|
@@ -5,6 +5,7 @@ from exo.helpers import DEBUG # Make sure to import DEBUG
|
|
|
from typing import Tuple, Optional
|
|
|
from abc import ABC, abstractmethod
|
|
|
from .shard import Shard
|
|
|
+from exo.download.shard_download import ShardDownloader
|
|
|
|
|
|
|
|
|
class InferenceEngine(ABC):
|
|
@@ -55,7 +56,7 @@ inference_engine_classes = {
|
|
|
"dummy": "DummyInferenceEngine",
|
|
|
}
|
|
|
|
|
|
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
|
|
|
+def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
|
|
|
if DEBUG >= 2:
|
|
|
print(f"get_inference_engine called with: {inference_engine_name}")
|
|
|
if inference_engine_name == "mlx":
|