Browse Source

Minor fix for Shard typing

Sandesh Bharadwaj 5 months ago
parent
commit
349b5344eb
1 changed files with 2 additions and 1 deletions
  1. 2 1
      exo/inference/inference_engine.py

+ 2 - 1
exo/inference/inference_engine.py

@@ -5,6 +5,7 @@ from exo.helpers import DEBUG  # Make sure to import DEBUG
 from typing import Tuple, Optional
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from .shard import Shard
 from .shard import Shard
+from exo.download.shard_download import ShardDownloader
 
 
 
 
 class InferenceEngine(ABC):
 class InferenceEngine(ABC):
@@ -55,7 +56,7 @@ inference_engine_classes = {
   "dummy": "DummyInferenceEngine",
   "dummy": "DummyInferenceEngine",
 }
 }
 
 
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
   if DEBUG >= 2:
   if DEBUG >= 2:
     print(f"get_inference_engine called with: {inference_engine_name}")
     print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":
   if inference_engine_name == "mlx":