Bläddra i källkod

add pydantic dependency

Ian Paul 6 månader sedan
förälder
incheckning
6b48a936b9

+ 10 - 25
exo/networking/manual/network_topology_config.py

@@ -1,20 +1,17 @@
-from typing import Dict
-from dataclasses import dataclass
-
 import json
+from typing import Dict
+from pydantic import BaseModel, ValidationError
 
 from exo.topology.device_capabilities import DeviceCapabilities
 
 
-@dataclass
-class PeerConfig:
+class PeerConfig(BaseModel):
   address: str
   port: int
   device_capabilities: DeviceCapabilities
 
 
-@dataclass
-class NetworkTopology:
+class NetworkTopology(BaseModel):
   """Configuration of the network. A collection outlining all nodes in the network, including the node this is running from."""
 
   peers: Dict[str, PeerConfig]
@@ -26,23 +23,11 @@ class NetworkTopology:
   def from_path(cls, path: str) -> "NetworkTopology":
     try:
       with open(path, "r") as f:
-        config = json.load(f)
-    except FileNotFoundError:
-      raise FileNotFoundError(f"Config file not found at {path}")
-    except json.JSONDecodeError as e:
-      raise json.JSONDecodeError(f"Error decoding JSON data from {path}: {e}", e.doc, e.pos)
+        config_data = f.read()
+    except FileNotFoundError as e:
+      raise FileNotFoundError(f"Config file not found at {path}") from e
 
     try:
-      peers = {}
-      for node_id, peer_data in config["peers"].items():
-        device_capabilities = DeviceCapabilities(**peer_data["device_capabilities"])
-        peer_config = PeerConfig(address=peer_data["address"], port=peer_data["port"], device_capabilities=device_capabilities)
-        peers[node_id] = peer_config
-
-      networking_config = cls(peers=peers)
-    except KeyError as e:
-      raise KeyError(f"Missing required key in config file: {e}")
-    except TypeError as e:
-      raise TypeError(f"Error parsing networking config from {path}: {e}")
-
-    return networking_config
+      return cls.model_validate_json(config_data)
+    except ValidationError as e:
+      raise ValueError(f"Error validating network topology config from {path}: {e}") from e

+ 9 - 5
exo/networking/manual/test_network_topology_config.py

@@ -1,5 +1,7 @@
 import unittest
 import json
+
+from pydantic import ValidationError
 from exo.networking.manual.network_topology_config import NetworkTopology
 
 root_path = "./exo/networking/manual/test_data/"
@@ -9,17 +11,19 @@ class TestNetworkTopologyConfig(unittest.TestCase):
   def test_from_path_invalid_path(self):
     with self.assertRaises(FileNotFoundError) as e:
       NetworkTopology.from_path("invalid_path")
-    self.assertEqual(e.exception.args[0], "Config file not found at invalid_path")
+    self.assertEqual(str(e.exception), "Config file not found at invalid_path")
 
   def test_from_path_invalid_json(self):
-    with self.assertRaises(json.JSONDecodeError) as e:
+    with self.assertRaises(ValueError) as e:
       NetworkTopology.from_path(root_path + "invalid_json.json")
-    self.assertEqual(e.exception.args[0], "Error decoding JSON data from ./exo/networking/manual/test_data/invalid_json.json: Expecting value: line 1 column 1 (char 0): line 1 column 1 (char 0)")
+    self.assertIn("Error validating network topology config from", str(e.exception))
+    self.assertIn("1 validation error for NetworkTopology\n  Invalid JSON: EOF while parsing a value at line 1 column 0", str(e.exception))
 
   def test_from_path_invalid_config(self):
-    with self.assertRaises(KeyError) as e:
+    with self.assertRaises(ValueError) as e:
       NetworkTopology.from_path(root_path + "invalid_config.json")
-    self.assertEqual(e.exception.args[0], "Missing required key in config file: 'port'")
+    self.assertIn("Error validating network topology config from", str(e.exception))
+    self.assertIn("port\n  Field required", str(e.exception))
 
   def test_from_path_valid(self):
     config = NetworkTopology.from_path(root_path + "test_config.json")

+ 6 - 7
exo/topology/device_capabilities.py

@@ -1,13 +1,13 @@
+from typing import Any
+from pydantic import BaseModel
 from exo import DEBUG
-from dataclasses import dataclass, asdict
 import subprocess
 import psutil
 
 TFLOPS = 1.00
 
 
-@dataclass
-class DeviceFlops:
+class DeviceFlops(BaseModel):
   # units of TFLOPS
   fp32: float
   fp16: float
@@ -17,11 +17,10 @@ class DeviceFlops:
     return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
 
   def to_dict(self):
-    return asdict(self)
+    return self.model_dump()
 
 
-@dataclass
-class DeviceCapabilities:
+class DeviceCapabilities(BaseModel):
   model: str
   chip: str
   memory: int
@@ -30,7 +29,7 @@ class DeviceCapabilities:
   def __str__(self):
     return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
 
-  def __post_init__(self):
+  def model_post_init(self, __context: Any) -> None:
     if isinstance(self.flops, dict):
       self.flops = DeviceFlops(**self.flops)
 

+ 1 - 0
setup.py

@@ -17,6 +17,7 @@ install_requires = [
   "prometheus-client==0.20.0",
   "protobuf==5.27.1",
   "psutil==6.0.0",
+  "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
   "safetensors==0.4.3",