浏览代码

[ML] Fix Get DeepSeek Model (#124802) (#124812)

When secrets are null, the Get Model API should still be able to return
the model.
Pat Whelan 7 月之前
父节点
当前提交
42b80dda12

+ 8 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java

@@ -10,6 +10,8 @@ package org.elasticsearch.xpack.inference.external.deepseek;
 import org.apache.http.HttpHeaders;
 import org.apache.http.client.methods.HttpPost;
 import org.apache.http.entity.ByteArrayEntity;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.xcontent.ToXContent;
@@ -29,6 +31,7 @@ import java.util.Objects;
 import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
 
 public class DeepSeekChatCompletionRequest implements Request {
+    private static final Logger logger = LogManager.getLogger(DeepSeekChatCompletionRequest.class);
     private static final String MODEL_FIELD = "model";
     private static final String MAX_TOKENS = "max_tokens";
 
@@ -47,7 +50,11 @@ public class DeepSeekChatCompletionRequest implements Request {
         httpPost.setEntity(createEntity());
 
         httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
-        httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
+        model.apiKey()
+            .ifPresentOrElse(
+                apiKey -> httpPost.setHeader(createAuthBearerHeader(apiKey)),
+                () -> logger.debug("No auth token present in request, sending without auth...")
+            );
 
         return new HttpRequest(httpPost, getInferenceEntityId());
     }

+ 6 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
@@ -30,6 +31,7 @@ import java.net.URI;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
@@ -63,6 +65,7 @@ public class DeepSeekChatCompletionModel extends Model {
 
     private static final URI DEFAULT_URI = URI.create("https://api.deepseek.com/chat/completions");
     private final DeepSeekServiceSettings serviceSettings;
+    @Nullable
     private final DefaultSecretSettings secretSettings;
 
     public static List<NamedWriteableRegistry.Entry> namedWriteables() {
@@ -126,7 +129,7 @@ public class DeepSeekChatCompletionModel extends Model {
 
     private DeepSeekChatCompletionModel(
         DeepSeekServiceSettings serviceSettings,
-        DefaultSecretSettings secretSettings,
+        @Nullable DefaultSecretSettings secretSettings,
         ModelConfigurations configurations,
         ModelSecrets secrets
     ) {
@@ -135,8 +138,8 @@ public class DeepSeekChatCompletionModel extends Model {
         this.secretSettings = secretSettings;
     }
 
-    public SecureString apiKey() {
-        return secretSettings.apiKey();
+    public Optional<SecureString> apiKey() {
+        return Optional.ofNullable(secretSettings).map(DefaultSecretSettings::apiKey);
     }
 
     public String model() {

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java

@@ -150,7 +150,8 @@ public class DeepSeekService extends SenderService {
 
     @Override
     public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
-        return parsePersistedConfigWithSecrets(modelId, taskType, config, config);
+        var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+        return DeepSeekChatCompletionModel.readFromStorage(modelId, taskType, NAME, serviceSettingsMap, null);
     }
 
     @Override

+ 10 - 28
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java

@@ -44,6 +44,7 @@ import java.net.URISyntaxException;
 import java.nio.charset.StandardCharsets;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.ExceptionsHelper.unwrapCause;
@@ -90,7 +91,7 @@ public class DeepSeekServiceTests extends ESTestCase {
             }
             """, webServer.getUri(null).toString()), assertNoFailureListener(model -> {
             if (model instanceof DeepSeekChatCompletionModel deepSeekModel) {
-                assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
+                assertThat(deepSeekModel.apiKey().get().getChars(), equalTo("12345".toCharArray()));
                 assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
                 assertThat(deepSeekModel.uri(), equalTo(webServer.getUri(null)));
             } else {
@@ -158,13 +159,10 @@ public class DeepSeekServiceTests extends ESTestCase {
             {
               "service_settings": {
                 "model_id": "some-cool-model"
-              },
-              "secret_settings": {
-                "api_key": "12345"
               }
             }
             """);
-        assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
+        assertThat(deepSeekModel.apiKey(), equalTo(Optional.empty()));
         assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
     }
 
@@ -174,33 +172,14 @@ public class DeepSeekServiceTests extends ESTestCase {
               "service_settings": {
                 "model_id": "some-cool-model",
                 "url": "http://localhost:989"
-              },
-              "secret_settings": {
-                "api_key": "12345"
               }
             }
             """);
-        assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
+        assertThat(deepSeekModel.apiKey(), equalTo(Optional.empty()));
         assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
         assertThat(deepSeekModel.uri(), equalTo(URI.create("http://localhost:989")));
     }
 
-    public void testParsePersistedConfigWithoutApiKey() {
-        assertThrows(
-            "Validation Failed: 1: [secret_settings] does not contain the required setting [api_key];",
-            ValidationException.class,
-            () -> parsePersistedConfig("""
-                {
-                  "service_settings": {
-                    "model_id": "some-cool-model"
-                  },
-                  "secret_settings": {
-                  }
-                }
-                """)
-        );
-    }
-
     public void testParsePersistedConfigWithoutModel() {
         assertThrows(
             "Validation Failed: 1: [service_settings] does not contain the required setting [model];",
@@ -424,17 +403,20 @@ public class DeepSeekServiceTests extends ESTestCase {
     }
 
     private DeepSeekChatCompletionModel createModel(DeepSeekService service, TaskType taskType) throws URISyntaxException, IOException {
-        var model = service.parsePersistedConfig("inference-id", taskType, map(Strings.format("""
+        var model = service.parsePersistedConfigWithSecrets("inference-id", taskType, map(Strings.format("""
             {
               "service_settings": {
                 "model_id": "some-cool-model",
                 "url": "%s"
-              },
+              }
+            }
+            """, webServer.getUri(null).toString())), map("""
+            {
               "secret_settings": {
                 "api_key": "12345"
               }
             }
-            """, webServer.getUri(null).toString())));
+            """));
         assertThat(model, isA(DeepSeekChatCompletionModel.class));
         return (DeepSeekChatCompletionModel) model;
     }