|
@@ -3,13 +3,13 @@ import mlx.core as mx
|
|
|
import mlx.nn as nn
|
|
|
from ..inference_engine import InferenceEngine
|
|
|
from .stateful_model import StatefulModel
|
|
|
-from .sharded_utils import load_shard, get_image_from_str
|
|
|
+from .sharded_utils import load_shard
|
|
|
from ..shard import Shard
|
|
|
from typing import Dict, Optional, Tuple
|
|
|
from exo.download.shard_download import ShardDownloader
|
|
|
import asyncio
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
-from functools import partial
|
|
|
+
|
|
|
def sample_logits(
|
|
|
logits: mx.array,
|
|
|
temp: float = 0.0,
|
|
@@ -28,7 +28,6 @@ def sample_logits(
|
|
|
token = top_p_sampling(logits, top_p, temp)
|
|
|
else:
|
|
|
token = mx.random.categorical(logits*(1/temp))
|
|
|
-
|
|
|
return token
|
|
|
|
|
|
class MLXDynamicShardInferenceEngine(InferenceEngine):
|