# test_sharded_llava.py
  1. import codecs
  2. import asyncio
  3. import requests
  4. from PIL import Image
  5. from io import BytesIO
  6. import mlx.core as mx
  7. from mlx_lm.models.cache import KVCache
  8. from exo.inference.mlx.stateful_model import StatefulModel
  9. from exo.inference.mlx.sharded_utils import load_shard
  10. from exo.inference.shard import Shard
  11. shard_full = Shard("llava", 0, 31, 32)
  12. shard1 = Shard("llava", 0, 12, 32)
  13. shard2 = Shard("llava", 13, 31, 32)
  14. model_path = "llava-hf/llava-1.5-7b-hf"
  15. full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full))
  16. model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1))
  17. model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2))
  18. full = StatefulShardedModel(shard_full, full_model_shard)
  19. m1 = StatefulShardedModel(shard1, model_shard1)
  20. m2 = StatefulShardedModel(shard2, model_shard2)
  21. PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
  22. IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
  23. response = requests.get(IMAGE_FILE)
  24. img = Image.open(BytesIO(response.content))
  25. prompt = codecs.decode(PROMPT, "unicode_escape")
  26. inputs = full_processor(prompt, img, return_tensors="np")
  27. pixel_values = mx.array(inputs["pixel_values"])
  28. input_ids = mx.array(inputs["input_ids"])
  29. print(prompt)
  30. y = full.step("full", input_ids, pixel_values, temp=0)
  31. full_generated_tokens = [y.item()]
  32. for _ in range(13):
  33. y = full.step("full", y, temp=0)
  34. full_generated_tokens.append(y.item())
  35. full_response = full_processor.tokenizer.decode(full_generated_tokens)
  36. print("full response:", full_response)
  37. inputs = processor1(prompt, img, return_tensors="np")
  38. pixel_values = mx.array(inputs["pixel_values"])
  39. input_ids = mx.array(inputs["input_ids"])
  40. y = m1.step("shard", input_ids, pixel_values, temp=0)
  41. y = m2.step("shard", y, temp=0)
  42. full_generated_tokens = [y.item()]
  43. for _ in range(13):
  44. y = m1.step("shard", y, temp=0)
  45. y = m2.step("shard", y, temp=0)
  46. full_generated_tokens.append(y.item())
  47. sharded_response = processor2.tokenizer.decode(full_generated_tokens)
  48. print("sharded response:", sharded_response)
  49. assert full_response == sharded_response