Browse Source

[ML] Adding clustering and classification enums for cohere input type (#105253)

* Adding clustering and classification enums for cohere

* Adding comment about default package scope

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Jonathan Buttner 1 year ago
parent
commit
89e714ee5d

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -128,6 +128,7 @@ public class TransportVersions {
     public static final TransportVersion HEALTH_INFO_ENRICHED_WITH_REPOS = def(8_588_00_0);
     public static final TransportVersion RESOLVE_CLUSTER_ENDPOINT_ADDED = def(8_589_00_0);
     public static final TransportVersion FIELD_CAPS_FIELD_HAS_VALUE = def(8_590_00_0);
+    public static final TransportVersion ML_INFERENCE_REQUEST_INPUT_TYPE_CLASS_CLUSTER_ADDED = def(8_591_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 3 - 1
server/src/main/java/org/elasticsearch/inference/InputType.java

@@ -16,7 +16,9 @@ import java.util.Locale;
 public enum InputType {
     INGEST,
     SEARCH,
-    UNSPECIFIED;
+    UNSPECIFIED,
+    CLASSIFICATION,
+    CLUSTERING;
 
     @Override
     public String toString() {

+ 17 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

@@ -32,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.EnumSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -57,6 +58,12 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
         }
 
+        private static final EnumSet<InputType> validEnumsBeforeUnspecifiedAdded = EnumSet.of(InputType.INGEST, InputType.SEARCH);
+        private static final EnumSet<InputType> validEnumsBeforeClassificationClusteringAdded = EnumSet.range(
+            InputType.INGEST,
+            InputType.UNSPECIFIED
+        );
+
         public static Request parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) {
             Request.Builder builder = PARSER.apply(parser, null);
             builder.setInferenceEntityId(inferenceEntityId);
@@ -152,16 +159,20 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             // in version ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED the input type enum was added, so we only want to write the enum if we're
             // at that version or later
             if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) {
-                out.writeEnum(getInputTypeToWrite(out.getTransportVersion()));
+                out.writeEnum(getInputTypeToWrite(inputType, out.getTransportVersion()));
             }
         }
 
-        private InputType getInputTypeToWrite(TransportVersion version) {
-            // in version ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED the UNSPECIFIED value was added, so if we're before that
-            // version other nodes won't know about it, so set it to INGEST instead
-            if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED) && inputType == InputType.UNSPECIFIED) {
+        // default for easier testing
+        static InputType getInputTypeToWrite(InputType inputType, TransportVersion version) {
+            if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED)
+                && validEnumsBeforeUnspecifiedAdded.contains(inputType) == false) {
                 return InputType.INGEST;
-            }
+            } else if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_CLASS_CLUSTER_ADDED)
+                && validEnumsBeforeClassificationClusteringAdded.contains(inputType) == false) {
+                    return InputType.UNSPECIFIED;
+                }
+
             return inputType;
         }
 

+ 117 - 11
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java → x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.inference.action;
+package org.elasticsearch.xpack.core.inference.action;
 
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
@@ -14,7 +14,6 @@ import org.elasticsearch.core.Tuple;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.json.JsonXContent;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 
 import java.io.IOException;
@@ -23,6 +22,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.getInputTypeToWrite;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
 
@@ -159,15 +159,26 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                 InputType.UNSPECIFIED
             );
         } else if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED)
-            && instance.getInputType() == InputType.UNSPECIFIED) {
-                return new InferenceAction.Request(
-                    instance.getTaskType(),
-                    instance.getInferenceEntityId(),
-                    instance.getInput(),
-                    instance.getTaskSettings(),
-                    InputType.INGEST
-                );
-            }
+            && (instance.getInputType() == InputType.UNSPECIFIED
+                || instance.getInputType() == InputType.CLASSIFICATION
+                || instance.getInputType() == InputType.CLUSTERING)) {
+                    return new InferenceAction.Request(
+                        instance.getTaskType(),
+                        instance.getInferenceEntityId(),
+                        instance.getInput(),
+                        instance.getTaskSettings(),
+                        InputType.INGEST
+                    );
+                } else if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_CLASS_CLUSTER_ADDED)
+                    && (instance.getInputType() == InputType.CLUSTERING || instance.getInputType() == InputType.CLASSIFICATION)) {
+                        return new InferenceAction.Request(
+                            instance.getTaskType(),
+                            instance.getInferenceEntityId(),
+                            instance.getInput(),
+                            instance.getTaskSettings(),
+                            InputType.UNSPECIFIED
+                        );
+                    }
 
         return instance;
     }
@@ -199,6 +210,66 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
         assertThat(deserializedInstance.getInputType(), is(InputType.INGEST));
     }
 
+    public void testWriteTo_WhenVersionIsBeforeUnspecifiedAdded_ButAfterInputTypeAdded_ShouldSetToIngest_WhenClustering_ManualCheck()
+        throws IOException {
+        var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.CLUSTERING);
+
+        InferenceAction.Request deserializedInstance = copyWriteable(
+            instance,
+            getNamedWriteableRegistry(),
+            instanceReader(),
+            TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED
+        );
+
+        assertThat(deserializedInstance.getInputType(), is(InputType.INGEST));
+    }
+
+    public void testWriteTo_WhenVersionIsBeforeUnspecifiedAdded_ButAfterInputTypeAdded_ShouldSetToIngest_WhenClassification_ManualCheck()
+        throws IOException {
+        var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.CLASSIFICATION);
+
+        InferenceAction.Request deserializedInstance = copyWriteable(
+            instance,
+            getNamedWriteableRegistry(),
+            instanceReader(),
+            TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED
+        );
+
+        assertThat(deserializedInstance.getInputType(), is(InputType.INGEST));
+    }
+
+    public
+        void
+        testWriteTo_WhenVersionIsBeforeClusterClassAdded_ButAfterUnspecifiedAdded_ShouldSetToUnspecified_WhenClassification_ManualCheck()
+            throws IOException {
+        var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.CLASSIFICATION);
+
+        InferenceAction.Request deserializedInstance = copyWriteable(
+            instance,
+            getNamedWriteableRegistry(),
+            instanceReader(),
+            TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED
+        );
+
+        assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
+    }
+
+    public
+        void
+        testWriteTo_WhenVersionIsBeforeClusterClassAdded_ButAfterUnspecifiedAdded_ShouldSetToUnspecified_WhenClustering_ManualCheck()
+            throws IOException {
+        var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.CLUSTERING);
+
+        InferenceAction.Request deserializedInstance = copyWriteable(
+            instance,
+            getNamedWriteableRegistry(),
+            instanceReader(),
+            TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED
+        );
+
+        assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
+    }
+
     public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUnspecified() throws IOException {
         var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.INGEST);
 
@@ -211,4 +282,39 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
 
         assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
     }
+
+    public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsUnspecified_VersionBeforeUnspecifiedIntroduced() {
+        assertThat(
+            getInputTypeToWrite(InputType.UNSPECIFIED, TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED),
+            is(InputType.INGEST)
+        );
+    }
+
+    public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsClassification_VersionBeforeUnspecifiedIntroduced() {
+        assertThat(
+            getInputTypeToWrite(InputType.CLASSIFICATION, TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED),
+            is(InputType.INGEST)
+        );
+    }
+
+    public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsClustering_VersionBeforeUnspecifiedIntroduced() {
+        assertThat(
+            getInputTypeToWrite(InputType.CLUSTERING, TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED),
+            is(InputType.INGEST)
+        );
+    }
+
+    public void testGetInputTypeToWrite_ReturnsUnspecified_WhenInputTypeIsClassification_VersionBeforeClusteringClassIntroduced() {
+        assertThat(
+            getInputTypeToWrite(InputType.CLUSTERING, TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED),
+            is(InputType.UNSPECIFIED)
+        );
+    }
+
+    public void testGetInputTypeToWrite_ReturnsUnspecified_WhenInputTypeIsClustering_VersionBeforeClusteringClassIntroduced() {
+        assertThat(
+            getInputTypeToWrite(InputType.CLASSIFICATION, TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED),
+            is(InputType.UNSPECIFIED)
+        );
+    }
 }

+ 4 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntity.java

@@ -31,6 +31,8 @@ public record CohereEmbeddingsRequestEntity(
 
     private static final String SEARCH_DOCUMENT = "search_document";
     private static final String SEARCH_QUERY = "search_query";
+    private static final String CLUSTERING = "clustering";
+    private static final String CLASSIFICATION = "classification";
 
     private static final String TEXTS_FIELD = "texts";
 
@@ -71,6 +73,8 @@ public record CohereEmbeddingsRequestEntity(
         return switch (inputType) {
             case INGEST -> SEARCH_DOCUMENT;
             case SEARCH -> SEARCH_QUERY;
+            case CLASSIFICATION -> CLASSIFICATION;
+            case CLUSTERING -> CLUSTERING;
             default -> {
                 assert false : invalidInputTypeMessage(inputType);
                 yield null;

+ 7 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

@@ -24,6 +24,7 @@ import org.elasticsearch.xpack.inference.common.SimilarityMeasure;
 
 import java.net.URI;
 import java.net.URISyntaxException;
+import java.util.Arrays;
 import java.util.EnumSet;
 import java.util.List;
 import java.util.Locale;
@@ -110,13 +111,16 @@ public class ServiceUtils {
         return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
     }
 
-    public static String invalidValue(String settingName, String scope, String invalidType, String[] requiredTypes) {
+    public static String invalidValue(String settingName, String scope, String invalidType, String[] requiredValues) {
+        var copyOfRequiredValues = requiredValues.clone();
+        Arrays.sort(copyOfRequiredValues);
+
         return Strings.format(
             "[%s] Invalid value [%s] received. [%s] must be one of [%s]",
             scope,
             invalidType,
             settingName,
-            String.join(", ", requiredTypes)
+            String.join(", ", copyOfRequiredValues)
         );
     }
 
@@ -235,6 +239,7 @@ public class ServiceUtils {
         }
 
         var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
+
         try {
             var createdEnum = constructor.apply(enumString);
             validateEnumValue(createdEnum, validValues);

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

@@ -223,6 +223,6 @@ public class CohereService extends SenderService {
 
     @Override
     public TransportVersion getMinimalSupportedVersion() {
-        return TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED;
+        return TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_CLASS_CLUSTER_ADDED;
     }
 }

+ 9 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java

@@ -40,7 +40,12 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings {
     public static final String NAME = "cohere_embeddings_task_settings";
     public static final CohereEmbeddingsTaskSettings EMPTY_SETTINGS = new CohereEmbeddingsTaskSettings(null, null);
     static final String INPUT_TYPE = "input_type";
-    private static final EnumSet<InputType> VALID_REQUEST_VALUES2 = EnumSet.of(InputType.INGEST, InputType.SEARCH);
+    static final EnumSet<InputType> VALID_REQUEST_VALUES = EnumSet.of(
+        InputType.INGEST,
+        InputType.SEARCH,
+        InputType.CLASSIFICATION,
+        InputType.CLUSTERING
+    );
 
     public static CohereEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
         if (map == null || map.isEmpty()) {
@@ -54,7 +59,7 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings {
             INPUT_TYPE,
             ModelConfigurations.TASK_SETTINGS,
             InputType::fromString,
-            VALID_REQUEST_VALUES2,
+            VALID_REQUEST_VALUES,
             validationException
         );
         CohereTruncation truncation = extractOptionalEnum(
@@ -103,7 +108,7 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings {
     ) {
         InputType inputTypeToUse = originalSettings.inputType;
 
-        if (VALID_REQUEST_VALUES2.contains(requestInputType)) {
+        if (VALID_REQUEST_VALUES.contains(requestInputType)) {
             inputTypeToUse = requestInputType;
         } else if (requestTaskSettings.inputType != null) {
             inputTypeToUse = requestTaskSettings.inputType;
@@ -137,7 +142,7 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings {
             return;
         }
 
-        assert VALID_REQUEST_VALUES2.contains(inputType) : invalidInputTypeMessage(inputType);
+        assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType);
     }
 
     @Override

+ 1 - 5
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java

@@ -12,10 +12,6 @@ import org.elasticsearch.test.ESTestCase;
 
 public class InputTypeTests extends ESTestCase {
     public static InputType randomWithoutUnspecified() {
-        return randomFrom(InputType.INGEST, InputType.SEARCH);
-    }
-
-    public static InputType[] valuesWithoutUnspecified() {
-        return new InputType[] { InputType.INGEST, InputType.SEARCH };
+        return randomFrom(InputType.INGEST, InputType.SEARCH, InputType.CLUSTERING, InputType.CLASSIFICATION);
     }
 }

+ 17 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java

@@ -311,6 +311,23 @@ public class ServiceUtilsTests extends ESTestCase {
         assertTrue(map.isEmpty());
     }
 
+    public void testExtractOptionalEnum_ReturnsClassification_WhenValueIsAcceptable() {
+        var validation = new ValidationException();
+        Map<String, Object> map = modifiableMap(Map.of("key", InputType.CLASSIFICATION.toString()));
+        var createdEnum = extractOptionalEnum(
+            map,
+            "key",
+            "scope",
+            InputType::fromString,
+            EnumSet.of(InputType.INGEST, InputType.CLASSIFICATION),
+            validation
+        );
+
+        assertThat(createdEnum, is(InputType.CLASSIFICATION));
+        assertTrue(validation.validationErrors().isEmpty());
+        assertTrue(map.isEmpty());
+    }
+
     public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty() {
         var service = mock(InferenceService.class);
 

+ 24 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettingsTests.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.services.cohere.embeddings;
 
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.core.Nullable;
@@ -18,10 +19,14 @@ import org.hamcrest.CoreMatchers;
 import org.hamcrest.MatcherAssert;
 
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.EnumSet;
 import java.util.HashMap;
+import java.util.Locale;
 import java.util.Map;
 
 import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified;
+import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.VALID_REQUEST_VALUES;
 import static org.hamcrest.Matchers.is;
 
 public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<CohereEmbeddingsTaskSettings> {
@@ -68,7 +73,12 @@ public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTe
 
         MatcherAssert.assertThat(
             exception.getMessage(),
-            is("Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [ingest, search];")
+            is(
+                Strings.format(
+                    "Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [%s];",
+                    getValidValuesSortedAndCombined(VALID_REQUEST_VALUES)
+                )
+            )
         );
     }
 
@@ -82,10 +92,22 @@ public class CohereEmbeddingsTaskSettingsTests extends AbstractWireSerializingTe
 
         MatcherAssert.assertThat(
             exception.getMessage(),
-            is("Validation Failed: 1: [task_settings] Invalid value [unspecified] received. [input_type] must be one of [ingest, search];")
+            is(
+                Strings.format(
+                    "Validation Failed: 1: [task_settings] Invalid value [unspecified] received. [input_type] must be one of [%s];",
+                    getValidValuesSortedAndCombined(VALID_REQUEST_VALUES)
+                )
+            )
         );
     }
 
+    private static <E extends Enum<E>> String getValidValuesSortedAndCombined(EnumSet<E> validValues) {
+        var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
+        Arrays.sort(validValuesAsStrings);
+
+        return String.join(", ", validValuesAsStrings);
+    }
+
     public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() {
         var thrownException = expectThrows(AssertionError.class, () -> new CohereEmbeddingsTaskSettings(InputType.UNSPECIFIED, null));
         MatcherAssert.assertThat(thrownException.getMessage(), CoreMatchers.is("received invalid input type value [unspecified]"));