瀏覽代碼

add missing top_p_sampling import

Alex Cheema 8 月之前
父節點
當前提交
6659a18e94
共有 1 個文件被更改,包括 1 次插入0 次删除
  1. 1 0
      exo/inference/mlx/sharded_inference_engine.py

+ 1 - 0
exo/inference/mlx/sharded_inference_engine.py

@@ -1,6 +1,7 @@
 import numpy as np
 import numpy as np
 import mlx.core as mx
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.nn as nn
+from mlx_lm.sample_utils import top_p_sampling
 from ..inference_engine import InferenceEngine
 from ..inference_engine import InferenceEngine
 from .stateful_model import StatefulModel
 from .stateful_model import StatefulModel
 from .sharded_utils import load_shard
 from .sharded_utils import load_shard