
strip out tinygrad temporarily

Alex Cheema · 6 months ago · commit e7201292de

+ 0 - 25
.circleci/config.yml

@@ -223,30 +223,6 @@ jobs:
      - checkout
      - run: system_profiler SPHardwareDataType

-  # chatgpt_api_integration_test_tinygrad:
-  #   macos:
-  #     xcode: "16.0.0"
-  #   resource_class: m2pro.large
-  #   steps:
-  #     - checkout
-  #     - run:
-  #         name: Set up Python
-  #         command: |
-  #           brew install python@3.12
-  #           python3.12 -m venv env
-  #           source env/bin/activate
-  #     - run:
-  #         name: Install dependencies
-  #         command: |
-  #           source env/bin/activate
-  #           pip install --upgrade pip
-  #           pip install .
-  #     - run_chatgpt_api_test:
-  #         inference_engine: tinygrad
-  #         model_id: llama-3-8b
-  #         prompt: "Keep responses concise. Who was the king of pop?"
-  #         expected_output: "Michael Jackson"
-
workflows:
  version: 2
  build_and_test:
@@ -256,4 +232,3 @@ workflows:
      - chatgpt_api_integration_test_mlx
      - chatgpt_api_integration_test_dummy
      - test_macos_m1
-      # - chatgpt_api_integration_test_tinygrad

+ 0 - 59
exo/inference/debug_inference_engine.py

@@ -1,59 +0,0 @@
-from exo.inference.inference_engine import InferenceEngine
-from exo.inference.shard import Shard
-from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-import asyncio
-import numpy as np
-
-
-# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
-  from exo.inference.tinygrad.inference import Tokenizer
-  from pathlib import Path
-
-  _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
-
-  prompt = "In a single word only, what is the last name of the president of the United States? "
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
-    "A",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
-    input_data=resp_full,
-    inference_state=inference_state_full,
-  )
-
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
-    "B",
-    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
-    input_data=resp1,
-    inference_state=inference_state_1,
-  )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
-    "B",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
-    input_data=resp2,
-    inference_state=inference_state_2,
-  )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
-    "B",
-    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
-    input_data=resp3,
-    inference_state=inference_state_3,
-  )
-
-  print(f"{resp2=}")
-  print(f"full: {_tokenizer.decode(resp_full)}")
-  print(f"next full: {_tokenizer.decode(next_resp_full)}")
-  print(f"resp2: {_tokenizer.decode(resp2)}")
-  print(f"{resp4=}")
-  print(f"resp4: {_tokenizer.decode(resp4)}")
-
-  assert np.array_equal(resp_full, resp2)
-  assert np.array_equal(next_resp_full, resp4)
-
-
-asyncio.run(test_inference_engine(
-  TinygradDynamicShardInferenceEngine(),
-  TinygradDynamicShardInferenceEngine(),
-  "llama3-8b-sfr",
-))
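Note: the removed harness hinged on the invariant stated in its top comment, that continuous shards produce the same output as one full shard. A minimal sketch of that continuity condition, using only the Shard fields visible above (illustrative, not part of the diff):

from exo.inference.shard import Shard

# "Continuous" means the second shard starts exactly where the first one ends.
a = Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=30, n_layers=32)
b = Shard(model_id="llama3-8b-sfr", start_layer=31, end_layer=31, n_layers=32)
assert a.end_layer + 1 == b.start_layer
assert b.end_layer == b.n_layers - 1  # together they cover the whole model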

+ 0 - 19
exo/inference/inference_engine.py

@@ -15,22 +15,3 @@ class InferenceEngine(ABC):
  @abstractmethod
  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    pass
-
-
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
-  if DEBUG >= 2:
-    print(f"get_inference_engine called with: {inference_engine_name}")
-  if inference_engine_name == "mlx":
-    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-
-    return MLXDynamicShardInferenceEngine(shard_downloader)
-  elif inference_engine_name == "tinygrad":
-    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-    import tinygrad.helpers
-    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-
-    return TinygradDynamicShardInferenceEngine(shard_downloader)
-  elif inference_engine_name == "dummy":
-    from exo.inference.dummy_inference_engine import DummyInferenceEngine
-    return DummyInferenceEngine()
-  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")

+ 0 - 64
exo/inference/mlx/test_sharded_llava.py

@@ -1,64 +0,0 @@
-import codecs
-import asyncio
-import requests
-from PIL import Image
-from io import BytesIO
-
-import mlx.core as mx
-from mlx_lm.models.base import KVCache
-
-from exo.inference.mlx.sharded_model import StatefulShardedModel
-from exo.inference.mlx.sharded_utils import load_shard
-from exo.inference.shard import Shard
-
-shard_full = Shard("llava", 0, 31, 32)
-shard1 = Shard("llava", 0, 12, 32)
-shard2 = Shard("llava", 13, 31, 32)
-
-model_path = "llava-hf/llava-1.5-7b-hf"
-
-full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full))
-model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1))
-model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2))
-
-full = StatefulShardedModel(shard_full, full_model_shard)
-m1 = StatefulShardedModel(shard1, model_shard1)
-m2 = StatefulShardedModel(shard2, model_shard2)
-
-PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
-IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
-response = requests.get(IMAGE_FILE)
-img = Image.open(BytesIO(response.content))
-prompt = codecs.decode(PROMPT, "unicode_escape")
-inputs = full_processor(prompt, img, return_tensors="np")
-pixel_values = mx.array(inputs["pixel_values"])
-input_ids = mx.array(inputs["input_ids"])
-
-print(prompt)
-y = full.step("full", input_ids, pixel_values, temp=0)
-full_generated_tokens = [y.item()]
-
-for _ in range(13):
-  y = full.step("full", y, temp=0)
-  full_generated_tokens.append(y.item())
-
-full_response = full_processor.tokenizer.decode(full_generated_tokens)
-print("full response:", full_response)
-
-inputs = processor1(prompt, img, return_tensors="np")
-pixel_values = mx.array(inputs["pixel_values"])
-input_ids = mx.array(inputs["input_ids"])
-
-y = m1.step("shard", input_ids, pixel_values, temp=0)
-y = m2.step("shard", y, temp=0)
-full_generated_tokens = [y.item()]
-
-for _ in range(13):
-  y = m1.step("shard", y, temp=0)
-  y = m2.step("shard", y, temp=0)
-  full_generated_tokens.append(y.item())
-
-sharded_response = processor2.tokenizer.decode(full_generated_tokens)
-print("sharded response:", sharded_response)
-
-assert full_response == sharded_response

+ 0 - 9
exo/inference/test_inference_engine.py

@@ -45,12 +45,3 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e


asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "mlx-community/Llama-3.2-1B-Instruct-4bit", 16))
-
-if os.getenv("RUN_TINYGRAD", default="0") == "1":
-  import tinygrad
-  import os
-  from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-  tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-  asyncio.run(
-    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", 32)
-  )

+ 0 - 0
exo/inference/tinygrad/__init__.py


+ 0 - 101
exo/inference/tinygrad/inference.py

@@ -1,101 +0,0 @@
-from pathlib import Path
-import json
-import os
-from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from exo.inference.shard import Shard
-from exo.inference.tokenizers import resolve_tokenizer
-from tinygrad.nn.state import load_state_dict
-from tinygrad import Tensor, nn, Context
-from exo.inference.inference_engine import InferenceEngine
-from typing import Optional, Tuple
-import numpy as np
-from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
-from exo.download.shard_download import ShardDownloader
-from concurrent.futures import ThreadPoolExecutor
-import asyncio
-
-Tensor.no_grad = True
-# default settings
-TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
-TOP_K = 25
-TOP_P = 0.9
-ALPHA_F = 0.1
-ALPHA_P = 0.0
-MODEL_PARAMS = {
-  "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
-  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
-}
-
-
-def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
-  # build model
-  linear = nn.Linear
-  with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
-
-  # load weights
-  if model_path.is_dir():
-    if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
-    elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
-    else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
-  else:
-    weights = load(str(model_path), shard)
-  weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
-  weights = fix_bf16(weights)
-
-  with Context(BEAM=0):
-    # replace weights in model
-    load_state_dict(model, weights, strict=False, consume=False)  # consume=True
-  return model
-
-
-class TinygradDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, shard_downloader: ShardDownloader):
-    self.shard = None
-    self.shard_downloader = shard_downloader
-    self.executor = ThreadPoolExecutor(max_workers=1)
-
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
-
-    if h.shape == (1,):
-      start_pos += len(toks)
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      n_captured_toks = len(toks)
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
-
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-    await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
-
-    if h.shape == (1,):
-      start_pos += n_captured_toks
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
-
-  async def ensure_shard(self, shard: Shard):
-    if self.shard == shard:
-      return
-
-    model_path = await self.shard_downloader.ensure_shard(shard)
-
-    if self.shard != shard:
-      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
-
-      tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
-      self.tokenizer = await resolve_tokenizer(tokenizer_path)
-      self.shard = shard

+ 0 - 0
exo/inference/tinygrad/models/__init__.py


+ 0 - 257
exo/inference/tinygrad/models/llama.py

@@ -1,257 +0,0 @@
-from typing import Tuple, Union, Optional, Dict, Any
-from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
-from tinygrad.helpers import getenv
-
-
-# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
-  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
-  # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
-
-
-# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
-def complex_mult(A, c, d):
-  a, b = A[..., 0:1], A[..., 1:2]
-  ro = a*c - b*d
-  co = a*d + b*c
-  return ro.cat(co, dim=-1)
-
-
-def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
-  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
-  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
-  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
-  assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
-  c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
-  xq_out = complex_mult(xq, c, d)
-  xk_out = complex_mult(xk, c, d)
-  return xq_out.flatten(3), xk_out.flatten(3)
-
-
-def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
-  bs, seqlen, n_kv_heads, head_dim = x.shape
-  if n_rep == 1: return x
-  # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
-  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
-
-
-class Attention:
-  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
-    self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
-    self.head_dim = dim // n_heads
-    self.n_rep = self.n_heads // self.n_kv_heads
-    self.max_context = max_context
-
-    self.wq = linear(dim, self.n_heads*self.head_dim, bias=False)
-    self.wk = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
-    self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
-    self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
-
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
-    if getenv("WQKV"):
-      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
-      xqkv = x @ self.wqkv.T
-      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
-    else:
-      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-
-    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
-    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
-    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
-
-    xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
-    bsz, seqlen, _, _ = xq.shape
-
-    # create kv cache
-    if not hasattr(self, "cache_kv"):
-      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
-      if isinstance(x.device, tuple):
-        # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
-
-    # update the cache
-    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
-
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
-
-    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
-    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
-    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
-    attn = attn.reshape(bsz, seqlen, -1)
-    return self.wo(attn)
-
-
-class FeedForward:
-  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
-    self.w1 = linear(dim, hidden_dim, bias=False)
-    self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
-
-  def __call__(self, x: Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu()*self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
-
-
-class TransformerBlock:
-  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
-    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
-    self.feed_forward = feed_forward(dim, hidden_dim, linear)
-    self.attention_norm = nn.RMSNorm(dim, norm_eps)
-    self.ffn_norm = nn.RMSNorm(dim, norm_eps)
-
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
-    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
-
-
-# standard openai sampling
-def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
-  assert logits.ndim == 1, "only works on 1d tensors"
-  assert 0 <= p <= 1, "p must be between 0 and 1"
-  assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
-
-  # if temperature is very low just use argmax
-  if temp < 1e-6: return logits.argmax().reshape(1)
-
-  # alpha sampling
-  if af or ap:
-    if not hasattr(sample, "alpha_counter"):
-      setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
-    logits = logits - (sample.alpha_counter*af + (sample.alpha_counter > 0)*ap)
-
-  # replace NaNs with -inf
-  logits = (logits != logits).where(-float("inf"), logits)
-
-  # softmax
-  t = (logits/temp).softmax()
-
-  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
-  # top k
-  if k:
-    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
-    for i in range(k):
-      t_argmax = (t.numel() - ((t == (t_max := t.max()))*counter2).max() - 1).cast(dtypes.default_int)
-      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
-      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
-      t = (counter == t_argmax).where(0, t)
-
-    # approximate top p
-    # because we are already limited to top k elements we can do top p "without sorting"
-    output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
-    output = (output_cumsum >= (1 - p))*output
-    output_indices = (output_cumsum >= (1 - p))*output_indices
-
-    # sample
-    output_idx = output.multinomial()
-    output_token = output_indices[output_idx]
-  else:
-    output_token = t.multinomial()
-
-  # increase alpha counter
-  if af or ap:
-    sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)
-
-  return output_token
-
-
-from exo.inference.shard import Shard
-
-
-class Transformer:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_layers: int,
-    norm_eps: float,
-    vocab_size,
-    shard: Shard = None,
-    linear=nn.Linear,
-    n_kv_heads=None,
-    rope_theta=10000,
-    max_context=1024,
-    jit=True,
-    feed_forward=FeedForward
-  ):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
-    self.norm = nn.RMSNorm(dim, norm_eps)
-    self.tok_embeddings = nn.Embedding(vocab_size, dim)
-    self.output = nn.Linear(dim, vocab_size, bias=False)
-    self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
-    self.forward_jit = TinyJit(self.forward) if jit else None
-    self.shard = shard
-
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
-    seqlen = x.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
-
-    if self.shard.is_first_layer():
-      h = self.tok_embeddings(x)
-    else:
-      h = x
-
-    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
-      layer = self.layers[i]
-      h = layer(h, start_pos, freqs_cis, mask)
-
-    if self.shard.is_last_layer():
-      logits = self.output(self.norm(h)).float()[:, -1, :]
-      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
-    else:
-      return h
-
-  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
-    # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
-    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
-
-
-# *** helpers ***
-
-
-def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
-  def permute(v: Tensor, n_heads: int):
-    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
-
-  keymap = {
-    "model.embed_tokens.weight": "tok_embeddings.weight",
-    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
-       for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
-       for x in ["q", "k", "v", "o"]
-       for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
-       for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
-       for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
-       for l in range(len(model.layers))},
-    "model.norm.weight": "norm.weight",
-    "lm_head.weight": "output.weight",
-  }
-  sd = {}
-  for k, v in weights.items():
-    if ".rotary_emb." in k: continue
-    v = v.to(Device.DEFAULT)
-    if "model.layers" in k:
-      if "q_proj" in k:
-        v = permute(v, n_heads)
-      elif "k_proj" in k:
-        v = permute(v, n_kv_heads)
-    sd[keymap[k]] = v
-  return sd
-
-
-def fix_bf16(weights: Dict[Any, Tensor]):
-  if getenv("SUPPORT_BF16", 1):
-    # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
-  # TODO: check if device supports bf16
-  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

+ 0 - 47
exo/inference/tinygrad/tinygrad_helpers.py

@@ -1,47 +0,0 @@
-from tinygrad.nn.state import safe_load, torch_load
-from tinygrad import Tensor
-from pathlib import Path
-import json
-from typing import List
-from exo.inference.shard import Shard
-from exo.helpers import DEBUG
-from exo.download.hf.hf_helpers import get_allow_patterns
-from fnmatch import fnmatch
-
-
-# **** helper functions ****
-def concat_weights(models, device=None):
-  def convert(name) -> Tensor:
-    disk_tensors: List[Tensor] = [model[name] for model in models]
-    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
-      return disk_tensors[0].to(device=device)
-    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
-    lazy_tensors = [data.to(device=device) for data in disk_tensors]
-    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
-
-  return {name: convert(name) for name in {name: None for model in models for name in model}}
-
-
-def load(fn: str, shard: Shard):
-  if fn.endswith('.index.json'):
-    with open(fn) as fp:
-      weight_map = json.load(fp)['weight_map']
-    parts = {}
-    filtered_weight_map = {}
-    allow_patterns = get_allow_patterns(weight_map, shard)
-    for k, n in weight_map.items():
-      if allow_patterns is not None and not any(fnmatch(n, r) for r in allow_patterns):
-        continue
-      if k.startswith("model.layers."):
-        layer_num = int(k.split('.')[2])
-        if layer_num < shard.start_layer or layer_num > shard.end_layer:
-          continue
-
-      parts[n] = load(str(Path(fn).parent/Path(n).name), shard)
-      filtered_weight_map[k] = n
-    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
-    return {k: parts[n][k] for k, n in filtered_weight_map.items()}
-  elif fn.endswith(".safetensors"):
-    return safe_load(fn)
-  else:
-    return torch_load(fn)

+ 1 - 1
exo/inference/tokenizers.py

@@ -53,7 +53,7 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
     if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
     if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
     return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
     return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
   except Exception as e:
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Error: {e}")
    if DEBUG >= 4: print(traceback.format_exc())

  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")

+ 4 - 8
exo/main.py

@@ -19,7 +19,7 @@ from exo.download.shard_download import ShardDownloader, RepoProgressEvent, Noop
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
from exo.inference.shard import Shard
-from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.inference_engine import InferenceEngine
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration.node import Node
@@ -45,14 +45,12 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
-parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
 args = parser.parse_args()
-print(f"Selected inference engine: {args.inference_engine}")
 
 
 print_yellow_exo()
 print_yellow_exo()
 
 
@@ -60,12 +58,10 @@ system_info = get_system_info()
 print(f"Detected system: {system_info}")
 print(f"Detected system: {system_info}")
 
 
 shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
 shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
-                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
-inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
-print(f"Inference engine name after selection: {inference_engine_name}")
+                                                      max_parallel_downloads=args.max_parallel_downloads)
 
 
-inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
-print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
+from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+inference_engine = MLXDynamicShardInferenceEngine(shard_downloader)

if args.node_port is None:
  args.node_port = find_available_port(args.node_host)

+ 0 - 7
exo/models.py

@@ -6,24 +6,19 @@ model_base_shards = {
   "llama-3.2-3b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
   "llama-3.2-3b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
   "llama-3.1-8b": {
   "llama-3.1-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
   },
   },
   "llama-3.1-70b": {
   "llama-3.1-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
   },
   },
   "llama-3.1-70b-bf16": {
   "llama-3.1-70b-bf16": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80),
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
   },
   },
   "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
   "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
   "llama-3-8b": {
   "llama-3-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
   },
   },
   "llama-3-70b": {
   "llama-3-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
   },
   },
   ### mistral
   ### mistral
   "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
   "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
@@ -44,6 +39,4 @@ model_base_shards = {
  ### nemotron
  "nemotron-70b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),},
  "nemotron-70b-bf16": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),},
-  # dummy
-  "dummy": {"DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),},
}
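Note: after this change every entry in the table maps to an MLX shard only. A hedged example of how a caller would resolve a starting shard from it (keys taken from the entries above):

from exo.models import model_base_shards

# Look up the base shard for a model under the remaining engine class name.
shard = model_base_shards["llama-3-8b"]["MLXDynamicShardInferenceEngine"]
print(shard.model_id, shard.n_layers)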

+ 1 - 33
exo/orchestration/node.py

@@ -14,7 +14,7 @@ from exo import DEBUG
from exo.helpers import AsyncCallbackSystem
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import RepoProgressEvent
-from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.inference_engine import InferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader


@@ -63,10 +63,6 @@ class Node:
  def on_node_status(self, request_id, opaque_status):
    try:
      status_data = json.loads(opaque_status)
-      if status_data.get("type", "") == "supported_inference_engines":
-        node_id = status_data.get("node_id")
-        engines = status_data.get("engines", [])
-        self.topology_inference_engines_pool.append(engines)
      if status_data.get("type", "") == "node_status":
        if status_data.get("status", "").startswith("start_"):
          self.current_topology.active_node_id = status_data.get("node_id")
@@ -84,22 +80,6 @@ class Node:
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: traceback.print_exc()
       if DEBUG >= 1: traceback.print_exc()
 
 
-  def get_supported_inference_engines(self):
-    supported_engine_names = []
-    if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
-      supported_engine_names.append('mlx')
-      supported_engine_names.append('tinygrad')
-    else:
-      supported_engine_names.append('tinygrad')
-    return supported_engine_names
-
-  async def broadcast_supported_engines(self, supported_engines_names: List[str]):
-    status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
-    await self.broadcast_opaque_status("", status_message)
-
-  def get_topology_inference_engines(self) -> List[List[str]]:
-    return self.topology_inference_engines_pool
-
  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    shard = self.get_current_shard(base_shard)
    asyncio.create_task(
@@ -352,17 +332,6 @@ class Node:
    self.peers = next_peers
    return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0

-  async def select_best_inference_engine(self):
-    supported_engines = self.get_supported_inference_engines()
-    await self.broadcast_supported_engines(supported_engines)
-    if len(self.get_topology_inference_engines()):
-      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
-        if DEBUG >= 1: print("Found node with only tinygrad, using tinygrad on all nodes")
-        self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
-      else:
-        if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
-        self.inference_engine = get_inference_engine("mlx", self.shard_downloader)
-
  async def periodic_topology_collection(self, interval: int):
    while True:
      await asyncio.sleep(interval)
@@ -371,7 +340,6 @@ class Node:
         if DEBUG >= 2: print(f"{did_peers_change=}")
         if DEBUG >= 2: print(f"{did_peers_change=}")
         if did_peers_change:
         if did_peers_change:
           await self.collect_topology()
           await self.collect_topology()
-          await self.select_best_inference_engine()
      except Exception as e:
        print(f"Error collecting topology: {e}")
        traceback.print_exc()
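Note: the deleted negotiation exchanged a small JSON status message between peers before picking an engine. For reference, a sketch of that broadcast payload as built by the removed broadcast_supported_engines (the node id is a made-up placeholder):

import json

node_id = "node-a"  # hypothetical id, for illustration only
status_message = json.dumps({
  "type": "supported_inference_engines",
  "node_id": node_id,
  "engines": ["mlx", "tinygrad"],
})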

+ 0 - 41
exo/topology/device_capabilities.py

@@ -142,8 +142,6 @@ CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items
def device_capabilities() -> DeviceCapabilities:
  if psutil.MACOS:
    return mac_device_capabilities()
-  elif psutil.LINUX:
-    return linux_device_capabilities()
  else:
    return DeviceCapabilities(
      model="Unknown Device",
@@ -171,42 +169,3 @@ def mac_device_capabilities() -> DeviceCapabilities:

  # Assuming static values for other attributes for demonstration
  return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
-
-
-def linux_device_capabilities() -> DeviceCapabilities:
-  import psutil
-  from tinygrad import Device
-
-  if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
-  if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
-    import pynvml
-
-    pynvml.nvmlInit()
-    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
-    gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
-    gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
-    gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
-
-    if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
-
-    return DeviceCapabilities(
-      model=f"Linux Box ({gpu_name})",
-      chip=gpu_name,
-      memory=gpu_memory_info.total // 2**20,
-      flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
-    )
-  elif Device.DEFAULT == "AMD":
-    # TODO AMD support
-    return DeviceCapabilities(
-      model="Linux Box (AMD)",
-      chip="Unknown AMD",
-      memory=psutil.virtual_memory().total // 2**20,
-      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
-    )
-  else:
-    return DeviceCapabilities(
-      model=f"Linux Box (Device: {Device.DEFAULT})",
-      chip=f"Unknown Chip (Device: {Device.DEFAULT})",
-      memory=psutil.virtual_memory().total // 2**20,
-      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
-    )

+ 0 - 1
setup.py

@@ -26,7 +26,6 @@ install_requires = [
   "tqdm==4.66.4",
   "tqdm==4.66.4",
   "transformers==4.43.3",
   "transformers==4.43.3",
   "uuid==1.30",
   "uuid==1.30",
-  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
 ]
 ]
 
 
 extras_require = {
 extras_require = {
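Note: with the pinned git dependency gone, any leftover code that still references tinygrad would have to treat it as optional. A minimal sketch of such a guard (an assumption, not something this commit adds):

try:
  import tinygrad  # no longer installed by default after this change
except ImportError:
  tinygrad = None  # tinygrad-dependent paths are skipped when the package is absent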