|
@@ -25,7 +25,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
input_ids = mx.array(inputs["input_ids"])
|
|
|
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
|
|
|
else:
|
|
|
- input_ids = await loop.run_in_executor(self.executor, lambda: mx.array(self.tokenizer.encode(prompt)))
|
|
|
+ input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
|
|
|
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
|
|
|
return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
|
|
|
|