
Removed StatefulModel from the MLX implementation too

Nel Nibcord 4 months ago
parent commit a4313da8d1

+ 0 - 32
exo/inference/mlx/stateful_model.py

@@ -1,32 +0,0 @@
-from typing import Dict, Tuple, Optional
-from collections import OrderedDict
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.cache import make_prompt_cache
-import numpy as np
-
-from ..shard import Shard
-class StatefulModel(nn.Module):
-  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
-    super().__init__()
-    self.model = model
-    self.max_kv_size = max_kv_size
-    self.max_caches = max_caches
-    self.caches = OrderedDict()
-  
-  def __call__(self, x, request_id: Optional[str] = None, use_cache: bool = True):
-    #print(f"StatefulModel in <- {x}")
-    if use_cache and request_id is not None:
-      if request_id not in self.caches:
-        self.init_cache(request_id)
-      else:
-        self.caches.move_to_end(request_id)
-
-      cache = mx.array(self.caches[request_id])
-      y = self.model(x, cache=cache)
-    else:
-      y = self.model(x)
-    #print(f"StatefulModel out -> {y}")
-    return y
-    
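The removed __call__ initializes a per-request cache via self.init_cache(request_id), but no such method appears in the 32 deleted lines, and the make_prompt_cache import goes unused. A minimal sketch of what that method could have looked like, with LRU eviction bounded by max_caches; the max_kv_size keyword of make_prompt_cache is assumed from the mlx_lm API and is not taken from this file:

```python
from collections import OrderedDict

import mlx.nn as nn
from mlx_lm.models.cache import make_prompt_cache

class StatefulModel(nn.Module):
  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
    super().__init__()
    self.model = model
    self.max_kv_size = max_kv_size
    self.max_caches = max_caches
    self.caches = OrderedDict()  # request_id -> per-layer KV caches, LRU order

  def init_cache(self, request_id: str):
    # Evict the least-recently-used request once the cache budget is reached.
    if len(self.caches) >= self.max_caches:
      self.caches.popitem(last=False)
    # make_prompt_cache builds one cache object per transformer layer; the
    # max_kv_size argument (an assumption here) bounds each layer's KV cache.
    self.caches[request_id] = make_prompt_cache(self.model, max_kv_size=self.max_kv_size)
```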

+ 0 - 40
exo/inference/mlx/test_sharded_llama.py

@@ -1,40 +0,0 @@
-import mlx.core as mx
-from exo.inference.mlx.stateful_model import StatefulModel
-from exo.inference.mlx.sharded_utils import load_shard
-from exo.inference.shard import Shard
-
-# 79, 80 for Llama-3-70B
-shard_full = Shard("llama", 0, 31, 32)
-shard1 = Shard("llama", 0, 12, 32)
-shard2 = Shard("llama", 13, 31, 32)
-
-full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard_full)
-model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
-model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
-
-full = StatefulModel(shard_full, full_model_shard)
-m1 = StatefulModel(shard1, model_shard1)
-m2 = StatefulModel(shard2, model_shard2)
-
-prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
-prompt_tokens = mx.array(full_tokenizer.encode(prompt))
-max_tokens = 50
-
-resp = prompt_tokens
-full_generated_tokens = []
-for _ in range(max_tokens):
-  resp = full.step(resp)
-  full_generated_tokens.append(resp.item())
-
-print("full response: ", full_tokenizer.decode(full_generated_tokens))
-
-sharded_generated_tokens = []
-sharded_resp = prompt_tokens
-for _ in range(max_tokens):
-  resp1 = m1.step(sharded_resp)
-  sharded_resp = m2.step(resp1)
-  sharded_generated_tokens.append(sharded_resp.item())
-
-print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))
-
-assert tokenizer1.decode(full_generated_tokens) == tokenizer1.decode(sharded_generated_tokens)
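For context, this deleted test constructs StatefulModel(shard, model) and drives it with a step() method, neither of which matches the class removed above (constructor (model, max_kv_size, max_caches), invoked via __call__), so it was already stale. A rough sketch of the same greedy decode loop written against that __call__ interface; the batch dimension and the argmax over last-position logits are assumptions about the model's output layout, not something the repo guarantees:

```python
import mlx.core as mx

def greedy_decode(model, prompt_tokens: mx.array, max_tokens: int, request_id: str) -> list[int]:
  # Feed the full prompt once, then one token at a time, reusing the
  # per-request KV cache keyed by request_id.
  tokens = []
  y = prompt_tokens[None]  # add a batch dimension (assumed layout)
  for _ in range(max_tokens):
    logits = model(y, request_id=request_id)           # StatefulModel.__call__
    next_token = mx.argmax(logits[:, -1, :], axis=-1)  # greedy pick from the last position (assumed)
    tokens.append(next_token.item())
    y = next_token[None]
  return tokens
```

Running this once on the full model and once chained through the shard1 and shard2 halves would reproduce the equality check in the deleted assert.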

+ 0 - 64
exo/inference/mlx/test_sharded_llava.py

@@ -1,64 +0,0 @@
-import codecs
-import asyncio
-import requests
-from PIL import Image
-from io import BytesIO
-
-import mlx.core as mx
-from mlx_lm.models.cache import KVCache
-
-from exo.inference.mlx.stateful_model import StatefulModel
-from exo.inference.mlx.sharded_utils import load_shard
-from exo.inference.shard import Shard
-
-shard_full = Shard("llava", 0, 31, 32)
-shard1 = Shard("llava", 0, 12, 32)
-shard2 = Shard("llava", 13, 31, 32)
-
-model_path = "llava-hf/llava-1.5-7b-hf"
-
-full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full))
-model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1))
-model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2))
-
-full = StatefulShardedModel(shard_full, full_model_shard)
-m1 = StatefulShardedModel(shard1, model_shard1)
-m2 = StatefulShardedModel(shard2, model_shard2)
-
-PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
-IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
-response = requests.get(IMAGE_FILE)
-img = Image.open(BytesIO(response.content))
-prompt = codecs.decode(PROMPT, "unicode_escape")
-inputs = full_processor(prompt, img, return_tensors="np")
-pixel_values = mx.array(inputs["pixel_values"])
-input_ids = mx.array(inputs["input_ids"])
-
-print(prompt)
-y = full.step("full", input_ids, pixel_values, temp=0)
-full_generated_tokens = [y.item()]
-
-for _ in range(13):
-  y = full.step("full", y, temp=0)
-  full_generated_tokens.append(y.item())
-
-full_response = full_processor.tokenizer.decode(full_generated_tokens)
-print("full response:", full_response)
-
-inputs = processor1(prompt, img, return_tensors="np")
-pixel_values = mx.array(inputs["pixel_values"])
-input_ids = mx.array(inputs["input_ids"])
-
-y = m1.step("shard", input_ids, pixel_values, temp=0)
-y = m2.step("shard", y, temp=0)
-full_generated_tokens = [y.item()]
-
-for _ in range(13):
-  y = m1.step("shard", y, temp=0)
-  y = m2.step("shard", y, temp=0)
-  full_generated_tokens.append(y.item())
-
-sharded_response = processor2.tokenizer.decode(full_generated_tokens)
-print("sharded response:", sharded_response)
-
-assert full_response == sharded_response
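Note that this deleted test imports StatefulModel but instantiates StatefulShardedModel, an older class name, so it would have raised a NameError as written. Its call pattern is still a useful reference: the first step passes both input_ids and pixel_values, and every later step feeds back only the previously sampled token while the per-request cache carries the image and prompt context forward. A sketch of that pattern as a helper, assuming the step(request_id, tokens, pixel_values, temp=...) signature shown in the test; the helper name itself is hypothetical:

```python
def llava_greedy_decode(model, input_ids, pixel_values, n_tokens: int, request_id: str) -> list[int]:
  # First step: token ids and image features go through the model together.
  y = model.step(request_id, input_ids, pixel_values, temp=0)
  tokens = [y.item()]
  # Later steps: only the sampled token is fed back; the KV cache keyed by
  # request_id retains the multimodal prompt state.
  for _ in range(n_tokens - 1):
    y = model.step(request_id, y, temp=0)
    tokens.append(y.item())
  return tokens
```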