
run unit test on llama 3.2 1b for faster test

Alex Cheema · 9 months ago · commit ae74d2da16
1 changed file with 10 additions and 8 deletions: exo/inference/test_inference_engine.py

exo/inference/test_inference_engine.py (+10 −8)

@@ -9,33 +9,33 @@ import numpy as np


# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
+async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
  prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
    "A",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
    input_data=resp_full,
    inference_state=inference_state_full,
  )

  pp = 15
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
+  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
    "B",
-    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
    input_data=resp1,
    inference_state=inference_state_1,
  )
  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
    "B",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
    input_data=resp2,
    inference_state=inference_state_2,
  )
  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
    "B",
-    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
    input_data=resp3,
    inference_state=inference_state_3,
  )
@@ -47,7 +47,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
asyncio.run(test_inference_engine(
  MLXDynamicShardInferenceEngine(HFShardDownloader()),
  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+  "mlx-community/Llama-3.2-1B-Instruct-4bit",
+  16
))

if os.getenv("RUN_TINYGRAD", default="0") == "1":
@@ -60,5 +61,6 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
      "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+      32
    )
  )
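
The test's own comment captures the invariant that makes the new n_layers parameter safe to thread through: an engine should behave the same for any number of shards as long as they are continuous. A minimal sketch of that continuity check, assuming only the Shard fields visible in the diff (model_id, start_layer, end_layer, n_layers); the split point pp below is a hypothetical midpoint chosen for illustration, not the value in the test:

from dataclasses import dataclass

@dataclass
class Shard:
    # Same fields the test constructs above; the real class lives in exo.
    model_id: str
    start_layer: int
    end_layer: int
    n_layers: int

def shards_are_continuous(shards: list[Shard]) -> bool:
    # Shards must start at layer 0, end at n_layers - 1, and each shard
    # must begin exactly where the previous one ended.
    ordered = sorted(shards, key=lambda s: s.start_layer)
    if ordered[0].start_layer != 0:
        return False
    if ordered[-1].end_layer != ordered[-1].n_layers - 1:
        return False
    return all(b.start_layer == a.end_layer + 1 for a, b in zip(ordered, ordered[1:]))

# With n_layers = 16 (Llama 3.2 1B), any cut 0 <= pp < 15 leaves both shards non-empty.
pp = 7  # hypothetical split point for illustration
first = Shard("mlx-community/Llama-3.2-1B-Instruct-4bit", 0, pp, 16)
second = Shard("mlx-community/Llama-3.2-1B-Instruct-4bit", pp + 1, 15, 16)
assert shards_are_continuous([first, second])

Because the file calls asyncio.run at module level, running python3 exo/inference/test_inference_engine.py exercises the MLX case directly, and setting RUN_TINYGRAD=1 additionally runs the tinygrad case.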