|
@@ -24,12 +24,17 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
|
|
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
|
|
|
|
|
+import java.nio.ByteBuffer;
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
import java.util.Arrays;
|
|
|
+import java.util.Base64;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.List;
|
|
|
import java.util.TreeSet;
|
|
|
|
|
|
+import static org.hamcrest.Matchers.containsString;
|
|
|
+import static org.hamcrest.Matchers.emptyString;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
import static org.hamcrest.Matchers.hasSize;
|
|
|
import static org.hamcrest.Matchers.instanceOf;
|
|
@@ -183,6 +188,144 @@ public class TrainedModelProviderTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testGetDefinitionFromDocsTruncated() {
|
|
|
+ String modelId = randomAlphaOfLength(10);
|
|
|
+ Exception ex = expectThrows(
|
|
|
+ Exception.class,
|
|
|
+ () -> TrainedModelProvider.getDefinitionFromDocs(
|
|
|
+ List.of(
|
|
|
+ new TrainedModelDefinitionDoc(
|
|
|
+ new BytesArray(randomByteArrayOfLength(10)),
|
|
|
+ modelId,
|
|
|
+ 0,
|
|
|
+ randomLongBetween(10, 100),
|
|
|
+ 10,
|
|
|
+ 1,
|
|
|
+ randomBoolean()
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ modelId
|
|
|
+ )
|
|
|
+ );
|
|
|
+ assertThat(
|
|
|
+ ex.getMessage(),
|
|
|
+ containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
|
|
|
+ );
|
|
|
+
|
|
|
+ ex = expectThrows(
|
|
|
+ Exception.class,
|
|
|
+ () -> TrainedModelProvider.getDefinitionFromDocs(
|
|
|
+ List.of(
|
|
|
+ new TrainedModelDefinitionDoc(
|
|
|
+ new BytesArray(randomByteArrayOfLength(10)),
|
|
|
+ modelId,
|
|
|
+ 0,
|
|
|
+ randomLongBetween(21, 100),
|
|
|
+ 10,
|
|
|
+ 1,
|
|
|
+ randomBoolean()
|
|
|
+ ),
|
|
|
+ new TrainedModelDefinitionDoc(
|
|
|
+ new BytesArray(randomByteArrayOfLength(10)),
|
|
|
+ modelId,
|
|
|
+ 0,
|
|
|
+ randomLongBetween(21, 100),
|
|
|
+ 10,
|
|
|
+ 1,
|
|
|
+ randomBoolean()
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ modelId
|
|
|
+ )
|
|
|
+ );
|
|
|
+ assertThat(
|
|
|
+ ex.getMessage(),
|
|
|
+ containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
|
|
|
+ );
|
|
|
+
|
|
|
+ ex = expectThrows(
|
|
|
+ Exception.class,
|
|
|
+ () -> TrainedModelProvider.getDefinitionFromDocs(
|
|
|
+ List.of(
|
|
|
+ new TrainedModelDefinitionDoc(
|
|
|
+ new BytesArray(randomByteArrayOfLength(10)),
|
|
|
+ modelId,
|
|
|
+ 0,
|
|
|
+ randomFrom((Long) null, 20L),
|
|
|
+ 10,
|
|
|
+ 1,
|
|
|
+ randomBoolean()
|
|
|
+ ),
|
|
|
+ new TrainedModelDefinitionDoc(
|
|
|
+ new BytesArray(randomByteArrayOfLength(10)),
|
|
|
+ modelId,
|
|
|
+ 1,
|
|
|
+ randomFrom((Long) null, 20L),
|
|
|
+ 10,
|
|
|
+ 1,
|
|
|
+ false
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ modelId
|
|
|
+ )
|
|
|
+ );
|
|
|
+ assertThat(
|
|
|
+ ex.getMessage(),
|
|
|
+ containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testGetDefinitionFromDocs() {
|
|
|
+ String modelId = randomAlphaOfLength(10);
|
|
|
+
|
|
|
+ int byteArrayLength = randomIntBetween(1, 1000);
|
|
|
+ BytesReference bytesReference = TrainedModelProvider.getDefinitionFromDocs(
|
|
|
+ List.of(
|
|
|
+ new TrainedModelDefinitionDoc(
|
|
|
+ new BytesArray(randomByteArrayOfLength(byteArrayLength)),
|
|
|
+ modelId,
|
|
|
+ 0,
|
|
|
+ randomFrom((Long) null, (long) byteArrayLength),
|
|
|
+ byteArrayLength,
|
|
|
+ 1,
|
|
|
+ true
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ modelId
|
|
|
+ );
|
|
|
+ // None of the following should throw
|
|
|
+ ByteBuffer bb = Base64.getEncoder()
|
|
|
+ .encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
|
|
|
+ assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));
|
|
|
+
|
|
|
+ bytesReference = TrainedModelProvider.getDefinitionFromDocs(
|
|
|
+ List.of(
|
|
|
+ new TrainedModelDefinitionDoc(
|
|
|
+ new BytesArray(randomByteArrayOfLength(byteArrayLength)),
|
|
|
+ modelId,
|
|
|
+ 0,
|
|
|
+ randomFrom((Long) null, (long) byteArrayLength * 2),
|
|
|
+ byteArrayLength,
|
|
|
+ 1,
|
|
|
+ false
|
|
|
+ ),
|
|
|
+ new TrainedModelDefinitionDoc(
|
|
|
+ new BytesArray(randomByteArrayOfLength(byteArrayLength)),
|
|
|
+ modelId,
|
|
|
+ 1,
|
|
|
+ randomFrom((Long) null, (long) byteArrayLength * 2),
|
|
|
+ byteArrayLength,
|
|
|
+ 1,
|
|
|
+ true
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ modelId
|
|
|
+ );
|
|
|
+
|
|
|
+ bb = Base64.getEncoder().encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
|
|
|
+ assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
protected NamedXContentRegistry xContentRegistry() {
|
|
|
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|