|
@@ -57,6 +57,7 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
|
|
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
|
|
|
|
|
import java.util.ArrayList;
|
|
|
+import java.util.Comparator;
|
|
|
import java.util.EnumSet;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.HashSet;
|
|
@@ -65,6 +66,7 @@ import java.util.Locale;
|
|
|
import java.util.Map;
|
|
|
import java.util.Objects;
|
|
|
import java.util.Set;
|
|
|
+import java.util.TreeSet;
|
|
|
import java.util.concurrent.CountDownLatch;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
import java.util.concurrent.atomic.AtomicReference;
|
|
@@ -90,14 +92,24 @@ public class ElasticInferenceService extends SenderService {
|
|
|
private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class);
|
|
|
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
|
|
|
private static final String SERVICE_NAME = "Elastic";
|
|
|
+
|
|
|
+ // rainbow-sprinkles
|
|
|
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
|
|
|
- static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
|
|
|
+ static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
|
|
|
+
|
|
|
+ // elser-v2
|
|
|
+ static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
|
|
|
+ static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
|
|
|
|
|
|
/**
|
|
|
* The task types that the {@link InferenceAction.Request} can accept.
|
|
|
*/
|
|
|
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
|
|
|
|
|
|
+ private static String defaultEndpointId(String modelId) {
|
|
|
+ return Strings.format(".%s-elastic", modelId);
|
|
|
+ }
|
|
|
+
|
|
|
private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
|
|
|
private Configuration configuration;
|
|
|
private final AtomicReference<AuthorizedContent> authRef = new AtomicReference<>(AuthorizedContent.empty());
|
|
@@ -142,6 +154,19 @@ public class ElasticInferenceService extends SenderService {
|
|
|
elasticInferenceServiceComponents
|
|
|
),
|
|
|
MinimalServiceSettings.chatCompletion()
|
|
|
+ ),
|
|
|
+ DEFAULT_ELSER_MODEL_ID_V2,
|
|
|
+ new DefaultModelConfig(
|
|
|
+ new ElasticInferenceServiceSparseEmbeddingsModel(
|
|
|
+ DEFAULT_ELSER_ENDPOINT_ID_V2,
|
|
|
+ TaskType.SPARSE_EMBEDDING,
|
|
|
+ NAME,
|
|
|
+ new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
|
|
|
+ EmptyTaskSettings.INSTANCE,
|
|
|
+ EmptySecretSettings.INSTANCE,
|
|
|
+ elasticInferenceServiceComponents
|
|
|
+ ),
|
|
|
+ MinimalServiceSettings.sparseEmbedding()
|
|
|
)
|
|
|
);
|
|
|
}
|
|
@@ -184,13 +209,13 @@ public class ElasticInferenceService extends SenderService {
|
|
|
|
|
|
configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());
|
|
|
|
|
|
- defaultConfigIds().forEach(modelRegistry::addDefaultIds);
|
|
|
+ defaultConfigIds().forEach(modelRegistry::putDefaultIdIfAbsent);
|
|
|
handleRevokedDefaultConfigs(authorizedDefaultModelIds);
|
|
|
}
|
|
|
|
|
|
private Set<String> getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) {
|
|
|
var authorizedModels = auth.getAuthorizedModelIds();
|
|
|
- var authorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet());
|
|
|
+ var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet());
|
|
|
authorizedDefaultModelIds.retainAll(authorizedModels);
|
|
|
|
|
|
return authorizedDefaultModelIds;
|
|
@@ -218,6 +243,7 @@ public class ElasticInferenceService extends SenderService {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ authorizedConfigIds.sort(Comparator.comparing(DefaultConfigId::inferenceId));
|
|
|
return authorizedConfigIds;
|
|
|
}
|
|
|
|
|
@@ -230,6 +256,7 @@ public class ElasticInferenceService extends SenderService {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model.getInferenceEntityId()));
|
|
|
return authorizedModels;
|
|
|
}
|
|
|
|