Sfoglia il codice sorgente

Minor fix for Shard typing

Sandesh Bharadwaj 3 mesi fa
parent
commit
349b5344eb
1 ha cambiato i file con 2 aggiunte e 1 eliminazioni
  1. 2 1
      exo/inference/inference_engine.py

+ 2 - 1
exo/inference/inference_engine.py

@@ -5,6 +5,7 @@ from exo.helpers import DEBUG  # Make sure to import DEBUG
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
+from exo.download.shard_download import ShardDownloader
 
 
 class InferenceEngine(ABC):
@@ -55,7 +56,7 @@ inference_engine_classes = {
   "dummy": "DummyInferenceEngine",
 }
 
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
   if DEBUG >= 2:
     print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":