|
|
@@ -24,9 +24,7 @@ import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
|
|
import org.elasticsearch.action.bulk.BulkResponse;
|
|
|
import org.elasticsearch.action.index.IndexAction;
|
|
|
import org.elasticsearch.action.index.IndexRequest;
|
|
|
-import org.elasticsearch.action.search.MultiSearchAction;
|
|
|
import org.elasticsearch.action.search.MultiSearchRequest;
|
|
|
-import org.elasticsearch.action.search.MultiSearchRequestBuilder;
|
|
|
import org.elasticsearch.action.search.MultiSearchResponse;
|
|
|
import org.elasticsearch.action.search.SearchAction;
|
|
|
import org.elasticsearch.action.search.SearchRequest;
|
|
|
@@ -602,26 +600,17 @@ public class TrainedModelProvider {
|
|
|
}, finalListener::onFailure);
|
|
|
|
|
|
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelId));
|
|
|
- MultiSearchRequestBuilder multiSearchRequestBuilder = client.prepareMultiSearch()
|
|
|
- .add(
|
|
|
- client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
|
|
- .setQuery(queryBuilder)
|
|
|
- // use sort to get the last
|
|
|
- .addSort("_index", SortOrder.DESC)
|
|
|
- .setSize(1)
|
|
|
- .request()
|
|
|
- );
|
|
|
-
|
|
|
- if (includes.isIncludeModelDefinition()) {
|
|
|
- multiSearchRequestBuilder.add(
|
|
|
- ChunkedTrainedModelRestorer.buildSearch(client, modelId, InferenceIndexConstants.INDEX_PATTERN, MAX_NUM_DEFINITION_DOCS)
|
|
|
- );
|
|
|
- }
|
|
|
+ SearchRequest trainedModelConfigSearch = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
|
|
+ .setQuery(queryBuilder)
|
|
|
+ // use sort to get the last
|
|
|
+ .addSort("_index", SortOrder.DESC)
|
|
|
+ .setSize(1)
|
|
|
+ .request();
|
|
|
|
|
|
- ActionListener<MultiSearchResponse> multiSearchResponseActionListener = ActionListener.wrap(multiSearchResponse -> {
|
|
|
+ ActionListener<SearchResponse> trainedModelSearchHandler = ActionListener.wrap(modelSearchResponse -> {
|
|
|
TrainedModelConfig.Builder builder;
|
|
|
try {
|
|
|
- builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseModelConfigLenientlyFromSource);
|
|
|
+ builder = handleHits(modelSearchResponse.getHits().getHits(), modelId, this::parseModelConfigLenientlyFromSource).get(0);
|
|
|
} catch (ResourceNotFoundException ex) {
|
|
|
getTrainedModelListener.onFailure(
|
|
|
new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))
|
|
|
@@ -631,46 +620,58 @@ public class TrainedModelProvider {
|
|
|
getTrainedModelListener.onFailure(ex);
|
|
|
return;
|
|
|
}
|
|
|
-
|
|
|
- if (includes.isIncludeModelDefinition()) {
|
|
|
- try {
|
|
|
- List<TrainedModelDefinitionDoc> docs = handleSearchItems(
|
|
|
- multiSearchResponse.getResponses()[1],
|
|
|
+ if (includes.isIncludeModelDefinition() == false) {
|
|
|
+ getTrainedModelListener.onResponse(builder);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if (builder.getModelType() == TrainedModelType.PYTORCH && includes.isIncludeModelDefinition()) {
|
|
|
+ finalListener.onFailure(
|
|
|
+ ExceptionsHelper.badRequestException(
|
|
|
+ "[{}] is type [{}] and does not support retrieving the definition",
|
|
|
modelId,
|
|
|
- (bytes, resourceId) -> ChunkedTrainedModelRestorer.parseModelDefinitionDocLenientlyFromSource(
|
|
|
- bytes,
|
|
|
- resourceId,
|
|
|
- xContentRegistry
|
|
|
- )
|
|
|
- );
|
|
|
+ builder.getModelType()
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ executeAsyncWithOrigin(
|
|
|
+ client,
|
|
|
+ ML_ORIGIN,
|
|
|
+ SearchAction.INSTANCE,
|
|
|
+ ChunkedTrainedModelRestorer.buildSearch(client, modelId, InferenceIndexConstants.INDEX_PATTERN, MAX_NUM_DEFINITION_DOCS),
|
|
|
+ ActionListener.wrap(definitionSearchResponse -> {
|
|
|
try {
|
|
|
- BytesReference compressedData = getDefinitionFromDocs(docs, modelId);
|
|
|
- builder.setDefinitionFromBytes(compressedData);
|
|
|
- } catch (ElasticsearchException elasticsearchException) {
|
|
|
- getTrainedModelListener.onFailure(elasticsearchException);
|
|
|
+ List<TrainedModelDefinitionDoc> docs = handleHits(
|
|
|
+ definitionSearchResponse.getHits().getHits(),
|
|
|
+ modelId,
|
|
|
+ (bytes, resourceId) -> ChunkedTrainedModelRestorer.parseModelDefinitionDocLenientlyFromSource(
|
|
|
+ bytes,
|
|
|
+ resourceId,
|
|
|
+ xContentRegistry
|
|
|
+ )
|
|
|
+ );
|
|
|
+ try {
|
|
|
+ BytesReference compressedData = getDefinitionFromDocs(docs, modelId);
|
|
|
+ builder.setDefinitionFromBytes(compressedData);
|
|
|
+ } catch (ElasticsearchException elasticsearchException) {
|
|
|
+ getTrainedModelListener.onFailure(elasticsearchException);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ } catch (ResourceNotFoundException ex) {
|
|
|
+ getTrainedModelListener.onFailure(
|
|
|
+ new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ } catch (Exception ex) {
|
|
|
+ getTrainedModelListener.onFailure(ex);
|
|
|
return;
|
|
|
}
|
|
|
-
|
|
|
- } catch (ResourceNotFoundException ex) {
|
|
|
- getTrainedModelListener.onFailure(
|
|
|
- new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))
|
|
|
- );
|
|
|
- return;
|
|
|
- } catch (Exception ex) {
|
|
|
- getTrainedModelListener.onFailure(ex);
|
|
|
- return;
|
|
|
- }
|
|
|
- }
|
|
|
- getTrainedModelListener.onResponse(builder);
|
|
|
+ getTrainedModelListener.onResponse(builder);
|
|
|
+ }, getTrainedModelListener::onFailure)
|
|
|
+ );
|
|
|
}, getTrainedModelListener::onFailure);
|
|
|
-
|
|
|
- executeAsyncWithOrigin(
|
|
|
- client,
|
|
|
- ML_ORIGIN,
|
|
|
- MultiSearchAction.INSTANCE,
|
|
|
- multiSearchRequestBuilder.request(),
|
|
|
- multiSearchResponseActionListener
|
|
|
- );
|
|
|
+ executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, trainedModelConfigSearch, trainedModelSearchHandler);
|
|
|
}
|
|
|
|
|
|
public void getTrainedModels(
|
|
|
@@ -1204,6 +1205,9 @@ public class TrainedModelProvider {
|
|
|
String resourceId,
|
|
|
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently
|
|
|
) throws Exception {
|
|
|
+ if (hits.length == 0) {
|
|
|
+ throw new ResourceNotFoundException(resourceId);
|
|
|
+ }
|
|
|
List<T> results = new ArrayList<>(hits.length);
|
|
|
String initialIndex = hits[0].getIndex();
|
|
|
for (SearchHit hit : hits) {
|