
remove hard dependency on MLX (fixes #8)

Alex Cheema · 1 year ago · commit dbbc7be57f
4 changed files with 12 additions and 12 deletions
  1. exo/api/chatgpt_api.py (+2 -2)
  2. exo/inference/mlx/sharded_utils.py (+0 -2)
  3. main.py (+8 -5)
  4. requirements.txt (+2 -3)

+ 2 - 2
exo/api/chatgpt_api.py

@@ -1,13 +1,11 @@
 import uuid
 import time
 import asyncio
-from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import List
 from aiohttp import web
 from exo import DEBUG
 from exo.inference.shard import Shard
 from exo.orchestration import Node
-from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer

 shard_mappings = {
     "llama-3-8b": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
@@ -41,6 +39,8 @@ class ChatGPTAPI:
             return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
         request_id = str(uuid.uuid4())

+        # TODO: add an equivalent for non-MLX setups, since the user can't install mlx on non-Macs even though the tokenizer itself doesn't need it
+        from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
         tokenizer = load_tokenizer(get_model_path(shard.model_id))
         prompt = tokenizer.apply_chat_template(
             chat_request.messages, tokenize=False, add_generation_prompt=True

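Moving the import inside the request handler is the standard lazy-import pattern: the module-level dependency on mlx becomes function-local, so importing exo.api no longer requires mlx-lm to be installed, and the import cost is only paid when the MLX code path actually runs. A minimal sketch of the pattern (the load_tokenizer_for helper is illustrative, not part of this commit):

    def load_tokenizer_for(model_id: str):
        # Deferred import: mlx-lm is only installable on macOS, so pull it in
        # at call time rather than at module load time.
        from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
        return load_tokenizer(get_model_path(model_id))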
+ 0 - 2
exo/inference/mlx/sharded_utils.py

@@ -11,8 +11,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils._errors import RepositoryNotFoundError
-from mlx.utils import tree_flatten
-from transformers import PreTrainedTokenizer

 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tuner.utils import apply_lora_layers

+ 8 - 5
main.py

@@ -1,18 +1,15 @@
 import argparse
 import asyncio
 import signal
-import mlx.core as mx
-import mlx.nn as nn
 import uuid
+import platform
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
-from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI

-
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
@@ -24,8 +21,14 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 args = parser.parse_args()
 args = parser.parse_args()
 
 
+print(f"Starting {platform.system()=}")
+if platform.system() == "Darwin":
+    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+    inference_engine = MLXDynamicShardInferenceEngine()
+else:
+    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+    inference_engine = TinygradDynamicShardInferenceEngine()
 
 
-inference_engine = MLXDynamicShardInferenceEngine()
 def on_token(tokens: List[int]):
 def on_token(tokens: List[int]):
     if inference_engine.tokenizer:
     if inference_engine.tokenizer:
         print(inference_engine.tokenizer.decode(tokens))
         print(inference_engine.tokenizer.decode(tokens))

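platform.system() returns "Darwin" on macOS, "Linux" on Linux, and "Windows" on Windows, so this branch selects MLX only where it can actually run and falls back to tinygrad everywhere else. Keeping each engine's import inside its branch is what makes the fallback work: the unused engine's dependencies are never touched. The same dispatch could be factored into a small helper so other entry points can reuse it; a sketch, assuming both engines expose the same interface (get_inference_engine is hypothetical, not in this commit):

    import platform

    def get_inference_engine():
        # MLX is only available on Apple hardware, so pick it on macOS
        # ("Darwin") and fall back to the tinygrad engine elsewhere.
        if platform.system() == "Darwin":
            from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
            return MLXDynamicShardInferenceEngine()
        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
        return TinygradDynamicShardInferenceEngine()

    inference_engine = get_inference_engine()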
+ 2 - 3
requirements.txt

@@ -2,8 +2,8 @@ aiohttp==3.9.5
 grpcio==1.64.1
 grpcio-tools==1.64.1
 huggingface-hub==0.23.4
-mlx==0.15.1
-mlx-lm==0.14.3
+mlx==0.15.1; sys.platform == "darwin"
+mlx-lm==0.14.3; sys.platform == "darwin"
 numpy==2.0.0
 protobuf==5.27.1
 requests==2.32.3
@@ -14,4 +14,3 @@ tokenizers==0.19.1
 tqdm==4.66.4
 transformers==4.41.2
 uuid==1.30
-
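The '; sys.platform == "darwin"' suffix is a PEP 508 environment marker: pip evaluates it at install time and skips the requirement when it is false, so Linux and Windows installs never attempt to fetch mlx. (The canonical PEP 508 spelling is sys_platform; the dotted sys.platform form is a legacy alias that pip has historically accepted.) One way to check how a marker evaluates on the current machine, using the packaging library:

    from packaging.markers import Marker

    # Evaluates against the running interpreter's environment:
    # True on macOS, False on Linux/Windows.
    print(Marker('sys_platform == "darwin"').evaluate())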