ring_memory_weighted_partitioning_strategy.py 1.2 KB

123456789101112131415161718192021222324252627
  1. from .partitioning_strategy import PartitioningStrategy
  2. from inference.shard import Shard
  3. from .topology import Topology
  4. class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
  5. def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
  6. # Get all nodes from the topology and include the current node
  7. nodes = list(topology.all_nodes())
  8. nodes.append((self.id, None, node_stats))
  9. # Sort nodes by their IDs
  10. nodes.sort(key=lambda x: x[0])
  11. # Calculate the total memory of all nodes
  12. total_memory = sum(node[2]['memory'] for node in nodes)
  13. # Calculate the number of layers to assign to each node proportional to its memory
  14. layers_per_node = {node[0]: (node[2]['memory'] / total_memory) * current_shard.n_layers for node in nodes}
  15. # Find the successor node
  16. node_ids = [node[0] for node in nodes]
  17. current_index = node_ids.index(self.id)
  18. successor_index = (current_index + 1) % len(node_ids)
  19. successor_id = node_ids[successor_index]
  20. # Return the Shard calculated for the successor
  21. return Shard(successor_id, layers_per_node[successor_id])