Browse Source

Formatting

Sandesh Bharadwaj 10 months ago
parent
commit
b9eccedc3d

+ 114 - 153
exo/api/chatgpt_api.py

@@ -34,6 +34,7 @@ import shutil
 from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
 
+
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
@@ -47,7 +48,6 @@ class Message:
     return data
 
 
-
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
@@ -138,11 +138,7 @@ def remap_messages(messages: List[Message]) -> List[Message]:
 
 def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  chat_template_args = {
-    "conversation": [m.to_dict() for m in messages],
-    "tokenize": False,
-    "add_generation_prompt": True
-  }
+  chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
   if tools: chat_template_args["tools"] = tools
 
   prompt = tokenizer.apply_chat_template(**chat_template_args)
@@ -171,8 +167,17 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
+
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
+  def __init__(
+    self,
+    node: Node,
+    inference_engine_classname: str,
+    response_timeout: int = 90,
+    on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
+    default_model: Optional[str] = None,
+    system_prompt: Optional[str] = None
+  ):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -208,7 +213,6 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
     cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
-      
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
       self.app.router.add_get("/", self.handle_root)
@@ -219,7 +223,7 @@ class ChatGPTAPI:
     self.app.middlewares.append(self.log_request)
 
   async def handle_quit(self, request):
-    if DEBUG>=1: print("Received quit signal")
+    if DEBUG >= 1: print("Received quit signal")
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     await response.prepare(request)
     await response.write_eof()
@@ -249,61 +253,48 @@ class ChatGPTAPI:
 
   async def handle_model_support(self, request):
     try:
-        response = web.StreamResponse(
-            status=200,
-            reason='OK',
-            headers={
-                'Content-Type': 'text/event-stream',
-                'Cache-Control': 'no-cache',
-                'Connection': 'keep-alive',
-            }
-        )
-        await response.prepare(request)
+      response = web.StreamResponse(status=200, reason='OK', headers={
+        'Content-Type': 'text/event-stream',
+        'Cache-Control': 'no-cache',
+        'Connection': 'keep-alive',
+      })
+      await response.prepare(request)
 
-        async def process_model(model_name, pretty):
-            if model_name in model_cards:
-                model_info = model_cards[model_name]
-
-                if self.inference_engine_classname in model_info.get("repo", {}):
-                    shard = build_base_shard(model_name, self.inference_engine_classname)
-                    if shard:
-                        downloader = HFShardDownloader(quick_check=True)
-                        downloader.current_shard = shard
-                        downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-                        status = await downloader.get_shard_download_status()
-
-                        download_percentage = status.get("overall") if status else None
-                        total_size = status.get("total_size") if status else None
-                        total_downloaded = status.get("total_downloaded") if status else False
-
-                        model_data = {
-                            model_name: {
-                                "name": pretty,
-                                "downloaded": download_percentage == 100 if download_percentage is not None else False,
-                                "download_percentage": download_percentage,
-                                "total_size": total_size,
-                                "total_downloaded": total_downloaded
-                            }
-                        }
-
-                        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-
-        # Process all models in parallel
-        await asyncio.gather(*[
-            process_model(model_name, pretty)
-            for model_name, pretty in pretty_name.items()
-        ])
-
-        await response.write(b"data: [DONE]\n\n")
-        return response
+      async def process_model(model_name, pretty):
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+
+          if self.inference_engine_classname in model_info.get("repo", {}):
+            shard = build_base_shard(model_name, self.inference_engine_classname)
+            if shard:
+              downloader = HFShardDownloader(quick_check=True)
+              downloader.current_shard = shard
+              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+              status = await downloader.get_shard_download_status()
+
+              download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else False
+
+              model_data = {
+                model_name: {
+                  "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
+                  "total_downloaded": total_downloaded
+                }
+              }
+
+              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+
+      # Process all models in parallel
+      await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
+
+      await response.write(b"data: [DONE]\n\n")
+      return response
 
     except Exception as e:
-        print(f"Error in handle_model_support: {str(e)}")
-        traceback.print_exc()
-        return web.json_response(
-            {"detail": f"Server error: {str(e)}"},
-            status=500
-        )
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
 
   async def handle_get_models(self, request):
     models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
@@ -472,7 +463,6 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
-  
   async def handle_post_image_generations(self, request):
     data = await request.json()
 
@@ -485,7 +475,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     if DEBUG >= 2: print(f"shard: {shard}")
     if not shard:
-        return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
+      return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
 
     request_id = str(uuid.uuid4())
     callback_id = f"chatgpt-api-wait-response-{request_id}"
@@ -497,77 +487,73 @@ class ChatGPTAPI:
         img = None
       await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
 
-
-      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
+      response = web.StreamResponse(status=200, reason='OK', headers={
+        'Content-Type': 'application/octet-stream',
+        "Cache-Control": "no-cache",
+      })
       await response.prepare(request)
 
       def get_progress_bar(current_step, total_steps, bar_length=50):
         # Calculate the percentage of completion
-        percent = float(current_step) / total_steps
+        percent = float(current_step)/total_steps
         # Calculate the number of hashes to display
-        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
-        spaces = ' ' * (bar_length - len(arrow))
-        
+        arrow = '-'*int(round(percent*bar_length) - 1) + '>'
+        spaces = ' '*(bar_length - len(arrow))
+
         # Create the progress bar string
         progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
         return progress_bar
 
       async def stream_image(_request_id: str, result, is_finished: bool):
-          if isinstance(result, list):
-              await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
-
-          elif isinstance(result, np.ndarray):
-            im = Image.fromarray(np.array(result))
-            images_folder = get_exo_images_dir()
-            # Save the image to a file
-            image_filename = f"{_request_id}.png"
-            image_path = images_folder / image_filename
-            im.save(image_path)
-            image_url = request.app.router['static_images'].url_for(filename=image_filename)
-            base_url = f"{request.scheme}://{request.host}"
-            # Construct the full URL correctly
-            full_image_url = base_url + str(image_url)
-            
-            await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
-            if is_finished:
-              await response.write_eof()
-              
+        if isinstance(result, list):
+          await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
+
+        elif isinstance(result, np.ndarray):
+          im = Image.fromarray(np.array(result))
+          images_folder = get_exo_images_dir()
+          # Save the image to a file
+          image_filename = f"{_request_id}.png"
+          image_path = images_folder/image_filename
+          im.save(image_path)
+          image_url = request.app.router['static_images'].url_for(filename=image_filename)
+          base_url = f"{request.scheme}://{request.host}"
+          # Construct the full URL correctly
+          full_image_url = base_url + str(image_url)
+
+          await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+          if is_finished:
+            await response.write_eof()
 
       stream_task = None
+
       def on_result(_request_id: str, result, is_finished: bool):
-          nonlocal stream_task
-          stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
-          return _request_id == request_id and is_finished
+        nonlocal stream_task
+        stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
+        return _request_id == request_id and is_finished
 
       await callback.wait(on_result, timeout=self.response_timeout*10)
-      
+
       if stream_task:
-          # Wait for the stream task to complete before returning
-          await stream_task
+        # Wait for the stream task to complete before returning
+        await stream_task
 
       return response
 
     except Exception as e:
-        if DEBUG >= 2: traceback.print_exc()
-        return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-  
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+
   async def handle_delete_model(self, request):
     try:
       model_name = request.match_info.get('model_name')
       if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
 
       if not model_name or model_name not in model_cards:
-        return web.json_response(
-          {"detail": f"Invalid model name: {model_name}"},
-          status=400
-          )
+        return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
 
       shard = build_base_shard(model_name, self.inference_engine_classname)
       if not shard:
-        return web.json_response(
-          {"detail": "Could not build shard for model"},
-          status=400
-        )
+        return web.json_response({"detail": "Could not build shard for model"}, status=400)
 
       repo_id = get_repo(shard.model_id, self.inference_engine_classname)
       if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
@@ -582,38 +568,28 @@ class ChatGPTAPI:
         if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
         try:
           shutil.rmtree(cache_dir)
-          return web.json_response({
-            "status": "success",
-            "message": f"Model {model_name} deleted successfully",
-            "path": str(cache_dir)
-          })
+          return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
         except Exception as e:
-          return web.json_response({
-            "detail": f"Failed to delete model files: {str(e)}"
-          }, status=500)
+          return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
       else:
-        return web.json_response({
-          "detail": f"Model files not found at {cache_dir}"
-        }, status=404)
+        return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
 
     except Exception as e:
-        print(f"Error in handle_delete_model: {str(e)}")
-        traceback.print_exc()
-        return web.json_response({
-            "detail": f"Server error: {str(e)}"
-        }, status=500)
+      print(f"Error in handle_delete_model: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
 
   async def handle_get_initial_models(self, request):
     model_data = {}
     for model_name, pretty in pretty_name.items():
-        model_data[model_name] = {
-            "name": pretty,
-            "downloaded": None,  # Initially unknown
-            "download_percentage": None,  # Change from 0 to null
-            "total_size": None,
-            "total_downloaded": None,
-            "loading": True  # Add loading state
-        }
+      model_data[model_name] = {
+        "name": pretty,
+        "downloaded": None,  # Initially unknown
+        "download_percentage": None,  # Change from 0 to null
+        "total_size": None,
+        "total_downloaded": None,
+        "loading": True  # Add loading state
+      }
     return web.json_response(model_data)
 
   async def handle_create_animation(self, request):
@@ -639,17 +615,9 @@ class ChatGPTAPI:
       if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
 
       # Create the animation
-      create_animation_mp4(
-        replacement_image_path,
-        output_path,
-        device_name,
-        prompt_text
-      )
+      create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
 
-      return web.json_response({
-        "status": "success",
-        "output_path": output_path
-      })
+      return web.json_response({"status": "success", "output_path": output_path})
 
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
@@ -665,10 +633,7 @@ class ChatGPTAPI:
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
       asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
 
-      return web.json_response({
-        "status": "success",
-        "message": f"Download started for model: {model_name}"
-      })
+      return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"error": str(e)}, status=500)
@@ -682,10 +647,7 @@ class ChatGPTAPI:
         return web.json_response({})
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
-      return web.json_response(
-        {"detail": f"Error getting topology: {str(e)}"},
-        status=500
-      )
+      return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
 
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
@@ -696,15 +658,14 @@ class ChatGPTAPI:
   def base64_decode(self, base64_string):
     #decode and reshape image
     if base64_string.startswith('data:image'):
-        base64_string = base64_string.split(',')[1]
+      base64_string = base64_string.split(',')[1]
     image_data = base64.b64decode(base64_string)
     img = Image.open(BytesIO(image_data))
-    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    W, H = (dim - dim%64 for dim in (img.width, img.height))
     if W != img.width or H != img.height:
-        if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
-        img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+      if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+      img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
     img = mx.array(np.array(img))
-    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
     img = img[None]
     return img
-  

+ 12 - 13
exo/helpers.py

@@ -245,15 +245,13 @@ def get_all_ip_addresses_and_interfaces():
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return [("localhost", "lo")]
 
+
 async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
   try:
     # Use the shared subprocess_pool
-    output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
-      ['system_profiler', 'SPNetworkDataType', '-json'],
-      capture_output=True,
-      text=True,
-      close_fds=True
-    ).stdout)
+    output = await asyncio.get_running_loop().run_in_executor(
+      subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
+    )
 
     data = json.loads(output)
 
@@ -279,6 +277,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
 
   return None
 
+
 async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
   # On macOS, try to get interface type using networksetup
   if psutil.MACOS:
@@ -286,8 +285,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
     if macos_type is not None: return macos_type
 
   # Local container/virtual interfaces
-  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
-    'bridge' in ifname):
+  if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
     return (7, "Container Virtual")
 
   # Loopback interface
@@ -313,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
   # Other physical interfaces
   return (2, "Other")
 
+
 async def shutdown(signal, loop, server):
   """Gracefully shutdown the server and close the asyncio loop."""
   print(f"Received exit signal {signal.name}...")
@@ -332,16 +331,16 @@ def is_frozen():
 
 
 def get_exo_home() -> Path:
-  if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
-  else: docs_folder = Path.home() / "Documents"
+  if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
+  else: docs_folder = Path.home()/"Documents"
   if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
-  exo_folder = docs_folder / "Exo"
+  exo_folder = docs_folder/"Exo"
   if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
   return exo_folder
 
+
 def get_exo_images_dir() -> Path:
   exo_home = get_exo_home()
-  images_dir = exo_home / "Images"
+  images_dir = exo_home/"Images"
   if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
   return images_dir
-  

+ 6 - 4
exo/inference/inference_engine.py

@@ -14,7 +14,7 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     pass
-  
+
   @abstractmethod
   async def sample(self, x: np.ndarray) -> np.ndarray:
     pass
@@ -33,13 +33,13 @@ class InferenceEngine(ABC):
 
   async def save_checkpoint(self, shard: Shard, path: str):
     pass
-  
+
   async def save_session(self, key, value):
     self.session[key] = value
-  
+
   async def clear_session(self):
     self.session.empty()
-  
+
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     tokens = await self.encode(shard, prompt)
     if shard.model_id != 'stable-diffusion-2-1-base':
@@ -50,12 +50,14 @@ class InferenceEngine(ABC):
 
     return output_data, inference_state
 
+
 inference_engine_classes = {
   "mlx": "MLXDynamicShardInferenceEngine",
   "tinygrad": "TinygradDynamicShardInferenceEngine",
   "dummy": "DummyInferenceEngine",
 }
 
+
 def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
   if DEBUG >= 2:
     print(f"get_inference_engine called with: {inference_engine_name}")

+ 21 - 33
exo/networking/grpc/grpc_peer_handle.py

@@ -19,6 +19,7 @@ if platform.system().lower() == "darwin" and platform.machine().lower() == "arm6
 else:
   import numpy as mx
 
+
 class GRPCPeerHandle(PeerHandle):
   def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
     self._id = _id
@@ -42,11 +43,9 @@ class GRPCPeerHandle(PeerHandle):
 
   async def connect(self):
     if self.channel is None:
-      self.channel = grpc.aio.insecure_channel(self.address, options=[
-        ("grpc.max_metadata_size", 32*1024*1024),
-        ('grpc.max_receive_message_length', 32*1024*1024),
-        ('grpc.max_send_message_length', 32*1024*1024)
-      ])
+      self.channel = grpc.aio.insecure_channel(
+        self.address, options=[("grpc.max_metadata_size", 32*1024*1024), ('grpc.max_receive_message_length', 32*1024*1024), ('grpc.max_send_message_length', 32*1024*1024)]
+      )
       self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
     await self.channel.channel_ready()
 
@@ -114,7 +113,7 @@ class GRPCPeerHandle(PeerHandle):
       return None
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-  
+
   async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.ExampleRequest(
       shard=node_service_pb2.Shard(
@@ -136,7 +135,7 @@ class GRPCPeerHandle(PeerHandle):
       return loss, grads
     else:
       return loss
-  
+
   async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
@@ -171,10 +170,7 @@ class GRPCPeerHandle(PeerHandle):
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
       device_capabilities = DeviceCapabilities(
-        model=capabilities.model,
-        chip=capabilities.chip,
-        memory=capabilities.memory,
-        flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
       )
       topology.update_node(node_id, device_capabilities)
     for node_id, peer_connections in response.peer_graph.items():
@@ -198,28 +194,20 @@ class GRPCPeerHandle(PeerHandle):
     proto_inference_state = node_service_pb2.InferenceState()
     other_data = {}
     for k, v in inference_state.items():
-        if isinstance(v, mx.array):
-            np_array = np.array(v)
-            tensor_data = node_service_pb2.Tensor(
-                tensor_data=np_array.tobytes(),
-                shape=list(np_array.shape),
-                dtype=str(np_array.dtype)
-            )
-            proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
-        elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
-            tensor_list = node_service_pb2.TensorList()
-            for tensor in v:
-                np_array = np.array(tensor)
-                tensor_data = node_service_pb2.Tensor(
-                    tensor_data=np_array.tobytes(),
-                    shape=list(np_array.shape),
-                    dtype=str(np_array.dtype)
-                )
-                tensor_list.tensors.append(tensor_data)
-            proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
-        else:
-            # For non-tensor data, we'll still use JSON
-            other_data[k] = v
+      if isinstance(v, mx.array):
+        np_array = np.array(v)
+        tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
+        proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
+      elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
+        tensor_list = node_service_pb2.TensorList()
+        for tensor in v:
+          np_array = np.array(tensor)
+          tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
+          tensor_list.tensors.append(tensor_data)
+        proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
+      else:
+        # For non-tensor data, we'll still use JSON
+        other_data[k] = v
     if other_data:
       proto_inference_state.other_data_json = json.dumps(other_data)
     return proto_inference_state

+ 14 - 22
exo/networking/grpc/grpc_server.py

@@ -80,7 +80,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
-  
+
   async def SendExample(self, request, context):
     shard = Shard(
       model_id=request.shard.model_id,
@@ -102,7 +102,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     else:
       loss = await self.node.process_example(shard, example, target, length, train, request_id)
       return node_service_pb2.Loss(loss=loss, grads=None)
-    
+
   async def CollectTopology(self, request, context):
     max_depth = request.max_depth
     visited = set(request.visited)
@@ -118,12 +118,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       for node_id, cap in topology.nodes.items()
     }
     peer_graph = {
-      node_id: node_service_pb2.PeerConnections(
-        connections=[
-          node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
-          for conn in connections
-        ]
-      )
+      node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
       for node_id, connections in topology.peer_graph.items()
     }
     if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
@@ -137,7 +132,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
     result = list(result)
     if len(img.tensor_data) > 0:
-      result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
+      result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
     self.node.on_token.trigger_all(request_id, result, is_finished)
     return node_service_pb2.Empty()
 
@@ -151,21 +146,18 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
   async def HealthCheck(self, request, context):
     return node_service_pb2.HealthCheckResponse(is_healthy=True)
 
-  def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
+  def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
     inference_state = {}
-    
+
     for k, tensor_data in inference_state_proto.tensor_data.items():
-        np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
-        inference_state[k] = mx.array(np_array)
-    
+      np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
+      inference_state[k] = mx.array(np_array)
+
     for k, tensor_list in inference_state_proto.tensor_list_data.items():
-        inference_state[k] = [
-            mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
-            for tensor in tensor_list.tensors
-        ]
-    
+      inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
+
     if inference_state_proto.other_data_json:
-        other_data = json.loads(inference_state_proto.other_data_json)
-        inference_state.update(other_data)
-    
+      other_data = json.loads(inference_state_proto.other_data_json)
+      inference_state.update(other_data)
+
     return inference_state

+ 26 - 25
exo/topology/device_capabilities.py

@@ -195,7 +195,7 @@ def linux_device_capabilities() -> DeviceCapabilities:
     gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
 
     if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
-    
+
     pynvml.nvmlShutdown()
 
     return DeviceCapabilities(
@@ -207,22 +207,22 @@ def linux_device_capabilities() -> DeviceCapabilities:
   elif Device.DEFAULT == "AMD":
     # For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
     from pyrsmi import rocml
-    
+
     rocml.smi_initialize()
     gpu_name = rocml.smi_get_device_name(0).upper()
     gpu_memory_info = rocml.smi_get_device_memory_total(0)
-    
+
     if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
-    
+
     rocml.smi_shutdown()
-      
+
     return DeviceCapabilities(
       model="Linux Box ({gpu_name})",
       chip={gpu_name},
       memory=gpu_memory_info.total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )
-    
+
   else:
     return DeviceCapabilities(
       model=f"Linux Box (Device: {Device.DEFAULT})",
@@ -234,30 +234,31 @@ def linux_device_capabilities() -> DeviceCapabilities:
 
 def windows_device_capabilities() -> DeviceCapabilities:
   import psutil
+
   def get_gpu_info():
-    import win32com.client # install pywin32
+    import win32com.client  # install pywin32
 
     wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
     gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
 
     gpu_info = []
     for gpu in gpus:
-        info = {
-            "Name": gpu.Name,
-            "AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
-            "DriverVersion": gpu.DriverVersion,
-            "VideoProcessor": gpu.VideoProcessor
-        }
-        gpu_info.append(info)
-    
+      info = {
+        "Name": gpu.Name,
+        "AdapterRAM": gpu.AdapterRAM,  # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
+        "DriverVersion": gpu.DriverVersion,
+        "VideoProcessor": gpu.VideoProcessor
+      }
+      gpu_info.append(info)
+
     return gpu_info
-    
+
   gpus_info = get_gpu_info()
   gpu_names = [gpu['Name'] for gpu in gpus_info]
-  
-  contains_nvidia = any('nvidia' in gpu_name.lower()for gpu_name in gpu_names)
+
+  contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
   contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
-  
+
   if contains_nvidia:
     import pynvml
 
@@ -266,7 +267,7 @@ def windows_device_capabilities() -> DeviceCapabilities:
     gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
     gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
     gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
-    
+
     if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
 
     return DeviceCapabilities(
@@ -278,15 +279,15 @@ def windows_device_capabilities() -> DeviceCapabilities:
   elif contains_amd:
     # For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
     from pyrsmi import rocml
-    
+
     rocml.smi_initialize()
     gpu_name = rocml.smi_get_device_name(0).upper()
     gpu_memory_info = rocml.smi_get_device_memory_total(0)
-    
+
     if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
-    
+
     rocml.smi_shutdown()
-      
+
     return DeviceCapabilities(
       model="Windows Box ({gpu_name})",
       chip={gpu_name},
@@ -299,4 +300,4 @@ def windows_device_capabilities() -> DeviceCapabilities:
       chip=f"Unknown Chip (Device(s): {gpu_names})",
       memory=psutil.virtual_memory().total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
-    )
+    )