浏览代码

Do we need casting here?

Nel Nibcord 8 月之前
父节点
当前提交
59af2dd592
共有 1 个文件被更改,包括 2 次插入和 2 次删除
  1. +2 -2
      exo/inference/mlx/sharded_inference_engine.py

+ 2 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -81,7 +81,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id)
-    x = mx.array(input_data).astype(mx.int64) if self.shard.is_first_layer() else mx.array(input_data)
+    x = mx.array(input_data)
     output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
     return output_data
 
@@ -90,7 +90,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.save_session('loss', loss_fns[loss])
     loop = asyncio.get_running_loop()
     #print(f"evaluate in <- {inputs}")
-    x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
+    x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
     score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)