@@ -120,7 +120,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
mx.eval(self.model.parameters(), self.session['opt'].state, lval)
return lval, gradlayers
- x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
+ x = mx.array(inputs)
y = mx.array(targets)
l = mx.array(lengths)