Browse Source

Okay we should probably await the update

Nel Nibcord 8 months ago
parent
commit
3e869051f6
1 changed files with 1 additions and 1 deletions
  1. 1 1
      exo/inference/mlx/sharded_inference_engine.py

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -93,8 +93,8 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     l = mx.array(lengths)
     loop = asyncio.get_running_loop()
     score, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
-    loop.run_in_executor(self.executor, self.update_model, grad, score)
     layers = [{k: v["weight"].shape for k,v in l.items() if 'weight' in v} for l in grad['model']['model']['layers'] if l]
+    await loop.run_in_executor(self.executor, self.update_model, grad, score)
 
     return np.array(score).reshape(inputs.shape[0], -1), np.array(layers[0]['input_layernorm']).reshape(inputs.shape[0], -1)