@@ -8,7 +8,7 @@ from mlx_lm.sample_utils import top_p_sampling
from ..shard import Shard
-
+# TODO: support a speculative model so we can parallelise compute across devices
class StatefulShardedModel:
def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
self.shard = shard