Merge remote-tracking branch 'origin/main' into runners2

Alex Cheema 3 months ago
parent
commit
461e4f37cb

+ 148 - 152
exo/api/chatgpt_api.py

@@ -21,7 +21,13 @@ from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
-import mlx.core as mx
+import platform
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
+
 import tempfile
 from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
@@ -29,6 +35,7 @@ from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
 from collections import defaultdict
 
+
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
@@ -42,7 +49,6 @@ class Message:
     return data
 
 
-
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
@@ -133,16 +139,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
 
 def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  chat_template_args = {
-    "conversation": [m.to_dict() for m in messages],
-    "tokenize": False,
-    "add_generation_prompt": True
-  }
-  if tools: chat_template_args["tools"] = tools
-
-  prompt = tokenizer.apply_chat_template(**chat_template_args)
-  print(f"!!! Prompt: {prompt}")
-  return prompt
+  chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
+  if tools: 
+    chat_template_args["tools"] = tools
+
+  try:
+    prompt = tokenizer.apply_chat_template(**chat_template_args)
+    if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
+    return prompt
+  except UnicodeEncodeError:
+    # Handle Unicode encoding by ensuring everything is UTF-8
+    chat_template_args["conversation"] = [
+      {k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v 
+       for k, v in m.to_dict().items()}
+      for m in messages
+    ]
+    prompt = tokenizer.apply_chat_template(**chat_template_args)
+    if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
+    return prompt
 
 
 def parse_message(data: dict):
@@ -166,8 +180,17 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
+
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
+  def __init__(
+    self,
+    node: Node,
+    inference_engine_classname: str,
+    response_timeout: int = 90,
+    on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
+    default_model: Optional[str] = None,
+    system_prompt: Optional[str] = None
+  ):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -209,18 +232,22 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
     cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
-      
+    # Add static routes
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
       self.app.router.add_get("/", self.handle_root)
       self.app.router.add_static("/", self.static_dir, name="static")
-      self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
+      
+    # Always add images route, regardless of compilation status
+    self.images_dir = get_exo_images_dir()
+    self.images_dir.mkdir(parents=True, exist_ok=True)
+    self.app.router.add_static('/images/', self.images_dir, name='static_images')
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
 
   async def handle_quit(self, request):
-    if DEBUG>=1: print("Received quit signal")
+    if DEBUG >= 1: print("Received quit signal")
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     await response.prepare(request)
     await response.write_eof()
@@ -250,61 +277,48 @@ class ChatGPTAPI:
 
   async def handle_model_support(self, request):
     try:
-        response = web.StreamResponse(
-            status=200,
-            reason='OK',
-            headers={
-                'Content-Type': 'text/event-stream',
-                'Cache-Control': 'no-cache',
-                'Connection': 'keep-alive',
-            }
-        )
-        await response.prepare(request)
+      response = web.StreamResponse(status=200, reason='OK', headers={
+        'Content-Type': 'text/event-stream',
+        'Cache-Control': 'no-cache',
+        'Connection': 'keep-alive',
+      })
+      await response.prepare(request)
+
+      async def process_model(model_name, pretty):
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+
+          if self.inference_engine_classname in model_info.get("repo", {}):
+            shard = build_base_shard(model_name, self.inference_engine_classname)
+            if shard:
+              downloader = HFShardDownloader(quick_check=True)
+              downloader.current_shard = shard
+              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+              status = await downloader.get_shard_download_status()
+
+              download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else False
 
-        async def process_model(model_name, pretty):
-            if model_name in model_cards:
-                model_info = model_cards[model_name]
-
-                if self.inference_engine_classname in model_info.get("repo", {}):
-                    shard = build_base_shard(model_name, self.inference_engine_classname)
-                    if shard:
-                        downloader = HFShardDownloader(quick_check=True)
-                        downloader.current_shard = shard
-                        downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-                        status = await downloader.get_shard_download_status()
-
-                        download_percentage = status.get("overall") if status else None
-                        total_size = status.get("total_size") if status else None
-                        total_downloaded = status.get("total_downloaded") if status else False
-
-                        model_data = {
-                            model_name: {
-                                "name": pretty,
-                                "downloaded": download_percentage == 100 if download_percentage is not None else False,
-                                "download_percentage": download_percentage,
-                                "total_size": total_size,
-                                "total_downloaded": total_downloaded
-                            }
-                        }
-
-                        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-
-        # Process all models in parallel
-        await asyncio.gather(*[
-            process_model(model_name, pretty)
-            for model_name, pretty in pretty_name.items()
-        ])
-
-        await response.write(b"data: [DONE]\n\n")
-        return response
+              model_data = {
+                model_name: {
+                  "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
+                  "total_downloaded": total_downloaded
+                }
+              }
+
+              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+
+      # Process all models in parallel
+      await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
+
+      await response.write(b"data: [DONE]\n\n")
+      return response
 
     except Exception as e:
-        print(f"Error in handle_model_support: {str(e)}")
-        traceback.print_exc()
-        return web.json_response(
-            {"detail": f"Server error: {str(e)}"},
-            status=500
-        )
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
 
   async def handle_get_models(self, request):
     models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
@@ -466,7 +480,6 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
-  
   async def handle_post_image_generations(self, request):
     data = await request.json()
 
@@ -479,7 +492,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     if DEBUG >= 2: print(f"shard: {shard}")
     if not shard:
-        return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
+      return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
 
     request_id = str(uuid.uuid4())
     callback_id = f"chatgpt-api-wait-response-{request_id}"
@@ -491,77 +504,85 @@ class ChatGPTAPI:
         img = None
       await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
 
-
-      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
+      response = web.StreamResponse(status=200, reason='OK', headers={
+        'Content-Type': 'application/octet-stream',
+        "Cache-Control": "no-cache",
+      })
       await response.prepare(request)
 
       def get_progress_bar(current_step, total_steps, bar_length=50):
         # Calculate the percentage of completion
-        percent = float(current_step) / total_steps
+        percent = float(current_step)/total_steps
         # Calculate the number of hashes to display
-        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
-        spaces = ' ' * (bar_length - len(arrow))
-        
+        arrow = '-'*int(round(percent*bar_length) - 1) + '>'
+        spaces = ' '*(bar_length - len(arrow))
+
         # Create the progress bar string
         progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
         return progress_bar
 
       async def stream_image(_request_id: str, result, is_finished: bool):
-          if isinstance(result, list):
-              await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
+        if isinstance(result, list):
+          await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
 
-          elif isinstance(result, np.ndarray):
+        elif isinstance(result, np.ndarray):
+          try:
             im = Image.fromarray(np.array(result))
-            images_folder = get_exo_images_dir()
             # Save the image to a file
             image_filename = f"{_request_id}.png"
-            image_path = images_folder / image_filename
+            image_path = self.images_dir/image_filename
             im.save(image_path)
-            image_url = request.app.router['static_images'].url_for(filename=image_filename)
-            base_url = f"{request.scheme}://{request.host}"
-            # Construct the full URL correctly
-            full_image_url = base_url + str(image_url)
             
-            await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            # Get URL for the saved image
+            try:
+              image_url = request.app.router['static_images'].url_for(filename=image_filename)
+              base_url = f"{request.scheme}://{request.host}"
+              full_image_url = base_url + str(image_url)
+              
+              await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            except KeyError as e:
+              if DEBUG >= 2: print(f"Error getting image URL: {e}")
+              # Fallback to direct file path if URL generation fails
+              await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            
             if is_finished:
               await response.write_eof()
-              
+            
+          except Exception as e:
+            if DEBUG >= 2: print(f"Error processing image: {e}")
+            if DEBUG >= 2: traceback.print_exc()
+            await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')
 
       stream_task = None
+
       def on_result(_request_id: str, result, is_finished: bool):
-          nonlocal stream_task
-          stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
-          return _request_id == request_id and is_finished
+        nonlocal stream_task
+        stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
+        return _request_id == request_id and is_finished
 
       await callback.wait(on_result, timeout=self.response_timeout*10)
-      
+
       if stream_task:
-          # Wait for the stream task to complete before returning
-          await stream_task
+        # Wait for the stream task to complete before returning
+        await stream_task
 
       return response
 
     except Exception as e:
-        if DEBUG >= 2: traceback.print_exc()
-        return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-  
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+
   async def handle_delete_model(self, request):
     try:
       model_name = request.match_info.get('model_name')
       if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
 
       if not model_name or model_name not in model_cards:
-        return web.json_response(
-          {"detail": f"Invalid model name: {model_name}"},
-          status=400
-          )
+        return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
 
       shard = build_base_shard(model_name, self.inference_engine_classname)
       if not shard:
-        return web.json_response(
-          {"detail": "Could not build shard for model"},
-          status=400
-        )
+        return web.json_response({"detail": "Could not build shard for model"}, status=400)
 
       repo_id = get_repo(shard.model_id, self.inference_engine_classname)
       if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
@@ -576,38 +597,28 @@ class ChatGPTAPI:
         if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
         try:
           shutil.rmtree(cache_dir)
-          return web.json_response({
-            "status": "success",
-            "message": f"Model {model_name} deleted successfully",
-            "path": str(cache_dir)
-          })
+          return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
         except Exception as e:
-          return web.json_response({
-            "detail": f"Failed to delete model files: {str(e)}"
-          }, status=500)
+          return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
       else:
-        return web.json_response({
-          "detail": f"Model files not found at {cache_dir}"
-        }, status=404)
+        return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
 
     except Exception as e:
-        print(f"Error in handle_delete_model: {str(e)}")
-        traceback.print_exc()
-        return web.json_response({
-            "detail": f"Server error: {str(e)}"
-        }, status=500)
+      print(f"Error in handle_delete_model: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
 
   async def handle_get_initial_models(self, request):
     model_data = {}
     for model_name, pretty in pretty_name.items():
-        model_data[model_name] = {
-            "name": pretty,
-            "downloaded": None,  # Initially unknown
-            "download_percentage": None,  # Change from 0 to null
-            "total_size": None,
-            "total_downloaded": None,
-            "loading": True  # Add loading state
-        }
+      model_data[model_name] = {
+        "name": pretty,
+        "downloaded": None,  # Initially unknown
+        "download_percentage": None,  # Change from 0 to null
+        "total_size": None,
+        "total_downloaded": None,
+        "loading": True  # Add loading state
+      }
     return web.json_response(model_data)
 
   async def handle_create_animation(self, request):
@@ -633,17 +644,9 @@ class ChatGPTAPI:
       if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
 
       # Create the animation
-      create_animation_mp4(
-        replacement_image_path,
-        output_path,
-        device_name,
-        prompt_text
-      )
+      create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
 
-      return web.json_response({
-        "status": "success",
-        "output_path": output_path
-      })
+      return web.json_response({"status": "success", "output_path": output_path})
 
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
@@ -659,10 +662,7 @@ class ChatGPTAPI:
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
       asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
 
-      return web.json_response({
-        "status": "success",
-        "message": f"Download started for model: {model_name}"
-      })
+      return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"error": str(e)}, status=500)
@@ -676,10 +676,7 @@ class ChatGPTAPI:
         return web.json_response({})
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
-      return web.json_response(
-        {"detail": f"Error getting topology: {str(e)}"},
-        status=500
-      )
+      return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
 
   async def handle_token(self, request_id: str, token: int, is_finished: bool):
     await self.token_queues[request_id].put((token, is_finished))
@@ -693,15 +690,14 @@ class ChatGPTAPI:
   def base64_decode(self, base64_string):
     #decode and reshape image
     if base64_string.startswith('data:image'):
-        base64_string = base64_string.split(',')[1]
+      base64_string = base64_string.split(',')[1]
     image_data = base64.b64decode(base64_string)
     img = Image.open(BytesIO(image_data))
-    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    W, H = (dim - dim%64 for dim in (img.width, img.height))
     if W != img.width or H != img.height:
-        if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
-        img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+      if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+      img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
     img = mx.array(np.array(img))
-    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
     img = img[None]
     return img
-  
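The hunks above (and the matching ones in grpc_peer_handle.py and grpc_server.py below) replace the unconditional `import mlx.core as mx` with a platform gate so hosts that are not Apple Silicon fall back to NumPy. A minimal standalone sketch of that pattern, assuming only the array operations actually used at these call sites; `normalize_image` is an illustrative name, not part of the commit:

```python
# Sketch of the platform-gated import introduced in this commit. On Apple Silicon
# we get mlx.core; elsewhere NumPy stands in, which works here because the call
# sites only touch the shared surface (mx.array, astype, mx.float32).
import platform

if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
  import mlx.core as mx
else:
  import numpy as mx

def normalize_image(pixels) -> "mx.array":
  # Mirrors base64_decode above: keep 3 channels and scale 8-bit RGB into [-1, 1].
  arr = mx.array(pixels)
  return (arr[:, :, :3].astype(mx.float32)/255)*2 - 1
```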

+ 12 - 5
exo/apputil/anim.py

@@ -2,6 +2,7 @@ from PIL import Image, ImageDraw, ImageFont, ImageFilter
 import os
 import numpy as np
 import cv2
+import sys
 
 def draw_rounded_rectangle(draw, coords, radius, fill):
   left, top, right, bottom = coords
@@ -80,14 +81,20 @@ def create_animation_mp4(
     font = ImageFont.load_default()
     promptfont = ImageFont.load_default()
 
+  # Get the base directory for images when running as a bundled app
+  if hasattr(sys, '_MEIPASS'):
+    base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages")
+  else:
+    base_dir = os.path.join(os.path.dirname(__file__), "baseimages")
+
   # Process first frame
-  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
+  base_img = Image.open(os.path.join(base_dir, "image1.png"))
   draw = ImageDraw.Draw(base_img)
   draw_centered_text_rounded(draw, device_name, font, device_coords)
   frames.extend([crop_image(base_img)] * 30)  # 1 second at 30fps
 
   # Process second frame with typing animation
-  base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
+  base_img2 = Image.open(os.path.join(base_dir, "image2.png"))
   for i in range(len(prompt_text) + 1):
     current_frame = base_img2.copy()
     draw = ImageDraw.Draw(current_frame)
@@ -101,7 +108,7 @@ def create_animation_mp4(
 
   # Create blur sequence
   replacement_img = Image.open(replacement_image_path)
-  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
+  base_img = Image.open(os.path.join(base_dir, "image3.png"))
   blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
 
   for i, blur_amount in enumerate(blur_steps):
@@ -123,7 +130,7 @@ def create_animation_mp4(
     frames.extend([crop_image(new_frame)] * 15)  # 0.5 seconds at 30fps
 
   # Create and add final frame (image4)
-  final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
+  final_base = Image.open(os.path.join(base_dir, "image4.png"))
   draw = ImageDraw.Draw(final_base)
 
   draw_centered_text_rounded(draw, device_name, font, device_coords)
@@ -158,4 +165,4 @@ def create_animation_mp4(
       out.write(frame_array)
 
     out.release()
-    print(f"Video saved successfully to {output_path}")
+    print(f"Video saved successfully to {output_path}")
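The anim.py change routes every base-image load through one directory that differs between a source checkout and a frozen bundle. A small sketch of that lookup under the same `sys._MEIPASS` convention; `resolve_base_image` is a hypothetical helper, the commit inlines this logic in `create_animation_mp4`:

```python
# Hypothetical helper isolating the bundled-asset lookup used in create_animation_mp4.
# When the app is frozen (the bundler sets sys._MEIPASS), base images live under the
# unpack directory; otherwise they sit next to this module on disk.
import os
import sys

def resolve_base_image(filename: str) -> str:
  if hasattr(sys, '_MEIPASS'):
    base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages")
  else:
    base_dir = os.path.join(os.path.dirname(__file__), "baseimages")
  return os.path.join(base_dir, filename)

# e.g. base_img = Image.open(resolve_base_image("image1.png"))
```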

+ 25 - 25
exo/helpers.py

@@ -7,7 +7,8 @@ import random
 import platform
 import psutil
 import uuid
-import netifaces
+from scapy.all import get_if_addr, get_if_list
+import re
 import subprocess
 from pathlib import Path
 import tempfile
@@ -231,26 +232,26 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
 def get_all_ip_addresses_and_interfaces():
   try:
     ip_addresses = []
-    for interface in netifaces.interfaces():
-      ifaddresses = netifaces.ifaddresses(interface)
-      if netifaces.AF_INET in ifaddresses:
-        for link in ifaddresses[netifaces.AF_INET]:
-          ip = link['addr']
-          ip_addresses.append((ip, interface))
+    for interface in get_if_list():
+      ip = get_if_addr(interface)
+      # Include all addresses, including loopback
+      # Filter out link-local addresses
+      if not ip.startswith('169.254.') and not ip.startswith('0.0.'):
+        # Remove "\\Device\\NPF_" prefix from interface name
+        simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
+        ip_addresses.append((ip, simplified_interface))
     return list(set(ip_addresses))
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return [("localhost", "lo")]
 
+
 async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
   try:
     # Use the shared subprocess_pool
-    output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
-      ['system_profiler', 'SPNetworkDataType', '-json'],
-      capture_output=True,
-      text=True,
-      close_fds=True
-    ).stdout)
+    output = await asyncio.get_running_loop().run_in_executor(
+      subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
+    )
 
     data = json.loads(output)
 
@@ -276,6 +277,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
 
   return None
 
+
 async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
   # On macOS, try to get interface type using networksetup
   if psutil.MACOS:
@@ -283,8 +285,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
     if macos_type is not None: return macos_type
 
   # Local container/virtual interfaces
-  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
-    'bridge' in ifname):
+  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
     return (7, "Container Virtual")
 
   # Loopback interface
@@ -310,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
   # Other physical interfaces
   return (2, "Other")
 
+
 async def shutdown(signal, loop, server):
   """Gracefully shutdown the server and close the asyncio loop."""
   print(f"Received exit signal {signal.name}...")
@@ -353,18 +355,16 @@ async def get_mac_system_info() -> Tuple[str, str, int]:
         return "Unknown Model", "Unknown Chip", 0
 
 def get_exo_home() -> Path:
-  if os.name == "nt":  # Check if the OS is Windows
-    docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
-  else:
-    docs_folder = Path.home() / "Documents"
-  exo_folder = docs_folder / "Exo"
-  if not exo_folder.exists():
-    exo_folder.mkdir()
+  if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
+  else: docs_folder = Path.home()/"Documents"
+  if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
+  exo_folder = docs_folder/"Exo"
+  if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
   return exo_folder
 
+
 def get_exo_images_dir() -> Path:
   exo_home = get_exo_home()
-  images_dir = exo_home / "Images"
-  if not images_dir.exists():
-    images_dir.mkdir()
+  images_dir = exo_home/"Images"
+  if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
   return images_dir
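helpers.py swaps netifaces for scapy when enumerating interfaces, filters link-local and unassigned addresses, and strips the Windows NPF device prefix from interface names. A condensed sketch of that flow; the function name is illustrative and the commit's exception fallback to localhost is omitted:

```python
# Minimal sketch of the scapy-based enumeration above: collect (ip, interface) pairs,
# skip link-local / unassigned addresses, and normalize "\Device\NPF_" interface names.
import re
from scapy.all import get_if_addr, get_if_list

def list_ip_addresses_and_interfaces():
  pairs = []
  for interface in get_if_list():
    ip = get_if_addr(interface)
    if ip.startswith('169.254.') or ip.startswith('0.0.'):
      continue  # link-local or unassigned
    pairs.append((ip, re.sub(r'^\\Device\\NPF_', '', interface)))
  return list(set(pairs))
```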

+ 5 - 5
exo/inference/debug_inference_engine.py

@@ -16,25 +16,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
 
-  next_resp_full = await inference_engine_1.infer_tensor(
+  next_resp_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=token_full,
   )
 
-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
+  resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp1,
   )
   token2 = await inference_engine_2.sample(resp2)
-  resp3 = await inference_engine_1.infer_tensor(
+  resp3, _ = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
     input_data=token2,
   )
-  resp4 = await inference_engine_2.infer_tensor(
+  resp4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,

+ 2 - 2
exo/inference/dummy_inference_engine.py

@@ -25,9 +25,9 @@ class DummyInferenceEngine(InferenceEngine):
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return self.tokenizer.decode(tokens)
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
-    return input_data + 1 if self.shard.is_last_layer() else input_data
+    return input_data + 1 if self.shard.is_last_layer() else input_data, None
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard: return

+ 10 - 7
exo/inference/inference_engine.py

@@ -5,6 +5,7 @@ from exo.helpers import DEBUG  # Make sure to import DEBUG
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
+from exo.download.shard_download import ShardDownloader
 
 
 class InferenceEngine(ABC):
@@ -13,7 +14,7 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     pass
-  
+
   @abstractmethod
   async def sample(self, x: np.ndarray) -> np.ndarray:
     pass
@@ -23,7 +24,7 @@ class InferenceEngine(ABC):
     pass
 
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     pass
 
   @abstractmethod
@@ -32,14 +33,14 @@ class InferenceEngine(ABC):
 
   async def save_checkpoint(self, shard: Shard, path: str):
     pass
-  
+
   async def save_session(self, key, value):
     self.session[key] = value
-  
+
   async def clear_session(self):
     self.session.empty()
-  
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
+
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     tokens = await self.encode(shard, prompt)
     if shard.model_id != 'stable-diffusion-2-1-base':
       x = tokens.reshape(1, -1)
@@ -49,13 +50,15 @@ class InferenceEngine(ABC):
 
     return output_data, inference_state
 
+
 inference_engine_classes = {
   "mlx": "MLXDynamicShardInferenceEngine",
   "tinygrad": "TinygradDynamicShardInferenceEngine",
   "dummy": "DummyInferenceEngine",
 }
 
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+
+def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
   if DEBUG >= 2:
     print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":
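The inference files in this commit adopt a new contract: `infer_tensor` and `infer_prompt` now return a `(np.ndarray, Optional[dict])` pair so per-request `inference_state` can be passed in and handed back. A minimal sketch of an engine and a call site written against that contract; `EchoEngine` is a hypothetical stand-in, not part of the commit:

```python
# Sketch of the updated InferenceEngine contract: engines return a
# (np.ndarray, Optional[dict]) tuple so inference_state can be threaded through.
from typing import Optional
import numpy as np

class EchoEngine:
  async def infer_tensor(self, request_id: str, shard, input_data: np.ndarray,
                         inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
    # A real engine would run the shard's layers here; this just passes data through.
    return input_data, inference_state

# Callers unpack the pair, e.g.:  output, _ = await engine.infer_tensor("req", shard, x)
```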

+ 3 - 3
exo/inference/mlx/sharded_inference_engine.py

@@ -67,15 +67,15 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     self.model.load_weights(path)
     
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
     if self.model.model_type != 'StableDiffusionPipeline':
-      output_data = self.model(x, **state, **inference_state)
+      output_data = self.model(x, **state, **(inference_state or {}))
     else:
-      output_data, inference_state = self.model(x, **state, **inference_state)
+      output_data, inference_state = self.model(x, **state, **(inference_state or {}))
     output_data = np.array(output_data, copy=False)
     return output_data, inference_state
 

+ 5 - 11
exo/inference/test_dummy_inference_engine.py

@@ -1,22 +1,16 @@
 import pytest
-import json
 import numpy as np
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.shard import Shard
 
 
-class MockShardDownloader:
-  async def ensure_shard(self, shard):
-    pass
-
-
 @pytest.mark.asyncio
 async def test_dummy_inference_specific():
-  engine = DummyInferenceEngine(MockShardDownloader())
+  engine = DummyInferenceEngine()
   test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
   test_prompt = "This is a test prompt"
 
-  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
+  result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt)
 
   print(f"Inference result shape: {result.shape}")
 
@@ -26,20 +20,20 @@ async def test_dummy_inference_specific():
 @pytest.mark.asyncio
 async def test_dummy_inference_engine():
   # Initialize the DummyInferenceEngine
-  engine = DummyInferenceEngine(MockShardDownloader())
+  engine = DummyInferenceEngine()
 
   # Create a test shard
   shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
 
   # Test infer_prompt
-  output = await engine.infer_prompt("test_id", shard, "Test prompt")
+  output, _ = await engine.infer_prompt("test_id", shard, "Test prompt")
 
   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"
 
   # Test infer_tensor
   input_tensor = np.array([[1, 2, 3]])
-  output = await engine.infer_tensor("test_id", shard, input_tensor)
+  output, _ = await engine.infer_tensor("test_id", shard, input_tensor)
 
   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"

+ 6 - 6
exo/inference/test_inference_engine.py

@@ -11,30 +11,30 @@ import numpy as np
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
   token_full = token_full.reshape(1, -1)
-  next_resp_full = await inference_engine_1.infer_tensor(
+  next_resp_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=token_full,
   )
 
   pp = n_layers // 2
-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
+  resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+  resp2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
   )
   tokens2 = await inference_engine_1.sample(resp2)
   tokens2 = tokens2.reshape(1, -1)
-  resp3 = await inference_engine_1.infer_tensor(
+  resp3, _ = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
     input_data=tokens2,
   )
-  resp4 = await inference_engine_2.infer_tensor(
+  resp4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp3,

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
     safe_save(state_dict, path) 
   
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     def wrap_infer():
       x = Tensor(input_data)

+ 26 - 31
exo/networking/grpc/grpc_peer_handle.py

@@ -12,7 +12,13 @@ from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
 import json
-import mlx.core as mx
+import platform
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
+
 
 class GRPCPeerHandle(PeerHandle):
   def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
@@ -101,7 +107,7 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
+      inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
     await self.stub.SendPrompt(request)
 
@@ -115,7 +121,7 @@ class GRPCPeerHandle(PeerHandle):
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
+      inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
     response =await self.stub.SendTensor(request)
 
@@ -123,7 +129,7 @@ class GRPCPeerHandle(PeerHandle):
       return None
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-  
+
   async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.ExampleRequest(
       shard=node_service_pb2.Shard(
@@ -145,7 +151,7 @@ class GRPCPeerHandle(PeerHandle):
       return loss, grads
     else:
       return loss
-  
+
   async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
@@ -170,10 +176,7 @@ class GRPCPeerHandle(PeerHandle):
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
       device_capabilities = DeviceCapabilities(
-        model=capabilities.model,
-        chip=capabilities.chip,
-        memory=capabilities.memory,
-        flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
       )
       topology.update_node(node_id, device_capabilities)
     for node_id, peer_connections in response.peer_graph.items():
@@ -197,28 +200,20 @@ class GRPCPeerHandle(PeerHandle):
     proto_inference_state = node_service_pb2.InferenceState()
     other_data = {}
     for k, v in inference_state.items():
-        if isinstance(v, mx.array):
-            np_array = np.array(v)
-            tensor_data = node_service_pb2.Tensor(
-                tensor_data=np_array.tobytes(),
-                shape=list(np_array.shape),
-                dtype=str(np_array.dtype)
-            )
-            proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
-        elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
-            tensor_list = node_service_pb2.TensorList()
-            for tensor in v:
-                np_array = np.array(tensor)
-                tensor_data = node_service_pb2.Tensor(
-                    tensor_data=np_array.tobytes(),
-                    shape=list(np_array.shape),
-                    dtype=str(np_array.dtype)
-                )
-                tensor_list.tensors.append(tensor_data)
-            proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
-        else:
-            # For non-tensor data, we'll still use JSON
-            other_data[k] = v
+      if isinstance(v, mx.array):
+        np_array = np.array(v)
+        tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
+        proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
+      elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
+        tensor_list = node_service_pb2.TensorList()
+        for tensor in v:
+          np_array = np.array(tensor)
+          tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
+          tensor_list.tensors.append(tensor_data)
+        proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
+      else:
+        # For non-tensor data, we'll still use JSON
+        other_data[k] = v
     if other_data:
     if other_data:
       proto_inference_state.other_data_json = json.dumps(other_data)
       proto_inference_state.other_data_json = json.dumps(other_data)
     return proto_inference_state
     return proto_inference_state

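serialize_inference_state and deserialize_inference_state above share one convention: every tensor travels as raw bytes plus a shape and a dtype string. A minimal sketch of that round trip, with a plain dict standing in for the node_service_pb2.Tensor message (the field names match the proto; everything else is illustrative):

import numpy as np

def to_wire(arr: np.ndarray) -> dict:
  # Same fields as the node_service_pb2.Tensor message: raw bytes, shape, dtype.
  return {"tensor_data": arr.tobytes(), "shape": list(arr.shape), "dtype": str(arr.dtype)}

def from_wire(msg: dict) -> np.ndarray:
  # Rebuild the array the same way deserialize_inference_state does.
  return np.frombuffer(msg["tensor_data"], dtype=msg["dtype"]).reshape(msg["shape"])

original = np.arange(6, dtype=np.float32).reshape(2, 3)
assert np.array_equal(original, from_wire(to_wire(original)))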
+ 23 - 25
exo/networking/grpc/grpc_server.py

@@ -3,13 +3,19 @@ from concurrent import futures
 import numpy as np
 import numpy as np
 from asyncio import CancelledError
 from asyncio import CancelledError
 
 
+import platform
+
 from . import node_service_pb2
 from . import node_service_pb2
 from . import node_service_pb2_grpc
 from . import node_service_pb2_grpc
 from exo import DEBUG
 from exo import DEBUG
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.orchestration import Node
 from exo.orchestration import Node
 import json
 import json
-import mlx.core as mx
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
 
 
 
 
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -60,7 +66,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     )
     prompt = request.prompt
     prompt = request.prompt
     request_id = request.request_id
     request_id = request.request_id
-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
     result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
     result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
     if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     tensor_data = result.tobytes() if result is not None else None
@@ -76,13 +82,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
     request_id = request.request_id
 
 
-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
 
 
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
-  
+
   async def SendExample(self, request, context):
   async def SendExample(self, request, context):
     shard = Shard(
     shard = Shard(
       model_id=request.shard.model_id,
       model_id=request.shard.model_id,
@@ -104,7 +110,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     else:
     else:
       loss = await self.node.process_example(shard, example, target, length, train, request_id)
       loss = await self.node.process_example(shard, example, target, length, train, request_id)
       return node_service_pb2.Loss(loss=loss, grads=None)
       return node_service_pb2.Loss(loss=loss, grads=None)
-    
+
   async def CollectTopology(self, request, context):
   async def CollectTopology(self, request, context):
     max_depth = request.max_depth
     max_depth = request.max_depth
     visited = set(request.visited)
     visited = set(request.visited)
@@ -120,12 +126,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       for node_id, cap in topology.nodes.items()
       for node_id, cap in topology.nodes.items()
     }
     }
     peer_graph = {
     peer_graph = {
-      node_id: node_service_pb2.PeerConnections(
-        connections=[
-          node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
-          for conn in connections
-        ]
-      )
+      node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
       for node_id, connections in topology.peer_graph.items()
       for node_id, connections in topology.peer_graph.items()
     }
     }
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
@@ -139,7 +140,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
     if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
     result = list(result)
     result = list(result)
     if len(img.tensor_data) > 0:
     if len(img.tensor_data) > 0:
-      result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
+      result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
     self.node.on_token.trigger_all(request_id, result, is_finished)
     self.node.on_token.trigger_all(request_id, result, is_finished)
     return node_service_pb2.Empty()
     return node_service_pb2.Empty()
 
 
@@ -153,21 +154,18 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
   async def HealthCheck(self, request, context):
   async def HealthCheck(self, request, context):
     return node_service_pb2.HealthCheckResponse(is_healthy=True)
     return node_service_pb2.HealthCheckResponse(is_healthy=True)
 
 
-  def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
+  def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
     inference_state = {}
     inference_state = {}
-    
+
     for k, tensor_data in inference_state_proto.tensor_data.items():
     for k, tensor_data in inference_state_proto.tensor_data.items():
-        np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
-        inference_state[k] = mx.array(np_array)
-    
+      np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
+      inference_state[k] = mx.array(np_array)
+
     for k, tensor_list in inference_state_proto.tensor_list_data.items():
     for k, tensor_list in inference_state_proto.tensor_list_data.items():
-        inference_state[k] = [
-            mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
-            for tensor in tensor_list.tensors
-        ]
-    
+      inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
+
     if inference_state_proto.other_data_json:
     if inference_state_proto.other_data_json:
-        other_data = json.loads(inference_state_proto.other_data_json)
-        inference_state.update(other_data)
-    
+      other_data = json.loads(inference_state_proto.other_data_json)
+      inference_state.update(other_data)
+
     return inference_state
     return inference_state

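Both gRPC files now guard the MLX import behind the same platform check. A small sketch of that shim in isolation, handy for confirming which backend a given host will pick (the print line is only for illustration):

import platform

# Use the MLX backend only on Apple Silicon; elsewhere alias numpy to the same
# name so calls such as mx.array(...) still resolve.
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
  import mlx.core as mx  # only available on Apple Silicon installs
else:
  import numpy as mx

x = mx.array([1.0, 2.0, 3.0])      # mlx array on Apple Silicon, numpy ndarray elsewhere
print(type(x).__module__)          # "mlx.core" or "numpy"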
+ 2 - 2
exo/orchestration/node.py

@@ -326,7 +326,7 @@ class Node:
           loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
           loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
         else:
         else:
           self.outstanding_requests[request_id] = "preprocessing"
           self.outstanding_requests[request_id] = "preprocessing"
-          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
           self.outstanding_requests[request_id] = "waiting"
           self.outstanding_requests[request_id] = "waiting"
           loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
           loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
           self.outstanding_requests[request_id] = "training"
           self.outstanding_requests[request_id] = "training"
@@ -342,7 +342,7 @@ class Node:
           loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
           loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
         else:
         else:
           self.outstanding_requests[request_id] = "preprocessing"
           self.outstanding_requests[request_id] = "preprocessing"
-          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
           self.outstanding_requests[request_id] = "waiting"
           self.outstanding_requests[request_id] = "waiting"
           loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
           loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
         self.outstanding_requests.pop(request_id)
         self.outstanding_requests.pop(request_id)

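The node.py hunks change the calling convention of infer_tensor: callers now unpack a pair and discard the second element. A hedged sketch of that shape, where infer_tensor below is a stand-in defined only for illustration (the real signature lives in the inference engine and is not shown in this diff):

import asyncio
import numpy as np

async def infer_tensor(request_id, shard, example):
  # Hypothetical stand-in: per the call sites above, infer_tensor now returns the
  # layer output together with an inference-state object, not the output alone.
  output = example + 1
  inference_state = {}
  return output, inference_state

async def preprocess(example):
  step, _ = await infer_tensor("req-1", None, example)  # state is unused by these callers
  return step

print(asyncio.run(preprocess(np.array([1, 2, 3]))))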
+ 90 - 4
exo/topology/device_capabilities.py

@@ -151,6 +151,8 @@ async def device_capabilities() -> DeviceCapabilities:
     return await mac_device_capabilities()
     return await mac_device_capabilities()
   elif psutil.LINUX:
   elif psutil.LINUX:
     return await linux_device_capabilities()
     return await linux_device_capabilities()
+  elif psutil.WINDOWS:
+    return await windows_device_capabilities()
   else:
   else:
     return DeviceCapabilities(
     return DeviceCapabilities(
       model="Unknown Device",
       model="Unknown Device",
@@ -187,6 +189,8 @@ async def linux_device_capabilities() -> DeviceCapabilities:
 
 
     if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
     if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
 
 
+    pynvml.nvmlShutdown()
+
     return DeviceCapabilities(
     return DeviceCapabilities(
       model=f"Linux Box ({gpu_name})",
       model=f"Linux Box ({gpu_name})",
       chip=gpu_name,
       chip=gpu_name,
@@ -194,13 +198,24 @@ async def linux_device_capabilities() -> DeviceCapabilities:
       flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
       flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
     )
     )
   elif Device.DEFAULT == "AMD":
   elif Device.DEFAULT == "AMD":
-    # TODO AMD support
+    # For AMD GPUs, use pyrsmi (the official Python package for rocm-smi)
+    from pyrsmi import rocml
+
+    rocml.smi_initialize()
+    gpu_name = rocml.smi_get_device_name(0).upper()
+    gpu_memory_info = rocml.smi_get_device_memory_total(0)
+
+    if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
+
+    rocml.smi_shutdown()
+
     return DeviceCapabilities(
     return DeviceCapabilities(
-      model="Linux Box (AMD)",
-      chip="Unknown AMD",
-      memory=psutil.virtual_memory().total // 2**20,
+      model=f"Linux Box ({gpu_name})",
+      chip=gpu_name,
+      memory=gpu_memory_info // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )
     )
+
   else:
   else:
     return DeviceCapabilities(
     return DeviceCapabilities(
       model=f"Linux Box (Device: {Device.DEFAULT})",
       model=f"Linux Box (Device: {Device.DEFAULT})",
@@ -208,3 +223,74 @@ async def linux_device_capabilities() -> DeviceCapabilities:
       memory=psutil.virtual_memory().total // 2**20,
       memory=psutil.virtual_memory().total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )
     )
+
+
+async def windows_device_capabilities() -> DeviceCapabilities:
+  import psutil
+
+  def get_gpu_info():
+    import win32com.client  # install pywin32
+
+    wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
+    gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
+
+    gpu_info = []
+    for gpu in gpus:
+      info = {
+        "Name": gpu.Name,
+        "AdapterRAM": gpu.AdapterRAM,  # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
+        "DriverVersion": gpu.DriverVersion,
+        "VideoProcessor": gpu.VideoProcessor
+      }
+      gpu_info.append(info)
+
+    return gpu_info
+
+  gpus_info = get_gpu_info()
+  gpu_names = [gpu['Name'] for gpu in gpus_info]
+
+  contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
+  contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
+
+  if contains_nvidia:
+    import pynvml
+
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
+    gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
+    gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+
+    if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
+
+    return DeviceCapabilities(
+      model=f"Windows Box ({gpu_name})",
+      chip=gpu_name,
+      memory=gpu_memory_info.total // 2**20,
+      flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+  elif contains_amd:
+    # For AMD GPUs, use pyrsmi (the official Python package for rocm-smi)
+    from pyrsmi import rocml
+
+    rocml.smi_initialize()
+    gpu_name = rocml.smi_get_device_name(0).upper()
+    gpu_memory_info = rocml.smi_get_device_memory_total(0)
+
+    if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
+
+    rocml.smi_shutdown()
+
+    return DeviceCapabilities(
+      model="Windows Box ({gpu_name})",
+      chip={gpu_name},
+      memory=gpu_memory_info.total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
+  else:
+    return DeviceCapabilities(
+      model=f"Windows Box (Device: Unknown)",
+      chip=f"Unknown Chip (Device(s): {gpu_names})",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )

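One detail worth isolating from windows_device_capabilities is the NVIDIA name normalization, which strips a trailing capacity token so the name can be looked up in CHIP_FLOPS. A small sketch (CHIP_FLOPS_KEYS here is an illustrative stand-in, not the real table):

CHIP_FLOPS_KEYS = {"NVIDIA GEFORCE RTX 3080"}  # illustrative subset only

def normalize_gpu_name(raw: str) -> str:
  # Same expression as in windows_device_capabilities: drop a trailing "...GB"
  # token (e.g. "NVIDIA GEFORCE RTX 3080 10GB") so the name matches CHIP_FLOPS keys.
  return raw.rsplit(" ", 1)[0] if raw.endswith("GB") else raw

print(normalize_gpu_name("NVIDIA GEFORCE RTX 3080 10GB") in CHIP_FLOPS_KEYS)  # True
print(normalize_gpu_name("NVIDIA GEFORCE RTX 3090"))                          # unchanged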
+ 6 - 2
scripts/build_exo.py

@@ -6,6 +6,9 @@ import pkgutil
 
 
 def run():
 def run():
     site_packages = site.getsitepackages()[0]
     site_packages = site.getsitepackages()[0]
+    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    baseimages_dir = os.path.join(base_dir, "exo", "apputil", "baseimages")
+    
     command = [
     command = [
         f"{sys.executable}", "-m", "nuitka", "exo/main.py",
         f"{sys.executable}", "-m", "nuitka", "exo/main.py",
         "--company-name=exolabs",
         "--company-name=exolabs",
@@ -15,7 +18,8 @@ def run():
         "--standalone",
         "--standalone",
         "--output-filename=exo",
         "--output-filename=exo",
         "--python-flag=no_site",
         "--python-flag=no_site",
-        "--onefile"
+        "--onefile",
+        f"--include-data-dir={baseimages_dir}=exo/apputil/baseimages"
     ]
     ]
 
 
     if sys.platform == "darwin": 
     if sys.platform == "darwin": 
@@ -23,7 +27,7 @@ def run():
             "--macos-app-name=exo",
             "--macos-app-name=exo",
             "--macos-app-mode=gui",
             "--macos-app-mode=gui",
             "--macos-app-version=0.0.1",
             "--macos-app-version=0.0.1",
-            "--macos-signed-app-name=com.exolabs.exo",
+            "--macos-signed-app-name=net.exolabs.exo",
             "--include-distribution-meta=mlx",
             "--include-distribution-meta=mlx",
             "--include-module=mlx._reprlib_fix",
             "--include-module=mlx._reprlib_fix",
             "--include-module=mlx._os_warning",
             "--include-module=mlx._os_warning",

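The new --include-data-dir flag ships exo/apputil/baseimages inside the Nuitka build. A hedged sketch of how such a bundled folder is typically resolved at runtime, assuming exo.apputil is a package and that paths are resolved relative to its __file__ (the helper name is illustrative, not part of the repo):

import os

def bundled_baseimages_dir() -> str:
  # Assumption for illustration: with --include-data-dir=<src>=exo/apputil/baseimages,
  # Nuitka copies the folder into the distribution next to the compiled exo.apputil
  # package, so resolving it relative to that package's __file__ works both when
  # running from source and from the built binary.
  import exo.apputil as apputil
  return os.path.join(os.path.dirname(os.path.abspath(apputil.__file__)), "baseimages")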
+ 37 - 4
setup.py

@@ -1,5 +1,6 @@
 import sys
 import sys
 import platform
 import platform
+import subprocess
 
 
 from setuptools import find_packages, setup
 from setuptools import find_packages, setup
 
 
@@ -11,7 +12,6 @@ install_requires = [
   "grpcio==1.68.0",
   "grpcio==1.68.0",
   "grpcio-tools==1.68.0",
   "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "Jinja2==3.1.4",
-  "netifaces==0.11.0",
   "numpy==2.0.0",
   "numpy==2.0.0",
   "nuitka==2.5.1",
   "nuitka==2.5.1",
   "nvidia-ml-py==12.560.30",
   "nvidia-ml-py==12.560.30",
@@ -23,6 +23,7 @@ install_requires = [
   "pydantic==2.9.2",
   "pydantic==2.9.2",
   "requests==2.32.3",
   "requests==2.32.3",
   "rich==13.7.1",
   "rich==13.7.1",
+  "scapy==2.6.1",
   "tenacity==9.0.0",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
   "tqdm==4.66.4",
   "transformers==4.46.3",
   "transformers==4.46.3",
@@ -32,19 +33,51 @@ install_requires = [
 ]
 ]
 
 
 extras_require = {
 extras_require = {
-  "formatting": [
-    "yapf==0.40.2",
-  ],
+  "formatting": ["yapf==0.40.2",],
   "apple_silicon": [
   "apple_silicon": [
     "mlx==0.21.1",
     "mlx==0.21.1",
     "mlx-lm==0.20.4",
     "mlx-lm==0.20.4",
   ],
   ],
+  "windows": ["pywin32==308",],
+  "nvidia-gpu": ["nvidia-ml-py==12.560.30",],
+  "amd-gpu": ["pyrsmi==0.2.0"],
 }
 }
 
 
 # Check if running on macOS with Apple Silicon
 # Check if running on macOS with Apple Silicon
 if sys.platform.startswith("darwin") and platform.machine() == "arm64":
 if sys.platform.startswith("darwin") and platform.machine() == "arm64":
   install_requires.extend(extras_require["apple_silicon"])
   install_requires.extend(extras_require["apple_silicon"])
 
 
+# Check if running Windows
+if sys.platform.startswith("win32"):
+  install_requires.extend(extras_require["windows"])
+
+
+def _add_gpu_requires():
+  global install_requires
+  # Add Nvidia-GPU
+  try:
+    out = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], shell=True, text=True, capture_output=True, check=False)
+    if out.returncode == 0:
+      install_requires.extend(extras_require["nvidia-gpu"])
+  except subprocess.CalledProcessError:
+    pass
+
+  # Add AMD-GPU
+  # This will mostly work only on Linux; amd/rocm-smi is not yet supported on Windows
+  try:
+    out = subprocess.run(['amd-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
+    if out.returncode != 0:
+      # Fall back to the older rocm-smi tool when amd-smi is not available
+      out = subprocess.run(['rocm-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
+    if out.returncode == 0:
+      install_requires.extend(extras_require["amd-gpu"])
+  except OSError:
+    pass
+
+
+_add_gpu_requires()
+
 setup(
 setup(
   name="exo",
   name="exo",
   version="0.0.1",
   version="0.0.1",

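The setup.py detection shells out to the vendor SMI tools and keys off the return code. For comparison, a minimal alternative sketch that only checks whether those tools are on PATH via shutil.which; this is not what setup.py does, just a probe with the same intent:

import shutil

def detect_gpu_extras() -> list:
  # Alternative probe: check whether the vendor CLI exists instead of running it
  # through a shell. Maps directly onto the extras_require keys added above.
  extras = []
  if shutil.which("nvidia-smi"):
    extras.append("nvidia-gpu")
  if shutil.which("amd-smi") or shutil.which("rocm-smi"):
    extras.append("amd-gpu")
  return extras

print(detect_gpu_extras())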
+ 1 - 1
test/test_tokenizers.py

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
 models = []
 models = []
 for model_id in model_cards:
 for model_id in model_cards: