Przeglądaj źródła

Add geometry example (#1631)

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 1 tydzień temu
rodzic
commit
de4202dc11

+ 2 - 2
docker-compose.yml

@@ -3,7 +3,7 @@ version: '3.5'
 services:
   standalone:
     container_name: milvus-javasdk-standalone-1
-    image: milvusdb/milvus:master-20250927-cc53b25b
+    image: milvusdb/milvus:master-20250929-ca1cc7c9-amd64
     command: [ "milvus", "run", "standalone" ]
     environment:
       - COMMON_STORAGETYPE=local
@@ -24,7 +24,7 @@ services:
 
   standaloneslave:
     container_name: milvus-javasdk-standalone-2
-    image: milvusdb/milvus:master-20250927-cc53b25b
+    image: milvusdb/milvus:master-20250929-ca1cc7c9-amd64
     command: [ "milvus", "run", "standalone" ]
     environment:
       - COMMON_STORAGETYPE=local

+ 155 - 0
examples/src/main/java/io/milvus/v2/GeometryExample.java

@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v2;
+
+import com.google.gson.JsonObject;
+import io.milvus.common.utils.JsonUtils;
+import io.milvus.v1.CommonUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.collection.request.LoadCollectionReq;
+import io.milvus.v2.service.index.request.CreateIndexReq;
+import io.milvus.v2.service.vector.request.InsertReq;
+import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.response.QueryResp;
+
+import java.util.*;
+
+public class GeometryExample {
+    private static final MilvusClientV2 client;
+    static {
+        client = new MilvusClientV2(ConnectConfig.builder()
+                .uri("http://localhost:19530")
+                .build());
+    }
+
+    private static final String COLLECTION_NAME = "java_sdk_example_geometry_v2";
+    private static final String ID_FIELD = "id";
+    private static final String GEO_FIELD = "geometry";
+    private static final String VECTOR_FIELD = "vector";
+    private static final Integer VECTOR_DIM = 128;
+
+    private static void createCollection() {
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(ID_FIELD)
+                .dataType(DataType.Int64)
+                .isPrimaryKey(true)
+                .autoID(true)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(GEO_FIELD)
+                .dataType(DataType.Geometry)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(VECTOR_FIELD)
+                .dataType(DataType.FloatVector)
+                .dimension(VECTOR_DIM)
+                .build());
+
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .collectionSchema(collectionSchema)
+                .build();
+        client.createCollection(requestCreate);
+
+        List<IndexParam> indexParams = new ArrayList<>();
+        indexParams.add(IndexParam.builder()
+                .fieldName(VECTOR_FIELD)
+                .indexType(IndexParam.IndexType.AUTOINDEX)
+                .metricType(IndexParam.MetricType.COSINE)
+                .build());
+        // geometry index no need metric type
+        indexParams.add(IndexParam.builder()
+                .fieldName(GEO_FIELD)
+                .indexType(IndexParam.IndexType.RTREE)
+                .build());
+        client.createIndex(CreateIndexReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .indexParams(indexParams)
+                .build());
+        client.loadCollection(LoadCollectionReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .build());
+        System.out.println("Collection created: " + COLLECTION_NAME);
+    }
+
+    private static void insertGeometry(String geo) {
+        JsonObject row = new JsonObject();
+        row.addProperty(GEO_FIELD, geo);
+        row.add(VECTOR_FIELD, JsonUtils.toJsonTree(CommonUtils.generateFloatVector(VECTOR_DIM)));
+
+        client.insert(InsertReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .data(Collections.singletonList(row))
+                .build());
+        System.out.println("Inserted geometry: " + geo);
+    }
+
+    private static void printRowCount() {
+        QueryResp countR = client.query(QueryReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .outputFields(Collections.singletonList("count(*)"))
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .build());
+        System.out.printf("%d rows persisted\n", (long)countR.getQueryResults().get(0).getEntity().get("count(*)"));
+    }
+
+    private static void query(String filter) {
+        System.out.println("===================================================");
+        System.out.println("Query with filter expression: " + filter);
+        QueryResp queryResp = client.query(QueryReq.builder()
+                .collectionName(COLLECTION_NAME)
+                .filter(filter)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .outputFields(Collections.singletonList(GEO_FIELD))
+                .build());
+        List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+        System.out.println("Query results:");
+        for (QueryResp.QueryResult result : queryResults) {
+            System.out.println(result.getEntity());
+        }
+    }
+
+    public static void main(String[] args) {
+        createCollection();
+        insertGeometry("POINT (1 1)");
+        insertGeometry("LINESTRING (10 10, 10 30, 40 40)");
+        insertGeometry("POLYGON ((0 100, 100 100, 100 50, 0 50, 0 100))");
+        printRowCount();
+
+        query("ST_EQUALS(" + GEO_FIELD + ", 'POINT (1 1)')");
+        query("ST_TOUCHES(" + GEO_FIELD + ", 'LINESTRING (0 50, 0 100)')");
+        query("ST_CONTAINS(" + GEO_FIELD + ", 'POINT (70 70)')");
+        query("ST_CROSSES(" + GEO_FIELD + ", 'LINESTRING (20 0, 20 100)')");
+        query("ST_WITHIN(" + GEO_FIELD + ", 'POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))')");
+    }
+}

+ 4 - 0
sdk-core/src/main/java/io/milvus/v2/common/IndexParam.java

@@ -88,6 +88,10 @@ public class IndexParam {
 
         // Only for varchar type field
         TRIE("Trie", 100),
+
+        // Only for geometry type field
+        RTREE("RTREE", 120),
+
         // Only for scalar type field
         STL_SORT(200), // only for numeric type field
         INVERTED(201), // works for all scalar fields except JSON type field

+ 1 - 1
sdk-core/src/test/java/io/milvus/TestUtils.java

@@ -11,7 +11,7 @@ public class TestUtils {
     private int dimension = 256;
     private static final Random RANDOM = new Random();
 
-    public static final String MilvusDockerImageID = "milvusdb/milvus:master-20250927-cc53b25b";
+    public static final String MilvusDockerImageID = "milvusdb/milvus:master-20250929-ca1cc7c9-amd64";
 
     public TestUtils(int dimension) {
         this.dimension = dimension;

+ 60 - 56
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -1258,6 +1258,10 @@ class MilvusClientV2DockerTest {
                 .indexType(IndexParam.IndexType.HNSW)
                 .metricType(IndexParam.MetricType.COSINE)
                 .build());
+        indexParams.add(IndexParam.builder()
+                .fieldName(geoField)
+                .indexType(IndexParam.IndexType.RTREE)
+                .build());
         client.createIndex(CreateIndexReq.builder()
                 .collectionName(randomCollectionName)
                 .indexParams(indexParams)
@@ -1276,62 +1280,62 @@ class MilvusClientV2DockerTest {
         Assertions.assertEquals(geoField, fields.get(2).getName());
         Assertions.assertEquals(DataType.Geometry, fields.get(2).getDataType());
 
-//        // insert
-//        List<JsonObject> rows = new ArrayList<>();
-//        {
-//            JsonObject row = new JsonObject();
-//            row.addProperty(pkField, 1);
-//            row.addProperty(geoField, "POINT (1.0 -1.0)");
-//            row.add(vectorField, JsonUtils.toJsonTree(utils.generateFloatVector()));
-//            rows.add(row);
-//        }
-//        {
-//            JsonObject row = new JsonObject();
-//            row.addProperty(pkField, 2);
-//            row.addProperty(geoField, "POINT (2.0 2.0)");
-//            row.add(vectorField, JsonUtils.toJsonTree(utils.generateFloatVector()));
-//            rows.add(row);
-//        }
-//        InsertResp insertResp = client.insert(InsertReq.builder()
-//                .collectionName(randomCollectionName)
-//                .data(rows)
-//                .build());
-//        Assertions.assertEquals(rows.size(), insertResp.getInsertCnt());
-//
-//        // query
-//        String filter = String.format("ST_WITHIN(%s, 'POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))')", geoField);
-//        QueryResp queryResp = client.query(QueryReq.builder()
-//                .collectionName(randomCollectionName)
-//                .limit(10)
-//                .filter(filter)
-//                .consistencyLevel(ConsistencyLevel.STRONG)
-//                .outputFields(Arrays.asList(pkField, geoField))
-//                .build());
-//        List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
-//        Assertions.assertEquals(1, queryResults.size());
-//        for (QueryResp.QueryResult res : queryResults) {
-//            Assertions.assertTrue(res.getEntity().containsKey(geoField));
-//            Assertions.assertEquals(res.getEntity().get(pkField), 2L);
-//        }
-//
-//        // search
-//        SearchResp searchResp = client.search(SearchReq.builder()
-//                .collectionName(randomCollectionName)
-//                .annsField(vectorField)
-//                .data(Collections.singletonList(new FloatVec(utils.generateFloatVector())))
-//                .limit(10)
-//                .filter(filter)
-//                .outputFields(Arrays.asList(pkField, geoField))
-//                .build());
-//        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
-//        Assertions.assertEquals(1, searchResults.size());
-//        for (List<SearchResp.SearchResult> oneResults : searchResults) {
-//            Assertions.assertEquals(1, oneResults.size());
-//            for (SearchResp.SearchResult res : oneResults) {
-//                Assertions.assertTrue(res.getEntity().containsKey(geoField));
-//                Assertions.assertEquals(res.getId(), 2L);
-//            }
-//        }
+        // insert
+        List<JsonObject> rows = new ArrayList<>();
+        {
+            JsonObject row = new JsonObject();
+            row.addProperty(pkField, 1);
+            row.addProperty(geoField, "POINT (1.0 -1.0)");
+            row.add(vectorField, JsonUtils.toJsonTree(utils.generateFloatVector()));
+            rows.add(row);
+        }
+        {
+            JsonObject row = new JsonObject();
+            row.addProperty(pkField, 2);
+            row.addProperty(geoField, "POINT (2.0 2.0)");
+            row.add(vectorField, JsonUtils.toJsonTree(utils.generateFloatVector()));
+            rows.add(row);
+        }
+        InsertResp insertResp = client.insert(InsertReq.builder()
+                .collectionName(randomCollectionName)
+                .data(rows)
+                .build());
+        Assertions.assertEquals(rows.size(), insertResp.getInsertCnt());
+
+        // query
+        String filter = String.format("ST_WITHIN(%s, 'POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))')", geoField);
+        QueryResp queryResp = client.query(QueryReq.builder()
+                .collectionName(randomCollectionName)
+                .limit(10)
+                .filter(filter)
+                .consistencyLevel(ConsistencyLevel.STRONG)
+                .outputFields(Arrays.asList(pkField, geoField))
+                .build());
+        List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+        Assertions.assertEquals(1, queryResults.size());
+        for (QueryResp.QueryResult res : queryResults) {
+            Assertions.assertTrue(res.getEntity().containsKey(geoField));
+            Assertions.assertEquals(res.getEntity().get(pkField), 2L);
+        }
+
+        // search
+        SearchResp searchResp = client.search(SearchReq.builder()
+                .collectionName(randomCollectionName)
+                .annsField(vectorField)
+                .data(Collections.singletonList(new FloatVec(utils.generateFloatVector())))
+                .limit(10)
+                .filter(filter)
+                .outputFields(Arrays.asList(pkField, geoField))
+                .build());
+        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        Assertions.assertEquals(1, searchResults.size());
+        for (List<SearchResp.SearchResult> oneResults : searchResults) {
+            Assertions.assertEquals(1, oneResults.size());
+            for (SearchResp.SearchResult res : oneResults) {
+                Assertions.assertTrue(res.getEntity().containsKey(geoField));
+                Assertions.assertEquals(res.getId(), 2L);
+            }
+        }
     }
 
     @Test