@@ -41,17 +41,17 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)

-# asyncio.run(
-#   test_inference_engine(
-#     MLXDynamicShardInferenceEngine(),
-#     MLXDynamicShardInferenceEngine(),
-#     "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-#   )
-# )
+asyncio.run(
+  test_inference_engine(
+    MLXDynamicShardInferenceEngine(),
+    MLXDynamicShardInferenceEngine(),
+    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+  )
+)

 # TODO: Need more memory or a smaller model
-asyncio.run(test_inference_engine(
-  TinygradDynamicShardInferenceEngine(),
-  TinygradDynamicShardInferenceEngine(),
-  "llama3-8b-sfr",
-))
+# asyncio.run(test_inference_engine(
+#   TinygradDynamicShardInferenceEngine(),
+#   TinygradDynamicShardInferenceEngine(),
+#   "mlx-community/Meta-Llama-3-8B-Instruct",
+# ))
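
For reference, a minimal sketch of how the same toggle could be driven at runtime instead of by commenting code in and out. `test_inference_engine`, the two engine classes, and the model IDs come straight from the hunk above; the `INFERENCE_ENGINE` environment variable is a hypothetical illustration, not something this diff or the repository defines:

```python
# Hypothetical alternative to the comment/uncomment toggle in the diff:
# pick the engine pair via an environment variable at runtime.
# The INFERENCE_ENGINE switch is an assumption for illustration only;
# the test helper, engine classes, and model IDs are from the hunk above.
import asyncio
import os

if os.environ.get("INFERENCE_ENGINE", "mlx") == "mlx":
  asyncio.run(test_inference_engine(
    MLXDynamicShardInferenceEngine(),
    MLXDynamicShardInferenceEngine(),
    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
  ))
else:
  # TODO: Need more memory or a smaller model
  asyncio.run(test_inference_engine(
    TinygradDynamicShardInferenceEngine(),
    TinygradDynamicShardInferenceEngine(),
    "mlx-community/Meta-Llama-3-8B-Instruct",
  ))
```

With a guard like this, the tinygrad path would be exercised by setting `INFERENCE_ENGINE=tinygrad` before running the test script, and the default would match the MLX path this diff enables.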