|
@@ -342,7 +342,22 @@ public class CsvTestsDataLoader {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
public static void deleteInferenceEndpoints(RestClient client) throws IOException {
|
|
|
+ Response response = client.performRequest(new Request("GET", "_inference/_all"));
|
|
|
+
|
|
|
+ try (InputStream content = response.getEntity().getContent()) {
|
|
|
+ XContentType xContentType = XContentType.fromMediaType(response.getEntity().getContentType().getValue());
|
|
|
+ Map<String, ?> responseMap = XContentHelper.convertToMap(xContentType.xContent(), content, false);
|
|
|
+ List<Map<String, ?>> endpoints = (List<Map<String, ?>>) responseMap.get("endpoints");
|
|
|
+ for (Map<String, ?> endpoint : endpoints) {
|
|
|
+ String inferenceId = (String) endpoint.get("inference_id");
|
|
|
+ String taskType = (String) endpoint.get("task_type");
|
|
|
+ if (inferenceId != null && taskType != null && inferenceId.startsWith(".") == false) {
|
|
|
+ deleteInferenceEndpoint(client, inferenceId, TaskType.fromString(taskType));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
deleteSparseEmbeddingInferenceEndpoint(client);
|
|
|
deleteRerankInferenceEndpoint(client);
|
|
|
deleteCompletionInferenceEndpoint(client);
|
|
@@ -360,7 +375,7 @@ public class CsvTestsDataLoader {
|
|
|
}
|
|
|
|
|
|
public static void deleteSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
|
|
|
- deleteInferenceEndpoint(client, "test_sparse_inference");
|
|
|
+ deleteInferenceEndpoint(client, "test_sparse_inference", TaskType.SPARSE_EMBEDDING);
|
|
|
}
|
|
|
|
|
|
public static boolean clusterHasSparseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
|
|
@@ -378,7 +393,7 @@ public class CsvTestsDataLoader {
|
|
|
}
|
|
|
|
|
|
public static void deleteRerankInferenceEndpoint(RestClient client) throws IOException {
|
|
|
- deleteInferenceEndpoint(client, "test_reranker");
|
|
|
+ deleteInferenceEndpoint(client, "test_reranker", TaskType.RERANK);
|
|
|
}
|
|
|
|
|
|
public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throws IOException {
|
|
@@ -396,7 +411,7 @@ public class CsvTestsDataLoader {
|
|
|
}
|
|
|
|
|
|
public static void deleteCompletionInferenceEndpoint(RestClient client) throws IOException {
|
|
|
- deleteInferenceEndpoint(client, "test_completion");
|
|
|
+ deleteInferenceEndpoint(client, "test_completion", TaskType.COMPLETION);
|
|
|
}
|
|
|
|
|
|
public static boolean clusterHasCompletionInferenceEndpoint(RestClient client) throws IOException {
|
|
@@ -423,9 +438,9 @@ public class CsvTestsDataLoader {
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
- private static void deleteInferenceEndpoint(RestClient client, String inferenceId) throws IOException {
|
|
|
+ private static void deleteInferenceEndpoint(RestClient client, String inferenceId, TaskType taskType) throws IOException {
|
|
|
try {
|
|
|
- client.performRequest(new Request("DELETE", "_inference/" + inferenceId));
|
|
|
+ client.performRequest(new Request("DELETE", "_inference/" + taskType + "/" + inferenceId));
|
|
|
} catch (ResponseException e) {
|
|
|
// 404 here means the endpoint was not created
|
|
|
if (e.getResponse().getStatusLine().getStatusCode() != 404) {
|