瀏覽代碼

Add AMD GPU querying + Windows device capabilities

Sandesh Bharadwaj 3 月之前
父節點
當前提交
df3624d27a
共有 1 個文件被更改,包括 89 次插入4 次删除
  1. 89 4
      exo/topology/device_capabilities.py

+ 89 - 4
exo/topology/device_capabilities.py

@@ -149,6 +149,8 @@ def device_capabilities() -> DeviceCapabilities:
     return mac_device_capabilities()
   elif psutil.LINUX:
     return linux_device_capabilities()
+  elif psutil.WINDOWS:
+    return windows_device_capabilities()
   else:
     return DeviceCapabilities(
       model="Unknown Device",
@@ -193,6 +195,8 @@ def linux_device_capabilities() -> DeviceCapabilities:
     gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
 
     if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
+    
+    pynvml.nvmlShutdown()
 
     return DeviceCapabilities(
       model=f"Linux Box ({gpu_name})",
@@ -201,13 +205,24 @@ def linux_device_capabilities() -> DeviceCapabilities:
       flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
     )
   elif Device.DEFAULT == "AMD":
-    # TODO AMD support
+    # For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
+    from pyrsmi import rocml
+    
+    rocml.smi_initialize()
+    gpu_name = rocml.smi_get_device_name(0).upper()
+    gpu_memory_info = rocml.smi_get_device_memory_total(0)
+    
+    if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
+    
+    rocml.smi_shutdown()
+      
     return DeviceCapabilities(
-      model="Linux Box (AMD)",
-      chip="Unknown AMD",
-      memory=psutil.virtual_memory().total // 2**20,
+      model="Linux Box ({gpu_name})",
+      chip={gpu_name},
+      memory=gpu_memory_info.total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )
+    
   else:
     return DeviceCapabilities(
       model=f"Linux Box (Device: {Device.DEFAULT})",
@@ -215,3 +230,73 @@ def linux_device_capabilities() -> DeviceCapabilities:
       memory=psutil.virtual_memory().total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )
+
+
+def windows_device_capabilities() -> DeviceCapabilities:
+  import psutil
+  def get_gpu_info():
+    import win32com.client # install pywin32
+
+    wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
+    gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
+
+    gpu_info = []
+    for gpu in gpus:
+        info = {
+            "Name": gpu.Name,
+            "AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
+            "DriverVersion": gpu.DriverVersion,
+            "VideoProcessor": gpu.VideoProcessor
+        }
+        gpu_info.append(info)
+    
+    return gpu_info
+    
+  gpus_info = get_gpu_info()
+  gpu_names = [gpu['Name'] for gpu in gpus_info]
+  
+  contains_nvidia = any('nvidia' in gpu_name.lower()for gpu_name in gpu_names)
+  contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
+  
+  if contains_nvidia:
+    import pynvml
+
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
+    gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
+    gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+    
+    if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
+
+    return DeviceCapabilities(
+      model=f"Windows Box ({gpu_name})",
+      chip=gpu_name,
+      memory=gpu_memory_info.total // 2**20,
+      flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+  elif contains_amd:
+    # For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
+    from pyrsmi import rocml
+    
+    rocml.smi_initialize()
+    gpu_name = rocml.smi_get_device_name(0).upper()
+    gpu_memory_info = rocml.smi_get_device_memory_total(0)
+    
+    if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
+    
+    rocml.smi_shutdown()
+      
+    return DeviceCapabilities(
+      model="Windows Box ({gpu_name})",
+      chip={gpu_name},
+      memory=gpu_memory_info.total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
+  else:
+    return DeviceCapabilities(
+      model=f"Windows Box (Device: Unknown)",
+      chip=f"Unknown Chip (Device(s): {gpu_names})",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )