|
@@ -41,6 +41,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
import static org.hamcrest.Matchers.hasSize;
|
|
|
import static org.hamcrest.Matchers.in;
|
|
|
+import static org.hamcrest.Matchers.instanceOf;
|
|
|
import static org.hamcrest.Matchers.is;
|
|
|
import static org.hamcrest.Matchers.not;
|
|
|
import static org.hamcrest.Matchers.nullValue;
|
|
@@ -193,8 +194,7 @@ public class LearningToRankRescorerBuilderRewriteTests extends AbstractBuilderTe
|
|
|
|
|
|
public void testBuildContext() throws Exception {
|
|
|
LocalModel localModel = mock(LocalModel.class);
|
|
|
- List<String> inputFields = List.of(DOUBLE_FIELD_NAME, INT_FIELD_NAME);
|
|
|
- when(localModel.inputFields()).thenReturn(inputFields);
|
|
|
+ when(localModel.inputFields()).thenReturn(GOOD_MODEL_CONFIG.getInput().getFieldNames());
|
|
|
|
|
|
IndexSearcher searcher = mock(IndexSearcher.class);
|
|
|
doAnswer(invocation -> invocation.getArgument(0)).when(searcher).rewrite(any(Query.class));
|
|
@@ -211,11 +211,48 @@ public class LearningToRankRescorerBuilderRewriteTests extends AbstractBuilderTe
|
|
|
assertNotNull(rescoreContext);
|
|
|
assertThat(rescoreContext.getWindowSize(), equalTo(20));
|
|
|
List<FeatureExtractor> featureExtractors = rescoreContext.buildFeatureExtractors(context.searcher());
|
|
|
- assertThat(featureExtractors, hasSize(2));
|
|
|
- assertThat(
|
|
|
- featureExtractors.stream().flatMap(featureExtractor -> featureExtractor.featureNames().stream()).toList(),
|
|
|
- containsInAnyOrder("feature_1", "feature_2", DOUBLE_FIELD_NAME, INT_FIELD_NAME)
|
|
|
+ assertThat(featureExtractors, hasSize(1));
|
|
|
+
|
|
|
+ FeatureExtractor queryExtractor = featureExtractors.get(0);
|
|
|
+ assertThat(queryExtractor, instanceOf(QueryFeatureExtractor.class));
|
|
|
+ assertThat(queryExtractor.featureNames(), hasSize(2));
|
|
|
+ assertThat(queryExtractor.featureNames(), containsInAnyOrder("feature_1", "feature_2"));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testLegacyFieldValueExtractorBuildContext() throws Exception {
|
|
|
+ // Models created before 8.15 have been saved with input fields.
|
|
|
+ // We check field value extractors are created and the deduplication is done correctly.
|
|
|
+ LocalModel localModel = mock(LocalModel.class);
|
|
|
+ when(localModel.inputFields()).thenReturn(List.of("feature_1", "field_1", "field_2"));
|
|
|
+
|
|
|
+ IndexSearcher searcher = mock(IndexSearcher.class);
|
|
|
+ doAnswer(invocation -> invocation.getArgument(0)).when(searcher).rewrite(any(Query.class));
|
|
|
+ SearchExecutionContext context = createSearchExecutionContext(searcher);
|
|
|
+
|
|
|
+ LearningToRankRescorerBuilder rescorerBuilder = new LearningToRankRescorerBuilder(
|
|
|
+ localModel,
|
|
|
+ (LearningToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig(),
|
|
|
+ null,
|
|
|
+ mock(LearningToRankService.class)
|
|
|
);
|
|
|
+
|
|
|
+ LearningToRankRescorerContext rescoreContext = rescorerBuilder.innerBuildContext(20, context);
|
|
|
+ assertNotNull(rescoreContext);
|
|
|
+ assertThat(rescoreContext.getWindowSize(), equalTo(20));
|
|
|
+ List<FeatureExtractor> featureExtractors = rescoreContext.buildFeatureExtractors(context.searcher());
|
|
|
+
|
|
|
+ assertThat(featureExtractors, hasSize(2));
|
|
|
+
|
|
|
+ FeatureExtractor queryExtractor = featureExtractors.stream().filter(fe -> fe instanceof QueryFeatureExtractor).findFirst().get();
|
|
|
+ assertThat(queryExtractor.featureNames(), hasSize(2));
|
|
|
+ assertThat(queryExtractor.featureNames(), containsInAnyOrder("feature_1", "feature_2"));
|
|
|
+
|
|
|
+ FeatureExtractor fieldValueExtractor = featureExtractors.stream()
|
|
|
+ .filter(fe -> fe instanceof FieldValueFeatureExtractor)
|
|
|
+ .findFirst()
|
|
|
+ .get();
|
|
|
+ assertThat(fieldValueExtractor.featureNames(), hasSize(2));
|
|
|
+ assertThat(fieldValueExtractor.featureNames(), containsInAnyOrder("field_1", "field_2"));
|
|
|
}
|
|
|
|
|
|
private LearningToRankRescorerBuilder rewriteAndFetch(
|