|
@@ -4,12 +4,13 @@
|
|
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
|
|
* 2.0.
|
|
|
*
|
|
|
- * this file was contributed to by a generative AI
|
|
|
+ * this file has been contributed to by a Generative AI
|
|
|
*/
|
|
|
|
|
|
package org.elasticsearch.xpack.inference.services.elser;
|
|
|
|
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
|
+import org.elasticsearch.ResourceNotFoundException;
|
|
|
import org.elasticsearch.TransportVersion;
|
|
|
import org.elasticsearch.TransportVersions;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
@@ -24,10 +25,12 @@ import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.xpack.core.ClientHelper;
|
|
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
|
|
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
|
|
|
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.util.List;
|
|
@@ -73,16 +76,8 @@ public class ElserMlNodeService implements InferenceService {
|
|
|
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
|
|
var serviceSettingsBuilder = ElserMlNodeServiceSettings.fromMap(serviceSettingsMap);
|
|
|
|
|
|
- // choose a default model version based on the cluster architecture
|
|
|
if (serviceSettingsBuilder.getModelVariant() == null) {
|
|
|
- boolean homogenous = modelArchitectures.size() == 1;
|
|
|
- if (homogenous && modelArchitectures.iterator().next().equals("linux-x86_64")) {
|
|
|
- // Use the hardware optimized model
|
|
|
- serviceSettingsBuilder.setModelVariant(ELSER_V2_MODEL_LINUX_X86);
|
|
|
- } else {
|
|
|
- // default to the platform-agnostic model
|
|
|
- serviceSettingsBuilder.setModelVariant(ELSER_V2_MODEL);
|
|
|
- }
|
|
|
+ serviceSettingsBuilder.setModelVariant(selectDefaultModelVersionBasedOnClusterArchitecture(modelArchitectures));
|
|
|
}
|
|
|
|
|
|
Map<String, Object> taskSettingsMap;
|
|
@@ -102,6 +97,18 @@ public class ElserMlNodeService implements InferenceService {
|
|
|
return new ElserMlNodeModel(modelId, taskType, NAME, serviceSettingsBuilder.build(), taskSettings);
|
|
|
}
|
|
|
|
|
|
+ private static String selectDefaultModelVersionBasedOnClusterArchitecture(Set<String> modelArchitectures) {
|
|
|
+ // choose a default model version based on the cluster architecture
|
|
|
+ boolean homogenous = modelArchitectures.size() == 1;
|
|
|
+ if (homogenous && modelArchitectures.iterator().next().equals("linux-x86_64")) {
|
|
|
+ // Use the hardware optimized model
|
|
|
+ return ELSER_V2_MODEL_LINUX_X86;
|
|
|
+ } else {
|
|
|
+ // default to the platform-agnostic model
|
|
|
+ return ELSER_V2_MODEL;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public ElserMlNodeModel parsePersistedConfigWithSecrets(
|
|
|
String modelId,
|
|
@@ -157,11 +164,33 @@ public class ElserMlNodeService implements InferenceService {
|
|
|
startRequest.setThreadsPerAllocation(serviceSettings.getNumThreads());
|
|
|
startRequest.setWaitForState(STARTED);
|
|
|
|
|
|
- client.execute(
|
|
|
- StartTrainedModelDeploymentAction.INSTANCE,
|
|
|
- startRequest,
|
|
|
- listener.delegateFailureAndWrap((l, r) -> l.onResponse(Boolean.TRUE))
|
|
|
- );
|
|
|
+ client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, elserNotDownloadedListener(model, listener));
|
|
|
+ }
|
|
|
+
|
|
|
+ private static ActionListener<CreateTrainedModelAssignmentAction.Response> elserNotDownloadedListener(
|
|
|
+ Model model,
|
|
|
+ ActionListener<Boolean> listener
|
|
|
+ ) {
|
|
|
+ return new ActionListener<>() {
|
|
|
+ @Override
|
|
|
+ public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
|
|
|
+ listener.onResponse(Boolean.TRUE);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onFailure(Exception e) {
|
|
|
+ if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
|
|
|
+ listener.onFailure(
|
|
|
+ new ResourceNotFoundException(
|
|
|
+ "Could not start the ELSER service as the ELSER model for this platform cannot be found."
|
|
|
+ + " ELSER needs to be downloaded before it can be started"
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ listener.onFailure(e);
|
|
|
+ }
|
|
|
+ };
|
|
|
}
|
|
|
|
|
|
@Override
|