Răsfoiți Sursa

sort topology by memory descending (works well for now to workaround #12

Alex Cheema 9 luni în urmă
părinte
comite
d4f55002ea

+ 2 - 2
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -7,12 +7,12 @@ from .partitioning_strategy import Partition
 class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
     def partition(self, topology: Topology) -> List[Partition]:
         nodes = list(topology.all_nodes())
-        nodes.sort(key=lambda x: x[0])
+        nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)
         total_memory = sum(node[1].memory for node in nodes)
         partitions = []
         start = 0
         for node in nodes:
-            end = start + (node[1].memory / total_memory)
+            end = round(start + (node[1].memory / total_memory), 5)
             partitions.append(Partition(node[0], start, end))
             start = end
         return partitions

+ 6 - 6
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -9,9 +9,9 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
         # triangle
         # node1 -> node2 -> node3 -> node1
         topology = Topology()
-        topology.update_node('node1', DeviceCapabilities(model="test1", chip="test1", memory=100))
-        topology.update_node('node2', DeviceCapabilities(model="test2", chip="test2", memory=300))
-        topology.update_node('node3', DeviceCapabilities(model="test3", chip="test3", memory=600))
+        topology.update_node('node1', DeviceCapabilities(model="test1", chip="test1", memory=3000))
+        topology.update_node('node2', DeviceCapabilities(model="test2", chip="test2", memory=1000))
+        topology.update_node('node3', DeviceCapabilities(model="test3", chip="test3", memory=6000))
         topology.add_edge('node1', 'node2')
         topology.add_edge('node2', 'node3')
         topology.add_edge('node3', 'node1')
@@ -22,9 +22,9 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
 
         self.assertEqual(len(partitions), 3)
         self.assertEqual(partitions, [
-            Partition('node1', 0.0, 0.1),
-            Partition('node2', 0.1, 0.4),
-            Partition('node3', 0.4, 1.0)
+            Partition('node3', 0.0, 0.6),
+            Partition('node1', 0.6, 0.9),
+            Partition('node2', 0.9, 1.0),
         ])
 
 if __name__ == '__main__':