|
@@ -56,6 +56,9 @@ import org.elasticsearch.index.mapper.SourceToParse;
|
|
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
|
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldTypeTests;
|
|
|
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
|
|
|
+import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapperTests;
|
|
|
+import org.elasticsearch.index.mapper.vectors.SparseVectorFieldTypeTests;
|
|
|
+import org.elasticsearch.index.mapper.vectors.TokenPruningConfig;
|
|
|
import org.elasticsearch.index.query.SearchExecutionContext;
|
|
|
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
|
|
|
import org.elasticsearch.inference.ChunkingSettings;
|
|
@@ -95,6 +98,7 @@ import java.util.function.BiConsumer;
|
|
|
import java.util.function.Supplier;
|
|
|
|
|
|
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldTypeTests.randomIndexOptionsAll;
|
|
|
+import static org.elasticsearch.index.mapper.vectors.SparseVectorFieldTypeTests.randomSparseVectorIndexOptions;
|
|
|
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
|
|
|
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD;
|
|
|
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD;
|
|
@@ -113,6 +117,9 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.ra
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
import static org.hamcrest.Matchers.instanceOf;
|
|
|
+import static org.mockito.ArgumentMatchers.anyString;
|
|
|
+import static org.mockito.Mockito.spy;
|
|
|
+import static org.mockito.Mockito.when;
|
|
|
|
|
|
public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
private final boolean useLegacyFormat;
|
|
@@ -123,9 +130,20 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
this.useLegacyFormat = useLegacyFormat;
|
|
|
}
|
|
|
|
|
|
+ ModelRegistry globalModelRegistry;
|
|
|
+
|
|
|
@Before
|
|
|
private void startThreadPool() {
|
|
|
threadPool = createThreadPool();
|
|
|
+ var clusterService = ClusterServiceUtils.createClusterService(threadPool);
|
|
|
+ var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
|
|
|
+ globalModelRegistry = spy(modelRegistry);
|
|
|
+ globalModelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
|
|
|
+ @Override
|
|
|
+ public boolean localNodeMaster() {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@After
|
|
@@ -140,18 +158,10 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
|
|
|
@Override
|
|
|
protected Collection<? extends Plugin> getPlugins() {
|
|
|
- var clusterService = ClusterServiceUtils.createClusterService(threadPool);
|
|
|
- var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
|
|
|
- modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
|
|
|
- @Override
|
|
|
- public boolean localNodeMaster() {
|
|
|
- return false;
|
|
|
- }
|
|
|
- });
|
|
|
return List.of(new InferencePlugin(Settings.EMPTY) {
|
|
|
@Override
|
|
|
protected Supplier<ModelRegistry> getModelRegistry() {
|
|
|
- return () -> modelRegistry;
|
|
|
+ return () -> globalModelRegistry;
|
|
|
}
|
|
|
}, new XPackClientPlugin());
|
|
|
}
|
|
@@ -174,6 +184,11 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
) throws IOException {
|
|
|
validateIndexVersion(minIndexVersion, useLegacyFormat);
|
|
|
IndexVersion indexVersion = IndexVersionUtils.randomVersionBetween(random(), minIndexVersion, maxIndexVersion);
|
|
|
+ return createMapperServiceWithIndexVersion(mappings, useLegacyFormat, indexVersion);
|
|
|
+ }
|
|
|
+
|
|
|
+ private MapperService createMapperServiceWithIndexVersion(XContentBuilder mappings, boolean useLegacyFormat, IndexVersion indexVersion)
|
|
|
+ throws IOException {
|
|
|
var settings = Settings.builder()
|
|
|
.put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), indexVersion)
|
|
|
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
|
|
@@ -189,17 +204,6 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private MapperService createMapperService(String mappings, boolean useLegacyFormat) throws IOException {
|
|
|
- var settings = Settings.builder()
|
|
|
- .put(
|
|
|
- IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(),
|
|
|
- SemanticInferenceMetadataFieldsMapperTests.getRandomCompatibleIndexVersion(useLegacyFormat)
|
|
|
- )
|
|
|
- .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
|
|
|
- .build();
|
|
|
- return createMapperService(settings, mappings);
|
|
|
- }
|
|
|
-
|
|
|
@Override
|
|
|
protected Settings getIndexSettings() {
|
|
|
return Settings.builder()
|
|
@@ -380,6 +384,14 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private SemanticTextIndexOptions getDefaultSparseVectorIndexOptionsForMapper(MapperService mapperService) {
|
|
|
+ var mapperIndexVersion = mapperService.getIndexSettings().getIndexVersionCreated();
|
|
|
+ var defaultSparseVectorIndexOptions = SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(mapperIndexVersion);
|
|
|
+ return defaultSparseVectorIndexOptions == null
|
|
|
+ ? null
|
|
|
+ : new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, defaultSparseVectorIndexOptions);
|
|
|
+ }
|
|
|
+
|
|
|
public void testInvalidTaskTypes() {
|
|
|
for (var taskType : TaskType.values()) {
|
|
|
if (taskType == TaskType.TEXT_EMBEDDING || taskType == TaskType.SPARSE_EMBEDDING) {
|
|
@@ -415,7 +427,13 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
}), useLegacyFormat));
|
|
|
assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields"));
|
|
|
} else {
|
|
|
- var mapperService = createMapperService(fieldMapping(b -> {
|
|
|
+ IndexVersion indexVersion = SparseVectorFieldMapperTests.getIndexOptionsCompatibleIndexVersion();
|
|
|
+ SparseVectorFieldMapper.SparseVectorIndexOptions expectedIndexOptions = SparseVectorFieldMapper.SparseVectorIndexOptions
|
|
|
+ .getDefaultIndexOptions(indexVersion);
|
|
|
+ SemanticTextIndexOptions semanticTextIndexOptions = expectedIndexOptions == null
|
|
|
+ ? null
|
|
|
+ : new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, expectedIndexOptions);
|
|
|
+ var mapperService = createMapperServiceWithIndexVersion(fieldMapping(b -> {
|
|
|
b.field("type", "text");
|
|
|
b.startObject("fields");
|
|
|
b.startObject("semantic");
|
|
@@ -426,10 +444,10 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
b.endObject();
|
|
|
b.endObject();
|
|
|
b.endObject();
|
|
|
- }), useLegacyFormat);
|
|
|
- assertSemanticTextField(mapperService, "field.semantic", true, null, null);
|
|
|
+ }), useLegacyFormat, indexVersion);
|
|
|
+ assertSemanticTextField(mapperService, "field.semantic", true, null, semanticTextIndexOptions);
|
|
|
|
|
|
- mapperService = createMapperService(fieldMapping(b -> {
|
|
|
+ mapperService = createMapperServiceWithIndexVersion(fieldMapping(b -> {
|
|
|
b.field("type", "semantic_text");
|
|
|
b.field("inference_id", "my_inference_id");
|
|
|
b.startObject("model_settings");
|
|
@@ -440,10 +458,10 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
b.field("type", "text");
|
|
|
b.endObject();
|
|
|
b.endObject();
|
|
|
- }), useLegacyFormat);
|
|
|
- assertSemanticTextField(mapperService, "field", true, null, null);
|
|
|
+ }), useLegacyFormat, indexVersion);
|
|
|
+ assertSemanticTextField(mapperService, "field", true, null, semanticTextIndexOptions);
|
|
|
|
|
|
- mapperService = createMapperService(fieldMapping(b -> {
|
|
|
+ mapperService = createMapperServiceWithIndexVersion(fieldMapping(b -> {
|
|
|
b.field("type", "semantic_text");
|
|
|
b.field("inference_id", "my_inference_id");
|
|
|
b.startObject("model_settings");
|
|
@@ -458,9 +476,9 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
b.endObject();
|
|
|
b.endObject();
|
|
|
b.endObject();
|
|
|
- }), useLegacyFormat);
|
|
|
- assertSemanticTextField(mapperService, "field", true, null, null);
|
|
|
- assertSemanticTextField(mapperService, "field.semantic", true, null, null);
|
|
|
+ }), useLegacyFormat, indexVersion);
|
|
|
+ assertSemanticTextField(mapperService, "field", true, null, semanticTextIndexOptions);
|
|
|
+ assertSemanticTextField(mapperService, "field.semantic", true, null, semanticTextIndexOptions);
|
|
|
|
|
|
Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
|
|
|
b.field("type", "semantic_text");
|
|
@@ -472,7 +490,6 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
b.endObject();
|
|
|
}), useLegacyFormat));
|
|
|
assertThat(e.getMessage(), containsString("is already used by another field"));
|
|
|
-
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -504,7 +521,8 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
inferenceId,
|
|
|
new MinimalServiceSettings("service", TaskType.SPARSE_EMBEDDING, null, null, null)
|
|
|
);
|
|
|
- assertSemanticTextField(mapperService, fieldName, true, null, null);
|
|
|
+ var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService);
|
|
|
+ assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions);
|
|
|
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
|
|
|
}
|
|
|
|
|
@@ -515,7 +533,8 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
searchInferenceId,
|
|
|
new MinimalServiceSettings("service", TaskType.SPARSE_EMBEDDING, null, null, null)
|
|
|
);
|
|
|
- assertSemanticTextField(mapperService, fieldName, true, null, null);
|
|
|
+ var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService);
|
|
|
+ assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions);
|
|
|
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId);
|
|
|
}
|
|
|
}
|
|
@@ -559,14 +578,16 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
.endObject()
|
|
|
)
|
|
|
);
|
|
|
- assertSemanticTextField(mapperService, fieldName, true, null, null);
|
|
|
+ var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService);
|
|
|
+ assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions);
|
|
|
}
|
|
|
{
|
|
|
merge(
|
|
|
mapperService,
|
|
|
mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject())
|
|
|
);
|
|
|
- assertSemanticTextField(mapperService, fieldName, true, null, null);
|
|
|
+ var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService);
|
|
|
+ assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions);
|
|
|
}
|
|
|
{
|
|
|
Exception exc = expectThrows(
|
|
@@ -614,6 +635,87 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private void addSparseVectorModelSettingsToBuilder(XContentBuilder b) throws IOException {
|
|
|
+ b.startObject("model_settings");
|
|
|
+ b.field("task_type", TaskType.SPARSE_EMBEDDING);
|
|
|
+ b.endObject();
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testSparseVectorIndexOptionsValidationAndMapping() throws IOException {
|
|
|
+ for (int depth = 1; depth < 5; depth++) {
|
|
|
+ String inferenceId = "test_model";
|
|
|
+ String fieldName = randomFieldName(depth);
|
|
|
+ IndexVersion indexVersion = SparseVectorFieldMapperTests.getIndexOptionsCompatibleIndexVersion();
|
|
|
+ var sparseVectorIndexOptions = SparseVectorFieldTypeTests.randomSparseVectorIndexOptions();
|
|
|
+ var expectedIndexOptions = sparseVectorIndexOptions == null
|
|
|
+ ? null
|
|
|
+ : new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, sparseVectorIndexOptions);
|
|
|
+
|
|
|
+ // should not throw an exception
|
|
|
+ MapperService mapper = createMapperServiceWithIndexVersion(mapping(b -> {
|
|
|
+ b.startObject(fieldName);
|
|
|
+ {
|
|
|
+ b.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
|
|
|
+ b.field(INFERENCE_ID_FIELD, inferenceId);
|
|
|
+ addSparseVectorModelSettingsToBuilder(b);
|
|
|
+ if (sparseVectorIndexOptions != null) {
|
|
|
+ b.startObject(INDEX_OPTIONS_FIELD);
|
|
|
+ {
|
|
|
+ b.field(SparseVectorFieldMapper.CONTENT_TYPE);
|
|
|
+ sparseVectorIndexOptions.toXContent(b, null);
|
|
|
+ }
|
|
|
+ b.endObject();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ b.endObject();
|
|
|
+ }), useLegacyFormat, indexVersion);
|
|
|
+
|
|
|
+ assertSemanticTextField(mapper, fieldName, true, null, expectedIndexOptions);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testSparseVectorMappingUpdate() throws IOException {
|
|
|
+ for (int i = 0; i < 5; i++) {
|
|
|
+ Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
|
|
|
+ when(globalModelRegistry.getMinimalServiceSettings(anyString())).thenAnswer(
|
|
|
+ invocation -> { return new MinimalServiceSettings(model); }
|
|
|
+ );
|
|
|
+
|
|
|
+ final ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false);
|
|
|
+ IndexVersion indexVersion = SparseVectorFieldMapperTests.getIndexOptionsCompatibleIndexVersion();
|
|
|
+ final SemanticTextIndexOptions indexOptions = randomSemanticTextIndexOptions(TaskType.SPARSE_EMBEDDING);
|
|
|
+ String fieldName = "field";
|
|
|
+
|
|
|
+ MapperService mapperService = createMapperServiceWithIndexVersion(
|
|
|
+ mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, chunkingSettings, indexOptions)),
|
|
|
+ useLegacyFormat,
|
|
|
+ indexVersion
|
|
|
+ );
|
|
|
+ var expectedIndexOptions = (indexOptions == null)
|
|
|
+ ? new SemanticTextIndexOptions(
|
|
|
+ SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR,
|
|
|
+ SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersion)
|
|
|
+ )
|
|
|
+ : indexOptions;
|
|
|
+ assertSemanticTextField(mapperService, fieldName, false, chunkingSettings, expectedIndexOptions);
|
|
|
+
|
|
|
+ final SemanticTextIndexOptions newIndexOptions = randomSemanticTextIndexOptions(TaskType.SPARSE_EMBEDDING);
|
|
|
+ expectedIndexOptions = (newIndexOptions == null)
|
|
|
+ ? new SemanticTextIndexOptions(
|
|
|
+ SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR,
|
|
|
+ SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersion)
|
|
|
+ )
|
|
|
+ : newIndexOptions;
|
|
|
+
|
|
|
+ ChunkingSettings newChunkingSettings = generateRandomChunkingSettingsOtherThan(chunkingSettings);
|
|
|
+ merge(
|
|
|
+ mapperService,
|
|
|
+ mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, newChunkingSettings, newIndexOptions))
|
|
|
+ );
|
|
|
+ assertSemanticTextField(mapperService, fieldName, false, newChunkingSettings, expectedIndexOptions);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testUpdateSearchInferenceId() throws IOException {
|
|
|
final String inferenceId = "test_inference_id";
|
|
|
final String searchInferenceId1 = "test_search_inference_id_1";
|
|
@@ -650,27 +752,24 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
inferenceId,
|
|
|
new MinimalServiceSettings("my-service", TaskType.SPARSE_EMBEDDING, null, null, null)
|
|
|
);
|
|
|
- assertSemanticTextField(mapperService, fieldName, true, null, null);
|
|
|
+ var expectedIndexOptions = getDefaultSparseVectorIndexOptionsForMapper(mapperService);
|
|
|
+ assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions);
|
|
|
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
|
|
|
|
|
|
merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1));
|
|
|
- assertSemanticTextField(mapperService, fieldName, true, null, null);
|
|
|
+ assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions);
|
|
|
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1);
|
|
|
|
|
|
merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2));
|
|
|
- assertSemanticTextField(mapperService, fieldName, true, null, null);
|
|
|
+ assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions);
|
|
|
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2);
|
|
|
|
|
|
merge(mapperService, buildMapping.apply(fieldName, null));
|
|
|
- assertSemanticTextField(mapperService, fieldName, true, null, null);
|
|
|
+ assertSemanticTextField(mapperService, fieldName, true, null, expectedIndexOptions);
|
|
|
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) {
|
|
|
- assertSemanticTextField(mapperService, fieldName, expectedModelSettings, null, null);
|
|
|
- }
|
|
|
-
|
|
|
private static void assertSemanticTextField(
|
|
|
MapperService mapperService,
|
|
|
String fieldName,
|
|
@@ -720,9 +819,20 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
switch (semanticFieldMapper.fieldType().getModelSettings().taskType()) {
|
|
|
case SPARSE_EMBEDDING -> {
|
|
|
assertThat(embeddingsMapper, instanceOf(SparseVectorFieldMapper.class));
|
|
|
- SparseVectorFieldMapper sparseMapper = (SparseVectorFieldMapper) embeddingsMapper;
|
|
|
- assertEquals(sparseMapper.fieldType().isStored(), semanticTextFieldType.useLegacyFormat() == false);
|
|
|
- assertNull(expectedIndexOptions);
|
|
|
+ SparseVectorFieldMapper sparseVectorFieldMapper = (SparseVectorFieldMapper) embeddingsMapper;
|
|
|
+ assertEquals(sparseVectorFieldMapper.fieldType().isStored(), semanticTextFieldType.useLegacyFormat() == false);
|
|
|
+
|
|
|
+ SparseVectorFieldMapper.SparseVectorIndexOptions applied = sparseVectorFieldMapper.fieldType().getIndexOptions();
|
|
|
+ SparseVectorFieldMapper.SparseVectorIndexOptions expected = expectedIndexOptions == null
|
|
|
+ ? null
|
|
|
+ : (SparseVectorFieldMapper.SparseVectorIndexOptions) expectedIndexOptions.indexOptions();
|
|
|
+ if (expected == null && applied != null) {
|
|
|
+ var indexVersionCreated = mapperService.getIndexSettings().getIndexVersionCreated();
|
|
|
+ if (SparseVectorFieldMapper.SparseVectorIndexOptions.isDefaultOptions(applied, indexVersionCreated)) {
|
|
|
+ expected = SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersionCreated);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ assertEquals(expected, applied);
|
|
|
}
|
|
|
case TEXT_EMBEDDING -> {
|
|
|
assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class));
|
|
@@ -763,6 +873,8 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
|
|
|
public void testSuccessfulParse() throws IOException {
|
|
|
for (int depth = 1; depth < 4; depth++) {
|
|
|
+ final IndexVersion indexVersion = SemanticInferenceMetadataFieldsMapperTests.getRandomCompatibleIndexVersion(useLegacyFormat);
|
|
|
+
|
|
|
final String fieldName1 = randomFieldName(depth);
|
|
|
final String fieldName2 = randomFieldName(depth + 1);
|
|
|
final String searchInferenceId = randomAlphaOfLength(8);
|
|
@@ -771,6 +883,18 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
TaskType taskType = TaskType.SPARSE_EMBEDDING;
|
|
|
Model model1 = TestModel.createRandomInstance(taskType);
|
|
|
Model model2 = TestModel.createRandomInstance(taskType);
|
|
|
+
|
|
|
+ when(globalModelRegistry.getMinimalServiceSettings(anyString())).thenAnswer(invocation -> {
|
|
|
+ var modelId = (String) invocation.getArguments()[0];
|
|
|
+ if (modelId.equals(model1.getInferenceEntityId())) {
|
|
|
+ return new MinimalServiceSettings(model1);
|
|
|
+ }
|
|
|
+ if (modelId.equals(model2.getInferenceEntityId())) {
|
|
|
+ return new MinimalServiceSettings(model2);
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ });
|
|
|
+
|
|
|
ChunkingSettings chunkingSettings = null; // Some chunking settings configs can produce different Lucene docs counts
|
|
|
SemanticTextIndexOptions indexOptions = randomSemanticTextIndexOptions(taskType);
|
|
|
XContentBuilder mapping = mapping(b -> {
|
|
@@ -792,15 +916,22 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
);
|
|
|
});
|
|
|
|
|
|
- MapperService mapperService = createMapperService(mapping, useLegacyFormat);
|
|
|
- assertSemanticTextField(mapperService, fieldName1, false, null, null);
|
|
|
+ var expectedIndexOptions = (indexOptions == null)
|
|
|
+ ? new SemanticTextIndexOptions(
|
|
|
+ SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR,
|
|
|
+ SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersion)
|
|
|
+ )
|
|
|
+ : indexOptions;
|
|
|
+
|
|
|
+ MapperService mapperService = createMapperServiceWithIndexVersion(mapping, useLegacyFormat, indexVersion);
|
|
|
+ assertSemanticTextField(mapperService, fieldName1, false, null, expectedIndexOptions);
|
|
|
assertInferenceEndpoints(
|
|
|
mapperService,
|
|
|
fieldName1,
|
|
|
model1.getInferenceEntityId(),
|
|
|
setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId()
|
|
|
);
|
|
|
- assertSemanticTextField(mapperService, fieldName2, false, null, null);
|
|
|
+ assertSemanticTextField(mapperService, fieldName2, false, null, expectedIndexOptions);
|
|
|
assertInferenceEndpoints(
|
|
|
mapperService,
|
|
|
fieldName2,
|
|
@@ -1015,24 +1146,19 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
|
|
|
public void testSettingAndUpdatingChunkingSettings() throws IOException {
|
|
|
Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
|
|
|
+ when(globalModelRegistry.getMinimalServiceSettings(anyString())).thenAnswer(
|
|
|
+ invocation -> { return new MinimalServiceSettings(model); }
|
|
|
+ );
|
|
|
+
|
|
|
final ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false);
|
|
|
- final SemanticTextIndexOptions indexOptions = null;
|
|
|
+ final SemanticTextIndexOptions indexOptions = randomSemanticTextIndexOptions(TaskType.SPARSE_EMBEDDING);
|
|
|
String fieldName = "field";
|
|
|
|
|
|
- SemanticTextField randomSemanticText = randomSemanticText(
|
|
|
- useLegacyFormat,
|
|
|
- fieldName,
|
|
|
- model,
|
|
|
- chunkingSettings,
|
|
|
- List.of("a"),
|
|
|
- XContentType.JSON
|
|
|
- );
|
|
|
-
|
|
|
MapperService mapperService = createMapperService(
|
|
|
mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, chunkingSettings, indexOptions)),
|
|
|
useLegacyFormat
|
|
|
);
|
|
|
- assertSemanticTextField(mapperService, fieldName, false, chunkingSettings, null);
|
|
|
+ assertSemanticTextField(mapperService, fieldName, false, chunkingSettings, indexOptions);
|
|
|
|
|
|
ChunkingSettings newChunkingSettings = generateRandomChunkingSettingsOtherThan(chunkingSettings);
|
|
|
merge(
|
|
@@ -1046,6 +1172,11 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
// Create inference results where model settings are set to null and chunks are provided
|
|
|
TaskType taskType = TaskType.SPARSE_EMBEDDING;
|
|
|
Model model = TestModel.createRandomInstance(taskType);
|
|
|
+
|
|
|
+ when(globalModelRegistry.getMinimalServiceSettings(anyString())).thenAnswer(
|
|
|
+ invocation -> { return new MinimalServiceSettings(model); }
|
|
|
+ );
|
|
|
+
|
|
|
ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false);
|
|
|
SemanticTextIndexOptions indexOptions = randomSemanticTextIndexOptions(taskType);
|
|
|
SemanticTextField randomSemanticText = randomSemanticText(
|
|
@@ -1196,6 +1327,13 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ private static SemanticTextIndexOptions defaultSparseVectorIndexOptions(IndexVersion indexVersion) {
|
|
|
+ return new SemanticTextIndexOptions(
|
|
|
+ SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR,
|
|
|
+ SparseVectorFieldMapper.SparseVectorIndexOptions.getDefaultIndexOptions(indexVersion)
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
public void testDefaultIndexOptions() throws IOException {
|
|
|
|
|
|
// We default to BBQ for eligible dense vectors
|
|
@@ -1318,6 +1456,42 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
IndexVersionUtils.getPreviousVersion(IndexVersions.SEMANTIC_TEXT_DEFAULTS_TO_BBQ_BACKPORT_8_X)
|
|
|
);
|
|
|
assertSemanticTextField(mapperService, "field", true, null, defaultDenseVectorSemanticIndexOptions());
|
|
|
+
|
|
|
+ mapperService = createMapperService(fieldMapping(b -> {
|
|
|
+ b.field("type", "semantic_text");
|
|
|
+ b.field("inference_id", "another_inference_id");
|
|
|
+ b.startObject("model_settings");
|
|
|
+ b.field("task_type", "sparse_embedding");
|
|
|
+ b.endObject();
|
|
|
+ }),
|
|
|
+ useLegacyFormat,
|
|
|
+ IndexVersionUtils.getPreviousVersion(IndexVersions.SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT),
|
|
|
+ IndexVersions.SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT
|
|
|
+ );
|
|
|
+
|
|
|
+ assertSemanticTextField(
|
|
|
+ mapperService,
|
|
|
+ "field",
|
|
|
+ true,
|
|
|
+ null,
|
|
|
+ defaultSparseVectorIndexOptions(mapperService.getIndexSettings().getIndexVersionCreated())
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testSparseVectorIndexOptionsDefaultsBeforeSupport() throws IOException {
|
|
|
+ var mapperService = createMapperService(fieldMapping(b -> {
|
|
|
+ b.field("type", "semantic_text");
|
|
|
+ b.field("inference_id", "another_inference_id");
|
|
|
+ b.startObject("model_settings");
|
|
|
+ b.field("task_type", "sparse_embedding");
|
|
|
+ b.endObject();
|
|
|
+ }),
|
|
|
+ useLegacyFormat,
|
|
|
+ IndexVersions.INFERENCE_METADATA_FIELDS,
|
|
|
+ IndexVersionUtils.getPreviousVersion(IndexVersions.SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT)
|
|
|
+ );
|
|
|
+
|
|
|
+ assertSemanticTextField(mapperService, "field", true, null, null);
|
|
|
}
|
|
|
|
|
|
public void testSpecifiedDenseVectorIndexOptions() throws IOException {
|
|
@@ -1428,7 +1602,74 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
b.endObject();
|
|
|
}), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT));
|
|
|
assertThat(e.getMessage(), containsString("Unsupported index options type invalid"));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testSpecificSparseVectorIndexOptions() throws IOException {
|
|
|
+ for (int i = 0; i < 10; i++) {
|
|
|
+ SparseVectorFieldMapper.SparseVectorIndexOptions testIndexOptions = randomSparseVectorIndexOptions(false);
|
|
|
+ var mapperService = createMapperService(fieldMapping(b -> {
|
|
|
+ b.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
|
|
|
+ b.field(INFERENCE_ID_FIELD, "test_inference_id");
|
|
|
+ addSparseVectorModelSettingsToBuilder(b);
|
|
|
+ b.startObject(INDEX_OPTIONS_FIELD);
|
|
|
+ {
|
|
|
+ b.field(SparseVectorFieldMapper.CONTENT_TYPE);
|
|
|
+ testIndexOptions.toXContent(b, null);
|
|
|
+ }
|
|
|
+ b.endObject();
|
|
|
+ }), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT);
|
|
|
|
|
|
+ assertSemanticTextField(
|
|
|
+ mapperService,
|
|
|
+ "field",
|
|
|
+ true,
|
|
|
+ null,
|
|
|
+ new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR, testIndexOptions)
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testSparseVectorIndexOptionsValidations() throws IOException {
|
|
|
+ Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
|
|
|
+ b.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
|
|
|
+ b.field(INFERENCE_ID_FIELD, "test_inference_id");
|
|
|
+ b.startObject(INDEX_OPTIONS_FIELD);
|
|
|
+ {
|
|
|
+ b.startObject(SparseVectorFieldMapper.CONTENT_TYPE);
|
|
|
+ {
|
|
|
+ b.field("prune", false);
|
|
|
+ b.startObject("pruning_config");
|
|
|
+ {
|
|
|
+ b.field(TokenPruningConfig.TOKENS_FREQ_RATIO_THRESHOLD.getPreferredName(), 5.0f);
|
|
|
+ }
|
|
|
+ b.endObject();
|
|
|
+ }
|
|
|
+ b.endObject();
|
|
|
+ }
|
|
|
+ b.endObject();
|
|
|
+ }), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT));
|
|
|
+ assertThat(e.getMessage(), containsString("failed to parse field [pruning_config]"));
|
|
|
+
|
|
|
+ e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
|
|
|
+ b.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
|
|
|
+ b.field(INFERENCE_ID_FIELD, "test_inference_id");
|
|
|
+ b.startObject(INDEX_OPTIONS_FIELD);
|
|
|
+ {
|
|
|
+ b.startObject(SparseVectorFieldMapper.CONTENT_TYPE);
|
|
|
+ {
|
|
|
+ b.field("prune", true);
|
|
|
+ b.startObject("pruning_config");
|
|
|
+ {
|
|
|
+ b.field(TokenPruningConfig.TOKENS_FREQ_RATIO_THRESHOLD.getPreferredName(), 1000.0f);
|
|
|
+ }
|
|
|
+ b.endObject();
|
|
|
+ }
|
|
|
+ b.endObject();
|
|
|
+ }
|
|
|
+ b.endObject();
|
|
|
+ }), useLegacyFormat, IndexVersions.INFERENCE_METADATA_FIELDS_BACKPORT));
|
|
|
+ var innerClause = e.getCause().getCause().getCause().getCause();
|
|
|
+ assertThat(innerClause.getMessage(), containsString("[tokens_freq_ratio_threshold] must be between [1] and [100], got 1000.0"));
|
|
|
}
|
|
|
|
|
|
public static SemanticTextIndexOptions randomSemanticTextIndexOptions() {
|
|
@@ -1437,13 +1678,21 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
|
|
|
}
|
|
|
|
|
|
public static SemanticTextIndexOptions randomSemanticTextIndexOptions(TaskType taskType) {
|
|
|
-
|
|
|
if (taskType == TaskType.TEXT_EMBEDDING) {
|
|
|
return randomBoolean()
|
|
|
? null
|
|
|
: new SemanticTextIndexOptions(SemanticTextIndexOptions.SupportedIndexOptions.DENSE_VECTOR, randomIndexOptionsAll());
|
|
|
}
|
|
|
|
|
|
+ if (taskType == TaskType.SPARSE_EMBEDDING) {
|
|
|
+ return randomBoolean()
|
|
|
+ ? null
|
|
|
+ : new SemanticTextIndexOptions(
|
|
|
+ SemanticTextIndexOptions.SupportedIndexOptions.SPARSE_VECTOR,
|
|
|
+ randomSparseVectorIndexOptions(false)
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
return null;
|
|
|
}
|
|
|
|