|
@@ -93,8 +93,8 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
l = mx.array(lengths)
|
|
|
loop = asyncio.get_running_loop()
|
|
|
score, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
|
|
|
- loop.run_in_executor(self.executor, self.update_model, grad, score)
|
|
|
layers = [{k: v["weight"].shape for k,v in l.items() if 'weight' in v} for l in grad['model']['model']['layers'] if l]
|
|
|
+ await loop.run_in_executor(self.executor, self.update_model, grad, score)
|
|
|
|
|
|
return np.array(score).reshape(inputs.shape[0], -1), np.array(layers[0]['input_layernorm']).reshape(inputs.shape[0], -1)
|
|
|
|