|
@@ -18,7 +18,6 @@ import org.elasticsearch.inference.ChunkedInferenceServiceResults;
|
|
|
import org.elasticsearch.inference.ChunkingOptions;
|
|
|
import org.elasticsearch.inference.ChunkingSettings;
|
|
|
import org.elasticsearch.inference.EmptySettingsConfiguration;
|
|
|
-import org.elasticsearch.inference.InferenceService;
|
|
|
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
|
|
import org.elasticsearch.inference.InferenceServiceResults;
|
|
|
import org.elasticsearch.inference.InputType;
|
|
@@ -51,6 +50,7 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.Alib
|
|
|
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
|
|
|
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
|
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
|
|
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
|
|
|
|
|
|
import java.util.EnumSet;
|
|
|
import java.util.HashMap;
|
|
@@ -60,7 +60,6 @@ import java.util.stream.Stream;
|
|
|
|
|
|
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
|
|
|
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
|
|
|
-import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT;
|
|
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
|
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
|
|
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
|
|
@@ -332,68 +331,39 @@ public class AlibabaCloudSearchService extends SenderService {
|
|
|
*/
|
|
|
@Override
|
|
|
public void checkModelConfig(Model model, ActionListener<Model> listener) {
|
|
|
+ // TODO: Remove this function once all services have been updated to use the new model validators
|
|
|
+ ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
|
|
|
if (model instanceof AlibabaCloudSearchEmbeddingsModel embeddingsModel) {
|
|
|
- ServiceUtils.getEmbeddingSize(
|
|
|
- model,
|
|
|
- this,
|
|
|
- listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
|
|
|
+ var serviceSettings = embeddingsModel.getServiceSettings();
|
|
|
+
|
|
|
+ var updatedServiceSettings = new AlibabaCloudSearchEmbeddingsServiceSettings(
|
|
|
+ new AlibabaCloudSearchServiceSettings(
|
|
|
+ serviceSettings.getCommonSettings().modelId(),
|
|
|
+ serviceSettings.getCommonSettings().getHost(),
|
|
|
+ serviceSettings.getCommonSettings().getWorkspaceName(),
|
|
|
+ serviceSettings.getCommonSettings().getHttpSchema(),
|
|
|
+ serviceSettings.getCommonSettings().rateLimitSettings()
|
|
|
+ ),
|
|
|
+ SimilarityMeasure.DOT_PRODUCT,
|
|
|
+ embeddingSize,
|
|
|
+ serviceSettings.getMaxInputTokens()
|
|
|
);
|
|
|
+
|
|
|
+ return new AlibabaCloudSearchEmbeddingsModel(embeddingsModel, updatedServiceSettings);
|
|
|
} else {
|
|
|
- checkAlibabaCloudSearchServiceConfig(model, this, listener);
|
|
|
+ throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private AlibabaCloudSearchEmbeddingsModel updateModelWithEmbeddingDetails(AlibabaCloudSearchEmbeddingsModel model, int embeddingSize) {
|
|
|
- AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings = new AlibabaCloudSearchEmbeddingsServiceSettings(
|
|
|
- new AlibabaCloudSearchServiceSettings(
|
|
|
- model.getServiceSettings().getCommonSettings().modelId(),
|
|
|
- model.getServiceSettings().getCommonSettings().getHost(),
|
|
|
- model.getServiceSettings().getCommonSettings().getWorkspaceName(),
|
|
|
- model.getServiceSettings().getCommonSettings().getHttpSchema(),
|
|
|
- model.getServiceSettings().getCommonSettings().rateLimitSettings()
|
|
|
- ),
|
|
|
- SimilarityMeasure.DOT_PRODUCT,
|
|
|
- embeddingSize,
|
|
|
- model.getServiceSettings().getMaxInputTokens()
|
|
|
- );
|
|
|
-
|
|
|
- return new AlibabaCloudSearchEmbeddingsModel(model, serviceSettings);
|
|
|
- }
|
|
|
-
|
|
|
@Override
|
|
|
public TransportVersion getMinimalSupportedVersion() {
|
|
|
return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED;
|
|
|
}
|
|
|
|
|
|
- /**
|
|
|
- * For other models except of text embedding
|
|
|
- * check the model's service settings and task settings
|
|
|
- *
|
|
|
- * @param model The new model
|
|
|
- * @param service The inferenceService
|
|
|
- * @param listener The listener
|
|
|
- */
|
|
|
- private void checkAlibabaCloudSearchServiceConfig(Model model, InferenceService service, ActionListener<Model> listener) {
|
|
|
- String input = ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_INPUT;
|
|
|
- String query = model.getTaskType().equals(TaskType.RERANK) ? ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_QUERY : null;
|
|
|
-
|
|
|
- service.infer(
|
|
|
- model,
|
|
|
- query,
|
|
|
- List.of(input),
|
|
|
- false,
|
|
|
- Map.of(),
|
|
|
- InputType.INGEST,
|
|
|
- DEFAULT_TIMEOUT,
|
|
|
- listener.delegateFailureAndWrap((delegate, r) -> {
|
|
|
- listener.onResponse(model);
|
|
|
- })
|
|
|
- );
|
|
|
- }
|
|
|
-
|
|
|
- private static final String ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_INPUT = "input";
|
|
|
- private static final String ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_QUERY = "query";
|
|
|
-
|
|
|
public static class Configuration {
|
|
|
public static InferenceServiceConfiguration get() {
|
|
|
return configuration.getOrCompute();
|