make StatefulShardedModel callable, add some tests for mlx sharded inference

Alex Cheema 11 months ago
commit 850b72d3ea

+ 1 - 0
.gitignore

@@ -1,2 +1,3 @@
 __pycache__/
 .venv
+test_weights.npz

+ 0 - 0
inference/__init__.py


+ 0 - 0
inference/mlx/__init__.py


+ 0 - 0
inference/mlx/models/__init__.py


+ 1 - 1
inference/mlx/sharded_inference_engine.py

@@ -60,4 +60,4 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
 
         model_shard, self.tokenizer = load_shard(shard.model_id, shard)
         self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
-        self.shard = shard
+        self.shard = shard

+ 10 - 1
inference/mlx/sharded_model.py

@@ -1,4 +1,4 @@
-from typing import Dict, Generator, Optional, Tuple
+from typing import Any, Dict, Generator, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -47,6 +47,15 @@ class StatefulShardedModel:
         else:
             return output
 
+    def __call__(
+        self,
+        x,
+        temp: float = 0.0,
+        top_p: float = 1.0,
+        logit_bias: Optional[Dict[int, float]] = None,
+    ) -> mx.array:
+        return self.step(x, temp, top_p, logit_bias)
+
     def reset(self):
         kv_heads = (
             [self.model.n_kv_heads] * len(self.model.layers)
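
With __call__ in place, a StatefulShardedModel can be invoked like a function instead of going through .step() explicitly. A minimal usage sketch, with the model id and setup borrowed from the tests below (the prompt is arbitrary):

import mlx.core as mx
from inference.mlx.sharded_model import StatefulShardedModel
from inference.mlx.sharded_utils import load_shard
from inference.shard import Shard

shard = Shard("llama", 0, 31, 32)  # full model: layers 0-31 of 32
model_shard, tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard)
model = StatefulShardedModel(shard, model_shard)

tokens = mx.array(tokenizer.encode("hello"))
next_tok = model(tokens)                       # equivalent to model.step(tokens)
next_tok = model(tokens, temp=0.7, top_p=0.9)  # sampling args are forwarded to step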

+ 44 - 0
inference/mlx/test_sharded_llama.py

@@ -0,0 +1,44 @@
+import mlx.core as mx
+from inference.mlx.sharded_model import StatefulShardedModel
+from inference.mlx.sharded_utils import load_shard
+from inference.shard import Shard
+
+# Shard(model_id, start_layer, end_layer, n_layers)
+shard_full = Shard("llama", 0, 31, 32)
+shard1 = Shard("llama", 0, 12, 32)
+shard2 = Shard("llama", 13, 31, 32)
+
+full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard_full)
+model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
+model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
+
+full = StatefulShardedModel(shard_full, full_model_shard)
+m1 = StatefulShardedModel(shard1, model_shard1)
+m2 = StatefulShardedModel(shard2, model_shard2)
+
+prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
+prompt_tokens = mx.array(tokenizer1.encode(prompt))
+max_tokens = 50
+
+# generate max_tokens tokens with the full (unsharded) model
+resp = prompt_tokens
+full_generated_tokens = []
+for _ in range(max_tokens):
+    resp = full.step(resp)
+    full_generated_tokens.append(resp.item())
+
+print("full response: ", tokenizer1.decode(full_generated_tokens))
+
+
+# generate the same tokens by chaining shard1's activations into shard2
+sharded_generated_tokens = []
+sharded_resp = prompt_tokens
+for _ in range(max_tokens):
+    resp1 = m1.step(sharded_resp)
+    sharded_resp = m2.step(resp1)
+    sharded_generated_tokens.append(sharded_resp.item())
+
+print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))
+
+# the sharded pipeline must reproduce the full model's output
+assert tokenizer1.decode(full_generated_tokens) == tokenizer1.decode(sharded_generated_tokens)
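
Both test files construct shards positionally as Shard(model_id, start_layer, end_layer, n_layers). The actual definition lives in inference/shard.py, which this commit does not touch, so the following is only a sketch of what the call sites here imply (is_last_layer() is taken from its use in test_sharded_model.py below):

from dataclasses import dataclass

@dataclass
class Shard:
    model_id: str     # e.g. "llama"
    start_layer: int  # first layer this shard runs, inclusive
    end_layer: int    # last layer this shard runs, inclusive
    n_layers: int     # layer count of the full model

    def is_last_layer(self) -> bool:
        # True when this shard produces the model's final output
        return self.end_layer == self.n_layers - 1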

+ 55 - 0
inference/mlx/test_sharded_model.py

@@ -0,0 +1,55 @@
+from inference.shard import Shard
+from inference.mlx.sharded_model import StatefulShardedModel
+import mlx.core as mx
+import mlx.nn as nn
+from typing import Optional
+import numpy as np
+
+class DummyModel(nn.Module):
+    def __init__(self, shard: Optional[Shard] = None):
+        super().__init__()
+        self.shard = shard
+        self.layers = [
+            nn.Linear(8, 128),
+            nn.Linear(128, 128),
+            nn.Linear(128, 128),
+            nn.Linear(128, 128),
+            nn.Linear(128, 8),
+        ]
+
+        self.n_kv_heads = 4
+        self.head_dim = 4
+
+    def __call__(self, x, cache=None):
+        if self.shard:
+            for layer in self.layers[self.shard.start_layer:self.shard.end_layer+1]:
+                x = layer(x)
+            if self.shard.is_last_layer():
+                x = x.reshape((1, 2, 4))
+        else:
+            for layer in self.layers:
+                x = layer(x)
+            x = x.reshape((1, 2, 4))
+
+        return x
+
+# save the full model's weights so every shard loads identical ones
+model = DummyModel()
+model.save_weights("./test_weights.npz")
+n_layers = 5
+shard1 = Shard("test", 0, n_layers // 2, n_layers)
+sharded_model1 = DummyModel(shard1)
+shard2 = Shard("test", n_layers // 2 + 1, n_layers - 1, n_layers)
+sharded_model2 = DummyModel(shard2)
+
+model.load_weights("./test_weights.npz")
+sharded_model1.load_weights("./test_weights.npz")
+sharded_model2.load_weights("./test_weights.npz")
+
+# run the full model, then chain shard1 -> shard2 on the same input
+fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
+resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
+resp2 = sharded_model2(resp1)
+
+# the chained shards must reproduce the full model's output exactly
+assert np.all(np.array(fullresp) == np.array(resp2))
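
The final exact-equality assert is expected to hold: shard1 runs layers 0-2 and shard2 runs layers 3-4 on shard1's activations, so chaining the two executes the same Linear layers, in the same order, on the same weights as the full model, and the deterministic forward pass reproduces the output bit for bit. Note that both test files are plain scripts with top-level asserts rather than pytest cases; given the __init__.py files added above, they are presumably meant to be run as modules from the repository root, e.g. python -m inference.mlx.test_sharded_model.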