
fix dummy setup

Alex Cheema, 6 months ago
commit 9b8d58c464

+ 12 - 6
.circleci/config.yml

@@ -10,6 +10,10 @@ commands:
         type: string
       model_id:
         type: string
+      expected_output:
+        type: string
+      prompt:
+        type: string
     steps:
       - run:
           name: Run chatgpt api integration test (<<parameters.inference_engine>>, <<parameters.model_id>>)
@@ -56,7 +60,7 @@ commands:
               -H "Content-Type: application/json" \
               -d '{
                 "model": "<<parameters.model_id>>",
-                "messages": [{"role": "user", "content": "Keep responses concise. Who was the king of pop?"}],
+                "messages": [{"role": "user", "content": "<<parameters.prompt>>"}],
                 "temperature": 0.7
               }')
             echo "Response 1: $response_1"
@@ -69,7 +73,7 @@ commands:
               -H "Content-Type: application/json" \
               -d '{
                 "model": "<<parameters.model_id>>",
-                "messages": [{"role": "user", "content": "Keep responses concise. Who was the king of pop?"}],
+                "messages": [{"role": "user", "content": "<<parameters.prompt>>"}],
                 "temperature": 0.7
               }')
             echo "Response 2: $response_2"
@@ -81,8 +85,8 @@ commands:
             kill $PID1 $PID2
 
             echo ""
-            if ! echo "$response_1" | grep -q "$expected_content" || ! echo "$response_2" | grep -q "$expected_content"; then
-              echo "Test failed: Response does not contain '$expected_content'"
+            if ! echo "$response_1" | grep -q "<<parameters.expected_output>>" || ! echo "$response_2" | grep -q "<<parameters.expected_output>>"; then
+              echo "Test failed: Response does not contain '<<parameters.expected_output>>'"
               echo "Response 1: $response_1"
               echo ""
               echo "Response 2: $response_2"
@@ -92,7 +96,7 @@ commands:
               cat output2.log
               exit 1
             else
-              echo "Test passed: Response from both nodes contains '$expected_content'"
+              echo "Test passed: Response from both nodes contains '<<parameters.expected_output>>'"
             fi
 
 jobs:
@@ -184,6 +188,7 @@ jobs:
       - run_chatgpt_api_test:
           inference_engine: mlx
           model_id: llama-3.2-1b
+          expected_output: "Michael Jackson"
 
   chatgpt_api_integration_test_dummy:
     macos:
@@ -206,6 +211,7 @@ jobs:
       - run_chatgpt_api_test:
           inference_engine: dummy
           model_id: dummy-model
+          expected_output: "dummy"
 
   test_macos_m1:
     macos:
@@ -246,4 +252,4 @@ workflows:
       - chatgpt_api_integration_test_mlx
       - chatgpt_api_integration_test_dummy
       - test_macos_m1
-      # - chatgpt_api_integration_test_tinygrad
+      # - chatgpt_api_integration_test_tinygrad
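
The CI step above boils down to a substring check on two API responses. A minimal Python sketch of the same check, assuming the exo ChatGPT-compatible API is reachable at http://localhost:8000/v1/chat/completions (host, port and path are assumptions, not taken from this diff):

import json
import urllib.request

def chat_contains(model: str, prompt: str, expected_output: str, base_url: str = "http://localhost:8000") -> bool:
  # Build the same request body the curl command in the CI step sends.
  body = json.dumps({
    "model": model,
    "messages": [{"role": "user", "content": prompt}],
    "temperature": 0.7,
  }).encode()
  req = urllib.request.Request(f"{base_url}/v1/chat/completions", data=body,
                               headers={"Content-Type": "application/json"})
  with urllib.request.urlopen(req) as resp:
    text = resp.read().decode()
  print("Response:", text)
  return expected_output in text  # same pass condition as the grep in the step above

# Mirrors the mlx job above:
# chat_contains("llama-3.2-1b", "Keep responses concise. Who was the king of pop?", "Michael Jackson")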

+ 8 - 0
exo/download/shard_download.py

@@ -24,3 +24,11 @@ class ShardDownloader(ABC):
   @abstractmethod
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
+
+class NoopShardDownloader(ShardDownloader):
+  async def ensure_shard(self, shard: Shard) -> Path:
+    return Path("/tmp/noop_shard")
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return AsyncCallbackSystem()
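
For reference, a minimal sketch of how the new no-op downloader behaves at runtime; the Shard values mirror the dummy entry added to exo/models.py further down, and nothing is actually downloaded:

import asyncio
from exo.inference.shard import Shard
from exo.download.shard_download import NoopShardDownloader

async def main():
  downloader = NoopShardDownloader()
  shard = Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8)
  path = await downloader.ensure_shard(shard)
  print(path)  # /tmp/noop_shard -- nothing is created or fetched

asyncio.run(main())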

+ 53 - 54
exo/inference/dummy_inference_engine.py

@@ -6,60 +6,59 @@ from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 
 class DummyInferenceEngine(InferenceEngine):
-    def __init__(self, shard_downloader):
-        self.shard = None
-        self.shard_downloader = shard_downloader
-        self.vocab_size = 1000
-        self.eos_token_id = 0
-        self.latency_mean = 0.1
-        self.latency_stddev = 0.02
+  def __init__(self):
+    self.shard = None
+    self.vocab_size = 1000
+    self.eos_token_id = 0
+    self.latency_mean = 0.1
+    self.latency_stddev = 0.02
 
-    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-        try:
-            await self.ensure_shard(shard)
-            
-            # Generate random tokens
-            output_length = np.random.randint(1, 10)
-            output = np.random.randint(1, self.vocab_size, size=(1, output_length))
-            
-            # Simulate latency
-            await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
-            
-            # Randomly decide if finished
-            is_finished = np.random.random() < 0.2
-            if is_finished:
-                output = np.array([[self.eos_token_id]])
-            
-            new_state = json.dumps({"dummy_state": "some_value"})
-            
-            return output, new_state, is_finished
-        except Exception as e:
-            print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
-            return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+    try:
+      await self.ensure_shard(shard)
 
-    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-        await self.ensure_shard(shard)
-        state = json.loads(inference_state or "{}")
-        start_pos = state.get("start_pos", 0)
-        
-        output_length = np.random.randint(1, 10)
-        output = np.random.randint(1, self.vocab_size, size=(1, output_length))
-        
-        await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
-        
-        is_finished = np.random.random() < 0.2
-        if is_finished:
-            output = np.array([[self.eos_token_id]])
-        
-        start_pos += input_data.shape[1] + output_length
-        new_state = json.dumps({"start_pos": start_pos})
-        
-        return output, new_state, is_finished
+      # Generate random tokens
+      output_length = np.random.randint(1, 10)
+      output = np.random.randint(1, self.vocab_size, size=(1, output_length))
 
-    async def ensure_shard(self, shard: Shard):
-        if self.shard == shard:
-            return
-        # Simulate shard loading without making any API calls
-        await asyncio.sleep(0.1)  # Simulate a short delay
-        self.shard = shard
-        print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")
+      # Simulate latency
+      await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
+
+      # Randomly decide if finished
+      is_finished = np.random.random() < 0.2
+      if is_finished:
+        output = np.array([[self.eos_token_id]])
+
+      new_state = json.dumps({"dummy_state": "some_value"})
+
+      return output, new_state, is_finished
+    except Exception as e:
+      print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
+      return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
+
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+    await self.ensure_shard(shard)
+    state = json.loads(inference_state or "{}")
+    start_pos = state.get("start_pos", 0)
+
+    output_length = np.random.randint(1, 10)
+    output = np.random.randint(1, self.vocab_size, size=(1, output_length))
+
+    await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
+
+    is_finished = np.random.random() < 0.2
+    if is_finished:
+      output = np.array([[self.eos_token_id]])
+
+    start_pos += input_data.shape[1] + output_length
+    new_state = json.dumps({"start_pos": start_pos})
+
+    return output, new_state, is_finished
+
+  async def ensure_shard(self, shard: Shard):
+    if self.shard == shard:
+      return
+    # Simulate shard loading without making any API calls
+    await asyncio.sleep(0.1)  # Simulate a short delay
+    self.shard = shard
+    print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")

+ 1 - 1
exo/inference/inference_engine.py

@@ -32,5 +32,5 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
     return TinygradDynamicShardInferenceEngine(shard_downloader)
   elif inference_engine_name == "dummy":
     from exo.inference.dummy_inference_engine import DummyInferenceEngine
-    return DummyInferenceEngine(shard_downloader)
+    return DummyInferenceEngine()
   raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
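
A small sketch of the updated call site; get_inference_engine still accepts a shard downloader, but for "dummy" it is no longer passed through to the engine:

from exo.download.shard_download import NoopShardDownloader
from exo.inference.inference_engine import get_inference_engine

engine = get_inference_engine("dummy", NoopShardDownloader())
print(type(engine).__name__)  # DummyInferenceEngine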

+ 10 - 0
exo/inference/tokenizers.py

@@ -7,7 +7,17 @@ from transformers import AutoTokenizer, AutoProcessor
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
 
+class DummyTokenizer:
+  def __init__(self):
+    self.eos_token_id = 0
+  def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
+    return [1,2,3]
+  def decode(self, tokens):
+    return "dummy"
+
 async def resolve_tokenizer(model_id: str):
+  if model_id == "dummy":
+    return DummyTokenizer()
   local_path = await get_local_snapshot_dir(model_id)
   if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
   try:
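
A quick sketch of the dummy tokenizer path; resolve_tokenizer is async, and for model_id "dummy" no local snapshot or Hugging Face lookup happens:

import asyncio
from exo.inference.tokenizers import resolve_tokenizer

async def main():
  tokenizer = await resolve_tokenizer("dummy")
  tokens = tokenizer.apply_chat_template([{"role": "user", "content": "hi"}])
  print(tokens)                    # [1, 2, 3]
  print(tokenizer.decode(tokens))  # dummy
  print(tokenizer.eos_token_id)    # 0

asyncio.run(main())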

+ 2 - 2
exo/main.py

@@ -15,7 +15,7 @@ from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.download.shard_download import ShardDownloader, RepoProgressEvent
+from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.inference.shard import Shard
@@ -61,7 +61,7 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 

+ 4 - 0
exo/models.py

@@ -66,4 +66,8 @@ model_base_shards = {
   "nemotron-70b-bf16": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),
   },
+  # dummy
+  "dummy": {
+    "DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),
+  },
 }
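
Finally, a sketch of how the new registry entry could be resolved; keying on the inference-engine class name follows the convention of the existing entries, and the attribute access assumes Shard exposes its constructor fields:

from exo.models import model_base_shards
from exo.inference.dummy_inference_engine import DummyInferenceEngine

engine = DummyInferenceEngine()
# Look up the dummy shard by the engine's class name, matching the dict keys above.
shard = model_base_shards["dummy"][type(engine).__name__]
print(shard.model_id, shard.n_layers)  # dummy 8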