|
@@ -1,5 +1,6 @@
|
|
|
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
|
|
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
|
|
+from exo.download.hf.hf_shard_download import HFShardDownloader
|
|
|
from exo.inference.inference_engine import InferenceEngine
|
|
|
from exo.inference.shard import Shard
|
|
|
import asyncio
|
|
@@ -43,8 +44,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
|
|
|
|
|
|
asyncio.run(
|
|
|
test_inference_engine(
|
|
|
- MLXDynamicShardInferenceEngine(),
|
|
|
- MLXDynamicShardInferenceEngine(),
|
|
|
+ MLXDynamicShardInferenceEngine(HFShardDownloader()),
|
|
|
+ MLXDynamicShardInferenceEngine(HFShardDownloader()),
|
|
|
"mlx-community/Meta-Llama-3-8B-Instruct-4bit",
|
|
|
)
|
|
|
)
|