|
|
@@ -0,0 +1,146 @@
|
|
|
+/*
|
|
|
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
|
+ * or more contributor license agreements. Licensed under the Elastic License
|
|
|
+ * 2.0; you may not use this file except in compliance with the Elastic License
|
|
|
+ * 2.0.
|
|
|
+ */
|
|
|
+
|
|
|
+package org.elasticsearch.xpack.inference;
|
|
|
+
|
|
|
+import com.sun.net.httpserver.HttpExchange;
|
|
|
+import com.sun.net.httpserver.HttpServer;
|
|
|
+
|
|
|
+import org.apache.http.HttpHeaders;
|
|
|
+import org.apache.http.HttpStatus;
|
|
|
+import org.apache.http.client.utils.URIBuilder;
|
|
|
+import org.elasticsearch.logging.LogManager;
|
|
|
+import org.elasticsearch.logging.Logger;
|
|
|
+import org.elasticsearch.test.fixture.HttpHeaderParser;
|
|
|
+import org.elasticsearch.xcontent.XContentParser;
|
|
|
+import org.elasticsearch.xcontent.XContentParserConfiguration;
|
|
|
+import org.elasticsearch.xcontent.XContentType;
|
|
|
+import org.elasticsearch.xpack.core.XPackSettings;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
|
|
|
+import org.junit.rules.TestRule;
|
|
|
+import org.junit.runner.Description;
|
|
|
+import org.junit.runners.model.Statement;
|
|
|
+
|
|
|
+import java.io.ByteArrayInputStream;
|
|
|
+import java.io.IOException;
|
|
|
+import java.io.InputStream;
|
|
|
+import java.io.OutputStream;
|
|
|
+import java.net.InetSocketAddress;
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
+import java.util.Random;
|
|
|
+import java.util.concurrent.ExecutorService;
|
|
|
+import java.util.concurrent.Executors;
|
|
|
+
|
|
|
+/**
|
|
|
+ * Simple model server to serve ML models.
|
|
|
+ * The URL path corresponds to a file name in this class's resources.
|
|
|
+ * If the file is found, its content is returned, otherwise 404.
|
|
|
+ * Respects a range header to serve partial content.
|
|
|
+ */
|
|
|
+public class MlModelServer implements TestRule {
|
|
|
+
|
|
|
+ private static final String HOST = "localhost";
|
|
|
+ private static final Logger logger = LogManager.getLogger(MlModelServer.class);
|
|
|
+
|
|
|
+ private int port;
|
|
|
+
|
|
|
+ public String getUrl() {
|
|
|
+ return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString();
|
|
|
+ }
|
|
|
+
|
|
|
+ private void handle(HttpExchange exchange) throws IOException {
|
|
|
+ String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE);
|
|
|
+ HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null;
|
|
|
+ logger.info("request: {} range={}", exchange.getRequestURI().getPath(), range);
|
|
|
+
|
|
|
+ try (InputStream is = getInputStream(exchange)) {
|
|
|
+ int httpStatus;
|
|
|
+ long numBytes;
|
|
|
+ if (is == null) {
|
|
|
+ httpStatus = HttpStatus.SC_NOT_FOUND;
|
|
|
+ numBytes = 0;
|
|
|
+ } else if (range == null) {
|
|
|
+ httpStatus = HttpStatus.SC_OK;
|
|
|
+ numBytes = is.available();
|
|
|
+ } else {
|
|
|
+ httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
|
|
|
+ is.skipNBytes(range.start());
|
|
|
+ numBytes = range.end() - range.start() + 1;
|
|
|
+ }
|
|
|
+ logger.info("response: {} {}", exchange.getRequestURI().getPath(), httpStatus);
|
|
|
+ exchange.sendResponseHeaders(httpStatus, numBytes);
|
|
|
+ try (OutputStream os = exchange.getResponseBody()) {
|
|
|
+ while (numBytes > 0) {
|
|
|
+ byte[] bytes = is.readNBytes((int) Math.min(1 << 20, numBytes));
|
|
|
+ os.write(bytes);
|
|
|
+ numBytes -= bytes.length;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private InputStream getInputStream(HttpExchange exchange) throws IOException {
|
|
|
+ String path = exchange.getRequestURI().getPath().substring(1); // Strip leading slash
|
|
|
+ String modelId = path.substring(0, path.indexOf('.'));
|
|
|
+ String extension = path.substring(path.indexOf('.') + 1);
|
|
|
+
|
|
|
+ // If a model specifically optimized for some platform is requested,
|
|
|
+ // serve the default non-optimized model instead, which is compatible.
|
|
|
+ String defaultModelId = modelId;
|
|
|
+ for (String platform : XPackSettings.ML_NATIVE_CODE_PLATFORMS) {
|
|
|
+ defaultModelId = defaultModelId.replace("_" + platform, "");
|
|
|
+ }
|
|
|
+
|
|
|
+ ClassLoader classloader = Thread.currentThread().getContextClassLoader();
|
|
|
+ InputStream is = classloader.getResourceAsStream(defaultModelId + "." + extension);
|
|
|
+ if (is != null && modelId.equals(defaultModelId) == false && extension.equals("metadata.json")) {
|
|
|
+ // When an optimized version is requested, fix the default metadata,
|
|
|
+ // so that it contains the correct model ID.
|
|
|
+ try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, is.readAllBytes())) {
|
|
|
+ is.close();
|
|
|
+ ModelPackageConfig packageConfig = ModelPackageConfig.fromXContentLenient(parser);
|
|
|
+ packageConfig = new ModelPackageConfig.Builder(packageConfig).setPackedModelId(modelId).build();
|
|
|
+ is = new ByteArrayInputStream(packageConfig.toString().getBytes(StandardCharsets.UTF_8));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return is;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Statement apply(Statement statement, Description description) {
|
|
|
+ return new Statement() {
|
|
|
+ @Override
|
|
|
+ public void evaluate() throws Throwable {
|
|
|
+ logger.info("Starting ML model server");
|
|
|
+ HttpServer server = HttpServer.create();
|
|
|
+ while (true) {
|
|
|
+ port = new Random().nextInt(10000, 65536);
|
|
|
+ try {
|
|
|
+ server.bind(new InetSocketAddress(HOST, port), 1);
|
|
|
+ } catch (Exception e) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ logger.info("Bound ML model server to port {}", port);
|
|
|
+
|
|
|
+ ExecutorService executor = Executors.newCachedThreadPool();
|
|
|
+ server.setExecutor(executor);
|
|
|
+ server.createContext("/", MlModelServer.this::handle);
|
|
|
+ server.start();
|
|
|
+
|
|
|
+ try {
|
|
|
+ statement.evaluate();
|
|
|
+ } finally {
|
|
|
+ logger.info("Stopping ML model server on port {}", port);
|
|
|
+ server.stop(1);
|
|
|
+ executor.shutdown();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
+ }
|
|
|
+}
|