|
@@ -15,7 +15,8 @@ import org.elasticsearch.ResourceNotFoundException;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.action.DocWriteRequest;
|
|
|
import org.elasticsearch.action.bulk.BulkAction;
|
|
|
-import org.elasticsearch.action.bulk.BulkRequest;
|
|
|
+import org.elasticsearch.action.bulk.BulkItemResponse;
|
|
|
+import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
|
|
import org.elasticsearch.action.bulk.BulkResponse;
|
|
|
import org.elasticsearch.action.index.IndexRequest;
|
|
|
import org.elasticsearch.action.search.MultiSearchAction;
|
|
@@ -85,6 +86,7 @@ import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
import java.util.TreeSet;
|
|
|
+import java.util.stream.Collectors;
|
|
|
|
|
|
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
|
|
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
|
@@ -95,6 +97,9 @@ public class TrainedModelProvider {
|
|
|
public static final Set<String> MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1");
|
|
|
private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/";
|
|
|
private static final String MODEL_RESOURCE_FILE_EXT = ".json";
|
|
|
+ private static final int COMPRESSED_STRING_CHUNK_SIZE = 16 * 1024 * 1024;
|
|
|
+ private static final int MAX_NUM_DEFINITION_DOCS = 100;
|
|
|
+ private static final int MAX_COMPRESSED_STRING_SIZE = COMPRESSED_STRING_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS;
|
|
|
|
|
|
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
|
|
|
private final Client client;
|
|
@@ -138,30 +143,41 @@ public class TrainedModelProvider {
|
|
|
private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig,
|
|
|
ActionListener<Boolean> listener) {
|
|
|
|
|
|
- TrainedModelDefinitionDoc trainedModelDefinitionDoc;
|
|
|
+ List<TrainedModelDefinitionDoc> trainedModelDefinitionDocs = new ArrayList<>();
|
|
|
try {
|
|
|
- // TODO should we check length against allowed stream size???
|
|
|
String compressedString = trainedModelConfig.getCompressedDefinition();
|
|
|
- trainedModelDefinitionDoc = new TrainedModelDefinitionDoc.Builder()
|
|
|
- .setDocNum(0)
|
|
|
- .setModelId(trainedModelConfig.getModelId())
|
|
|
- .setCompressedString(compressedString)
|
|
|
- .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
|
|
|
- .setDefinitionLength(compressedString.length())
|
|
|
- .setTotalDefinitionLength(compressedString.length())
|
|
|
- .build();
|
|
|
+ if (compressedString.length() > MAX_COMPRESSED_STRING_SIZE) {
|
|
|
+ listener.onFailure(
|
|
|
+ ExceptionsHelper.badRequestException(
|
|
|
+ "Unable to store model as compressed definition has length [{}] the limit is [{}]",
|
|
|
+ compressedString.length(),
|
|
|
+ MAX_COMPRESSED_STRING_SIZE));
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ List<String> chunkedStrings = chunkStringWithSize(compressedString, COMPRESSED_STRING_CHUNK_SIZE);
|
|
|
+ for(int i = 0; i < chunkedStrings.size(); ++i) {
|
|
|
+ trainedModelDefinitionDocs.add(new TrainedModelDefinitionDoc.Builder()
|
|
|
+ .setDocNum(i)
|
|
|
+ .setModelId(trainedModelConfig.getModelId())
|
|
|
+ .setCompressedString(chunkedStrings.get(i))
|
|
|
+ .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
|
|
|
+ .setDefinitionLength(chunkedStrings.get(i).length())
|
|
|
+ .setTotalDefinitionLength(compressedString.length())
|
|
|
+ .build());
|
|
|
+ }
|
|
|
} catch (IOException ex) {
|
|
|
listener.onFailure(ExceptionsHelper.serverError(
|
|
|
- "Unexpected IOException while serializing definition for storage for model [" + trainedModelConfig.getModelId() + "]",
|
|
|
- ex));
|
|
|
+ "Unexpected IOException while serializing definition for storage for model [{}]",
|
|
|
+ ex,
|
|
|
+ trainedModelConfig.getModelId()));
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- BulkRequest bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME)
|
|
|
+ BulkRequestBuilder bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME)
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
- .add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig))
|
|
|
- .add(createRequest(TrainedModelDefinitionDoc.docId(trainedModelConfig.getModelId(), 0), trainedModelDefinitionDoc))
|
|
|
- .request();
|
|
|
+ .add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig));
|
|
|
+ trainedModelDefinitionDocs.forEach(defDoc ->
|
|
|
+ bulkRequest.add(createRequest(TrainedModelDefinitionDoc.docId(trainedModelConfig.getModelId(), defDoc.getDocNum()), defDoc)));
|
|
|
|
|
|
ActionListener<Boolean> wrappedListener = ActionListener.wrap(
|
|
|
listener::onResponse,
|
|
@@ -181,9 +197,8 @@ public class TrainedModelProvider {
|
|
|
|
|
|
ActionListener<BulkResponse> bulkResponseActionListener = ActionListener.wrap(
|
|
|
r -> {
|
|
|
- assert r.getItems().length == 2;
|
|
|
+ assert r.getItems().length == trainedModelDefinitionDocs.size() + 1;
|
|
|
if (r.getItems()[0].isFailed()) {
|
|
|
-
|
|
|
logger.error(new ParameterizedMessage(
|
|
|
"[{}] failed to store trained model config for inference",
|
|
|
trainedModelConfig.getModelId()),
|
|
@@ -192,12 +207,18 @@ public class TrainedModelProvider {
|
|
|
wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
|
|
|
return;
|
|
|
}
|
|
|
- if (r.getItems()[1].isFailed()) {
|
|
|
+ if (r.hasFailures()) {
|
|
|
+ Exception firstFailure = Arrays.stream(r.getItems())
|
|
|
+ .filter(BulkItemResponse::isFailed)
|
|
|
+ .map(BulkItemResponse::getFailure)
|
|
|
+ .map(BulkItemResponse.Failure::getCause)
|
|
|
+ .findFirst()
|
|
|
+ .orElse(new Exception("unknown failure"));
|
|
|
logger.error(new ParameterizedMessage(
|
|
|
"[{}] failed to store trained model definition for inference",
|
|
|
trainedModelConfig.getModelId()),
|
|
|
- r.getItems()[1].getFailure().getCause());
|
|
|
- wrappedListener.onFailure(r.getItems()[1].getFailure().getCause());
|
|
|
+ firstFailure);
|
|
|
+ wrappedListener.onFailure(firstFailure);
|
|
|
return;
|
|
|
}
|
|
|
wrappedListener.onResponse(true);
|
|
@@ -205,7 +226,7 @@ public class TrainedModelProvider {
|
|
|
wrappedListener::onFailure
|
|
|
);
|
|
|
|
|
|
- executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest, bulkResponseActionListener);
|
|
|
+ executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener);
|
|
|
}
|
|
|
|
|
|
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
|
|
@@ -234,11 +255,20 @@ public class TrainedModelProvider {
|
|
|
if (includeDefinition) {
|
|
|
multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
|
|
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
|
|
|
- .idsQuery()
|
|
|
- .addIds(TrainedModelDefinitionDoc.docId(modelId, 0))))
|
|
|
- // use sort to get the last
|
|
|
+ .boolQuery()
|
|
|
+ .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
|
|
|
+ .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME))))
|
|
|
+ // There should be AT MOST this many docs. There might be more if definitions have been reindexed to newer indices.
|
|
|
+ // If this ends up getting duplicate groups of definition documents, the parsing logic will throw away any doc that
|
|
|
+ // is in a different index than the first index seen.
|
|
|
+ .setSize(MAX_NUM_DEFINITION_DOCS)
|
|
|
+ // First find the latest index
|
|
|
.addSort("_index", SortOrder.DESC)
|
|
|
- .setSize(1)
|
|
|
+ // Then, sort by doc_num
|
|
|
+ .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName())
|
|
|
+ .order(SortOrder.ASC)
|
|
|
+ // We need this for the search not to fail when there are no mappings yet in the index
|
|
|
+ .unmappedType("long"))
|
|
|
.request());
|
|
|
}
|
|
|
|
|
@@ -258,15 +288,18 @@ public class TrainedModelProvider {
|
|
|
|
|
|
if (includeDefinition) {
|
|
|
try {
|
|
|
- TrainedModelDefinitionDoc doc = handleSearchItem(multiSearchResponse.getResponses()[1],
|
|
|
+ List<TrainedModelDefinitionDoc> docs = handleSearchItems(multiSearchResponse.getResponses()[1],
|
|
|
modelId,
|
|
|
this::parseModelDefinitionDocLenientlyFromSource);
|
|
|
- if (doc.getCompressedString().length() != doc.getTotalDefinitionLength()) {
|
|
|
+ String compressedString = docs.stream()
|
|
|
+ .map(TrainedModelDefinitionDoc::getCompressedString)
|
|
|
+ .collect(Collectors.joining());
|
|
|
+ if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
|
|
|
listener.onFailure(ExceptionsHelper.serverError(
|
|
|
Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
|
|
return;
|
|
|
}
|
|
|
- builder.setDefinitionFromString(doc.getCompressedString());
|
|
|
+ builder.setDefinitionFromString(compressedString);
|
|
|
} catch (ResourceNotFoundException ex) {
|
|
|
listener.onFailure(new ResourceNotFoundException(
|
|
|
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
|
@@ -677,13 +710,36 @@ public class TrainedModelProvider {
|
|
|
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
|
|
|
String resourceId,
|
|
|
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
|
|
|
+ return handleSearchItems(item, resourceId, parseLeniently).get(0);
|
|
|
+ }
|
|
|
+
|
|
|
+ // NOTE: This ignores any results that are in a different index than the first one seen in the search response.
|
|
|
+ private static <T> List<T> handleSearchItems(MultiSearchResponse.Item item,
|
|
|
+ String resourceId,
|
|
|
+ CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
|
|
|
if (item.isFailure()) {
|
|
|
throw item.getFailure();
|
|
|
}
|
|
|
if (item.getResponse().getHits().getHits().length == 0) {
|
|
|
throw new ResourceNotFoundException(resourceId);
|
|
|
}
|
|
|
- return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId);
|
|
|
+ List<T> results = new ArrayList<>(item.getResponse().getHits().getHits().length);
|
|
|
+ String initialIndex = item.getResponse().getHits().getHits()[0].getIndex();
|
|
|
+ for (SearchHit hit : item.getResponse().getHits().getHits()) {
|
|
|
+ // We don't want to spread across multiple backing indices
|
|
|
+ if (hit.getIndex().equals(initialIndex)) {
|
|
|
+ results.add(parseLeniently.apply(hit.getSourceRef(), resourceId));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return results;
|
|
|
+ }
|
|
|
+
|
|
|
+ static List<String> chunkStringWithSize(String str, int chunkSize) {
|
|
|
+ List<String> subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize));
|
|
|
+ for (int i = 0; i < str.length();i += chunkSize) {
|
|
|
+ subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
|
|
|
+ }
|
|
|
+ return subStrings;
|
|
|
}
|
|
|
|
|
|
private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws IOException {
|