# shard.py

from dataclasses import dataclass


@dataclass
class Shard:
    """Describe a contiguous slice of a model's layers.

    Layer indices are 0-based, as implied by the two predicates below:
    ``is_first_layer`` treats index 0 as the first layer, and
    ``is_last_layer`` treats ``n_layers - 1`` as the final layer. That makes
    ``end_layer`` an inclusive bound.
    """

    model_id: str  # model identifier; treated as an opaque string here
    start_layer: int  # index of the first layer covered by this shard
    end_layer: int  # index of the last layer covered (inclusive)
    n_layers: int  # total layer count of the whole model

    def is_first_layer(self) -> bool:
        """Return True if this shard starts at layer index 0."""
        return self.start_layer == 0

    def is_last_layer(self) -> bool:
        """Return True if this shard ends at the model's final layer index."""
        final_index = self.n_layers - 1
        return self.end_layer == final_index