Pārlūkot izejas kodu

[8.x] Test ML model server (#120270) (#120586)

* Test ML model server (#120270)

* Fix model downloading for very small models.

* Test MlModelServer

* Tiny ELSER

* unmute TextEmbeddingCrudIT and DefaultEndPointsIT

* update ELSER

* Improve MlModelServer

* tiny E5

* more logging

* improved E5 model

* tiny reranker

* scan for ports

* [CI] Auto commit changes from spotless

* Serve default models when optimized model is requested

* @ClassRule

* polish code

* Respect dynamic setting ML model repo

* fix metadata for optimized models

* improve logging

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>

* backport HttpHeaderParser

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Jan Kuipers 9 mēneši atpakaļ
vecāks
revīzija
dc66c15bc0
19 mainītis faili ar 354 papildinājumiem un 29 dzēšanām
  1. 0 21
      muted-tests.yml
  2. 42 0
      test/framework/src/main/java/org/elasticsearch/test/fixture/HttpHeaderParser.java
  3. 53 0
      test/framework/src/test/java/org/elasticsearch/http/HttpHeaderParserTests.java
  4. 1 0
      x-pack/plugin/inference/qa/inference-service-tests/build.gradle
  5. 18 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
  6. 146 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MlModelServer.java
  7. 0 2
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java
  8. 25 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/elser_model_2.metadata.json
  9. BIN
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/elser_model_2.pt
  10. 0 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/elser_model_2.vocab.json
  11. 32 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/multilingual-e5-small.metadata.json
  12. BIN
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/multilingual-e5-small.pt
  13. 0 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/multilingual-e5-small.vocab.json
  14. 15 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/rerank-v1.metadata.json
  15. BIN
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/rerank-v1.pt
  16. 0 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/rerank-v1.vocab.json
  17. 4 1
      x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java
  18. 1 5
      x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java
  19. 17 0
      x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java

+ 0 - 21
muted-tests.yml

@@ -176,12 +176,6 @@ tests:
 - class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
   method: test {p0=indices.split/40_routing_partition_size/more than 1}
   issue: https://github.com/elastic/elasticsearch/issues/113841
-- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
-  method: testPutE5WithTrainedModelAndInference
-  issue: https://github.com/elastic/elasticsearch/issues/114023
-- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
-  method: testPutE5Small_withPlatformAgnosticVariant
-  issue: https://github.com/elastic/elasticsearch/issues/113983
 - class: org.elasticsearch.datastreams.LazyRolloverDuringDisruptionIT
   method: testRolloverIsExecutedOnce
   issue: https://github.com/elastic/elasticsearch/issues/112634
@@ -191,18 +185,12 @@ tests:
 - class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityWithApmTracingRestIT
   method: testTracingCrossCluster
   issue: https://github.com/elastic/elasticsearch/issues/112731
-- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
-  method: testPutE5Small_withPlatformSpecificVariant
-  issue: https://github.com/elastic/elasticsearch/issues/113950
 - class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT
   method: test {p0=search.vectors/42_knn_search_int4_flat/Vector similarity with filter only}
   issue: https://github.com/elastic/elasticsearch/issues/115475
 - class: org.elasticsearch.reservedstate.service.FileSettingsServiceTests
   method: testProcessFileChanges
   issue: https://github.com/elastic/elasticsearch/issues/115280
-- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
-  method: testInferDeploysDefaultE5
-  issue: https://github.com/elastic/elasticsearch/issues/115361
 - class: org.elasticsearch.xpack.inference.InferenceCrudIT
   method: testSupportedStream
   issue: https://github.com/elastic/elasticsearch/issues/113430
@@ -244,9 +232,6 @@ tests:
 - class: org.elasticsearch.xpack.esql.qa.mixed.EsqlClientYamlIT
   method: test {p0=esql/61_enrich_ip/IP strings}
   issue: https://github.com/elastic/elasticsearch/issues/116529
-- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
-  method: testInferDeploysDefaultElser
-  issue: https://github.com/elastic/elasticsearch/issues/114913
 - class: org.elasticsearch.threadpool.SimpleThreadPoolIT
   method: testThreadPoolMetrics
   issue: https://github.com/elastic/elasticsearch/issues/108320
@@ -301,9 +286,6 @@ tests:
 - class: org.elasticsearch.discovery.ClusterDisruptionIT
   method: testAckedIndexing
   issue: https://github.com/elastic/elasticsearch/issues/117024
-- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
-  method: testMultipleInferencesTriggeringDownloadAndDeploy
-  issue: https://github.com/elastic/elasticsearch/issues/117208
 
 # Examples:
 #
@@ -365,9 +347,6 @@ tests:
 - class: org.elasticsearch.xpack.esql.action.EsqlActionTaskIT
   method: testCancelRequestWhenFailingFetchingPages
   issue: https://github.com/elastic/elasticsearch/issues/118213
-- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
-  method: testInferDeploysDefaultRerank
-  issue: https://github.com/elastic/elasticsearch/issues/118184
 - class: org.elasticsearch.reservedstate.service.RepositoriesFileSettingsIT
   method: testSettingsApplied
   issue: https://github.com/elastic/elasticsearch/issues/116694

+ 42 - 0
test/framework/src/main/java/org/elasticsearch/test/fixture/HttpHeaderParser.java

@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.test.fixture;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+public enum HttpHeaderParser {
+    ;
+
+    private static final Pattern RANGE_HEADER_PATTERN = Pattern.compile("bytes=([0-9]+)-([0-9]+)");
+
+    /**
+     * Parse a "Range" header
+     *
+     * Note: only a single bounded range is supported (e.g. <code>Range: bytes={range_start}-{range_end}</code>)
+     *
+     * @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range">MDN: Range header</a>
+     * @param rangeHeaderValue The header value as a string
+     * @return a {@link Range} instance representing the parsed value, or null if the header is not a single valid bounded range
+     */
+    public static Range parseRangeHeader(String rangeHeaderValue) {
+        final Matcher matcher = RANGE_HEADER_PATTERN.matcher(rangeHeaderValue);
+        if (matcher.matches()) {
+            try {
+                return new Range(Long.parseLong(matcher.group(1)), Long.parseLong(matcher.group(2)));
+            } catch (NumberFormatException e) {
+                return null;
+            }
+        }
+        return null;
+    }
+
+    public record Range(long start, long end) {}
+}

+ 53 - 0
test/framework/src/test/java/org/elasticsearch/http/HttpHeaderParserTests.java

@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.http;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.fixture.HttpHeaderParser;
+
+import java.math.BigInteger;
+
+public class HttpHeaderParserTests extends ESTestCase {
+
+    public void testParseRangeHeader() {
+        final long start = randomLongBetween(0, 10_000);
+        final long end = randomLongBetween(start, start + 10_000);
+        assertEquals(new HttpHeaderParser.Range(start, end), HttpHeaderParser.parseRangeHeader("bytes=" + start + "-" + end));
+    }
+
+    public void testParseRangeHeaderInvalidLong() {
+        final BigInteger longOverflow = BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE).add(randomBigInteger());
+        assertNull(HttpHeaderParser.parseRangeHeader("bytes=123-" + longOverflow));
+        assertNull(HttpHeaderParser.parseRangeHeader("bytes=" + longOverflow + "-123"));
+    }
+
+    public void testParseRangeHeaderMultipleRangesNotMatched() {
+        assertNull(
+            HttpHeaderParser.parseRangeHeader(
+                Strings.format(
+                    "bytes=%d-%d,%d-%d",
+                    randomIntBetween(0, 99),
+                    randomIntBetween(100, 199),
+                    randomIntBetween(200, 299),
+                    randomIntBetween(300, 399)
+                )
+            )
+        );
+    }
+
+    public void testParseRangeHeaderEndlessRangeNotMatched() {
+        assertNull(HttpHeaderParser.parseRangeHeader(Strings.format("bytes=%d-", randomLongBetween(0, Long.MAX_VALUE))));
+    }
+
+    public void testParseRangeHeaderSuffixLengthNotMatched() {
+        assertNull(HttpHeaderParser.parseRangeHeader(Strings.format("bytes=-%d", randomLongBetween(0, Long.MAX_VALUE))));
+    }
+}

+ 1 - 0
x-pack/plugin/inference/qa/inference-service-tests/build.gradle

@@ -1,6 +1,7 @@
 apply plugin: 'elasticsearch.internal-java-rest-test'
 
 dependencies {
+  javaRestTestImplementation project(path: xpackModule('core'))
   javaRestTestImplementation project(path: xpackModule('inference'))
   clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
 }

+ 18 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -26,6 +26,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
+import org.junit.Before;
 import org.junit.ClassRule;
 
 import java.io.IOException;
@@ -41,6 +42,7 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 
 public class InferenceBaseRestTest extends ESRestTestCase {
+
     @ClassRule
     public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
         .distribution(DistributionType.DEFAULT)
@@ -51,6 +53,22 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         .feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
         .build();
 
+    @ClassRule
+    public static MlModelServer mlModelServer = new MlModelServer();
+
+    @Before
+    public void setMlModelRepository() throws IOException {
+        logger.info("setting ML model repository to: {}", mlModelServer.getUrl());
+        var request = new Request("PUT", "/_cluster/settings");
+        request.setJsonEntity(Strings.format("""
+            {
+              "persistent": {
+                "xpack.ml.model_repository": "%s"
+              }
+            }""", mlModelServer.getUrl()));
+        assertOK(client().performRequest(request));
+    }
+
     @Override
     protected String getTestRestCluster() {
         return cluster.getHttpAddresses();

+ 146 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MlModelServer.java

@@ -0,0 +1,146 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import com.sun.net.httpserver.HttpExchange;
+import com.sun.net.httpserver.HttpServer;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.HttpStatus;
+import org.apache.http.client.utils.URIBuilder;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.test.fixture.HttpHeaderParser;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.XPackSettings;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runners.model.Statement;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.InetSocketAddress;
+import java.nio.charset.StandardCharsets;
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+/**
+ * Simple model server to serve ML models.
+ * The URL path corresponds to a file name in this class's resources.
+ * If the file is found, its content is returned, otherwise 404.
+ * Respects a range header to serve partial content.
+ */
+public class MlModelServer implements TestRule {
+
+    private static final String HOST = "localhost";
+    private static final Logger logger = LogManager.getLogger(MlModelServer.class);
+
+    private int port;
+
+    public String getUrl() {
+        return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString();
+    }
+
+    private void handle(HttpExchange exchange) throws IOException {
+        String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE);
+        HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null;
+        logger.info("request: {} range={}", exchange.getRequestURI().getPath(), range);
+
+        try (InputStream is = getInputStream(exchange)) {
+            int httpStatus;
+            long numBytes;
+            if (is == null) {
+                httpStatus = HttpStatus.SC_NOT_FOUND;
+                numBytes = 0;
+            } else if (range == null) {
+                httpStatus = HttpStatus.SC_OK;
+                numBytes = is.available();
+            } else {
+                httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
+                is.skipNBytes(range.start());
+                numBytes = range.end() - range.start() + 1;
+            }
+            logger.info("response: {} {}", exchange.getRequestURI().getPath(), httpStatus);
+            exchange.sendResponseHeaders(httpStatus, numBytes);
+            try (OutputStream os = exchange.getResponseBody()) {
+                while (numBytes > 0) {
+                    byte[] bytes = is.readNBytes((int) Math.min(1 << 20, numBytes));
+                    os.write(bytes);
+                    numBytes -= bytes.length;
+                }
+            }
+        }
+    }
+
+    private InputStream getInputStream(HttpExchange exchange) throws IOException {
+        String path = exchange.getRequestURI().getPath().substring(1);  // Strip leading slash
+        String modelId = path.substring(0, path.indexOf('.'));
+        String extension = path.substring(path.indexOf('.') + 1);
+
+        // If a model specifically optimized for some platform is requested,
+        // serve the default non-optimized model instead, which is compatible.
+        String defaultModelId = modelId;
+        for (String platform : XPackSettings.ML_NATIVE_CODE_PLATFORMS) {
+            defaultModelId = defaultModelId.replace("_" + platform, "");
+        }
+
+        ClassLoader classloader = Thread.currentThread().getContextClassLoader();
+        InputStream is = classloader.getResourceAsStream(defaultModelId + "." + extension);
+        if (is != null && modelId.equals(defaultModelId) == false && extension.equals("metadata.json")) {
+            // When an optimized version is requested, fix the default metadata,
+            // so that it contains the correct model ID.
+            try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, is.readAllBytes())) {
+                is.close();
+                ModelPackageConfig packageConfig = ModelPackageConfig.fromXContentLenient(parser);
+                packageConfig = new ModelPackageConfig.Builder(packageConfig).setPackedModelId(modelId).build();
+                is = new ByteArrayInputStream(packageConfig.toString().getBytes(StandardCharsets.UTF_8));
+            }
+        }
+        return is;
+    }
+
+    @Override
+    public Statement apply(Statement statement, Description description) {
+        return new Statement() {
+            @Override
+            public void evaluate() throws Throwable {
+                logger.info("Starting ML model server");
+                HttpServer server = HttpServer.create();
+                while (true) {
+                    port = new Random().nextInt(10000, 65536);
+                    try {
+                        server.bind(new InetSocketAddress(HOST, port), 1);
+                    } catch (Exception e) {
+                        continue;
+                    }
+                    break;
+                }
+                logger.info("Bound ML model server to port {}", port);
+
+                ExecutorService executor = Executors.newCachedThreadPool();
+                server.setExecutor(executor);
+                server.createContext("/", MlModelServer.this::handle);
+                server.start();
+
+                try {
+                    statement.evaluate();
+                } finally {
+                    logger.info("Stopping ML model server on port {}", port);
+                    server.stop(1);
+                    executor.shutdown();
+                }
+            }
+        };
+    }
+}

+ 0 - 2
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java

@@ -18,8 +18,6 @@ import java.util.Map;
 
 import static org.hamcrest.Matchers.containsString;
 
-// This test was previously disabled in CI due to the models being too large
-// See "https://github.com/elastic/elasticsearch/issues/105198".
 public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
 
     public void testPutE5Small_withNoModelVariant() {

+ 25 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/elser_model_2.metadata.json

@@ -0,0 +1,25 @@
+{
+  "packaged_model_id": "elser_model_2",
+  "minimum_version": "11.0.0",
+  "size": 1859242,
+  "sha256": "602dbccfb2746e5700bf65d8019b06fb2ec1e3c5bfb980eb2005fc17c1bfe0c0",
+  "description": "Elastic Learned Sparse EncodeR v2",
+  "model_type": "pytorch",
+  "tags": [
+    "elastic"
+  ],
+  "inference_config": {
+    "text_expansion": {
+      "tokenization": {
+        "bert": {
+          "do_lower_case": true,
+          "with_special_tokens": true,
+          "max_sequence_length": 512,
+          "truncate": "first",
+          "span": -1
+        }
+      }
+    }
+  },
+  "vocabulary_file": "elser_model_2.vocab.json"
+}

BIN
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/elser_model_2.pt


Failā izmaiņas netiks attēlotas, jo tās ir par lielu
+ 0 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/elser_model_2.vocab.json


+ 32 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/multilingual-e5-small.metadata.json

@@ -0,0 +1,32 @@
+{
+  "packaged_model_id": "multilingual-e5-small",
+  "minimum_version": "12.0.0",
+  "size": 5531160,
+  "sha256": "92e24566eff554d3a6808cc62731dbecf32db63e01801f3f62210aa9131c7a8b",
+  "description": "E5 small multilingual",
+  "model_type": "pytorch",
+  "tags": [],
+  "inference_config": {
+    "text_embedding": {
+      "tokenization": {
+        "xlm_roberta": {
+          "do_lower_case": false,
+          "with_special_tokens": true,
+          "max_sequence_length": 512,
+          "truncate": "first",
+          "span": -1
+        }
+      },
+      "embedding_size": 384
+    }
+  },
+  "prefix_strings": {
+      "search": "query: ",
+      "ingest": "passage: "
+  },
+  "metadata": {
+    "per_allocation_memory_bytes": 557785256,
+    "per_deployment_memory_bytes": 470031872
+  },
+  "vocabulary_file": "multilingual-e5-small.vocab.json"
+}

BIN
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/multilingual-e5-small.pt


Failā izmaiņas netiks attēlotas, jo tās ir par lielu
+ 0 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/multilingual-e5-small.vocab.json


+ 15 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/rerank-v1.metadata.json

@@ -0,0 +1,15 @@
+{
+    "packaged_model_id": "rerank-v1",
+    "minimum_version": "9.0.0",
+    "size": 12419194,
+    "sha256": "8d37d7240175b59a1a82f409e572c4d0136acff875da980ec5e5e1783263a042",
+    "description": "Elastic Rerank v1",
+    "model_type": "pytorch",
+    "tags": [
+          "curated"
+    ],
+    "inference_config": {
+        "text_similarity": {"tokenization": {"deberta_v2": {"truncate": "balanced"}}}
+    },
+    "vocabulary_file": "rerank-v1.vocab.json"
+  }

BIN
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/rerank-v1.pt


Failā izmaiņas netiks attēlotas, jo tās ir par lielu
+ 0 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/rerank-v1.vocab.json


+ 4 - 1
x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java

@@ -333,6 +333,7 @@ final class ModelLoaderUtils {
      * in size. The separate range for the final chunk is because when streaming and
      * uploading a large model definition, writing the last part has to be handled
      * as a special case.
+     * Fewer ranges may be returned if the stream size is too small.
      * @param sizeInBytes The total size of the stream
      * @param numberOfStreams Divide the bulk of the size into this many streams.
      * @param chunkSizeBytes The size of each chunk
@@ -340,7 +341,9 @@ final class ModelLoaderUtils {
      */
     static List<RequestRange> split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) {
         int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);
-
+        if (numberOfStreams > numberOfChunks) {
+            numberOfStreams = numberOfChunks;
+        }
         var ranges = new ArrayList<RequestRange>();
 
         int baseChunksPerStream = numberOfChunks / numberOfStreams;

+ 1 - 5
x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java

@@ -19,7 +19,6 @@ import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.rest.RestStatus;
@@ -45,11 +44,9 @@ import static org.elasticsearch.core.Strings.format;
 public class TransportGetTrainedModelPackageConfigAction extends TransportMasterNodeAction<Request, Response> {
 
     private static final Logger logger = LogManager.getLogger(TransportGetTrainedModelPackageConfigAction.class);
-    private final Settings settings;
 
     @Inject
     public TransportGetTrainedModelPackageConfigAction(
-        Settings settings,
         TransportService transportService,
         ClusterService clusterService,
         ThreadPool threadPool,
@@ -67,12 +64,11 @@ public class TransportGetTrainedModelPackageConfigAction extends TransportMaster
             GetTrainedModelPackageConfigAction.Response::new,
             EsExecutors.DIRECT_EXECUTOR_SERVICE
         );
-        this.settings = settings;
     }
 
     @Override
     protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) throws Exception {
-        String repository = MachineLearningPackageLoader.MODEL_REPOSITORY.get(settings);
+        String repository = clusterService.getClusterSettings().get(MachineLearningPackageLoader.MODEL_REPOSITORY);
 
         String packagedModelId = request.getPackagedModelId();
         logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository));

+ 17 - 0
x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java

@@ -143,6 +143,23 @@ public class ModelLoaderUtilsTests extends ESTestCase {
         assertThat(finalRange.numParts(), is(1));
     }
 
+    public void testSplitIntoRanges_numRangesSmallerThanNumStreams() {
+        long totalSize = 2_142;
+        int numStreams = 10;
+        int chunkSize = 1_000;
+        var ranges = ModelLoaderUtils.split(totalSize, numStreams, chunkSize);
+        assertThat(ranges.toString(), ranges, hasSize(3));
+        assertThat(ranges.get(0).rangeStart(), is(0L));
+        assertThat(ranges.get(0).rangeEnd(), is(999L));
+        assertThat(ranges.get(0).numParts(), is(1));
+        assertThat(ranges.get(1).rangeStart(), is(1000L));
+        assertThat(ranges.get(1).rangeEnd(), is(1999L));
+        assertThat(ranges.get(1).numParts(), is(1));
+        assertThat(ranges.get(2).rangeStart(), is(2000L));
+        assertThat(ranges.get(2).rangeEnd(), is(2141L));
+        assertThat(ranges.get(2).numParts(), is(1));
+    }
+
     public void testRangeRequestBytesRange() {
         long start = randomLongBetween(0, 2 << 10);
         long end = randomLongBetween(start + 1, 2 << 11);

Daži faili netika attēloti, jo izmaiņu fails ir pārāk liels