|  | @@ -24,12 +24,17 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import java.nio.ByteBuffer;
 | 
	
		
			
				|  |  | +import java.nio.charset.StandardCharsets;
 | 
	
		
			
				|  |  |  import java.util.Arrays;
 | 
	
		
			
				|  |  | +import java.util.Base64;
 | 
	
		
			
				|  |  |  import java.util.Collections;
 | 
	
		
			
				|  |  |  import java.util.HashSet;
 | 
	
		
			
				|  |  |  import java.util.List;
 | 
	
		
			
				|  |  |  import java.util.TreeSet;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import static org.hamcrest.Matchers.containsString;
 | 
	
		
			
				|  |  | +import static org.hamcrest.Matchers.emptyString;
 | 
	
		
			
				|  |  |  import static org.hamcrest.Matchers.equalTo;
 | 
	
		
			
				|  |  |  import static org.hamcrest.Matchers.hasSize;
 | 
	
		
			
				|  |  |  import static org.hamcrest.Matchers.instanceOf;
 | 
	
	
		
			
				|  | @@ -183,6 +188,144 @@ public class TrainedModelProviderTests extends ESTestCase {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    public void testGetDefinitionFromDocsTruncated() {
 | 
	
		
			
				|  |  | +        String modelId = randomAlphaOfLength(10);
 | 
	
		
			
				|  |  | +        Exception ex = expectThrows(
 | 
	
		
			
				|  |  | +            Exception.class,
 | 
	
		
			
				|  |  | +            () -> TrainedModelProvider.getDefinitionFromDocs(
 | 
	
		
			
				|  |  | +                List.of(
 | 
	
		
			
				|  |  | +                    new TrainedModelDefinitionDoc(
 | 
	
		
			
				|  |  | +                        new BytesArray(randomByteArrayOfLength(10)),
 | 
	
		
			
				|  |  | +                        modelId,
 | 
	
		
			
				|  |  | +                        0,
 | 
	
		
			
				|  |  | +                        randomLongBetween(10, 100),
 | 
	
		
			
				|  |  | +                        10,
 | 
	
		
			
				|  |  | +                        1,
 | 
	
		
			
				|  |  | +                        randomBoolean()
 | 
	
		
			
				|  |  | +                    )
 | 
	
		
			
				|  |  | +                ),
 | 
	
		
			
				|  |  | +                modelId
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +        assertThat(
 | 
	
		
			
				|  |  | +            ex.getMessage(),
 | 
	
		
			
				|  |  | +            containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ex = expectThrows(
 | 
	
		
			
				|  |  | +            Exception.class,
 | 
	
		
			
				|  |  | +            () -> TrainedModelProvider.getDefinitionFromDocs(
 | 
	
		
			
				|  |  | +                List.of(
 | 
	
		
			
				|  |  | +                    new TrainedModelDefinitionDoc(
 | 
	
		
			
				|  |  | +                        new BytesArray(randomByteArrayOfLength(10)),
 | 
	
		
			
				|  |  | +                        modelId,
 | 
	
		
			
				|  |  | +                        0,
 | 
	
		
			
				|  |  | +                        randomLongBetween(21, 100),
 | 
	
		
			
				|  |  | +                        10,
 | 
	
		
			
				|  |  | +                        1,
 | 
	
		
			
				|  |  | +                        randomBoolean()
 | 
	
		
			
				|  |  | +                    ),
 | 
	
		
			
				|  |  | +                    new TrainedModelDefinitionDoc(
 | 
	
		
			
				|  |  | +                        new BytesArray(randomByteArrayOfLength(10)),
 | 
	
		
			
				|  |  | +                        modelId,
 | 
	
		
			
				|  |  | +                        0,
 | 
	
		
			
				|  |  | +                        randomLongBetween(21, 100),
 | 
	
		
			
				|  |  | +                        10,
 | 
	
		
			
				|  |  | +                        1,
 | 
	
		
			
				|  |  | +                        randomBoolean()
 | 
	
		
			
				|  |  | +                    )
 | 
	
		
			
				|  |  | +                ),
 | 
	
		
			
				|  |  | +                modelId
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +        assertThat(
 | 
	
		
			
				|  |  | +            ex.getMessage(),
 | 
	
		
			
				|  |  | +            containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ex = expectThrows(
 | 
	
		
			
				|  |  | +            Exception.class,
 | 
	
		
			
				|  |  | +            () -> TrainedModelProvider.getDefinitionFromDocs(
 | 
	
		
			
				|  |  | +                List.of(
 | 
	
		
			
				|  |  | +                    new TrainedModelDefinitionDoc(
 | 
	
		
			
				|  |  | +                        new BytesArray(randomByteArrayOfLength(10)),
 | 
	
		
			
				|  |  | +                        modelId,
 | 
	
		
			
				|  |  | +                        0,
 | 
	
		
			
				|  |  | +                        randomFrom((Long) null, 20L),
 | 
	
		
			
				|  |  | +                        10,
 | 
	
		
			
				|  |  | +                        1,
 | 
	
		
			
				|  |  | +                        randomBoolean()
 | 
	
		
			
				|  |  | +                    ),
 | 
	
		
			
				|  |  | +                    new TrainedModelDefinitionDoc(
 | 
	
		
			
				|  |  | +                        new BytesArray(randomByteArrayOfLength(10)),
 | 
	
		
			
				|  |  | +                        modelId,
 | 
	
		
			
				|  |  | +                        1,
 | 
	
		
			
				|  |  | +                        randomFrom((Long) null, 20L),
 | 
	
		
			
				|  |  | +                        10,
 | 
	
		
			
				|  |  | +                        1,
 | 
	
		
			
				|  |  | +                        false
 | 
	
		
			
				|  |  | +                    )
 | 
	
		
			
				|  |  | +                ),
 | 
	
		
			
				|  |  | +                modelId
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +        assertThat(
 | 
	
		
			
				|  |  | +            ex.getMessage(),
 | 
	
		
			
				|  |  | +            containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testGetDefinitionFromDocs() {
 | 
	
		
			
				|  |  | +        String modelId = randomAlphaOfLength(10);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        int byteArrayLength = randomIntBetween(1, 1000);
 | 
	
		
			
				|  |  | +        BytesReference bytesReference = TrainedModelProvider.getDefinitionFromDocs(
 | 
	
		
			
				|  |  | +            List.of(
 | 
	
		
			
				|  |  | +                new TrainedModelDefinitionDoc(
 | 
	
		
			
				|  |  | +                    new BytesArray(randomByteArrayOfLength(byteArrayLength)),
 | 
	
		
			
				|  |  | +                    modelId,
 | 
	
		
			
				|  |  | +                    0,
 | 
	
		
			
				|  |  | +                    randomFrom((Long) null, (long) byteArrayLength),
 | 
	
		
			
				|  |  | +                    byteArrayLength,
 | 
	
		
			
				|  |  | +                    1,
 | 
	
		
			
				|  |  | +                    true
 | 
	
		
			
				|  |  | +                )
 | 
	
		
			
				|  |  | +            ),
 | 
	
		
			
				|  |  | +            modelId
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +        // None of the following should throw
 | 
	
		
			
				|  |  | +        ByteBuffer bb = Base64.getEncoder()
 | 
	
		
			
				|  |  | +            .encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
 | 
	
		
			
				|  |  | +        assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        bytesReference = TrainedModelProvider.getDefinitionFromDocs(
 | 
	
		
			
				|  |  | +            List.of(
 | 
	
		
			
				|  |  | +                new TrainedModelDefinitionDoc(
 | 
	
		
			
				|  |  | +                    new BytesArray(randomByteArrayOfLength(byteArrayLength)),
 | 
	
		
			
				|  |  | +                    modelId,
 | 
	
		
			
				|  |  | +                    0,
 | 
	
		
			
				|  |  | +                    randomFrom((Long) null, (long) byteArrayLength * 2),
 | 
	
		
			
				|  |  | +                    byteArrayLength,
 | 
	
		
			
				|  |  | +                    1,
 | 
	
		
			
				|  |  | +                    false
 | 
	
		
			
				|  |  | +                ),
 | 
	
		
			
				|  |  | +                new TrainedModelDefinitionDoc(
 | 
	
		
			
				|  |  | +                    new BytesArray(randomByteArrayOfLength(byteArrayLength)),
 | 
	
		
			
				|  |  | +                    modelId,
 | 
	
		
			
				|  |  | +                    1,
 | 
	
		
			
				|  |  | +                    randomFrom((Long) null, (long) byteArrayLength * 2),
 | 
	
		
			
				|  |  | +                    byteArrayLength,
 | 
	
		
			
				|  |  | +                    1,
 | 
	
		
			
				|  |  | +                    true
 | 
	
		
			
				|  |  | +                )
 | 
	
		
			
				|  |  | +            ),
 | 
	
		
			
				|  |  | +            modelId
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        bb = Base64.getEncoder().encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
 | 
	
		
			
				|  |  | +        assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      @Override
 | 
	
		
			
				|  |  |      protected NamedXContentRegistry xContentRegistry() {
 | 
	
		
			
				|  |  |          return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
 |