
remove hard dependency on MLX fixes #8

Alex Cheema, 1 year ago
parent commit dbbc7be57f
4 changed files with 12 additions and 12 deletions
  1. exo/api/chatgpt_api.py (+2 -2)
  2. exo/inference/mlx/sharded_utils.py (+0 -2)
  3. main.py (+8 -5)
  4. requirements.txt (+2 -3)

+ 2 - 2
exo/api/chatgpt_api.py

@@ -1,13 +1,11 @@
 import uuid
 import time
 import asyncio
-from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import List
 from aiohttp import web
 from exo import DEBUG
 from exo.inference.shard import Shard
 from exo.orchestration import Node
-from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
 
 shard_mappings = {
     "llama-3-8b": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
@@ -41,6 +39,8 @@ class ChatGPTAPI:
             return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
         request_id = str(uuid.uuid4())
 
+        # TODO: equivalent for non-mlx engines, since the user can't install this on non-Macs even though the tokenizer itself doesn't depend on MLX
+        from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
         tokenizer = load_tokenizer(get_model_path(shard.model_id))
         prompt = tokenizer.apply_chat_template(
             chat_request.messages, tokenize=False, add_generation_prompt=True

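Note on this hunk: moving the import from module level into the request handler defers loading mlx/mlx-lm until a request actually arrives, so chatgpt_api.py can be imported on machines where MLX isn't installed. A minimal sketch of the same deferred-import pattern (build_prompt is a hypothetical wrapper; get_model_path and load_tokenizer are the real helpers from exo.inference.mlx.sharded_utils):

def build_prompt(model_id: str, messages: list) -> str:
    # Importing here instead of at module level means MLX is only loaded on
    # the code path that needs the tokenizer, not at import time.
    from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer

    tokenizer = load_tokenizer(get_model_path(model_id))
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
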
+ 0 - 2
exo/inference/mlx/sharded_utils.py

@@ -11,8 +11,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
-from mlx.utils import tree_flatten
-from transformers import PreTrainedTokenizer
 
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers

+ 8 - 5
main.py

@@ -1,18 +1,15 @@
 import argparse
 import asyncio
 import signal
-import mlx.core as mx
-import mlx.nn as nn
 import uuid
+import platform
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
-from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 
-
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
@@ -24,8 +21,14 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 args = parser.parse_args()
 
+print(f"Starting {platform.system()=}")
+if platform.system() == "Darwin":
+    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+    inference_engine = MLXDynamicShardInferenceEngine()
+else:
+    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+    inference_engine = TinygradDynamicShardInferenceEngine()
 
-inference_engine = MLXDynamicShardInferenceEngine()
 def on_token(tokens: List[int]):
     if inference_engine.tokenizer:
         print(inference_engine.tokenizer.decode(tokens))

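The platform check keeps the MLX engine behind an import guard: it is imported only on macOS (platform.system() == "Darwin"), and every other platform falls back to the tinygrad engine, so a missing mlx package can no longer break startup. If one wanted to keep the branch out of main.py, the same selection could live in a small factory; a sketch under that assumption (get_inference_engine is hypothetical, the two engine classes are the real ones from the diff):

import platform

def get_inference_engine():
    # Same branch as main.py, wrapped in a function: import the MLX engine
    # only on macOS, otherwise fall back to the tinygrad implementation.
    if platform.system() == "Darwin":
        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
        return MLXDynamicShardInferenceEngine()
    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
    return TinygradDynamicShardInferenceEngine()

inference_engine = get_inference_engine()
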
+ 2 - 3
requirements.txt

@@ -2,8 +2,8 @@ aiohttp==3.9.5
 grpcio==1.64.1
 grpcio-tools==1.64.1
 huggingface-hub==0.23.4
-mlx==0.15.1
-mlx-lm==0.14.3
+mlx==0.15.1; sys_platform == "darwin"
+mlx-lm==0.14.3; sys_platform == "darwin"
 numpy==2.0.0
 protobuf==5.27.1
 requests==2.32.3
@@ -14,4 +14,3 @@ tokenizers==0.19.1
 tqdm==4.66.4
 transformers==4.41.2
 uuid==1.30
-
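The suffix ; sys_platform == "darwin" is a PEP 508 environment marker (the marker variable is sys_platform, with an underscore, not sys.platform): pip evaluates it at install time and skips the requirement when it is false, which is what removes the hard MLX dependency for Linux and Windows installs. A marker can be checked against the current interpreter with the packaging library; a quick sketch:

from packaging.markers import Marker

# True on macOS, False elsewhere, so pip install -r requirements.txt
# installs mlx/mlx-lm only on Darwin hosts.
print(Marker('sys_platform == "darwin"').evaluate())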