|
|
@@ -14,6 +14,7 @@ import org.elasticsearch.common.Strings;
|
|
|
import org.elasticsearch.common.settings.SecureString;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
|
|
+import org.elasticsearch.core.Nullable;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.test.cluster.ElasticsearchCluster;
|
|
|
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
|
|
|
@@ -50,8 +51,14 @@ public class InferenceBaseRestTest extends ESRestTestCase {
|
|
|
}
|
|
|
|
|
|
static String mockServiceModelConfig() {
|
|
|
- return """
|
|
|
+ return mockServiceModelConfig(null);
|
|
|
+ }
|
|
|
+
|
|
|
+ static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody) {
|
|
|
+ var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
|
|
|
+ return Strings.format("""
|
|
|
{
|
|
|
+ %s
|
|
|
"service": "test_service",
|
|
|
"service_settings": {
|
|
|
"model": "my_model",
|
|
|
@@ -61,11 +68,35 @@ public class InferenceBaseRestTest extends ESRestTestCase {
|
|
|
"temperature": 3
|
|
|
}
|
|
|
}
|
|
|
- """;
|
|
|
+ """, taskType);
|
|
|
+ }
|
|
|
+
|
|
|
+ protected void deleteModel(String modelId) throws IOException {
|
|
|
+ var request = new Request("DELETE", "_inference/" + modelId);
|
|
|
+ var response = client().performRequest(request);
|
|
|
+ assertOkOrCreated(response);
|
|
|
+ }
|
|
|
+
|
|
|
+ protected void deleteModel(String modelId, TaskType taskType) throws IOException {
|
|
|
+ var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, modelId));
|
|
|
+ var response = client().performRequest(request);
|
|
|
+ assertOkOrCreated(response);
|
|
|
}
|
|
|
|
|
|
protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
|
|
|
String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
|
|
|
+ return putModelInternal(endpoint, modelConfig);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Task type should be in modelConfig
|
|
|
+ */
|
|
|
+ protected Map<String, Object> putModel(String modelId, String modelConfig) throws IOException {
|
|
|
+ String endpoint = Strings.format("_inference/%s", modelId);
|
|
|
+ return putModelInternal(endpoint, modelConfig);
|
|
|
+ }
|
|
|
+
|
|
|
+ private Map<String, Object> putModelInternal(String endpoint, String modelConfig) throws IOException {
|
|
|
var request = new Request("PUT", endpoint);
|
|
|
request.setJsonEntity(modelConfig);
|
|
|
var response = client().performRequest(request);
|
|
|
@@ -73,24 +104,38 @@ public class InferenceBaseRestTest extends ESRestTestCase {
|
|
|
return entityAsMap(response);
|
|
|
}
|
|
|
|
|
|
+ protected Map<String, Object> getModel(String modelId) throws IOException {
|
|
|
+ var endpoint = Strings.format("_inference/%s", modelId);
|
|
|
+ return getAllModelInternal(endpoint);
|
|
|
+ }
|
|
|
+
|
|
|
protected Map<String, Object> getModels(String modelId, TaskType taskType) throws IOException {
|
|
|
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
|
|
|
- var request = new Request("GET", endpoint);
|
|
|
- var response = client().performRequest(request);
|
|
|
- assertOkOrCreated(response);
|
|
|
- return entityAsMap(response);
|
|
|
+ return getAllModelInternal(endpoint);
|
|
|
}
|
|
|
|
|
|
protected Map<String, Object> getAllModels() throws IOException {
|
|
|
- var endpoint = Strings.format("_inference/_all");
|
|
|
+ return getAllModelInternal("_inference/_all");
|
|
|
+ }
|
|
|
+
|
|
|
+ private Map<String, Object> getAllModelInternal(String endpoint) throws IOException {
|
|
|
var request = new Request("GET", endpoint);
|
|
|
var response = client().performRequest(request);
|
|
|
assertOkOrCreated(response);
|
|
|
return entityAsMap(response);
|
|
|
}
|
|
|
|
|
|
+ protected Map<String, Object> inferOnMockService(String modelId, List<String> input) throws IOException {
|
|
|
+ var endpoint = Strings.format("_inference/%s", modelId);
|
|
|
+ return inferOnMockServiceInternal(endpoint, input);
|
|
|
+ }
|
|
|
+
|
|
|
protected Map<String, Object> inferOnMockService(String modelId, TaskType taskType, List<String> input) throws IOException {
|
|
|
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
|
|
|
+ return inferOnMockServiceInternal(endpoint, input);
|
|
|
+ }
|
|
|
+
|
|
|
+ private Map<String, Object> inferOnMockServiceInternal(String endpoint, List<String> input) throws IOException {
|
|
|
var request = new Request("POST", endpoint);
|
|
|
|
|
|
var bodyBuilder = new StringBuilder("{\"input\": [");
|