浏览代码

Merge remote-tracking branch 'origin/main' into runners2

Alex Cheema 3 月之前
父节点
当前提交
461e4f37cb

+ 148 - 152
exo/api/chatgpt_api.py

@@ -21,7 +21,13 @@ from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
-import mlx.core as mx
+import platform
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
+
 import tempfile
 from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
@@ -29,6 +35,7 @@ from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
 from collections import defaultdict
 
+
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
@@ -42,7 +49,6 @@ class Message:
     return data
 
 
-
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
@@ -133,16 +139,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
 
 def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  chat_template_args = {
-    "conversation": [m.to_dict() for m in messages],
-    "tokenize": False,
-    "add_generation_prompt": True
-  }
-  if tools: chat_template_args["tools"] = tools
-
-  prompt = tokenizer.apply_chat_template(**chat_template_args)
-  print(f"!!! Prompt: {prompt}")
-  return prompt
+  chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
+  if tools: 
+    chat_template_args["tools"] = tools
+
+  try:
+    prompt = tokenizer.apply_chat_template(**chat_template_args)
+    if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
+    return prompt
+  except UnicodeEncodeError:
+    # Handle Unicode encoding by ensuring everything is UTF-8
+    chat_template_args["conversation"] = [
+      {k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v 
+       for k, v in m.to_dict().items()}
+      for m in messages
+    ]
+    prompt = tokenizer.apply_chat_template(**chat_template_args)
+    if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
+    return prompt
 
 
 def parse_message(data: dict):
@@ -166,8 +180,17 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
+
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
+  def __init__(
+    self,
+    node: Node,
+    inference_engine_classname: str,
+    response_timeout: int = 90,
+    on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
+    default_model: Optional[str] = None,
+    system_prompt: Optional[str] = None
+  ):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -209,18 +232,22 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
     cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
-      
+    # Add static routes
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
       self.app.router.add_get("/", self.handle_root)
       self.app.router.add_static("/", self.static_dir, name="static")
-      self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
+      
+    # Always add images route, regardless of compilation status
+    self.images_dir = get_exo_images_dir()
+    self.images_dir.mkdir(parents=True, exist_ok=True)
+    self.app.router.add_static('/images/', self.images_dir, name='static_images')
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
 
   async def handle_quit(self, request):
-    if DEBUG>=1: print("Received quit signal")
+    if DEBUG >= 1: print("Received quit signal")
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     await response.prepare(request)
     await response.write_eof()
@@ -250,61 +277,48 @@ class ChatGPTAPI:
 
   async def handle_model_support(self, request):
     try:
-        response = web.StreamResponse(
-            status=200,
-            reason='OK',
-            headers={
-                'Content-Type': 'text/event-stream',
-                'Cache-Control': 'no-cache',
-                'Connection': 'keep-alive',
-            }
-        )
-        await response.prepare(request)
+      response = web.StreamResponse(status=200, reason='OK', headers={
+        'Content-Type': 'text/event-stream',
+        'Cache-Control': 'no-cache',
+        'Connection': 'keep-alive',
+      })
+      await response.prepare(request)
+
+      async def process_model(model_name, pretty):
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+
+          if self.inference_engine_classname in model_info.get("repo", {}):
+            shard = build_base_shard(model_name, self.inference_engine_classname)
+            if shard:
+              downloader = HFShardDownloader(quick_check=True)
+              downloader.current_shard = shard
+              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+              status = await downloader.get_shard_download_status()
+
+              download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else False
 
-        async def process_model(model_name, pretty):
-            if model_name in model_cards:
-                model_info = model_cards[model_name]
-
-                if self.inference_engine_classname in model_info.get("repo", {}):
-                    shard = build_base_shard(model_name, self.inference_engine_classname)
-                    if shard:
-                        downloader = HFShardDownloader(quick_check=True)
-                        downloader.current_shard = shard
-                        downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-                        status = await downloader.get_shard_download_status()
-
-                        download_percentage = status.get("overall") if status else None
-                        total_size = status.get("total_size") if status else None
-                        total_downloaded = status.get("total_downloaded") if status else False
-
-                        model_data = {
-                            model_name: {
-                                "name": pretty,
-                                "downloaded": download_percentage == 100 if download_percentage is not None else False,
-                                "download_percentage": download_percentage,
-                                "total_size": total_size,
-                                "total_downloaded": total_downloaded
-                            }
-                        }
-
-                        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-
-        # Process all models in parallel
-        await asyncio.gather(*[
-            process_model(model_name, pretty)
-            for model_name, pretty in pretty_name.items()
-        ])
-
-        await response.write(b"data: [DONE]\n\n")
-        return response
+              model_data = {
+                model_name: {
+                  "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
+                  "total_downloaded": total_downloaded
+                }
+              }
+
+              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+
+      # Process all models in parallel
+      await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
+
+      await response.write(b"data: [DONE]\n\n")
+      return response
 
     except Exception as e:
-        print(f"Error in handle_model_support: {str(e)}")
-        traceback.print_exc()
-        return web.json_response(
-            {"detail": f"Server error: {str(e)}"},
-            status=500
-        )
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
 
   async def handle_get_models(self, request):
     models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
@@ -466,7 +480,6 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
-  
   async def handle_post_image_generations(self, request):
     data = await request.json()
 
@@ -479,7 +492,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     if DEBUG >= 2: print(f"shard: {shard}")
     if not shard:
-        return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
+      return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
 
     request_id = str(uuid.uuid4())
     callback_id = f"chatgpt-api-wait-response-{request_id}"
@@ -491,77 +504,85 @@ class ChatGPTAPI:
         img = None
       await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
 
-
-      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
+      response = web.StreamResponse(status=200, reason='OK', headers={
+        'Content-Type': 'application/octet-stream',
+        "Cache-Control": "no-cache",
+      })
       await response.prepare(request)
 
       def get_progress_bar(current_step, total_steps, bar_length=50):
         # Calculate the percentage of completion
-        percent = float(current_step) / total_steps
+        percent = float(current_step)/total_steps
         # Calculate the number of hashes to display
-        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
-        spaces = ' ' * (bar_length - len(arrow))
-        
+        arrow = '-'*int(round(percent*bar_length) - 1) + '>'
+        spaces = ' '*(bar_length - len(arrow))
+
         # Create the progress bar string
         progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
         return progress_bar
 
       async def stream_image(_request_id: str, result, is_finished: bool):
-          if isinstance(result, list):
-              await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
+        if isinstance(result, list):
+          await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
 
-          elif isinstance(result, np.ndarray):
+        elif isinstance(result, np.ndarray):
+          try:
             im = Image.fromarray(np.array(result))
-            images_folder = get_exo_images_dir()
             # Save the image to a file
             image_filename = f"{_request_id}.png"
-            image_path = images_folder / image_filename
+            image_path = self.images_dir/image_filename
             im.save(image_path)
-            image_url = request.app.router['static_images'].url_for(filename=image_filename)
-            base_url = f"{request.scheme}://{request.host}"
-            # Construct the full URL correctly
-            full_image_url = base_url + str(image_url)
             
-            await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            # Get URL for the saved image
+            try:
+              image_url = request.app.router['static_images'].url_for(filename=image_filename)
+              base_url = f"{request.scheme}://{request.host}"
+              full_image_url = base_url + str(image_url)
+              
+              await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            except KeyError as e:
+              if DEBUG >= 2: print(f"Error getting image URL: {e}")
+              # Fallback to direct file path if URL generation fails
+              await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            
             if is_finished:
               await response.write_eof()
-              
+            
+          except Exception as e:
+            if DEBUG >= 2: print(f"Error processing image: {e}")
+            if DEBUG >= 2: traceback.print_exc()
+            await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')
 
       stream_task = None
+
       def on_result(_request_id: str, result, is_finished: bool):
-          nonlocal stream_task
-          stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
-          return _request_id == request_id and is_finished
+        nonlocal stream_task
+        stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
+        return _request_id == request_id and is_finished
 
       await callback.wait(on_result, timeout=self.response_timeout*10)
-      
+
       if stream_task:
-          # Wait for the stream task to complete before returning
-          await stream_task
+        # Wait for the stream task to complete before returning
+        await stream_task
 
       return response
 
     except Exception as e:
-        if DEBUG >= 2: traceback.print_exc()
-        return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-  
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+
   async def handle_delete_model(self, request):
     try:
       model_name = request.match_info.get('model_name')
       if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
 
       if not model_name or model_name not in model_cards:
-        return web.json_response(
-          {"detail": f"Invalid model name: {model_name}"},
-          status=400
-          )
+        return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
 
       shard = build_base_shard(model_name, self.inference_engine_classname)
       if not shard:
-        return web.json_response(
-          {"detail": "Could not build shard for model"},
-          status=400
-        )
+        return web.json_response({"detail": "Could not build shard for model"}, status=400)
 
       repo_id = get_repo(shard.model_id, self.inference_engine_classname)
       if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
@@ -576,38 +597,28 @@ class ChatGPTAPI:
         if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
         try:
           shutil.rmtree(cache_dir)
-          return web.json_response({
-            "status": "success",
-            "message": f"Model {model_name} deleted successfully",
-            "path": str(cache_dir)
-          })
+          return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
         except Exception as e:
-          return web.json_response({
-            "detail": f"Failed to delete model files: {str(e)}"
-          }, status=500)
+          return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
       else:
-        return web.json_response({
-          "detail": f"Model files not found at {cache_dir}"
-        }, status=404)
+        return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
 
     except Exception as e:
-        print(f"Error in handle_delete_model: {str(e)}")
-        traceback.print_exc()
-        return web.json_response({
-            "detail": f"Server error: {str(e)}"
-        }, status=500)
+      print(f"Error in handle_delete_model: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
 
   async def handle_get_initial_models(self, request):
     model_data = {}
     for model_name, pretty in pretty_name.items():
-        model_data[model_name] = {
-            "name": pretty,
-            "downloaded": None,  # Initially unknown
-            "download_percentage": None,  # Change from 0 to null
-            "total_size": None,
-            "total_downloaded": None,
-            "loading": True  # Add loading state
-        }
+      model_data[model_name] = {
+        "name": pretty,
+        "downloaded": None,  # Initially unknown
+        "download_percentage": None,  # Change from 0 to null
+        "total_size": None,
+        "total_downloaded": None,
+        "loading": True  # Add loading state
+      }
     return web.json_response(model_data)
 
   async def handle_create_animation(self, request):
@@ -633,17 +644,9 @@ class ChatGPTAPI:
       if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
 
       # Create the animation
-      create_animation_mp4(
-        replacement_image_path,
-        output_path,
-        device_name,
-        prompt_text
-      )
+      create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
 
-      return web.json_response({
-        "status": "success",
-        "output_path": output_path
-      })
+      return web.json_response({"status": "success", "output_path": output_path})
 
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
@@ -659,10 +662,7 @@ class ChatGPTAPI:
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
       asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
 
-      return web.json_response({
-        "status": "success",
-        "message": f"Download started for model: {model_name}"
-      })
+      return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"error": str(e)}, status=500)
@@ -676,10 +676,7 @@ class ChatGPTAPI:
         return web.json_response({})
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
-      return web.json_response(
-        {"detail": f"Error getting topology: {str(e)}"},
-        status=500
-      )
+      return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
 
   async def handle_token(self, request_id: str, token: int, is_finished: bool):
     await self.token_queues[request_id].put((token, is_finished))
@@ -693,15 +690,14 @@ class ChatGPTAPI:
   def base64_decode(self, base64_string):
     #decode and reshape image
     if base64_string.startswith('data:image'):
-        base64_string = base64_string.split(',')[1]
+      base64_string = base64_string.split(',')[1]
     image_data = base64.b64decode(base64_string)
     img = Image.open(BytesIO(image_data))
-    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    W, H = (dim - dim%64 for dim in (img.width, img.height))
     if W != img.width or H != img.height:
-        if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
-        img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+      if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+      img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
     img = mx.array(np.array(img))
-    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
     img = img[None]
     return img
-  

+ 12 - 5
exo/apputil/anim.py

@@ -2,6 +2,7 @@ from PIL import Image, ImageDraw, ImageFont, ImageFilter
 import os
 import numpy as np
 import cv2
+import sys
 
 def draw_rounded_rectangle(draw, coords, radius, fill):
   left, top, right, bottom = coords
@@ -80,14 +81,20 @@ def create_animation_mp4(
     font = ImageFont.load_default()
     promptfont = ImageFont.load_default()
 
+  # Get the base directory for images when running as a bundled app
+  if hasattr(sys, '_MEIPASS'):
+    base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages")
+  else:
+    base_dir = os.path.join(os.path.dirname(__file__), "baseimages")
+
   # Process first frame
-  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
+  base_img = Image.open(os.path.join(base_dir, "image1.png"))
   draw = ImageDraw.Draw(base_img)
   draw_centered_text_rounded(draw, device_name, font, device_coords)
   frames.extend([crop_image(base_img)] * 30)  # 1 second at 30fps
 
   # Process second frame with typing animation
-  base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
+  base_img2 = Image.open(os.path.join(base_dir, "image2.png"))
   for i in range(len(prompt_text) + 1):
     current_frame = base_img2.copy()
     draw = ImageDraw.Draw(current_frame)
@@ -101,7 +108,7 @@ def create_animation_mp4(
 
   # Create blur sequence
   replacement_img = Image.open(replacement_image_path)
-  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
+  base_img = Image.open(os.path.join(base_dir, "image3.png"))
   blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
 
   for i, blur_amount in enumerate(blur_steps):
@@ -123,7 +130,7 @@ def create_animation_mp4(
     frames.extend([crop_image(new_frame)] * 15)  # 0.5 seconds at 30fps
 
   # Create and add final frame (image4)
-  final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
+  final_base = Image.open(os.path.join(base_dir, "image4.png"))
   draw = ImageDraw.Draw(final_base)
 
   draw_centered_text_rounded(draw, device_name, font, device_coords)
@@ -158,4 +165,4 @@ def create_animation_mp4(
       out.write(frame_array)
     
     out.release()
-    print(f"Video saved successfully to {output_path}")
+    print(f"Video saved successfully to {output_path}")

+ 25 - 25
exo/helpers.py

@@ -7,7 +7,8 @@ import random
 import platform
 import psutil
 import uuid
-import netifaces
+from scapy.all import get_if_addr, get_if_list
+import re
 import subprocess
 from pathlib import Path
 import tempfile
@@ -231,26 +232,26 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
 def get_all_ip_addresses_and_interfaces():
   try:
     ip_addresses = []
-    for interface in netifaces.interfaces():
-      ifaddresses = netifaces.ifaddresses(interface)
-      if netifaces.AF_INET in ifaddresses:
-        for link in ifaddresses[netifaces.AF_INET]:
-          ip = link['addr']
-          ip_addresses.append((ip, interface))
+    for interface in get_if_list():
+      ip = get_if_addr(interface)
+      # Include all addresses, including loopback
+      # Filter out link-local addresses
+      if not ip.startswith('169.254.') and not ip.startswith('0.0.'):
+        # Remove "\\Device\\NPF_" prefix from interface name
+        simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
+        ip_addresses.append((ip, simplified_interface))
     return list(set(ip_addresses))
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return [("localhost", "lo")]
 
+
 async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
   try:
     # Use the shared subprocess_pool
-    output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
-      ['system_profiler', 'SPNetworkDataType', '-json'],
-      capture_output=True,
-      text=True,
-      close_fds=True
-    ).stdout)
+    output = await asyncio.get_running_loop().run_in_executor(
+      subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
+    )
 
     data = json.loads(output)
 
@@ -276,6 +277,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
 
   return None
 
+
 async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
   # On macOS, try to get interface type using networksetup
   if psutil.MACOS:
@@ -283,8 +285,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
     if macos_type is not None: return macos_type
 
   # Local container/virtual interfaces
-  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
-    'bridge' in ifname):
+  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
     return (7, "Container Virtual")
 
   # Loopback interface
@@ -310,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
   # Other physical interfaces
   return (2, "Other")
 
+
 async def shutdown(signal, loop, server):
   """Gracefully shutdown the server and close the asyncio loop."""
   print(f"Received exit signal {signal.name}...")
@@ -353,18 +355,16 @@ async def get_mac_system_info() -> Tuple[str, str, int]:
         return "Unknown Model", "Unknown Chip", 0
 
 def get_exo_home() -> Path:
-  if os.name == "nt":  # Check if the OS is Windows
-    docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
-  else:
-    docs_folder = Path.home() / "Documents"
-  exo_folder = docs_folder / "Exo"
-  if not exo_folder.exists():
-    exo_folder.mkdir()
+  if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
+  else: docs_folder = Path.home()/"Documents"
+  if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
+  exo_folder = docs_folder/"Exo"
+  if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
   return exo_folder
 
+
 def get_exo_images_dir() -> Path:
   exo_home = get_exo_home()
-  images_dir = exo_home / "Images"
-  if not images_dir.exists():
-    images_dir.mkdir()
+  images_dir = exo_home/"Images"
+  if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
   return images_dir

+ 5 - 5
exo/inference/debug_inference_engine.py

@@ -16,25 +16,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
 
-  next_resp_full = await inference_engine_1.infer_tensor(
+  next_resp_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=token_full,
   )
 
-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
+  resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp1,
   )
   token2 = await inference_engine_2.sample(resp2)
-  resp3 = await inference_engine_1.infer_tensor(
+  resp3, _ = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
     input_data=token2,
   )
-  resp4 = await inference_engine_2.infer_tensor(
+  resp4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,

+ 2 - 2
exo/inference/dummy_inference_engine.py

@@ -25,9 +25,9 @@ class DummyInferenceEngine(InferenceEngine):
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return self.tokenizer.decode(tokens)
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
-    return input_data + 1 if self.shard.is_last_layer() else input_data
+    return input_data + 1 if self.shard.is_last_layer() else input_data, None
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard: return

+ 10 - 7
exo/inference/inference_engine.py

@@ -5,6 +5,7 @@ from exo.helpers import DEBUG  # Make sure to import DEBUG
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
+from exo.download.shard_download import ShardDownloader
 
 
 class InferenceEngine(ABC):
@@ -13,7 +14,7 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     pass
-  
+
   @abstractmethod
   async def sample(self, x: np.ndarray) -> np.ndarray:
     pass
@@ -23,7 +24,7 @@ class InferenceEngine(ABC):
     pass
 
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     pass
 
   @abstractmethod
@@ -32,14 +33,14 @@ class InferenceEngine(ABC):
 
   async def save_checkpoint(self, shard: Shard, path: str):
     pass
-  
+
   async def save_session(self, key, value):
     self.session[key] = value
-  
+
   async def clear_session(self):
     self.session.empty()
-  
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
+
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     tokens = await self.encode(shard, prompt)
     if shard.model_id != 'stable-diffusion-2-1-base':
       x = tokens.reshape(1, -1)
@@ -49,13 +50,15 @@ class InferenceEngine(ABC):
 
     return output_data, inference_state
 
+
 inference_engine_classes = {
   "mlx": "MLXDynamicShardInferenceEngine",
   "tinygrad": "TinygradDynamicShardInferenceEngine",
   "dummy": "DummyInferenceEngine",
 }
 
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+
+def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
   if DEBUG >= 2:
     print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":

+ 3 - 3
exo/inference/mlx/sharded_inference_engine.py

@@ -67,15 +67,15 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     self.model.load_weights(path)
     
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
     if self.model.model_type != 'StableDiffusionPipeline':
-      output_data = self.model(x, **state, **inference_state)
+      output_data = self.model(x, **state, **(inference_state or {}))
     else:
-      output_data, inference_state = self.model(x, **state, **inference_state)
+      output_data, inference_state = self.model(x, **state, **(inference_state or {}))
     output_data = np.array(output_data, copy=False)
     return output_data, inference_state
 

+ 5 - 11
exo/inference/test_dummy_inference_engine.py

@@ -1,22 +1,16 @@
 import pytest
-import json
 import numpy as np
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.shard import Shard
 
 
-class MockShardDownloader:
-  async def ensure_shard(self, shard):
-    pass
-
-
 @pytest.mark.asyncio
 async def test_dummy_inference_specific():
-  engine = DummyInferenceEngine(MockShardDownloader())
+  engine = DummyInferenceEngine()
   test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
   test_prompt = "This is a test prompt"
 
-  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
+  result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt)
 
   print(f"Inference result shape: {result.shape}")
 
@@ -26,20 +20,20 @@ async def test_dummy_inference_specific():
 @pytest.mark.asyncio
 async def test_dummy_inference_engine():
   # Initialize the DummyInferenceEngine
-  engine = DummyInferenceEngine(MockShardDownloader())
+  engine = DummyInferenceEngine()
 
   # Create a test shard
   shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
 
   # Test infer_prompt
-  output = await engine.infer_prompt("test_id", shard, "Test prompt")
+  output, _ = await engine.infer_prompt("test_id", shard, "Test prompt")
 
   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"
 
   # Test infer_tensor
   input_tensor = np.array([[1, 2, 3]])
-  output = await engine.infer_tensor("test_id", shard, input_tensor)
+  output, _ = await engine.infer_tensor("test_id", shard, input_tensor)
 
   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"

+ 6 - 6
exo/inference/test_inference_engine.py

@@ -11,30 +11,30 @@ import numpy as np
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
   token_full = token_full.reshape(1, -1)
-  next_resp_full = await inference_engine_1.infer_tensor(
+  next_resp_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=token_full,
   )
 
   pp = n_layers // 2
-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
+  resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+  resp2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
   )
   tokens2 = await inference_engine_1.sample(resp2)
   tokens2 = tokens2.reshape(1, -1)
-  resp3 = await inference_engine_1.infer_tensor(
+  resp3, _ = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
     input_data=tokens2,
   )
-  resp4 = await inference_engine_2.infer_tensor(
+  resp4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp3,

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
     safe_save(state_dict, path) 
   
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     def wrap_infer():
       x = Tensor(input_data)

+ 26 - 31
exo/networking/grpc/grpc_peer_handle.py

@@ -12,7 +12,13 @@ from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
 import json
-import mlx.core as mx
+import platform
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
+
 
 class GRPCPeerHandle(PeerHandle):
   def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
@@ -101,7 +107,7 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
+      inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
     await self.stub.SendPrompt(request)
 
@@ -115,7 +121,7 @@ class GRPCPeerHandle(PeerHandle):
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
+      inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
     response =await self.stub.SendTensor(request)
 
@@ -123,7 +129,7 @@ class GRPCPeerHandle(PeerHandle):
       return None
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-  
+
   async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.ExampleRequest(
       shard=node_service_pb2.Shard(
@@ -145,7 +151,7 @@ class GRPCPeerHandle(PeerHandle):
       return loss, grads
     else:
       return loss
-  
+
   async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
@@ -170,10 +176,7 @@ class GRPCPeerHandle(PeerHandle):
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
       device_capabilities = DeviceCapabilities(
-        model=capabilities.model,
-        chip=capabilities.chip,
-        memory=capabilities.memory,
-        flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
       )
       topology.update_node(node_id, device_capabilities)
     for node_id, peer_connections in response.peer_graph.items():
@@ -197,28 +200,20 @@ class GRPCPeerHandle(PeerHandle):
     proto_inference_state = node_service_pb2.InferenceState()
     other_data = {}
     for k, v in inference_state.items():
-        if isinstance(v, mx.array):
-            np_array = np.array(v)
-            tensor_data = node_service_pb2.Tensor(
-                tensor_data=np_array.tobytes(),
-                shape=list(np_array.shape),
-                dtype=str(np_array.dtype)
-            )
-            proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
-        elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
-            tensor_list = node_service_pb2.TensorList()
-            for tensor in v:
-                np_array = np.array(tensor)
-                tensor_data = node_service_pb2.Tensor(
-                    tensor_data=np_array.tobytes(),
-                    shape=list(np_array.shape),
-                    dtype=str(np_array.dtype)
-                )
-                tensor_list.tensors.append(tensor_data)
-            proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
-        else:
-            # For non-tensor data, we'll still use JSON
-            other_data[k] = v
+      if isinstance(v, mx.array):
+        np_array = np.array(v)
+        tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
+        proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
+      elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
+        tensor_list = node_service_pb2.TensorList()
+        for tensor in v:
+          np_array = np.array(tensor)
+          tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
+          tensor_list.tensors.append(tensor_data)
+        proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
+      else:
+        # For non-tensor data, we'll still use JSON
+        other_data[k] = v
     if other_data:
       proto_inference_state.other_data_json = json.dumps(other_data)
     return proto_inference_state

+ 23 - 25
exo/networking/grpc/grpc_server.py

@@ -3,13 +3,19 @@ from concurrent import futures
 import numpy as np
 from asyncio import CancelledError
 
+import platform
+
 from . import node_service_pb2
 from . import node_service_pb2_grpc
 from exo import DEBUG
 from exo.inference.shard import Shard
 from exo.orchestration import Node
 import json
-import mlx.core as mx
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
 
 
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -60,7 +66,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     prompt = request.prompt
     request_id = request.request_id
-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
     result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
     if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
@@ -76,13 +82,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
 
-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
 
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
-  
+
   async def SendExample(self, request, context):
     shard = Shard(
       model_id=request.shard.model_id,
@@ -104,7 +110,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     else:
       loss = await self.node.process_example(shard, example, target, length, train, request_id)
       return node_service_pb2.Loss(loss=loss, grads=None)
-    
+
   async def CollectTopology(self, request, context):
     max_depth = request.max_depth
     visited = set(request.visited)
@@ -120,12 +126,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       for node_id, cap in topology.nodes.items()
     }
     peer_graph = {
-      node_id: node_service_pb2.PeerConnections(
-        connections=[
-          node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
-          for conn in connections
-        ]
-      )
+      node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
       for node_id, connections in topology.peer_graph.items()
     }
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
@@ -139,7 +140,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
     result = list(result)
     if len(img.tensor_data) > 0:
-      result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
+      result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
     self.node.on_token.trigger_all(request_id, result, is_finished)
     return node_service_pb2.Empty()
 
@@ -153,21 +154,18 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
   async def HealthCheck(self, request, context):
     return node_service_pb2.HealthCheckResponse(is_healthy=True)
 
-  def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
+  def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
     inference_state = {}
-    
+
     for k, tensor_data in inference_state_proto.tensor_data.items():
-        np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
-        inference_state[k] = mx.array(np_array)
-    
+      np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
+      inference_state[k] = mx.array(np_array)
+
     for k, tensor_list in inference_state_proto.tensor_list_data.items():
-        inference_state[k] = [
-            mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
-            for tensor in tensor_list.tensors
-        ]
-    
+      inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
+
     if inference_state_proto.other_data_json:
-        other_data = json.loads(inference_state_proto.other_data_json)
-        inference_state.update(other_data)
-    
+      other_data = json.loads(inference_state_proto.other_data_json)
+      inference_state.update(other_data)
+
     return inference_state

+ 2 - 2
exo/orchestration/node.py

@@ -326,7 +326,7 @@ class Node:
           loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
         else:
           self.outstanding_requests[request_id] = "preprocessing"
-          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
           self.outstanding_requests[request_id] = "waiting"
           loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
           self.outstanding_requests[request_id] = "training"
@@ -342,7 +342,7 @@ class Node:
           loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
         else:
           self.outstanding_requests[request_id] = "preprocessing"
-          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
           self.outstanding_requests[request_id] = "waiting"
           loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
         self.outstanding_requests.pop(request_id)

+ 90 - 4
exo/topology/device_capabilities.py

@@ -151,6 +151,8 @@ async def device_capabilities() -> DeviceCapabilities:
     return await mac_device_capabilities()
   elif psutil.LINUX:
     return await linux_device_capabilities()
+  elif psutil.WINDOWS:
+    return await windows_device_capabilities()
   else:
     return DeviceCapabilities(
       model="Unknown Device",
@@ -187,6 +189,8 @@ async def linux_device_capabilities() -> DeviceCapabilities:
 
     if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
 
+    pynvml.nvmlShutdown()
+
     return DeviceCapabilities(
       model=f"Linux Box ({gpu_name})",
       chip=gpu_name,
@@ -194,13 +198,24 @@ async def linux_device_capabilities() -> DeviceCapabilities:
       flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
     )
   elif Device.DEFAULT == "AMD":
-    # TODO AMD support
+    # For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
+    from pyrsmi import rocml
+
+    rocml.smi_initialize()
+    gpu_name = rocml.smi_get_device_name(0).upper()
+    gpu_memory_info = rocml.smi_get_device_memory_total(0)
+
+    if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
+
+    rocml.smi_shutdown()
+
     return DeviceCapabilities(
-      model="Linux Box (AMD)",
-      chip="Unknown AMD",
-      memory=psutil.virtual_memory().total // 2**20,
+      model="Linux Box ({gpu_name})",
+      chip={gpu_name},
+      memory=gpu_memory_info.total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )
+
   else:
     return DeviceCapabilities(
       model=f"Linux Box (Device: {Device.DEFAULT})",
@@ -208,3 +223,74 @@ async def linux_device_capabilities() -> DeviceCapabilities:
       memory=psutil.virtual_memory().total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )
+
+
+def windows_device_capabilities() -> DeviceCapabilities:
+  import psutil
+
+  def get_gpu_info():
+    import win32com.client  # install pywin32
+
+    wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
+    gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
+
+    gpu_info = []
+    for gpu in gpus:
+      info = {
+        "Name": gpu.Name,
+        "AdapterRAM": gpu.AdapterRAM,  # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
+        "DriverVersion": gpu.DriverVersion,
+        "VideoProcessor": gpu.VideoProcessor
+      }
+      gpu_info.append(info)
+
+    return gpu_info
+
+  gpus_info = get_gpu_info()
+  gpu_names = [gpu['Name'] for gpu in gpus_info]
+
+  contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
+  contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
+
+  if contains_nvidia:
+    import pynvml
+
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
+    gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
+    gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+
+    if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
+
+    return DeviceCapabilities(
+      model=f"Windows Box ({gpu_name})",
+      chip=gpu_name,
+      memory=gpu_memory_info.total // 2**20,
+      flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+  elif contains_amd:
+    # For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
+    from pyrsmi import rocml
+
+    rocml.smi_initialize()
+    gpu_name = rocml.smi_get_device_name(0).upper()
+    gpu_memory_info = rocml.smi_get_device_memory_total(0)
+
+    if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
+
+    rocml.smi_shutdown()
+
+    return DeviceCapabilities(
+      model="Windows Box ({gpu_name})",
+      chip={gpu_name},
+      memory=gpu_memory_info.total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
+  else:
+    return DeviceCapabilities(
+      model=f"Windows Box (Device: Unknown)",
+      chip=f"Unknown Chip (Device(s): {gpu_names})",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )

+ 6 - 2
scripts/build_exo.py

@@ -6,6 +6,9 @@ import pkgutil
 
 def run():
     site_packages = site.getsitepackages()[0]
+    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    baseimages_dir = os.path.join(base_dir, "exo", "apputil", "baseimages")
+    
     command = [
         f"{sys.executable}", "-m", "nuitka", "exo/main.py",
         "--company-name=exolabs",
@@ -15,7 +18,8 @@ def run():
         "--standalone",
         "--output-filename=exo",
         "--python-flag=no_site",
-        "--onefile"
+        "--onefile",
+        f"--include-data-dir={baseimages_dir}=exo/apputil/baseimages"
     ]
 
     if sys.platform == "darwin": 
@@ -23,7 +27,7 @@ def run():
             "--macos-app-name=exo",
             "--macos-app-mode=gui",
             "--macos-app-version=0.0.1",
-            "--macos-signed-app-name=com.exolabs.exo",
+            "--macos-signed-app-name=net.exolabs.exo",
             "--include-distribution-meta=mlx",
             "--include-module=mlx._reprlib_fix",
             "--include-module=mlx._os_warning",

+ 37 - 4
setup.py

@@ -1,5 +1,6 @@
 import sys
 import platform
+import subprocess
 
 from setuptools import find_packages, setup
 
@@ -11,7 +12,6 @@ install_requires = [
   "grpcio==1.68.0",
   "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
-  "netifaces==0.11.0",
   "numpy==2.0.0",
   "nuitka==2.5.1",
   "nvidia-ml-py==12.560.30",
@@ -23,6 +23,7 @@ install_requires = [
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
+  "scapy==2.6.1",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
   "transformers==4.46.3",
@@ -32,19 +33,51 @@ install_requires = [
 ]
 
 extras_require = {
-  "formatting": [
-    "yapf==0.40.2",
-  ],
+  "formatting": ["yapf==0.40.2",],
   "apple_silicon": [
     "mlx==0.21.1",
     "mlx-lm==0.20.4",
   ],
+  "windows": ["pywin32==308",],
+  "nvidia-gpu": ["nvidia-ml-py==12.560.30",],
+  "amd-gpu": ["pyrsmi==0.2.0"],
 }
 
 # Check if running on macOS with Apple Silicon
 if sys.platform.startswith("darwin") and platform.machine() == "arm64":
   install_requires.extend(extras_require["apple_silicon"])
 
+# Check if running Windows
+if sys.platform.startswith("win32"):
+  install_requires.extend(extras_require["windows"])
+
+
+def _add_gpu_requires():
+  global install_requires
+  # Add Nvidia-GPU
+  try:
+    out = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], shell=True, text=True, capture_output=True, check=False)
+    if out.returncode == 0:
+      install_requires.extend(extras_require["nvidia-gpu"])
+  except subprocess.CalledProcessError:
+    pass
+
+  # Add AMD-GPU
+  # This will mostly work only on Linux, amd/rocm-smi is not yet supported on Windows
+  try:
+    out = subprocess.run(['amd-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
+    if out.returncode == 0:
+      install_requires.extend(extras_require["amd-gpu"])
+  except:
+    out = subprocess.run(['rocm-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
+    if out.returncode == 0:
+      install_requires.extend(extras_require["amd-gpu"])
+  finally:
+    pass
+
+
+_add_gpu_requires()
+
 setup(
   name="exo",
   version="0.0.1",

+ 1 - 1
test/test_tokenizers.py

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
 models = []
 for model_id in model_cards: