@@ -28,6 +28,7 @@ def sample_logits(
token = top_p_sampling(logits, top_p, temp)
else:
token = mx.random.categorical(logits*(1/temp))
+
return token
class MLXDynamicShardInferenceEngine(InferenceEngine):