
fix stable diffusion case for tui, make mlx run on its own thread again and non-blocking

Alex Cheema · 5 months ago
commit 8ab9977f01
2 changed files with 51 additions and 23 deletions
  1. exo/inference/mlx/sharded_inference_engine.py (+48 -20)
  2. exo/main.py (+3 -3)
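
The heart of the change: every MLX call now goes through a single dedicated worker thread via run_in_executor, so the asyncio event loop is never blocked by model execution. A minimal sketch of the pattern — the helper name _run_mlx is an illustrative stand-in; the commit inlines the run_in_executor calls:

  import asyncio
  from concurrent.futures import ThreadPoolExecutor

  import mlx.core as mx

  class Engine:
    def __init__(self):
      # max_workers=1 serializes all MLX work on one named thread
      # and keeps it off the asyncio event loop.
      self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")

    async def _run_mlx(self, fn, *args):
      # Offload a blocking MLX call; the coroutine suspends until it finishes.
      return await asyncio.get_running_loop().run_in_executor(self._mlx_thread, fn, *args)

    async def _eval_mlx(self, *arrays):
      # mx.eval forces MLX's lazy computation graph; run it off-loop as well.
      await self._run_mlx(mx.eval, *arrays)

A single worker both avoids blocking and guarantees MLX operations execute in order, which matters because concurrent calls into one model would race on shared state.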

exo/inference/mlx/sharded_inference_engine.py (+48 -20)

@@ -23,10 +23,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.sampler = make_sampler(*self.sampler_params)
     self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
     self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
+    self.session = {}

   async def _eval_mlx(self, *args):
-    loop = asyncio.get_running_loop()
-    await loop.run_in_executor(self._mlx_thread, mx.eval, *args)
+    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)

   async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:
@@ -51,31 +51,48 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):

   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
-    return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt))
+    return np.asarray(
+      await asyncio.get_running_loop().run_in_executor(
+        self._tokenizer_thread,
+        self.tokenizer.encode,
+        prompt
+      )
+    )

   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
-    return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)
+    return await asyncio.get_running_loop().run_in_executor(
+      self._tokenizer_thread,
+      self.tokenizer.decode,
+      tokens
+    )

   async def save_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
-    self.model.save_weights(path)
+    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.save_weights(path))

   async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
-    self.model.load_weights(path)
-    
+    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.load_weights(path))
+
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
+
     if self.model.model_type != 'StableDiffusionPipeline':
-      output_data = self.model(x, **state, **(inference_state or {}))
+      output_data = await asyncio.get_running_loop().run_in_executor(
+        self._mlx_thread,
+        lambda: self.model(x, **state, **(inference_state or {}))
+      )
+      inference_state = None
     else:
-      output_data, inference_state = self.model(x, **state, **(inference_state or {}))
+      result = await asyncio.get_running_loop().run_in_executor(
+        self._mlx_thread,
+        lambda: self.model(x, **state, **(inference_state or {}))
+      )
+      output_data, inference_state = result
+
     output_data = np.array(output_data, copy=False)
     return output_data, inference_state

@@ -85,18 +102,29 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
-    score = self.session['loss'](self.model, x, y, l)
+
+    score = await asyncio.get_running_loop().run_in_executor(
+      self._mlx_thread,
+      lambda: self.session['loss'](self.model, x, y, l)
+    )
     return score

   async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
     await self.ensure_shard(shard)
+
     if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
       await self.save_session('train_layers', trainable_layers)
-      self.model.freeze()
-      self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(lambda: k.endswith(i) for i in trainable_layers) else None)
+      def freeze_unfreeze():
+        self.model.freeze()
+        self.model.apply_to_modules(
+          lambda k, v: v.unfreeze() if any(k.endswith(layer_name) for layer_name in trainable_layers) else None
+        )
+      await asyncio.get_running_loop().run_in_executor(self._mlx_thread, freeze_unfreeze)
+
     if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
       await self.save_session('lossname', loss)
       await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
+
     if 'opt' not in self.session:
       await self.save_session('opt', opt(lr))
     return True
@@ -113,11 +141,13 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
-
-    score, gradients, eval_args = train_step(x, y, l)
+    score, gradients, eval_args = await asyncio.get_running_loop().run_in_executor(
+      self._mlx_thread,
+      lambda: train_step(x, y, l)
+    )
     await self._eval_mlx(*eval_args)

-    layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
+    layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
     first_layer = np.array(layers[0]['input_layernorm'], copy=False)
     await self._eval_mlx(first_layer)
     return score, first_layer
@@ -125,9 +155,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
-
     model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
-
     if self.shard != shard:
       model_shard, self.tokenizer = await load_shard(model_path, shard)
       self.shard = shard

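Besides the threading change, this file also fixes a subtle bug in ensure_train's layer-unfreezing predicate. The old code passed a generator of lambda objects to any(); function objects are always truthy, so every module was unfrozen regardless of its name. A small standalone demonstration, with illustrative values:

  trainable_layers = ['input_layernorm', 'gate_proj']
  k = 'model.layers.0.self_attn.q_proj'

  # Old: any() sees lambda objects, never their results -> always True (wrong).
  print(any(lambda: k.endswith(i) for i in trainable_layers))  # True

  # New: the predicate is evaluated for each layer name -> False here (correct).
  print(any(k.endswith(layer_name) for layer_name in trainable_layers))  # False
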
exo/main.py (+3 -3)

@@ -189,11 +189,11 @@ api = ChatGPTAPI(
 buffered_token_output = {}
 def update_topology_viz(req_id, tokens, __):
   if not topology_viz: return
+  if inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
+
   if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
   else: buffered_token_output[req_id] = tokens
-
-  if inference_engine.shard.model_id != 'stable-diffusion-2-1-base':
-    topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
+  topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
 node.on_token.register("update_topology_viz").on_next(update_topology_viz)

 def preemptively_start_download(request_id: str, opaque_status: str):
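
The main.py change turns the Stable Diffusion check into an early-return guard, so image-generation requests never touch the token buffer or the tokenizer, presumably because image outputs are not decodable text tokens. A condensed sketch of the resulting callback, using setdefault for the buffer where the commit keeps an explicit if/else (topology_viz and inference_engine are the module-level objects from main.py):

  buffered_token_output = {}

  def update_topology_viz(req_id, tokens, __):
    if not topology_viz: return
    # Stable Diffusion output is not decodable text; skip buffering and decoding.
    if inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
    buffered_token_output.setdefault(req_id, []).extend(tokens)
    topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))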