|
@@ -42,7 +42,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
|
|
|
assert np.array_equal(next_resp_full, resp4)
|
|
|
|
|
|
|
|
|
-asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "mlx-community/Llama-3.2-1B-Instruct-4bit", 16))
|
|
|
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
|
|
|
|
|
|
if os.getenv("RUN_TINYGRAD", default="0") == "1":
|
|
|
import tinygrad
|
|
@@ -50,5 +50,5 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
|
|
|
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
|
|
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
|
|
|
asyncio.run(
|
|
|
- test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", 32)
|
|
|
+ test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
|
|
|
)
|