|
@@ -1,6 +1,7 @@
|
|
import numpy as np
|
|
import numpy as np
|
|
import mlx.core as mx
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import mlx.nn as nn
|
|
|
|
+from mlx_lm.sample_utils import top_p_sampling
|
|
from ..inference_engine import InferenceEngine
|
|
from ..inference_engine import InferenceEngine
|
|
from .stateful_model import StatefulModel
|
|
from .stateful_model import StatefulModel
|
|
from .sharded_utils import load_shard
|
|
from .sharded_utils import load_shard
|