|
@@ -8,7 +8,6 @@ from typing import Optional
|
|
|
|
|
|
class MLXFixedShardInferenceEngine(InferenceEngine):
|
|
class MLXFixedShardInferenceEngine(InferenceEngine):
|
|
def __init__(self, model_path: str, shard: Shard):
|
|
def __init__(self, model_path: str, shard: Shard):
|
|
- print("initializing fixed shard inference", shard)
|
|
|
|
self.shard = shard
|
|
self.shard = shard
|
|
model_shard, self.tokenizer = load_shard(model_path, shard)
|
|
model_shard, self.tokenizer = load_shard(model_path, shard)
|
|
self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
|
|
self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
|
|
@@ -18,7 +17,6 @@ class MLXFixedShardInferenceEngine(InferenceEngine):
|
|
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
|
|
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
|
|
|
|
|
|
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
|
|
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
|
|
- print(f"output_data size: {output_data.size}, output_data: {output_data}")
|
|
|
|
return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
|
|
return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
|
|
|
|
|
|
async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, str, bool):
|
|
async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, str, bool):
|
|
@@ -32,7 +30,6 @@ class MLXFixedShardInferenceEngine(InferenceEngine):
|
|
if shard != self.shard:
|
|
if shard != self.shard:
|
|
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
|
|
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
|
|
|
|
|
|
- print(f"Resetting shard: {shard}")
|
|
|
|
self.stateful_sharded_model.reset()
|
|
self.stateful_sharded_model.reset()
|
|
|
|
|
|
class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
@@ -51,8 +48,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
|
|
|
async def reset_shard(self, shard: Shard):
|
|
async def reset_shard(self, shard: Shard):
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
-
|
|
|
|
- print(f"Resetting shard: {shard}")
|
|
|
|
self.stateful_sharded_model.reset()
|
|
self.stateful_sharded_model.reset()
|
|
|
|
|
|
async def ensure_shard(self, shard: Shard):
|
|
async def ensure_shard(self, shard: Shard):
|