@@ -50,16 +50,29 @@ shard_mappings = {
 class Message:
-  def __init__(self, role: str, content: Union[str, list]):
-    self.role = role
-    self.content = content
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+    self.role = role
+    self.content = content
+
+  def to_dict(self):
+    return {
+      "role": self.role,
+      "content": self.content
+    }
 
 
 class ChatCompletionRequest:
-  def __init__(self, model: str, messages: List[Message], temperature: float):
-    self.model = model
-    self.messages = messages
-    self.temperature = temperature
+  def __init__(self, model: str, messages: List[Message], temperature: float):
+    self.model = model
+    self.messages = messages
+    self.temperature = temperature
+
+  def to_dict(self):
+    return {
+      "model": self.model,
+      "messages": [message.to_dict() for message in self.messages],
+      "temperature": self.temperature
+    }
 
 
 def resolve_tinygrad_tokenizer(model_id: str):
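Not part of the patch, but a minimal usage sketch of the new to_dict() methods, assuming Message and ChatCompletionRequest come from this module; the model id and prompt are placeholders. With both classes serializable, a whole request round-trips cleanly to JSON.

import json

# Hypothetical usage; "example-model" and the prompt text are placeholders.
request = ChatCompletionRequest(
  model="example-model",
  messages=[Message(role="user", content="Hello!")],
  temperature=0.7,
)
payload = json.dumps(request.to_dict())
# payload is a plain JSON object with "model", "messages" and "temperature" keys.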
@@ -75,8 +88,12 @@ async def resolve_tokenizer(model_id: str):
   try:
     if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
     processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
-    processor.eos_token_id = processor.tokenizer.eos_token_id
-    processor.encode = processor.tokenizer.encode
+    if not hasattr(processor, 'eos_token_id'):
+      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
+    if not hasattr(processor, 'encode'):
+      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
+    if not hasattr(processor, 'decode'):
+      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
     if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
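The getattr chain above covers three cases: processors that already expose eos_token_id/encode/decode, processors that keep them on an inner tokenizer (or _tokenizer), and plain tokenizers that carry the attributes themselves. The same fallback could be written once as a helper; this is only an illustrative sketch and the helper name is made up.

def attach_tokenizer_methods(processor):
  # Prefer the wrapped tokenizer if one exists, otherwise fall back to
  # the processor itself, then copy over any missing attributes.
  inner = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor))
  for name in ('eos_token_id', 'encode', 'decode'):
    if not hasattr(processor, name):
      setattr(processor, name, getattr(inner, name))
  return processor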
@@ -157,6 +174,10 @@ def remap_messages(messages: List[Message]) -> List[Message]:
   remapped_messages = []
   last_image = None
   for message in messages:
+    if not isinstance(message.content, list):
+      remapped_messages.append(message)
+      continue
+
     remapped_content = []
     for content in message.content:
       if isinstance(content, dict):
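The new isinstance guard exists because content may be a plain string (for example a system prompt) rather than the OpenAI-style list of parts. A quick illustrative check, with made-up inputs and assuming Message and remap_messages from this module:

messages = [
  Message(role="system", content="You are a helpful assistant."),        # skipped by the guard
  Message(role="user", content=[{"type": "text", "text": "Hi there"}]),  # remapped as before
]
remapped = remap_messages(messages)
assert remapped[0].content == "You are a helpful assistant."  # passed through untouched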
@@ -168,16 +189,17 @@ def remap_messages(messages: List[Message]) -> List[Message]:
         else:
           remapped_content.append(content)
       else:
-        remapped_content.append({"type": "text", "text": content})
+        remapped_content.append(content)
     remapped_messages.append(Message(role=message.role, content=remapped_content))
 
   if last_image:
     # Replace the last image placeholder with the actual image content
     for message in reversed(remapped_messages):
       for i, content in enumerate(message.content):
-        if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
-          message.content[i] = last_image
-          return remapped_messages
+        if isinstance(content, dict):
+          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+            message.content[i] = last_image
+            return remapped_messages
 
   return remapped_messages
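For context, again not part of the patch: with the extra isinstance check, the placeholder-replacement pass tolerates string entries in message.content, and a vision-style request along these lines should still have its most recent image restored in place of the text placeholder. The content shape follows the OpenAI vision convention referenced below; the data URL is a dummy value.

messages = [
  Message(role="user", content=[
    {"type": "text", "text": "What is in this picture?"},
    {"type": "image", "image": "data:image/png;base64,..."},
  ]),
]
remapped = remap_messages(messages)
# The image entry should survive remapping (re-inserted where the placeholder
# was), and only the last image in the conversation is kept.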
@@ -192,7 +214,7 @@ def build_prompt(tokenizer, _messages: List[Message]):
     for content in message.content:
       # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
       # follows the convention in https://platform.openai.com/docs/guides/vision
-      if content.get("type", None) == "image":
+      if isinstance(content, dict) and content.get("type", None) == "image":
         image_str = content.get("image", None)
         break
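The isinstance-and-get pattern in build_prompt can be pictured as a small standalone scan over a content list. A rough sketch under the assumed content shapes; the helper name is hypothetical:

def extract_image(content_list):
  # Return the first image payload, skipping plain-string entries that
  # would otherwise fail on .get().
  for content in content_list:
    if isinstance(content, dict) and content.get("type", None) == "image":
      return content.get("image", None)
  return None

extract_image(["just text", {"type": "image", "image": "data:image/png;base64,..."}])
# -> "data:image/png;base64,..."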