|
@@ -5,7 +5,7 @@
|
|
|
* 2.0.
|
|
|
*/
|
|
|
|
|
|
-package org.elasticsearch.xpack.inference.action;
|
|
|
+package org.elasticsearch.xpack.core.inference.action;
|
|
|
|
|
|
import org.elasticsearch.TransportVersion;
|
|
|
import org.elasticsearch.TransportVersions;
|
|
@@ -14,7 +14,6 @@ import org.elasticsearch.core.Tuple;
|
|
|
import org.elasticsearch.inference.InputType;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.xcontent.json.JsonXContent;
|
|
|
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
|
|
|
|
|
import java.io.IOException;
|
|
@@ -23,6 +22,7 @@ import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
|
|
|
+import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.getInputTypeToWrite;
|
|
|
import static org.hamcrest.Matchers.is;
|
|
|
import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
|
|
|
|
|
@@ -159,15 +159,26 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
InputType.UNSPECIFIED
|
|
|
);
|
|
|
} else if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED)
|
|
|
- && instance.getInputType() == InputType.UNSPECIFIED) {
|
|
|
- return new InferenceAction.Request(
|
|
|
- instance.getTaskType(),
|
|
|
- instance.getInferenceEntityId(),
|
|
|
- instance.getInput(),
|
|
|
- instance.getTaskSettings(),
|
|
|
- InputType.INGEST
|
|
|
- );
|
|
|
- }
|
|
|
+ && (instance.getInputType() == InputType.UNSPECIFIED
|
|
|
+ || instance.getInputType() == InputType.CLASSIFICATION
|
|
|
+ || instance.getInputType() == InputType.CLUSTERING)) {
|
|
|
+ return new InferenceAction.Request(
|
|
|
+ instance.getTaskType(),
|
|
|
+ instance.getInferenceEntityId(),
|
|
|
+ instance.getInput(),
|
|
|
+ instance.getTaskSettings(),
|
|
|
+ InputType.INGEST
|
|
|
+ );
|
|
|
+ } else if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_CLASS_CLUSTER_ADDED)
|
|
|
+ && (instance.getInputType() == InputType.CLUSTERING || instance.getInputType() == InputType.CLASSIFICATION)) {
|
|
|
+ return new InferenceAction.Request(
|
|
|
+ instance.getTaskType(),
|
|
|
+ instance.getInferenceEntityId(),
|
|
|
+ instance.getInput(),
|
|
|
+ instance.getTaskSettings(),
|
|
|
+ InputType.UNSPECIFIED
|
|
|
+ );
|
|
|
+ }
|
|
|
|
|
|
return instance;
|
|
|
}
|
|
@@ -199,6 +210,66 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
assertThat(deserializedInstance.getInputType(), is(InputType.INGEST));
|
|
|
}
|
|
|
|
|
|
+ public void testWriteTo_WhenVersionIsBeforeUnspecifiedAdded_ButAfterInputTypeAdded_ShouldSetToIngest_WhenClustering_ManualCheck()
|
|
|
+ throws IOException {
|
|
|
+ var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.CLUSTERING);
|
|
|
+
|
|
|
+ InferenceAction.Request deserializedInstance = copyWriteable(
|
|
|
+ instance,
|
|
|
+ getNamedWriteableRegistry(),
|
|
|
+ instanceReader(),
|
|
|
+ TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThat(deserializedInstance.getInputType(), is(InputType.INGEST));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testWriteTo_WhenVersionIsBeforeUnspecifiedAdded_ButAfterInputTypeAdded_ShouldSetToIngest_WhenClassification_ManualCheck()
|
|
|
+ throws IOException {
|
|
|
+ var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.CLASSIFICATION);
|
|
|
+
|
|
|
+ InferenceAction.Request deserializedInstance = copyWriteable(
|
|
|
+ instance,
|
|
|
+ getNamedWriteableRegistry(),
|
|
|
+ instanceReader(),
|
|
|
+ TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThat(deserializedInstance.getInputType(), is(InputType.INGEST));
|
|
|
+ }
|
|
|
+
|
|
|
+ public
|
|
|
+ void
|
|
|
+ testWriteTo_WhenVersionIsBeforeClusterClassAdded_ButAfterUnspecifiedAdded_ShouldSetToUnspecified_WhenClassification_ManualCheck()
|
|
|
+ throws IOException {
|
|
|
+ var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.CLASSIFICATION);
|
|
|
+
|
|
|
+ InferenceAction.Request deserializedInstance = copyWriteable(
|
|
|
+ instance,
|
|
|
+ getNamedWriteableRegistry(),
|
|
|
+ instanceReader(),
|
|
|
+ TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
|
|
|
+ }
|
|
|
+
|
|
|
+ public
|
|
|
+ void
|
|
|
+ testWriteTo_WhenVersionIsBeforeClusterClassAdded_ButAfterUnspecifiedAdded_ShouldSetToUnspecified_WhenClustering_ManualCheck()
|
|
|
+ throws IOException {
|
|
|
+ var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.CLUSTERING);
|
|
|
+
|
|
|
+ InferenceAction.Request deserializedInstance = copyWriteable(
|
|
|
+ instance,
|
|
|
+ getNamedWriteableRegistry(),
|
|
|
+ instanceReader(),
|
|
|
+ TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
|
|
|
+ }
|
|
|
+
|
|
|
public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUnspecified() throws IOException {
|
|
|
var instance = new InferenceAction.Request(TaskType.TEXT_EMBEDDING, "model", List.of(), Map.of(), InputType.INGEST);
|
|
|
|
|
@@ -211,4 +282,39 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
|
|
|
|
|
assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
|
|
|
}
|
|
|
+
|
|
|
+ public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsUnspecified_VersionBeforeUnspecifiedIntroduced() {
|
|
|
+ assertThat(
|
|
|
+ getInputTypeToWrite(InputType.UNSPECIFIED, TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED),
|
|
|
+ is(InputType.INGEST)
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsClassification_VersionBeforeUnspecifiedIntroduced() {
|
|
|
+ assertThat(
|
|
|
+ getInputTypeToWrite(InputType.CLASSIFICATION, TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED),
|
|
|
+ is(InputType.INGEST)
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsClustering_VersionBeforeUnspecifiedIntroduced() {
|
|
|
+ assertThat(
|
|
|
+ getInputTypeToWrite(InputType.CLUSTERING, TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED),
|
|
|
+ is(InputType.INGEST)
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testGetInputTypeToWrite_ReturnsUnspecified_WhenInputTypeIsClassification_VersionBeforeClusteringClassIntroduced() {
|
|
|
+ assertThat(
|
|
|
+ getInputTypeToWrite(InputType.CLUSTERING, TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED),
|
|
|
+ is(InputType.UNSPECIFIED)
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testGetInputTypeToWrite_ReturnsUnspecified_WhenInputTypeIsClustering_VersionBeforeClusteringClassIntroduced() {
|
|
|
+ assertThat(
|
|
|
+ getInputTypeToWrite(InputType.CLASSIFICATION, TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED),
|
|
|
+ is(InputType.UNSPECIFIED)
|
|
|
+ );
|
|
|
+ }
|
|
|
}
|