|
@@ -81,7 +81,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
await self.ensure_shard(shard)
|
|
|
loop = asyncio.get_running_loop()
|
|
|
state = await self.poll_state(request_id)
|
|
|
- x = mx.array(input_data).astype(mx.int64) if self.shard.is_first_layer() else mx.array(input_data)
|
|
|
+ x = mx.array(input_data)
|
|
|
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
|
|
|
return output_data
|
|
|
|
|
@@ -90,7 +90,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
await self.save_session('loss', loss_fns[loss])
|
|
|
loop = asyncio.get_running_loop()
|
|
|
#print(f"evaluate in <- {inputs}")
|
|
|
- x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
|
|
|
+ x = mx.array(inputs)
|
|
|
y = mx.array(targets)
|
|
|
l = mx.array(lengths)
|
|
|
score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
|