|
@@ -222,14 +222,13 @@ class ChatGPTAPI:
|
|
|
if model_name in model_cards:
|
|
|
model_info = model_cards[model_name]
|
|
|
|
|
|
- # Get required engines
|
|
|
- required_engines = list(dict.fromkeys([
|
|
|
- inference_engine_classes.get(engine_name, None)
|
|
|
- for engine_list in self.node.topology_inference_engines_pool
|
|
|
- for engine_name in engine_list
|
|
|
- if engine_name is not None
|
|
|
- ] + [self.inference_engine_classname]))
|
|
|
-
|
|
|
+ # Get required engines from the node's topology directly
|
|
|
+ required_engines = list(dict.fromkeys(
|
|
|
+ [engine_name for engine_list in self.node.topology_inference_engines_pool
|
|
|
+ for engine_name in engine_list
|
|
|
+ if engine_name is not None] +
|
|
|
+ [self.inference_engine_classname]
|
|
|
+ ))
|
|
|
# Check if model supports required engines
|
|
|
if all(map(lambda engine: engine in model_info["repo"], required_engines)):
|
|
|
shard = build_base_shard(model_name, self.inference_engine_classname)
|