from dataclasses import dataclass


@dataclass()
class Shard:
    """A contiguous range of one model's layers.

    The range runs from ``start_layer`` to ``end_layer``, both expressed as
    layer indices out of ``n_layers`` total. ``end_layer`` is treated as
    inclusive: the shard counts as reaching the end of the model when
    ``end_layer`` equals the final index, ``n_layers - 1``.
    """

    # Model identifier (format not constrained by this class).
    model_id: str
    # Index of the first layer covered by this shard.
    start_layer: int
    # Index of the last layer covered by this shard (inclusive; see is_last_layer).
    end_layer: int
    # Total layer count of the model; valid indices are 0 .. n_layers - 1.
    n_layers: int

    def is_first_layer(self) -> bool:
        """Return True when this shard begins at layer index 0."""
        first_index = 0
        return first_index == self.start_layer

    def is_last_layer(self) -> bool:
        """Return True when this shard ends at the model's final layer index."""
        final_index = self.n_layers - 1
        return final_index == self.end_layer
|