partitioning_strategy.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from abc import ABC, abstractmethod
  2. from typing import List
  3. from dataclasses import dataclass
  4. from .topology import Topology
  5. from exo.inference.shard import Shard
  6. # Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1
  7. @dataclass
  8. class Partition:
  9. node_id: str
  10. start: float
  11. end: float
  12. class PartitioningStrategy(ABC):
  13. @abstractmethod
  14. def partition(self, topology: Topology) -> List[Partition]:
  15. pass
  16. def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
  17. shards = []
  18. for i, partition in enumerate(partitions):
  19. start_layer = int(partition.start*num_layers)
  20. end_layer = int(partition.end*num_layers) - 1
  21. # Ensure the last partition covers up to num_layers - 1
  22. if i == len(partitions) - 1:
  23. end_layer = num_layers - 1
  24. # Ensure no empty shards
  25. if start_layer <= end_layer:
  26. shards.append(Shard(model_id, start_layer, end_layer, num_layers))
  27. # Ensure full coverage
  28. if shards and shards[-1].end_layer < num_layers - 1:
  29. shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers)
  30. return shards