فهرست منبع

[ML] fixing bug when returning multi-doc compressed model definitions (#80377)

When deserializing multi-doc compressed model definitions, I
periodically received odd errors regarding bytes references. It turns
out that when we parse the individual bytes references, we parse them into
`BytesArray` instances, which satisfy the `bytes()` method. However,
`CompositeBytesReference`, when it wraps multiple `BytesReference`s, does
not allow `bytes()` to be called.
Benjamin Trent 4 سال پیش
والد
کامیت
7ed67e655d

+ 77 - 0
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.delete.DeleteRequest;
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.action.index.IndexResponse;
 import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.license.License;
 import org.elasticsearch.xcontent.ToXContent;
@@ -145,6 +146,82 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         assertThat(getConfigHolder.get().getMetadata(), hasKey("hyperparameters"));
     }
 
+    public void testGetTrainedModelConfigWithMultiDocDefinition() throws Exception {
+        String modelId = "test-get-trained-model-config";
+        TrainedModelConfig config = buildTrainedModelConfig(modelId);
+
+        AtomicReference<Void> dummy = new AtomicReference<>();
+        AtomicReference<Boolean> booleanDummy = new AtomicReference<>();
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+
+        BytesReference definition = config.getCompressedDefinition();
+
+        blockingCall(
+            listener -> trainedModelProvider.storeTrainedModelDefinitionDoc(
+                new TrainedModelDefinitionDoc(
+                    new BytesArray(definition.array(), 0, definition.length() - 5),
+                    modelId,
+                    0,
+                    (long) definition.length(),
+                    definition.length() - 5,
+                    1,
+                    false
+                ),
+                listener
+            ),
+            dummy::set,
+            e -> fail(e.getMessage())
+        );
+        blockingCall(
+            listener -> trainedModelProvider.storeTrainedModelDefinitionDoc(
+                new TrainedModelDefinitionDoc(
+                    new BytesArray(definition.array(), definition.length() - 5, 5),
+                    modelId,
+                    1,
+                    (long) definition.length(),
+                    5,
+                    1,
+                    true
+                ),
+                listener
+            ),
+            dummy::set,
+            e -> fail(e.getMessage())
+        );
+        blockingCall(
+            listener -> trainedModelProvider.storeTrainedModelConfig(
+                new TrainedModelConfig.Builder(config).clearDefinition().build(),
+                listener
+            ),
+            booleanDummy::set,
+            e -> fail(e.getMessage())
+        );
+        blockingCall(
+            listener -> trainedModelProvider.refreshInferenceIndex(listener),
+            new AtomicReference<RefreshResponse>(),
+            new AtomicReference<>()
+        );
+
+        AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
+            getConfigHolder,
+            exceptionHolder
+        );
+        if (exceptionHolder.get() != null) {
+            throw exceptionHolder.get();
+        }
+        getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
+        assertThat(getConfigHolder.get(), is(not(nullValue())));
+        assertThat(getConfigHolder.get(), equalTo(config));
+        assertThat(getConfigHolder.get().getModelDefinition(), is(not(nullValue())));
+
+        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            // Should not throw
+            getConfigHolder.get().toXContent(builder, ToXContent.EMPTY_PARAMS);
+        }
+    }
+
     public void testGetTrainedModelConfigWithoutDefinition() throws Exception {
         String modelId = "test-get-trained-model-config-no-definition";
         TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId).build();

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java

@@ -86,7 +86,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
     private final int compressionVersion;
     private final boolean eos;
 
-    private TrainedModelDefinitionDoc(
+    public TrainedModelDefinitionDoc(
         BytesReference binaryData,
         String modelId,
         int docNum,

+ 10 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

@@ -35,6 +35,7 @@ import org.elasticsearch.client.Client;
 import org.elasticsearch.common.CheckedBiFunction;
 import org.elasticsearch.common.Numbers;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.bytes.CompositeBytesReference;
 import org.elasticsearch.common.regex.Regex;
@@ -1219,14 +1220,16 @@ public class TrainedModelProvider {
         return results;
     }
 
-    private static BytesReference getDefinitionFromDocs(List<TrainedModelDefinitionDoc> docs, String modelId)
-        throws ElasticsearchException {
+    static BytesReference getDefinitionFromDocs(List<TrainedModelDefinitionDoc> docs, String modelId) throws ElasticsearchException {
 
-        BytesReference[] bb = new BytesReference[docs.size()];
-        for (int i = 0; i < docs.size(); i++) {
-            bb[i] = docs.get(i).getBinaryData();
-        }
-        BytesReference bytes = CompositeBytesReference.of(bb);
+        // If the user requested the compressed data string, we need access to the underlying bytes.
+        // BytesArray gives us that access.
+        BytesReference bytes = docs.size() == 1
+            ? docs.get(0).getBinaryData()
+            : new BytesArray(
+                CompositeBytesReference.of(docs.stream().map(TrainedModelDefinitionDoc::getBinaryData).toArray(BytesReference[]::new))
+                    .toBytesRef()
+            );
 
         if (docs.get(0).getTotalDefinitionLength() != null) {
             if (bytes.length() != docs.get(0).getTotalDefinitionLength()) {
@@ -1264,7 +1267,6 @@ public class TrainedModelProvider {
                 // lang ident model were the only models supported. Models created after
                 // VERSION_3RD_PARTY_CONFIG_ADDED must have modelType set, if not set modelType
                 // is a tree ensemble
-                assert builder.getVersion().before(TrainedModelConfig.VERSION_3RD_PARTY_CONFIG_ADDED);
                 builder.setModelType(TrainedModelType.TREE_ENSEMBLE);
             }
             return builder;

+ 7 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java

@@ -119,13 +119,18 @@ public abstract class MlSingleNodeTestCase extends ESSingleNodeTestCase {
     }
 
     protected <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)
+        throws InterruptedException {
+        blockingCall(function, response::set, error::set);
+    }
+
+    protected <T> void blockingCall(Consumer<ActionListener<T>> function, Consumer<T> response, Consumer<Exception> error)
         throws InterruptedException {
         CountDownLatch latch = new CountDownLatch(1);
         ActionListener<T> listener = ActionListener.wrap(r -> {
-            response.set(r);
+            response.accept(r);
             latch.countDown();
         }, e -> {
-            error.set(e);
+            error.accept(e);
             latch.countDown();
         });
 

+ 143 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java

@@ -24,12 +24,17 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
+import java.util.Base64;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.TreeSet;
 
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.emptyString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
@@ -183,6 +188,144 @@ public class TrainedModelProviderTests extends ESTestCase {
         }
     }
 
+    public void testGetDefinitionFromDocsTruncated() {
+        String modelId = randomAlphaOfLength(10);
+        Exception ex = expectThrows(
+            Exception.class,
+            () -> TrainedModelProvider.getDefinitionFromDocs(
+                List.of(
+                    new TrainedModelDefinitionDoc(
+                        new BytesArray(randomByteArrayOfLength(10)),
+                        modelId,
+                        0,
+                        randomLongBetween(10, 100),
+                        10,
+                        1,
+                        randomBoolean()
+                    )
+                ),
+                modelId
+            )
+        );
+        assertThat(
+            ex.getMessage(),
+            containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
+        );
+
+        ex = expectThrows(
+            Exception.class,
+            () -> TrainedModelProvider.getDefinitionFromDocs(
+                List.of(
+                    new TrainedModelDefinitionDoc(
+                        new BytesArray(randomByteArrayOfLength(10)),
+                        modelId,
+                        0,
+                        randomLongBetween(21, 100),
+                        10,
+                        1,
+                        randomBoolean()
+                    ),
+                    new TrainedModelDefinitionDoc(
+                        new BytesArray(randomByteArrayOfLength(10)),
+                        modelId,
+                        0,
+                        randomLongBetween(21, 100),
+                        10,
+                        1,
+                        randomBoolean()
+                    )
+                ),
+                modelId
+            )
+        );
+        assertThat(
+            ex.getMessage(),
+            containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
+        );
+
+        ex = expectThrows(
+            Exception.class,
+            () -> TrainedModelProvider.getDefinitionFromDocs(
+                List.of(
+                    new TrainedModelDefinitionDoc(
+                        new BytesArray(randomByteArrayOfLength(10)),
+                        modelId,
+                        0,
+                        randomFrom((Long) null, 20L),
+                        10,
+                        1,
+                        randomBoolean()
+                    ),
+                    new TrainedModelDefinitionDoc(
+                        new BytesArray(randomByteArrayOfLength(10)),
+                        modelId,
+                        1,
+                        randomFrom((Long) null, 20L),
+                        10,
+                        1,
+                        false
+                    )
+                ),
+                modelId
+            )
+        );
+        assertThat(
+            ex.getMessage(),
+            containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
+        );
+    }
+
+    public void testGetDefinitionFromDocs() {
+        String modelId = randomAlphaOfLength(10);
+
+        int byteArrayLength = randomIntBetween(1, 1000);
+        BytesReference bytesReference = TrainedModelProvider.getDefinitionFromDocs(
+            List.of(
+                new TrainedModelDefinitionDoc(
+                    new BytesArray(randomByteArrayOfLength(byteArrayLength)),
+                    modelId,
+                    0,
+                    randomFrom((Long) null, (long) byteArrayLength),
+                    byteArrayLength,
+                    1,
+                    true
+                )
+            ),
+            modelId
+        );
+        // None of the following should throw
+        ByteBuffer bb = Base64.getEncoder()
+            .encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
+        assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));
+
+        bytesReference = TrainedModelProvider.getDefinitionFromDocs(
+            List.of(
+                new TrainedModelDefinitionDoc(
+                    new BytesArray(randomByteArrayOfLength(byteArrayLength)),
+                    modelId,
+                    0,
+                    randomFrom((Long) null, (long) byteArrayLength * 2),
+                    byteArrayLength,
+                    1,
+                    false
+                ),
+                new TrainedModelDefinitionDoc(
+                    new BytesArray(randomByteArrayOfLength(byteArrayLength)),
+                    modelId,
+                    1,
+                    randomFrom((Long) null, (long) byteArrayLength * 2),
+                    byteArrayLength,
+                    1,
+                    true
+                )
+            ),
+            modelId
+        );
+
+        bb = Base64.getEncoder().encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
+        assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));
+    }
+
     @Override
     protected NamedXContentRegistry xContentRegistry() {
         return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());