Browse Source

fix inference engine test

Alex Cheema 10 months ago
parent
commit
09a9abc065
1 changed files with 3 additions and 2 deletions
  1. 3 2
      exo/inference/test_inference_engine.py

+ 3 - 2
exo/inference/test_inference_engine.py

@@ -1,5 +1,6 @@
 from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 import asyncio
@@ -43,8 +44,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
 
 asyncio.run(
   test_inference_engine(
-    MLXDynamicShardInferenceEngine(),
-    MLXDynamicShardInferenceEngine(),
+    MLXDynamicShardInferenceEngine(HFShardDownloader()),
+    MLXDynamicShardInferenceEngine(HFShardDownloader()),
     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
   )
 )