Browse Source

[ML] CustomService adding template validation prior to request flow (#129591) (#129672)

* Adding template validation prior to request flow

* Fixing tests

* Narrowing exception
Jonathan Buttner 3 tháng trước cách đây
mục cha
commit
21d4691952

+ 6 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutor.java

@@ -51,7 +51,12 @@ public class ValidatingSubstitutor {
         Matcher matcher = VARIABLE_PLACEHOLDER_PATTERN.matcher(substitutedString);
         if (matcher.find()) {
             throw new IllegalStateException(
-                Strings.format("Found placeholder [%s] in field [%s] after replacement call", matcher.group(), field)
+                Strings.format(
+                    "Found placeholder [%s] in field [%s] after replacement call, "
+                        + "please check that all templates have a corresponding field definition.",
+                    matcher.group(),
+                    field
+                )
             );
         }
     }

+ 22 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

@@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -55,6 +56,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNot
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
 
 public class CustomService extends SenderService {
+
     public static final String NAME = "custom";
     private static final String SERVICE_NAME = "Custom";
 
@@ -101,12 +103,32 @@ public class CustomService extends SenderService {
             throwIfNotEmptyMap(serviceSettingsMap, NAME);
             throwIfNotEmptyMap(taskSettingsMap, NAME);
 
+            validateConfiguration(model);
+
             parsedModelListener.onResponse(model);
         } catch (Exception e) {
             parsedModelListener.onFailure(e);
         }
     }
 
+    /**
+     * This does some initial validation with mock inputs to determine if any templates are missing a field to fill them.
+     */
+    private static void validateConfiguration(CustomModel model) {
+        String query = null;
+        if (model.getTaskType() == TaskType.RERANK) {
+            query = "test query";
+        }
+
+        try {
+            new CustomRequest(query, List.of("test input"), model).createHttpRequest();
+        } catch (IllegalStateException e) {
+            var validationException = new ValidationException();
+            validationException.addValidationError(Strings.format("Failed to validate model configuration: %s", e.getMessage()));
+            throw validationException;
+        }
+    }
+
     private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
         if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
             return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));

+ 21 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutorTests.java

@@ -37,20 +37,38 @@ public class ValidatingSubstitutorTests extends ESTestCase {
             var sub = new ValidatingSubstitutor(Map.of("some_key", "value", "key2", "value2"), "${", "}");
             var exception = expectThrows(IllegalStateException.class, () -> sub.replace("super:${key}", "setting"));
 
-            assertThat(exception.getMessage(), is("Found placeholder [${key}] in field [setting] after replacement call"));
+            assertThat(
+                exception.getMessage(),
+                is(
+                    "Found placeholder [${key}] in field [setting] after replacement call, "
+                        + "please check that all templates have a corresponding field definition."
+                )
+            );
         }
         // only reports the first placeholder pattern
         {
             var sub = new ValidatingSubstitutor(Map.of("some_key", "value", "some_key2", "value2"), "${", "}");
             var exception = expectThrows(IllegalStateException.class, () -> sub.replace("super, ${key}, ${key2}", "setting"));
 
-            assertThat(exception.getMessage(), is("Found placeholder [${key}] in field [setting] after replacement call"));
+            assertThat(
+                exception.getMessage(),
+                is(
+                    "Found placeholder [${key}] in field [setting] after replacement call, "
+                        + "please check that all templates have a corresponding field definition."
+                )
+            );
         }
         {
             var sub = new ValidatingSubstitutor(Map.of("some_key", "value", "key2", "value2"), "${", "}");
             var exception = expectThrows(IllegalStateException.class, () -> sub.replace("super:${     \\/\tkey\"}", "setting"));
 
-            assertThat(exception.getMessage(), is("Found placeholder [${     \\/\tkey\"}] in field [setting] after replacement call"));
+            assertThat(
+                exception.getMessage(),
+                is(
+                    "Found placeholder [${     \\/\tkey\"}] in field [setting] after replacement call,"
+                        + " please check that all templates have a corresponding field definition."
+                )
+            );
         }
     }
 }

+ 38 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.custom;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@@ -52,6 +53,7 @@ import java.util.List;
 import java.util.Map;
 
 import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
+import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -611,6 +613,42 @@ public class CustomServiceTests extends AbstractInferenceServiceTests {
         }
     }
 
+    public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNotFillTemplate() throws Exception {
+        try (var service = createService(threadPool, clientManager)) {
+
+            var settingsMap = new HashMap<>(
+                Map.of(
+                    CustomServiceSettings.URL,
+                    "http://www.abc.com",
+                    CustomServiceSettings.HEADERS,
+                    Map.of("key", "value"),
+                    QueryParameters.QUERY_PARAMETERS,
+                    List.of(List.of("key", "value")),
+                    CustomServiceSettings.REQUEST,
+                    "request body ${some_template}",
+                    CustomServiceSettings.RESPONSE,
+                    new HashMap<>(Map.of(CustomServiceSettings.JSON_PARSER, createResponseParserMap(TaskType.COMPLETION)))
+                )
+            );
+
+            var config = getRequestConfigMap(settingsMap, createTaskSettingsMap(), createSecretSettingsMap());
+
+            var listener = new PlainActionFuture<Model>();
+            service.parseRequestConfig("id", TaskType.COMPLETION, config, listener);
+
+            var exception = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT));
+
+            assertThat(
+                exception.getMessage(),
+                is(
+                    "Validation Failed: 1: Failed to validate model configuration: Found placeholder "
+                        + "[${some_template}] in field [request] after replacement call, please check that all "
+                        + "templates have a corresponding field definition.;"
+                )
+            );
+        }
+    }
+
     public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
         var model = createInternalEmbeddingModel(
             SimilarityMeasure.DOT_PRODUCT,

+ 7 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java

@@ -264,7 +264,13 @@ public class CustomRequestTests extends ESTestCase {
 
         var request = new CustomRequest(null, List.of("abc", "123"), model);
         var exception = expectThrows(IllegalStateException.class, request::createHttpRequest);
-        assertThat(exception.getMessage(), is("Found placeholder [${task.key}] in field [header.Accept] after replacement call"));
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Found placeholder [${task.key}] in field [header.Accept] after replacement call, "
+                    + "please check that all templates have a corresponding field definition."
+            )
+        );
     }
 
     public void testCreateRequest_ThrowsException_ForInvalidUrl() {