1
0
Эх сурвалжийг харах

Minor fix for Shard typing

Sandesh Bharadwaj 3 сар өмнө
parent
commit
349b5344eb

+ 2 - 1
exo/inference/inference_engine.py

@@ -5,6 +5,7 @@ from exo.helpers import DEBUG  # Make sure to import DEBUG
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
+from exo.download.shard_download import ShardDownloader
 
 
 class InferenceEngine(ABC):
@@ -55,7 +56,7 @@ inference_engine_classes = {
   "dummy": "DummyInferenceEngine",
 }
 
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
   if DEBUG >= 2:
     print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":