|
@@ -23,6 +23,8 @@ import org.elasticsearch.action.index.IndexRequest;
|
|
|
import org.elasticsearch.action.support.WriteRequest;
|
|
|
import org.elasticsearch.client.ml.GetBucketsRequest;
|
|
|
import org.elasticsearch.client.ml.GetBucketsResponse;
|
|
|
+import org.elasticsearch.client.ml.GetCategoriesRequest;
|
|
|
+import org.elasticsearch.client.ml.GetCategoriesResponse;
|
|
|
import org.elasticsearch.client.ml.GetInfluencersRequest;
|
|
|
import org.elasticsearch.client.ml.GetInfluencersResponse;
|
|
|
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
|
|
@@ -126,11 +128,150 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
|
|
|
bulkRequest.add(indexRequest);
|
|
|
}
|
|
|
|
|
|
+ private void addCategoryIndexRequest(long categoryId, String categoryName, BulkRequest bulkRequest) {
|
|
|
+ IndexRequest indexRequest = new IndexRequest(RESULTS_INDEX, DOC);
|
|
|
+ indexRequest.source("{\"job_id\":\"" + JOB_ID + "\", \"category_id\": " + categoryId + ", \"terms\": \"" +
|
|
|
+ categoryName + "\", \"regex\": \".*?" + categoryName + ".*\", \"max_matching_length\": 3, \"examples\": [\"" +
|
|
|
+ categoryName + "\"]}", XContentType.JSON);
|
|
|
+ bulkRequest.add(indexRequest);
|
|
|
+ }
|
|
|
+
|
|
|
+ private void addCategoriesIndexRequests(BulkRequest bulkRequest) {
|
|
|
+
|
|
|
+ List<String> categories = Arrays.asList("AAL", "JZA", "JBU");
|
|
|
+
|
|
|
+ for (int i = 0; i < categories.size(); i++) {
|
|
|
+ addCategoryIndexRequest(i+1, categories.get(i), bulkRequest);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
@After
|
|
|
public void deleteJob() throws IOException {
|
|
|
new MlRestTestStateCleaner(logger, client()).clearMlMetadata();
|
|
|
}
|
|
|
|
|
|
+ public void testGetCategories() throws IOException {
|
|
|
+
|
|
|
+ // index some category results
|
|
|
+ BulkRequest bulkRequest = new BulkRequest();
|
|
|
+ bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
|
+
|
|
|
+ addCategoriesIndexRequests(bulkRequest);
|
|
|
+
|
|
|
+ highLevelClient().bulk(bulkRequest, RequestOptions.DEFAULT);
|
|
|
+
|
|
|
+ MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
+
|
|
|
+ {
|
|
|
+ GetCategoriesRequest request = new GetCategoriesRequest(JOB_ID);
|
|
|
+ request.setPageParams(new PageParams(0, 10000));
|
|
|
+
|
|
|
+ GetCategoriesResponse response = execute(request, machineLearningClient::getCategories,
|
|
|
+ machineLearningClient::getCategoriesAsync);
|
|
|
+
|
|
|
+ assertThat(response.count(), equalTo(3L));
|
|
|
+ assertThat(response.categories().size(), equalTo(3));
|
|
|
+ assertThat(response.categories().get(0).getCategoryId(), equalTo(1L));
|
|
|
+ assertThat(response.categories().get(0).getGrokPattern(), equalTo(".*?AAL.*"));
|
|
|
+ assertThat(response.categories().get(0).getRegex(), equalTo(".*?AAL.*"));
|
|
|
+ assertThat(response.categories().get(0).getTerms(), equalTo("AAL"));
|
|
|
+
|
|
|
+ assertThat(response.categories().get(1).getCategoryId(), equalTo(2L));
|
|
|
+ assertThat(response.categories().get(1).getGrokPattern(), equalTo(".*?JZA.*"));
|
|
|
+ assertThat(response.categories().get(1).getRegex(), equalTo(".*?JZA.*"));
|
|
|
+ assertThat(response.categories().get(1).getTerms(), equalTo("JZA"));
|
|
|
+
|
|
|
+ assertThat(response.categories().get(2).getCategoryId(), equalTo(3L));
|
|
|
+ assertThat(response.categories().get(2).getGrokPattern(), equalTo(".*?JBU.*"));
|
|
|
+ assertThat(response.categories().get(2).getRegex(), equalTo(".*?JBU.*"));
|
|
|
+ assertThat(response.categories().get(2).getTerms(), equalTo("JBU"));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetCategoriesRequest request = new GetCategoriesRequest(JOB_ID);
|
|
|
+ request.setPageParams(new PageParams(0, 1));
|
|
|
+
|
|
|
+ GetCategoriesResponse response = execute(request, machineLearningClient::getCategories,
|
|
|
+ machineLearningClient::getCategoriesAsync);
|
|
|
+
|
|
|
+ assertThat(response.count(), equalTo(3L));
|
|
|
+ assertThat(response.categories().size(), equalTo(1));
|
|
|
+ assertThat(response.categories().get(0).getCategoryId(), equalTo(1L));
|
|
|
+ assertThat(response.categories().get(0).getGrokPattern(), equalTo(".*?AAL.*"));
|
|
|
+ assertThat(response.categories().get(0).getRegex(), equalTo(".*?AAL.*"));
|
|
|
+ assertThat(response.categories().get(0).getTerms(), equalTo("AAL"));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetCategoriesRequest request = new GetCategoriesRequest(JOB_ID);
|
|
|
+ request.setPageParams(new PageParams(1, 2));
|
|
|
+
|
|
|
+ GetCategoriesResponse response = execute(request, machineLearningClient::getCategories,
|
|
|
+ machineLearningClient::getCategoriesAsync);
|
|
|
+
|
|
|
+ assertThat(response.count(), equalTo(3L));
|
|
|
+ assertThat(response.categories().size(), equalTo(2));
|
|
|
+ assertThat(response.categories().get(0).getCategoryId(), equalTo(2L));
|
|
|
+ assertThat(response.categories().get(0).getGrokPattern(), equalTo(".*?JZA.*"));
|
|
|
+ assertThat(response.categories().get(0).getRegex(), equalTo(".*?JZA.*"));
|
|
|
+ assertThat(response.categories().get(0).getTerms(), equalTo("JZA"));
|
|
|
+
|
|
|
+ assertThat(response.categories().get(1).getCategoryId(), equalTo(3L));
|
|
|
+ assertThat(response.categories().get(1).getGrokPattern(), equalTo(".*?JBU.*"));
|
|
|
+ assertThat(response.categories().get(1).getRegex(), equalTo(".*?JBU.*"));
|
|
|
+ assertThat(response.categories().get(1).getTerms(), equalTo("JBU"));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetCategoriesRequest request = new GetCategoriesRequest(JOB_ID);
|
|
|
+ request.setCategoryId(0L); // request a non-existent category
|
|
|
+
|
|
|
+ GetCategoriesResponse response = execute(request, machineLearningClient::getCategories,
|
|
|
+ machineLearningClient::getCategoriesAsync);
|
|
|
+
|
|
|
+ assertThat(response.count(), equalTo(0L));
|
|
|
+ assertThat(response.categories().size(), equalTo(0));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetCategoriesRequest request = new GetCategoriesRequest(JOB_ID);
|
|
|
+ request.setCategoryId(1L);
|
|
|
+
|
|
|
+ GetCategoriesResponse response = execute(request, machineLearningClient::getCategories,
|
|
|
+ machineLearningClient::getCategoriesAsync);
|
|
|
+
|
|
|
+ assertThat(response.count(), equalTo(1L));
|
|
|
+ assertThat(response.categories().size(), equalTo(1));
|
|
|
+ assertThat(response.categories().get(0).getCategoryId(), equalTo(1L));
|
|
|
+ assertThat(response.categories().get(0).getGrokPattern(), equalTo(".*?AAL.*"));
|
|
|
+ assertThat(response.categories().get(0).getRegex(), equalTo(".*?AAL.*"));
|
|
|
+ assertThat(response.categories().get(0).getTerms(), equalTo("AAL"));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetCategoriesRequest request = new GetCategoriesRequest(JOB_ID);
|
|
|
+ request.setCategoryId(2L);
|
|
|
+
|
|
|
+ GetCategoriesResponse response = execute(request, machineLearningClient::getCategories,
|
|
|
+ machineLearningClient::getCategoriesAsync);
|
|
|
+
|
|
|
+ assertThat(response.count(), equalTo(1L));
|
|
|
+ assertThat(response.categories().get(0).getCategoryId(), equalTo(2L));
|
|
|
+ assertThat(response.categories().get(0).getGrokPattern(), equalTo(".*?JZA.*"));
|
|
|
+ assertThat(response.categories().get(0).getRegex(), equalTo(".*?JZA.*"));
|
|
|
+ assertThat(response.categories().get(0).getTerms(), equalTo("JZA"));
|
|
|
+
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetCategoriesRequest request = new GetCategoriesRequest(JOB_ID);
|
|
|
+ request.setCategoryId(3L);
|
|
|
+
|
|
|
+ GetCategoriesResponse response = execute(request, machineLearningClient::getCategories,
|
|
|
+ machineLearningClient::getCategoriesAsync);
|
|
|
+
|
|
|
+ assertThat(response.count(), equalTo(1L));
|
|
|
+ assertThat(response.categories().get(0).getCategoryId(), equalTo(3L));
|
|
|
+ assertThat(response.categories().get(0).getGrokPattern(), equalTo(".*?JBU.*"));
|
|
|
+ assertThat(response.categories().get(0).getRegex(), equalTo(".*?JBU.*"));
|
|
|
+ assertThat(response.categories().get(0).getTerms(), equalTo("JBU"));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testGetBuckets() throws IOException {
|
|
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
|