|
@@ -13,8 +13,10 @@ import org.elasticsearch.common.bytes.BytesArray;
|
|
|
import org.elasticsearch.common.bytes.BytesReference;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
+import org.elasticsearch.core.Nullable;
|
|
|
import org.elasticsearch.inference.ChunkingSettings;
|
|
|
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
|
|
+import org.elasticsearch.inference.InputType;
|
|
|
import org.elasticsearch.inference.Model;
|
|
|
import org.elasticsearch.inference.ModelConfigurations;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
@@ -109,7 +111,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
projectId
|
|
|
)
|
|
|
),
|
|
|
- new HashMap<>(Map.of()),
|
|
|
+ getTaskSettingsMap(true, InputType.INGEST),
|
|
|
getSecretSettingsMap(serviceAccountJson)
|
|
|
),
|
|
|
modelListener
|
|
@@ -154,7 +156,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
projectId
|
|
|
)
|
|
|
),
|
|
|
- new HashMap<>(Map.of()),
|
|
|
+ getTaskSettingsMap(true, InputType.INGEST),
|
|
|
createRandomChunkingSettingsMap(),
|
|
|
getSecretSettingsMap(serviceAccountJson)
|
|
|
),
|
|
@@ -200,7 +202,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
projectId
|
|
|
)
|
|
|
),
|
|
|
- new HashMap<>(Map.of()),
|
|
|
+ getTaskSettingsMap(false, InputType.SEARCH),
|
|
|
getSecretSettingsMap(serviceAccountJson)
|
|
|
),
|
|
|
modelListener
|
|
@@ -281,7 +283,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
"project"
|
|
|
)
|
|
|
),
|
|
|
- getTaskSettingsMap(true),
|
|
|
+ getTaskSettingsMap(true, InputType.SEARCH),
|
|
|
getSecretSettingsMap("{}")
|
|
|
);
|
|
|
config.put("extra_key", "value");
|
|
@@ -308,7 +310,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
);
|
|
|
serviceSettings.put("extra_key", "value");
|
|
|
|
|
|
- var config = getRequestConfigMap(serviceSettings, getTaskSettingsMap(true), getSecretSettingsMap("{}"));
|
|
|
+ var config = getRequestConfigMap(serviceSettings, getTaskSettingsMap(true, InputType.CLUSTERING), getSecretSettingsMap("{}"));
|
|
|
|
|
|
var failureListener = getModelListenerForException(
|
|
|
ElasticsearchStatusException.class,
|
|
@@ -362,7 +364,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
"project"
|
|
|
)
|
|
|
),
|
|
|
- getTaskSettingsMap(true),
|
|
|
+ getTaskSettingsMap(true, null),
|
|
|
secretSettings
|
|
|
);
|
|
|
|
|
@@ -399,7 +401,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
true
|
|
|
)
|
|
|
),
|
|
|
- getTaskSettingsMap(autoTruncate),
|
|
|
+ getTaskSettingsMap(autoTruncate, InputType.SEARCH),
|
|
|
getSecretSettingsMap(serviceAccountJson)
|
|
|
);
|
|
|
|
|
@@ -417,7 +419,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().location(), is(location));
|
|
|
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
|
|
|
- assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, InputType.SEARCH)));
|
|
|
assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
|
|
|
}
|
|
|
}
|
|
@@ -447,7 +449,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
true
|
|
|
)
|
|
|
),
|
|
|
- getTaskSettingsMap(autoTruncate),
|
|
|
+ getTaskSettingsMap(autoTruncate, null),
|
|
|
createRandomChunkingSettingsMap(),
|
|
|
getSecretSettingsMap(serviceAccountJson)
|
|
|
);
|
|
@@ -466,7 +468,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().location(), is(location));
|
|
|
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
|
|
|
- assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
|
|
|
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
|
|
|
}
|
|
@@ -497,7 +499,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
true
|
|
|
)
|
|
|
),
|
|
|
- getTaskSettingsMap(autoTruncate),
|
|
|
+ getTaskSettingsMap(autoTruncate, null),
|
|
|
getSecretSettingsMap(serviceAccountJson)
|
|
|
);
|
|
|
|
|
@@ -515,7 +517,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().location(), is(location));
|
|
|
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
|
|
|
- assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
|
|
|
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
|
|
|
}
|
|
@@ -573,7 +575,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
true
|
|
|
)
|
|
|
),
|
|
|
- getTaskSettingsMap(autoTruncate),
|
|
|
+ getTaskSettingsMap(autoTruncate, InputType.INGEST),
|
|
|
getSecretSettingsMap(serviceAccountJson)
|
|
|
);
|
|
|
persistedConfig.config().put("extra_key", "value");
|
|
@@ -592,7 +594,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().location(), is(location));
|
|
|
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
|
|
|
- assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, InputType.INGEST)));
|
|
|
assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
|
|
|
}
|
|
|
}
|
|
@@ -625,7 +627,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
true
|
|
|
)
|
|
|
),
|
|
|
- getTaskSettingsMap(autoTruncate),
|
|
|
+ getTaskSettingsMap(autoTruncate, null),
|
|
|
secretSettingsMap
|
|
|
);
|
|
|
|
|
@@ -643,7 +645,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().location(), is(location));
|
|
|
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
|
|
|
- assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
|
|
|
assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
|
|
|
}
|
|
|
}
|
|
@@ -676,7 +678,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
|
|
|
var persistedConfig = getPersistedConfigMap(
|
|
|
serviceSettingsMap,
|
|
|
- getTaskSettingsMap(autoTruncate),
|
|
|
+ getTaskSettingsMap(autoTruncate, InputType.CLUSTERING),
|
|
|
getSecretSettingsMap(serviceAccountJson)
|
|
|
);
|
|
|
|
|
@@ -694,7 +696,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().location(), is(location));
|
|
|
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
|
|
|
- assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, InputType.CLUSTERING)));
|
|
|
assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
|
|
|
}
|
|
|
}
|
|
@@ -711,7 +713,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
""";
|
|
|
|
|
|
try (var service = createGoogleVertexAiService()) {
|
|
|
- var taskSettings = getTaskSettingsMap(autoTruncate);
|
|
|
+ var taskSettings = getTaskSettingsMap(autoTruncate, InputType.SEARCH);
|
|
|
taskSettings.put("extra_key", "value");
|
|
|
|
|
|
var persistedConfig = getPersistedConfigMap(
|
|
@@ -745,7 +747,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().location(), is(location));
|
|
|
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
|
|
|
- assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, InputType.SEARCH)));
|
|
|
assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
|
|
|
}
|
|
|
}
|
|
@@ -770,7 +772,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
true
|
|
|
)
|
|
|
),
|
|
|
- getTaskSettingsMap(autoTruncate),
|
|
|
+ getTaskSettingsMap(autoTruncate, null),
|
|
|
createRandomChunkingSettingsMap()
|
|
|
);
|
|
|
|
|
@@ -783,7 +785,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().location(), is(location));
|
|
|
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
|
|
|
- assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
|
|
|
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
}
|
|
|
}
|
|
@@ -808,7 +810,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
true
|
|
|
)
|
|
|
),
|
|
|
- getTaskSettingsMap(autoTruncate)
|
|
|
+ getTaskSettingsMap(autoTruncate, null)
|
|
|
);
|
|
|
|
|
|
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
|
@@ -820,7 +822,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().location(), is(location));
|
|
|
assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
|
|
|
- assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
|
|
|
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
}
|
|
|
}
|
|
@@ -838,12 +840,44 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
{
|
|
|
"task_type": "text_embedding",
|
|
|
"configuration": {
|
|
|
+ "input_type": {
|
|
|
+ "default_value": null,
|
|
|
+ "depends_on": [],
|
|
|
+ "display": "dropdown",
|
|
|
+ "label": "Input Type",
|
|
|
+ "options": [
|
|
|
+ {
|
|
|
+ "label": "classification",
|
|
|
+ "value": "classification"
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "label": "clustering",
|
|
|
+ "value": "clustering"
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "label": "ingest",
|
|
|
+ "value": "ingest"
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "label": "search",
|
|
|
+ "value": "search"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "order": 1,
|
|
|
+ "required": false,
|
|
|
+ "sensitive": false,
|
|
|
+ "tooltip": "Specifies the type of input passed to the model.",
|
|
|
+ "type": "str",
|
|
|
+ "ui_restrictions": [],
|
|
|
+ "validations": [],
|
|
|
+ "value": ""
|
|
|
+ },
|
|
|
"auto_truncate": {
|
|
|
"default_value": null,
|
|
|
"depends_on": [],
|
|
|
"display": "toggle",
|
|
|
"label": "Auto Truncate",
|
|
|
- "order": 1,
|
|
|
+ "order": 2,
|
|
|
"required": false,
|
|
|
"sensitive": false,
|
|
|
"tooltip": "Specifies if the API truncates inputs longer than the maximum token length automatically.",
|
|
@@ -1005,11 +1039,15 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- private static Map<String, Object> getTaskSettingsMap(Boolean autoTruncate) {
|
|
|
+ private static Map<String, Object> getTaskSettingsMap(Boolean autoTruncate, @Nullable InputType inputType) {
|
|
|
var taskSettings = new HashMap<String, Object>();
|
|
|
|
|
|
taskSettings.put(GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, autoTruncate);
|
|
|
|
|
|
+ if (inputType != null) {
|
|
|
+ taskSettings.put(GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString());
|
|
|
+ }
|
|
|
+
|
|
|
return taskSettings;
|
|
|
}
|
|
|
|