@@ -23,21 +23,15 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np")
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])
-      output_data = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
+      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
     else:
       input_ids = await loop.run_in_executor(self.executor, lambda: mx.array(self.tokenizer.encode(prompt)))
-      output_data = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
+      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
-    input_tensor = mx.array(input_data)
-    output_data = np.array(await asyncio.get_running_loop().run_in_executor(
-      self.executor,
-      self.stateful_sharded_model.step,
-      request_id,
-      input_tensor
-    ))
+    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def ensure_shard(self, shard: Shard):
@@ -48,21 +42,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
 
     if self.shard != shard:
       loop = asyncio.get_running_loop()
-
-      # Run load_shard in a separate thread
-      def load_shard_wrapper():
-        return asyncio.run(load_shard(model_path, shard))
-
-      model_shard, self.tokenizer = await loop.run_in_executor(
-        self.executor,
-        load_shard_wrapper
-      )
-
-      # Create StatefulShardedModel in the executor
-      self.stateful_sharded_model = await loop.run_in_executor(
-        self.executor,
-        StatefulShardedModel,
-        shard,
-        model_shard
-      )
+      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+      model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
+      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard
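
Note on the pattern: both hunks keep the same threading model while collapsing each call site to a single line. Every blocking operation (tokenization, stateful_sharded_model.step, load_shard, and the StatefulShardedModel constructor) is pushed onto self.executor through loop.run_in_executor, and the async load_shard coroutine is driven inside the worker thread via asyncio.run. Below is a minimal, self-contained sketch of that pattern, not part of the patch; blocking_step, async_load, and load_wrapper are hypothetical stand-ins for the model step and shard loader.

# Sketch of the offload pattern used in the diff above: blocking work goes through
# loop.run_in_executor on a dedicated ThreadPoolExecutor, and an async loader is
# driven inside the worker thread with asyncio.run. Names here are illustrative only.
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def blocking_step(request_id: str, tokens: list) -> list:
  # Stand-in for a compute-heavy synchronous call such as an MLX forward pass.
  time.sleep(0.1)
  return tokens

async def async_load(path: str) -> str:
  # Stand-in for an async loader like load_shard(model_path, shard).
  await asyncio.sleep(0.1)
  return f"model loaded from {path}"

async def main():
  loop = asyncio.get_running_loop()
  # Positional arguments are forwarded to the callable; the event loop stays free
  # while the worker thread runs the blocking step.
  out = await loop.run_in_executor(executor, blocking_step, "req-1", [1, 2, 3])
  # A coroutine can still be offloaded by giving the worker thread its own event loop.
  def load_wrapper(): return asyncio.run(async_load("/tmp/shard"))
  model = await loop.run_in_executor(executor, load_wrapper)
  print(out, model)

if __name__ == "__main__":
  asyncio.run(main())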