|
@@ -14,11 +14,8 @@ import org.elasticsearch.common.xcontent.XContentParserUtils;
|
|
|
import org.elasticsearch.common.xcontent.support.XContentMapValues;
|
|
|
import org.elasticsearch.core.Nullable;
|
|
|
import org.elasticsearch.index.IndexVersions;
|
|
|
-import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
|
|
import org.elasticsearch.inference.ChunkedInference;
|
|
|
-import org.elasticsearch.inference.Model;
|
|
|
-import org.elasticsearch.inference.SimilarityMeasure;
|
|
|
-import org.elasticsearch.inference.TaskType;
|
|
|
+import org.elasticsearch.inference.MinimalServiceSettings;
|
|
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
|
|
import org.elasticsearch.xcontent.DeprecationHandler;
|
|
|
import org.elasticsearch.xcontent.NamedXContentRegistry;
|
|
@@ -38,10 +35,7 @@ import java.util.Iterator;
|
|
|
import java.util.LinkedHashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
-import java.util.Objects;
|
|
|
|
|
|
-import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
|
|
|
-import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
|
|
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
|
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
|
|
|
|
@@ -75,117 +69,13 @@ public record SemanticTextField(
|
|
|
static final String CHUNKED_START_OFFSET_FIELD = "start_offset";
|
|
|
static final String CHUNKED_END_OFFSET_FIELD = "end_offset";
|
|
|
static final String MODEL_SETTINGS_FIELD = "model_settings";
|
|
|
- static final String TASK_TYPE_FIELD = "task_type";
|
|
|
- static final String DIMENSIONS_FIELD = "dimensions";
|
|
|
- static final String SIMILARITY_FIELD = "similarity";
|
|
|
- static final String ELEMENT_TYPE_FIELD = "element_type";
|
|
|
|
|
|
- public record InferenceResult(String inferenceId, ModelSettings modelSettings, Map<String, List<Chunk>> chunks) {}
|
|
|
+ public record InferenceResult(String inferenceId, MinimalServiceSettings modelSettings, Map<String, List<Chunk>> chunks) {}
|
|
|
|
|
|
public record Chunk(@Nullable String text, int startOffset, int endOffset, BytesReference rawEmbeddings) {}
|
|
|
|
|
|
public record Offset(String sourceFieldName, int startOffset, int endOffset) {}
|
|
|
|
|
|
- public record ModelSettings(
|
|
|
- TaskType taskType,
|
|
|
- Integer dimensions,
|
|
|
- SimilarityMeasure similarity,
|
|
|
- DenseVectorFieldMapper.ElementType elementType
|
|
|
- ) implements ToXContentObject {
|
|
|
- public ModelSettings(Model model) {
|
|
|
- this(
|
|
|
- model.getTaskType(),
|
|
|
- model.getServiceSettings().dimensions(),
|
|
|
- model.getServiceSettings().similarity(),
|
|
|
- model.getServiceSettings().elementType()
|
|
|
- );
|
|
|
- }
|
|
|
-
|
|
|
- public ModelSettings(
|
|
|
- TaskType taskType,
|
|
|
- Integer dimensions,
|
|
|
- SimilarityMeasure similarity,
|
|
|
- DenseVectorFieldMapper.ElementType elementType
|
|
|
- ) {
|
|
|
- this.taskType = Objects.requireNonNull(taskType, "task type must not be null");
|
|
|
- this.dimensions = dimensions;
|
|
|
- this.similarity = similarity;
|
|
|
- this.elementType = elementType;
|
|
|
- validate();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
- builder.startObject();
|
|
|
- builder.field(TASK_TYPE_FIELD, taskType.toString());
|
|
|
- if (dimensions != null) {
|
|
|
- builder.field(DIMENSIONS_FIELD, dimensions);
|
|
|
- }
|
|
|
- if (similarity != null) {
|
|
|
- builder.field(SIMILARITY_FIELD, similarity);
|
|
|
- }
|
|
|
- if (elementType != null) {
|
|
|
- builder.field(ELEMENT_TYPE_FIELD, elementType);
|
|
|
- }
|
|
|
- return builder.endObject();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public String toString() {
|
|
|
- final StringBuilder sb = new StringBuilder();
|
|
|
- sb.append("task_type=").append(taskType);
|
|
|
- if (dimensions != null) {
|
|
|
- sb.append(", dimensions=").append(dimensions);
|
|
|
- }
|
|
|
- if (similarity != null) {
|
|
|
- sb.append(", similarity=").append(similarity);
|
|
|
- }
|
|
|
- if (elementType != null) {
|
|
|
- sb.append(", element_type=").append(elementType);
|
|
|
- }
|
|
|
- return sb.toString();
|
|
|
- }
|
|
|
-
|
|
|
- private void validate() {
|
|
|
- switch (taskType) {
|
|
|
- case TEXT_EMBEDDING:
|
|
|
- validateFieldPresent(DIMENSIONS_FIELD, dimensions);
|
|
|
- validateFieldPresent(SIMILARITY_FIELD, similarity);
|
|
|
- validateFieldPresent(ELEMENT_TYPE_FIELD, elementType);
|
|
|
- break;
|
|
|
- case SPARSE_EMBEDDING:
|
|
|
- validateFieldNotPresent(DIMENSIONS_FIELD, dimensions);
|
|
|
- validateFieldNotPresent(SIMILARITY_FIELD, similarity);
|
|
|
- validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType);
|
|
|
- break;
|
|
|
-
|
|
|
- default:
|
|
|
- throw new IllegalArgumentException(
|
|
|
- "Wrong ["
|
|
|
- + TASK_TYPE_FIELD
|
|
|
- + "], expected "
|
|
|
- + TEXT_EMBEDDING
|
|
|
- + " or "
|
|
|
- + SPARSE_EMBEDDING
|
|
|
- + ", got "
|
|
|
- + taskType.name()
|
|
|
- );
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- private void validateFieldPresent(String field, Object fieldValue) {
|
|
|
- if (fieldValue == null) {
|
|
|
- throw new IllegalArgumentException("required [" + field + "] field is missing for task_type [" + taskType.name() + "]");
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- private void validateFieldNotPresent(String field, Object fieldValue) {
|
|
|
- if (fieldValue != null) {
|
|
|
- throw new IllegalArgumentException("[" + field + "] is not allowed for task_type [" + taskType.name() + "]");
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
public static String getOriginalTextFieldName(String fieldName) {
|
|
|
return fieldName + "." + TEXT_FIELD;
|
|
|
}
|
|
@@ -212,7 +102,7 @@ public record SemanticTextField(
|
|
|
return SEMANTIC_TEXT_FIELD_PARSER.parse(parser, context);
|
|
|
}
|
|
|
|
|
|
- static ModelSettings parseModelSettingsFromMap(Object node) {
|
|
|
+ static MinimalServiceSettings parseModelSettingsFromMap(Object node) {
|
|
|
if (node == null) {
|
|
|
return null;
|
|
|
}
|
|
@@ -224,7 +114,7 @@ public record SemanticTextField(
|
|
|
map,
|
|
|
XContentType.JSON
|
|
|
);
|
|
|
- return MODEL_SETTINGS_PARSER.parse(parser, null);
|
|
|
+ return MinimalServiceSettings.parse(parser);
|
|
|
} catch (Exception exc) {
|
|
|
throw new ElasticsearchException(exc);
|
|
|
}
|
|
@@ -307,7 +197,7 @@ public record SemanticTextField(
|
|
|
private static final ConstructingObjectParser<InferenceResult, ParserContext> INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>(
|
|
|
INFERENCE_FIELD,
|
|
|
true,
|
|
|
- args -> new InferenceResult((String) args[0], (ModelSettings) args[1], (Map<String, List<Chunk>>) args[2])
|
|
|
+ args -> new InferenceResult((String) args[0], (MinimalServiceSettings) args[1], (Map<String, List<Chunk>>) args[2])
|
|
|
);
|
|
|
|
|
|
private static final ConstructingObjectParser<Chunk, ParserContext> CHUNKS_PARSER = new ConstructingObjectParser<>(
|
|
@@ -322,20 +212,6 @@ public record SemanticTextField(
|
|
|
}
|
|
|
);
|
|
|
|
|
|
- private static final ConstructingObjectParser<ModelSettings, Void> MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>(
|
|
|
- MODEL_SETTINGS_FIELD,
|
|
|
- true,
|
|
|
- args -> {
|
|
|
- TaskType taskType = TaskType.fromString((String) args[0]);
|
|
|
- Integer dimensions = (Integer) args[1];
|
|
|
- SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]);
|
|
|
- DenseVectorFieldMapper.ElementType elementType = args[3] == null
|
|
|
- ? null
|
|
|
- : DenseVectorFieldMapper.ElementType.fromString((String) args[3]);
|
|
|
- return new ModelSettings(taskType, dimensions, similarity, elementType);
|
|
|
- }
|
|
|
- );
|
|
|
-
|
|
|
static {
|
|
|
SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(TEXT_FIELD));
|
|
|
SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), INFERENCE_RESULT_PARSER, new ParseField(INFERENCE_FIELD));
|
|
@@ -343,7 +219,7 @@ public record SemanticTextField(
|
|
|
INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD));
|
|
|
INFERENCE_RESULT_PARSER.declareObjectOrNull(
|
|
|
constructorArg(),
|
|
|
- (p, c) -> MODEL_SETTINGS_PARSER.parse(p, null),
|
|
|
+ (p, c) -> MinimalServiceSettings.parse(p),
|
|
|
null,
|
|
|
new ParseField(MODEL_SETTINGS_FIELD)
|
|
|
);
|
|
@@ -362,11 +238,6 @@ public record SemanticTextField(
|
|
|
b.copyCurrentStructure(p);
|
|
|
return BytesReference.bytes(b);
|
|
|
}, new ParseField(CHUNKED_EMBEDDINGS_FIELD), ObjectParser.ValueType.OBJECT_ARRAY);
|
|
|
-
|
|
|
- MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD));
|
|
|
- MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD));
|
|
|
- MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD));
|
|
|
- MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ELEMENT_TYPE_FIELD));
|
|
|
}
|
|
|
|
|
|
private static Map<String, List<Chunk>> parseChunksMap(XContentParser parser, ParserContext context) throws IOException {
|