Applied new interface to tinygrad and dummy inference engines

Nel Nibcord 6 months ago
parent
commit
527c7a6e49
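
For orientation before the per-file diffs: this commit moves both engines onto a smaller per-engine surface of four async methods (encode, sample, decode, infer_tensor), with infer_prompt reduced to a thin composition of encode and infer_tensor. The sketch below reconstructs that interface from the diffs that follow; the exact signatures and docstrings of the base class in exo.inference.inference_engine are an assumption, not verbatim code from this commit.

from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
from exo.inference.shard import Shard

class InferenceEngine(ABC):
  # Assumed shape of the new interface; method names are taken from the diffs below.
  @abstractmethod
  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
    """Tokenize a prompt into model-ready token ids."""

  @abstractmethod
  async def sample(self, x: np.ndarray) -> np.ndarray:
    """Pick the next token from the logits produced by the last shard."""

  @abstractmethod
  async def decode(self, shard: Shard, tokens) -> str:
    """Turn generated token ids back into text."""

  @abstractmethod
  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
    """Run this shard's layers: hidden states for middle shards, logits for the last shard."""

  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
    # Default composition mirrored by both engines in this commit.
    return await self.infer_tensor(request_id, shard, await self.encode(shard, prompt), inference_state)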

+ 19 - 37
exo/inference/dummy_inference_engine.py

@@ -1,60 +1,42 @@
 from typing import Optional, Tuple, TYPE_CHECKING
 import numpy as np
+import random
+import string
 import asyncio
 import json
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
-
+def random_string(length: int):
+  return ''.join([random.choice(string.ascii_lowercase) for i in range(length)])
+  
 
 class DummyInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None
     self.vocab_size = 1000
+    self.hidden_size = 256
     self.eos_token_id = 0
     self.latency_mean = 0.1
     self.latency_stddev = 0.02
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-    try:
-      await self.ensure_shard(shard)
-
-      # Generate random tokens
-      output_length = np.random.randint(1, 10)
-      output = np.random.randint(1, self.vocab_size, size=(1, output_length))
-
-      # Simulate latency
-      await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size, size=(1, len(prompt.split())))
+  
+  async def sample(self, x: np.ndarray) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size)
 
-      # Randomly decide if finished
-      is_finished = np.random.random() < 0.2
-      if is_finished:
-        output = np.array([[self.eos_token_id]])
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
+    return ' '.join([random_string(np.random.randint(1, 34)) for token in tokens])
 
-      new_state = json.dumps({"dummy_state": "some_value"})
-
-      return output, new_state, is_finished
-    except Exception as e:
-      print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
-      return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None):
+    output_data = await self.infer_tensor(request_id, shard, await self.encode(shard, prompt), inference_state)
+    return output_data 
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
-    state = json.loads(inference_state or "{}")
-    start_pos = state.get("start_pos", 0)
-
-    output_length = np.random.randint(1, 10)
-    output = np.random.randint(1, self.vocab_size, size=(1, output_length))
-
-    await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
-
-    is_finished = np.random.random() < 0.2
-    if is_finished:
-      output = np.array([[self.eos_token_id]])
-
-    start_pos += input_data.shape[1] + output_length
-    new_state = json.dumps({"start_pos": start_pos})
-
-    return output, new_state, is_finished
+    sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
+    output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))
+    return output
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
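
As a usage illustration, a minimal single-shard generation loop over the dummy engine might look like the sketch below. The driver function and the Shard arguments are hypothetical; only the engine methods come from the diff above.

import asyncio
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard

async def generate(prompt: str, max_tokens: int = 16) -> str:
  engine = DummyInferenceEngine()
  shard = Shard(model_id="dummy", start_layer=0, end_layer=0, n_layers=1)  # single-shard assumption
  tokens = await engine.encode(shard, prompt)
  generated = []
  for _ in range(max_tokens):
    out = await engine.infer_tensor("req-0", shard, tokens)  # random logits-shaped output on the last shard
    next_token = int(await engine.sample(out))
    if next_token == engine.eos_token_id:
      break
    generated.append(next_token)
    tokens = np.array([[next_token]])  # feed only the new token back in
  return await engine.decode(shard, np.array(generated))

# asyncio.run(generate("hello dummy engine"))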

+ 2 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -54,7 +54,8 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     return tokens
     
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
-    output_data = await self.infer_tensor(request_id, shard, await self.encode(shard, prompt), inference_state)
+    tokens = await self.encode(shard, prompt)
+    output_data = await self.infer_tensor(request_id, shard, tokens, inference_state)
     return output_data 
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, bool):

+ 22 - 26
exo/inference/tinygrad/inference.py

@@ -1,7 +1,7 @@
 from pathlib import Path
 import json
 import os
-from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
@@ -65,37 +65,33 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
+  async def sample(self, x: np.ndarray):
+    logits = x[:, -1, :]
+    def sample_wrapper():
+      return sample_logits(Tensor(x).flatten(), TEMPERATURE, 0, 0.8, 0.0, 0.0).realize()
+    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
+    return out.numpy()
 
-    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
+  async def encode(self, shard: Shard, prompt: str):
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return tokens
+  
+  async def decode(self, shard: Shard, tokens):
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens
 
-    if h.shape == (1,):
-      start_pos += len(toks)
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      n_captured_toks = len(toks)
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
+    tokens = await self.encode(shard, prompt)
+    output_data = await self.infer_tensor(request_id, shard, tokens, inference_state)
+    return output_data 
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
-
-    if h.shape == (1,):
-      start_pos += n_captured_toks
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos)
+    return output_data.numpy()
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
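
Both the tokenizer calls and the model forward above are pushed onto a single-worker ThreadPoolExecutor so the blocking tinygrad work does not stall the asyncio event loop. A generic sketch of that pattern, with hypothetical names (not exo code):

import asyncio
from concurrent.futures import ThreadPoolExecutor

class OffloadedModel:
  """Run a blocking callable off the event loop, one call at a time."""

  def __init__(self, blocking_fn):
    self.blocking_fn = blocking_fn
    # max_workers=1 serializes calls, which matters when the model carries mutable KV-cache state.
    self.executor = ThreadPoolExecutor(max_workers=1)

  async def __call__(self, *args):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self.executor, self.blocking_fn, *args)

# Usage: out = await OffloadedModel(heavy_forward)(input_tensor, start_pos)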

+ 10 - 11
exo/inference/tinygrad/models/llama.py

@@ -120,7 +120,7 @@ class TransformerBlock:
 
 
 # standard openai sampling
-def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
+def sample_logits(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
   assert 0 <= p <= 1, "p must be between 0 and 1"
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
@@ -202,31 +202,30 @@ class Transformer:
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
 
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
-    seqlen = x.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
-
+  def forward(self, x: Tensor, start_pos: Union[Variable, int]):
     if self.shard.is_first_layer():
       h = self.tok_embeddings(x)
     else:
       h = x
+    seqlen = h.shape[1]
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
     for i in range(self.shard.start_layer, self.shard.end_layer + 1):
       layer = self.layers[i]
       h = layer(h, start_pos, freqs_cis, mask)
 
     if self.shard.is_last_layer():
-      logits = self.output(self.norm(h)).float()[:, -1, :]
-      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+      logits = self.output(self.norm(h)).float().realize()
+      return logits
     else:
       return h
 
-  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
+  def __call__(self, tokens: Tensor, start_pos: Variable):
     # TODO: better way to handle the first call v.s. the rest?
     if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
-    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
+      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos))
+    return self.forward(tokens, start_pos)
 
 
 # *** helpers ***
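
With sampling removed from Transformer.forward, the caller is now responsible for reducing the (1, seqlen, vocab) logits to a single token. A sketch of that step follows; the helper name and the temperature/top-p values are illustrative, and only sample_logits comes from the module above.

from tinygrad import Tensor
from exo.inference.tinygrad.models.llama import sample_logits

def pick_next_token(logits: Tensor, temperature: float = 0.85) -> int:
  # sample_logits expects a 1-D tensor, so keep only the last position's logits.
  last = logits[:, -1, :].flatten()
  return int(sample_logits(last, temperature, 0, 0.8, 0.0, 0.0).item())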

+ 4 - 0
exo/orchestration/standard_node.py

@@ -39,6 +39,7 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.buffered_inputs: Dict[str, Tuple[List[np.ndarray], bool]] = {}
     self.buffered_logits: Dict[str, Tuple[List[np.ndarray], bool]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
@@ -121,6 +122,8 @@ class StandardNode(Node):
     for i in np.reshape(result, (-1, 1, result.shape[-1])):
       self.buffered_logits[request_id][0].append(i)
 
+    inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id][0])})
+
     if shard.is_last_layer():
       result = await self.inference_engine.sample(result)
     
@@ -131,6 +134,7 @@ class StandardNode(Node):
 
     if result.size == 1:  # we got a new token out
       self.buffered_token_output[request_id][0].append(result.item())
+      inference_state = json.dumps({"start_pos": json.loads(inference_state or "{}").get("start_pos", 0) + 1})
       self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
     
     if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
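
The orchestration change above keeps start_pos in the inference_state JSON: it is set to the number of logits rows buffered for the request and bumped by one whenever a new token comes out. A simplified, standalone mirror of that bookkeeping (hypothetical helper, not part of the node):

import json

def next_inference_state(buffered_logits: int, new_token_emitted: bool) -> str:
  # start_pos tracks how far the sequence has advanced for this request.
  start_pos = buffered_logits + (1 if new_token_emitted else 0)
  return json.dumps({"start_pos": start_pos})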