1
0
Эх сурвалжийг харах

Changes required to detect AMD GPUs

DeftDawg 1 сар өмнө
parent
commit
f98d9bac53

+ 6 - 9
exo/topology/device_capabilities.py

@@ -198,22 +198,19 @@ async def linux_device_capabilities() -> DeviceCapabilities:
       flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
     )
   elif Device.DEFAULT == "AMD":
-    # For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
-    from pyrsmi import rocml
+    import pyamdgpuinfo
 
-    rocml.smi_initialize()
-    gpu_name = rocml.smi_get_device_name(0).upper()
-    gpu_memory_info = rocml.smi_get_device_memory_total(0)
+    gpu_raw_info = pyamdgpuinfo.get_gpu(0)
+    gpu_name = gpu_raw_info.name
+    gpu_memory_info = gpu_raw_info.memory_info["vram_size"]
 
     if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
 
-    rocml.smi_shutdown()
-
     return DeviceCapabilities(
-      model="Linux Box ({gpu_name})",
+      model="Linux Box (" + gpu_name + ")",
       chip=gpu_name,
       memory=gpu_memory_info // 2**20,
-      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+      flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
     )
 
   else:

+ 1 - 0
setup.py

@@ -20,6 +20,7 @@ install_requires = [
   "prometheus-client==0.20.0",
   "protobuf==5.28.1",
   "psutil==6.0.0",
+  "pyamdgpuinfo==2.1.6;platform_system=='Linux'",
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",