1
0

partitioning_strategy.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. from abc import ABC, abstractmethod
  2. from typing import List, Dict
  3. from dataclasses import dataclass
  4. from .topology import Topology
  5. from exo.inference.shard import Shard
  6. from exo.topology.device_capabilities import device_capabilities
  7. import asyncio
  8. # Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1
  9. @dataclass
  10. class Partition:
  11. node_id: str
  12. start: float
  13. end: float
  14. class PartitioningStrategy(ABC):
  15. @abstractmethod
  16. def partition(self, topology: Topology) -> List[Partition]:
  17. pass
  18. def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
  19. shards = []
  20. for i, partition in enumerate(partitions):
  21. start_layer = int(partition.start*num_layers)
  22. end_layer = int(partition.end*num_layers) - 1
  23. # Ensure the last partition covers up to num_layers - 1
  24. if i == len(partitions) - 1:
  25. end_layer = num_layers - 1
  26. # Ensure no empty shards
  27. if start_layer <= end_layer:
  28. shards.append(Shard(model_id, start_layer, end_layer, num_layers))
  29. # Ensure full coverage
  30. if shards and shards[-1].end_layer < num_layers - 1:
  31. shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers)
  32. return shards