|
@@ -61,12 +61,13 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
|
|
|
|
|
|
return model
|
|
|
|
|
|
+_executor = ThreadPoolExecutor(max_workers=1) # module-level singleton with a single worker, assigned to each engine's self.executor, so tinygrad always runs on the same thread
|
|
|
class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
def __init__(self, shard_downloader: ShardDownloader):
|
|
|
self.shard = None
|
|
|
self.shard_downloader = shard_downloader
|
|
|
- self.executor = ThreadPoolExecutor(max_workers=1)
|
|
|
self.states = OrderedDict()
|
|
|
+ self.executor = _executor # reuse the module-level single-worker pool rather than creating one per instance
|
|
|
|
|
|
def poll_state(self, x, request_id: str, max_states=2):
|
|
|
if request_id not in self.states:
|