shard.py 470 B

1234567891011121314151617181920212223
from dataclasses import asdict, dataclass, fields
  2. @dataclass
  3. class Shard:
  4. model_id: str
  5. start_layer: int
  6. end_layer: int
  7. n_layers: int
  8. def is_first_layer(self) -> bool:
  9. return self.start_layer == 0
  10. def is_last_layer(self) -> bool:
  11. return self.end_layer == self.n_layers - 1
  12. def to_dict(self) -> dict:
  13. return {
  14. "model_id": self.model_id,
  15. "start_layer": self.start_layer,
  16. "end_layer": self.end_layer,
  17. "n_layers": self.n_layers,
  18. }