# test_sharded_model.py — checks that running DummyModel in two shards
# reproduces the output of the unsharded model.
from inference.shard import Shard
from inference.mlx.sharded_model import StatefulShardedModel
import mlx.core as mx
import mlx.nn as nn
from typing import Optional
import numpy as np
  7. class DummyModel(nn.Module):
  8. def __init__(self, shard: Optional[Shard] = None):
  9. self.shard = shard
  10. self.layers = [
  11. nn.Linear(8, 128),
  12. nn.Linear(128, 128),
  13. nn.Linear(128, 128),
  14. nn.Linear(128, 128),
  15. nn.Linear(128, 8),
  16. ]
  17. self.n_kv_heads = 4
  18. self.head_dim = 4
  19. def __call__(self, x, cache=None):
  20. if self.shard:
  21. for layer in self.layers[self.shard.start_layer:self.shard.end_layer+1]:
  22. x = layer(x)
  23. if self.shard.is_last_layer():
  24. x = x.reshape((1, 2, 4))
  25. else:
  26. for layer in self.layers:
  27. x = layer(x)
  28. x = x.reshape((1, 2, 4))
  29. return x
  30. model = DummyModel()
  31. model.save_weights("./test_weights.npz")
  32. n_layers = 5
  33. shard1 = Shard("test", 0, n_layers // 2, n_layers)
  34. sharded_model1 = DummyModel(shard1)
  35. shard2 = Shard("test", n_layers // 2 + 1, n_layers - 1, n_layers)
  36. sharded_model2 = DummyModel(shard2)
  37. model.load_weights("./test_weights.npz")
  38. sharded_model1.load_weights("./test_weights.npz")
  39. sharded_model2.load_weights("./test_weights.npz")
  40. fullresp = model(mx.array([1,2,3,4,5,6,7,8]))
  41. resp1 = sharded_model1(mx.array([1,2,3,4,5,6,7,8]))
  42. resp2 = sharded_model2(resp1)
  43. assert np.all(np.array(fullresp) == np.array(resp2))