Merge branch 'exo-explore:main' into adalovace

lipere123 9 months ago
commit 6c302244e0

+ 9 - 9
.circleci/config.yml

@@ -92,7 +92,7 @@ jobs:
   unit_test:
     macos:
       xcode: "16.0.0"
-    resource_class: macos.m1.large.gen1
+    resource_class: m2pro.large
     steps:
       - checkout
       - run:
@@ -119,7 +119,7 @@ jobs:
 
   discovery_integration_test:
     macos:
-      xcode: "15.4.0"
+      xcode: "16.0.0"
     steps:
       - checkout
       - run:
@@ -158,8 +158,8 @@ jobs:
 
   chatgpt_api_integration_test_mlx:
     macos:
-      xcode: "15.4.0"
-    resource_class: macos.m1.large.gen1
+      xcode: "16.0.0"
+    resource_class: m2pro.large
     steps:
       - checkout
       - run:
@@ -176,20 +176,20 @@ jobs:
             pip install .
       - run_chatgpt_api_test:
           inference_engine: mlx
-          model_id: llama-3.1-8b
+          model_id: llama-3.2-1b
 
   test_macos_m1:
     macos:
-      xcode: "15.4.0"
-    resource_class: macos.m1.large.gen1
+      xcode: "16.0.0"
+    resource_class: m2pro.large
     steps:
       - checkout
       - run: system_profiler SPHardwareDataType
 
   # chatgpt_api_integration_test_tinygrad:
   #   macos:
-  #     xcode: "15.4.0"
-  #   resource_class: macos.m1.large.gen1
+  #     xcode: "16.0.0"
+  #   resource_class: m2pro.large
   #   steps:
   #     - checkout
   #     - run:
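
The same job also pins the ChatGPT-compatible API test to the smaller llama-3.2-1b model, which keeps the integration run quick on shared CI runners. A minimal sketch of the kind of request that test exercises, assuming an OpenAI-style /v1/chat/completions route served on localhost:52415 (both the port and the route are assumptions for illustration, not taken from this diff):

# Sketch: query a ChatGPT-compatible endpoint like the one the CI job
# exercises. Port 52415 and the /v1/chat/completions route are assumed.
import json
import urllib.request

payload = {
  "model": "llama-3.2-1b",  # matches the model_id the CI job now uses
  "messages": [{"role": "user", "content": "Say hello in one word."}],
}
req = urllib.request.Request(
  "http://localhost:52415/v1/chat/completions",
  data=json.dumps(payload).encode("utf-8"),
  headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
  print(json.load(resp)["choices"][0]["message"]["content"])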

+ 4 - 4
exo/inference/mlx/models/llama.py

@@ -32,15 +32,15 @@ class LlamaModel(nn.Module):
     self.vocab_size = args.vocab_size
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0
-    if self.args.shard.is_first_layer():
+    if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings):
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
     self.layers = []
     for i in range(self.num_hidden_layers):
-      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+      if args.shard.start_layer <= i <= args.shard.end_layer:
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
-    if self.args.shard.is_last_layer():
+    if args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
   def __call__(
@@ -74,7 +74,7 @@ class Model(nn.Module):
     self.args = args
     self.model_type = args.model_type
     self.model = LlamaModel(args)
-    if self.args.shard.is_last_layer():
+    if args.shard.is_last_layer():
       if not args.tie_word_embeddings:
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 

+ 3 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -8,6 +8,7 @@ from typing import Optional
 from exo.download.shard_download import ShardDownloader
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -20,7 +21,8 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     loop = asyncio.get_running_loop()
     if image_str:
       image = await get_image_from_str(image_str)
-      inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np")
+      tokenize = partial(self.tokenizer, prompt, image, return_tensors="np")
+      inputs = await loop.run_in_executor(self.executor, tokenize)
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])
       output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
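
The functools.partial change works around a genuine limitation of the event loop API: loop.run_in_executor(executor, func, *args) forwards positional arguments only, so a keyword argument like return_tensors="np" cannot be passed directly and must be bound into the callable first. A self-contained illustration:

# loop.run_in_executor accepts only positional args for the callable,
# so keyword arguments must be pre-bound with functools.partial.
import asyncio
from functools import partial

def tokenize(prompt, *, return_tensors="pt"):
  # Stand-in for a HuggingFace-style tokenizer call (hypothetical helper).
  return f"{prompt!r} as {return_tensors}"

async def main():
  loop = asyncio.get_running_loop()
  # Passing the keyword directly raises TypeError:
  #   loop.run_in_executor(None, tokenize, "hi", return_tensors="np")
  call = partial(tokenize, "hi", return_tensors="np")
  print(await loop.run_in_executor(None, call))  # 'hi' as np

asyncio.run(main())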

+ 11 - 9
exo/inference/test_inference_engine.py

@@ -9,33 +9,33 @@ import numpy as np
 
 
# An inference engine should work the same for any number of Shards, as long as the Shards are contiguous.
-async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
+async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
   next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp_full,
     inference_state=inference_state_full,
   )
 
-  pp = 15
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
+  pp = n_layers // 2
+  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
     inference_state=inference_state_1,
   )
   resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
     "B",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
     input_data=resp2,
     inference_state=inference_state_2,
   )
   resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp3,
     inference_state=inference_state_3,
   )
@@ -47,7 +47,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
 asyncio.run(test_inference_engine(
   MLXDynamicShardInferenceEngine(HFShardDownloader()),
   MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+  "mlx-community/Llama-3.2-1B-Instruct-4bit",
+  16
 ))
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
@@ -60,5 +61,6 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+      32
     )
   )
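
Threading n_layers through the test is what lets one harness cover both the 16-layer Llama-3.2-1B run above and the 32-layer tinygrad run: pp = n_layers // 2 picks the boundary between two contiguous shards. Worked out for n_layers = 16 (values follow directly from the test code):

# Shard split produced by pp = n_layers // 2 for the 16-layer model.
n_layers = 16
pp = n_layers // 2                  # 8
first = (0, pp)                     # layers 0..8 on engine 1 (9 layers)
second = (pp + 1, n_layers - 1)     # layers 9..15 on engine 2 (7 layers)
assert first[1] + 1 == second[0]    # the shards stay contiguous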