Browse Source

Updated unit tests

Nel Nibcord 7 months ago
parent
commit
b787c676de
1 changed files with 2 additions and 2 deletions
  1. 2 2
      exo/inference/test_inference_engine.py

+ 2 - 2
exo/inference/test_inference_engine.py

@@ -42,7 +42,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "mlx-community/Llama-3.2-1B-Instruct-4bit", 16))
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
@@ -50,5 +50,5 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   asyncio.run(
-    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", 32)
+    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
   )