|
@@ -8,8 +8,14 @@ package org.elasticsearch.xpack.ml.rest.inference;
|
|
|
import org.elasticsearch.client.node.NodeClient;
|
|
|
import org.elasticsearch.cluster.metadata.MetaData;
|
|
|
import org.elasticsearch.common.Strings;
|
|
|
+import org.elasticsearch.common.xcontent.ToXContent;
|
|
|
+import org.elasticsearch.common.xcontent.ToXContentObject;
|
|
|
+import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.rest.BaseRestHandler;
|
|
|
+import org.elasticsearch.rest.BytesRestResponse;
|
|
|
+import org.elasticsearch.rest.RestChannel;
|
|
|
import org.elasticsearch.rest.RestRequest;
|
|
|
+import org.elasticsearch.rest.RestResponse;
|
|
|
import org.elasticsearch.rest.action.RestToXContentListener;
|
|
|
import org.elasticsearch.xpack.core.action.util.PageParams;
|
|
|
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
|
@@ -18,7 +24,9 @@ import org.elasticsearch.xpack.ml.MachineLearning;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.util.Collections;
|
|
|
+import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
|
|
|
import static java.util.Arrays.asList;
|
|
@@ -34,6 +42,8 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
|
|
|
new Route(GET, MachineLearning.BASE_PATH + "inference"));
|
|
|
}
|
|
|
|
|
|
+ private static final Map<String, String> DEFAULT_TO_XCONTENT_VALUES =
|
|
|
+ Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true));
|
|
|
@Override
|
|
|
public String getName() {
|
|
|
return "ml_get_trained_models_action";
|
|
@@ -56,7 +66,9 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
|
|
|
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
|
|
|
}
|
|
|
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
|
|
|
- return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel));
|
|
|
+ return channel -> client.execute(GetTrainedModelsAction.INSTANCE,
|
|
|
+ request,
|
|
|
+ new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES));
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -64,4 +76,23 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
|
|
|
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
|
|
|
}
|
|
|
|
|
|
+ private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
|
|
|
+ private final Map<String, String> defaultToXContentParamValues;
|
|
|
+
|
|
|
+ private RestToXContentListenerWithDefaultValues(RestChannel channel, Map<String, String> defaultToXContentParamValues) {
|
|
|
+ super(channel);
|
|
|
+ this.defaultToXContentParamValues = defaultToXContentParamValues;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception {
|
|
|
+ assert response.isFragment() == false; //would be nice if we could make default methods final
|
|
|
+ Map<String, String> params = new HashMap<>(channel.request().params());
|
|
|
+ defaultToXContentParamValues.forEach((k, v) ->
|
|
|
+ params.computeIfAbsent(k, defaultToXContentParamValues::get)
|
|
|
+ );
|
|
|
+ response.toXContent(builder, new ToXContent.MapParams(params));
|
|
|
+ return new BytesRestResponse(getStatus(response), builder);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|