瀏覽代碼

linux device capabilities

Alex Cheema 9 月之前
父節點
當前提交
ce46f00059
共有 3 個文件被更改，包括 28 次插入和 4 次刪除
  1. 23 2
      exo/topology/device_capabilities.py
  2. 3 2
      main.py
  3. 2 0
      requirements.txt

+ 23 - 2
exo/topology/device_capabilities.py

@@ -12,8 +12,8 @@ def device_capabilities() -> DeviceCapabilities:
     system = platform.system()
     if system == 'Darwin':
         return mac_device_capabilities()
-    # elif system == 'Linux':
-    #     return linux_device_capabilities()
+    elif system == 'Linux':
+        return linux_device_capabilities()
     # elif system == 'Windows':
     #     return windows_device_capabilities()
     else:
@@ -37,3 +37,24 @@ def mac_device_capabilities() -> DeviceCapabilities:
 
     # Assuming static values for other attributes for demonstration
     return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory)
+
+def linux_device_capabilities() -> DeviceCapabilities:
+    import psutil
+    from tinygrad import Device
+    
+    print(f"tinygrad {Device.DEFAULT=}")
+    if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
+        import pynvml, pynvml_utils
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+        gpu_name = pynvml.nvmlDeviceGetName(handle)
+        gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+
+        print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
+
+        return DeviceCapabilities(model=f"Linux Box ({gpu_name})", chip=gpu_name, memory=gpu_memory_info.total)
+    elif Device.DEFAULT == "AMD":
+        # TODO AMD support
+        return DeviceCapabilities(model="Linux Box (AMD)", chip="Unknown AMD", memory=psutil.virtual_memory().total)
+    else:
+        return DeviceCapabilities(model=f"Linux Box (Device: {Device.DEFAULT})", chip=f"Unknown Chip (Device: {Device.DEFAULT})", memory=psutil.virtual_memory().total // 2**20)

+ 3 - 2
main.py

@@ -3,6 +3,7 @@ import asyncio
 import signal
 import uuid
 import platform
+import psutil
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
@@ -21,8 +22,8 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 args = parser.parse_args()
 
-print(f"Starting {platform.system()=}")
-if platform.system() == "Darwin":
+print(f"Starting {platform.system()=} {psutil.virtual_memory()=}")
+if psutil.MACOS:
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
     inference_engine = MLXDynamicShardInferenceEngine()
 else:

+ 2 - 0
requirements.txt

@@ -6,6 +6,8 @@ mlx==0.15.1; sys.platform == "darwin"
 mlx-lm==0.14.3; sys.platform == "darwin"
 numpy==2.0.0
 protobuf==5.27.1
+psutil==6.0.0
+pynvml==11.5.3
 requests==2.32.3
 safetensors==0.4.3
 tiktoken==0.7.0