Explorar o código

Support Geometry field (#1627)

Signed-off-by: groot <yihua.mo@zilliz.com>
groot hai 1 semana
pai
achega
2668972794

+ 2 - 2
docker-compose.yml

@@ -3,7 +3,7 @@ version: '3.5'
 services:
   standalone:
     container_name: milvus-javasdk-standalone-1
-    image: milvusdb/milvus:master-20250924-20411e52-amd64
+    image: milvusdb/milvus:master-20250927-cc53b25b
     command: [ "milvus", "run", "standalone" ]
     environment:
       - COMMON_STORAGETYPE=local
@@ -24,7 +24,7 @@ services:
 
   standaloneslave:
     container_name: milvus-javasdk-standalone-2
-    image: milvusdb/milvus:master-20250924-20411e52-amd64
+    image: milvusdb/milvus:master-20250927-cc53b25b
     command: [ "milvus", "run", "standalone" ]
     environment:
       - COMMON_STORAGETYPE=local

+ 1 - 0
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BulkWriter.java

@@ -295,6 +295,7 @@ public abstract class BulkWriter implements AutoCloseable {
                     break;
                 }
                 case VarChar:
+                case Geometry:
                 case Timestamptz: {
                     Pair<Object, Integer> objectAndSize = verifyVarchar(obj, field);
                     rowValues.put(fieldName, objectAndSize.getLeft());

+ 1 - 1
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/ParquetUtils.java

@@ -94,6 +94,7 @@ public class ParquetUtils {
                     setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, false);
                     break;
                 case VarChar:
+                case Geometry:
                 case Timestamptz:
                 case JSON:
                 case SparseFloatVector: // sparse vector is parsed as JSON format string in the server side
@@ -137,7 +138,6 @@ public class ParquetUtils {
                 setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, true);
                 break;
             case VarChar:
-            case Timestamptz:
                 setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BINARY,
                         LogicalTypeAnnotation.stringType(), field, true);
                 break;

+ 1 - 1
sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/writer/ParquetFileWriter.java

@@ -135,6 +135,7 @@ public class ParquetFileWriter implements FormatFileWriter {
                 break;
             case VarChar:
             case String:
+            case Geometry:
             case Timestamptz:
             case JSON:
                 group.append(paramName, (String)value);
@@ -170,7 +171,6 @@ public class ParquetFileWriter implements FormatFileWriter {
                         break;
                     case String:
                     case VarChar:
-                    case Timestamptz:
                         addStringArray(group, paramName, (List<String>) value);
                         break;
                     case Bool:

+ 1 - 0
sdk-bulkwriter/src/test/java/io/milvus/bulkwriter/BulkWriterTest.java

@@ -521,6 +521,7 @@ public class BulkWriterTest {
                 Assertions.assertEquals(element.getAsDouble(), obj);
                 break;
             case VarChar:
+            case Geometry:
             case Timestamptz:
             case JSON:
                 verifyJsonString(element.getAsString(), ((Utf8)obj).toString());

+ 9 - 1
sdk-core/src/main/java/io/milvus/param/ParamUtils.java

@@ -270,6 +270,7 @@ public class ParamUtils {
                 break;
             case VarChar:
             case String:
+            case Geometry:
             case Timestamptz:
                 for (Object value : values) {
                     if (checkNullableFieldData(fieldSchema, value, verifyElementType)) {
@@ -419,6 +420,7 @@ public class ParamUtils {
                 return value.getAsDouble(); // return double for genFieldData()
             case VarChar:
             case String:
+            case Geometry:
             case Timestamptz:
                 if (!(value.isJsonPrimitive())) {
                     throw new ParamException(String.format(errMsgs.get(dataType), fieldName));
@@ -466,7 +468,6 @@ public class ParamUtils {
                 case Double:
                     return JsonUtils.fromJson(jsonArray, new TypeToken<List<Double>>() {}.getType());
                 case VarChar:
-                case Timestamptz:
                     return JsonUtils.fromJson(jsonArray, new TypeToken<List<String>>() {}.getType());
                 default:
                     throw new ParamException(String.format("Unsupported element type of Array field '%s'", fieldName));
@@ -1376,6 +1377,11 @@ public class ParamUtils {
                 StringArray stringArray = StringArray.newBuilder().addAllData(strings).build();
                 return ScalarField.newBuilder().setStringData(stringArray).build();
             }
+            case Geometry: {
+                List<String> strings = objects.stream().map(p -> (p == null) ? null : (String) p).collect(Collectors.toList());
+                GeometryWktArray wktArray = GeometryWktArray.newBuilder().addAllData(strings).build();
+                return ScalarField.newBuilder().setGeometryWktData(wktArray).build();
+            }
             case JSON: {
                 List<ByteString> byteStrings = objects.stream().map(p -> (p == null) ? null : ByteString.copyFromUtf8(p.toString()))
                         .collect(Collectors.toList());
@@ -1509,6 +1515,7 @@ public class ParamUtils {
                 break;
             case VarChar:
             case String:
+            case Geometry:
             case Timestamptz:
                 if (obj instanceof String) {
                     return builder.setStringData((String) obj).build();
@@ -1546,6 +1553,7 @@ public class ParamUtils {
                 return value.getBoolData();
             case VarChar:
             case String:
+            case Geometry:
             case Timestamptz:
                 return value.getStringData();
             case JSON:

+ 13 - 3
sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java

@@ -175,6 +175,8 @@ public class FieldDataWrapper {
             case String:
             case Timestamptz:
                 return fieldData.getScalars().getStringData().getDataCount();
+            case Geometry:
+                return fieldData.getScalars().getGeometryWktData().getDataCount();
             case JSON:
                 return fieldData.getScalars().getJsonData().getDataCount();
             case Array:
@@ -247,6 +249,7 @@ public class FieldDataWrapper {
             case Double:
             case VarChar:
             case String:
+            case Geometry:
             case Timestamptz:
             case JSON:
                 return getScalarData(dt, fieldData.getScalars(), fieldData.getValidDataList());
@@ -343,13 +346,19 @@ public class FieldDataWrapper {
                 return setNoneData(scalar.getDoubleData().getDataList(), validData);
             case VarChar:
             case String:
-            case Timestamptz:
+            case Timestamptz: {
                 ProtocolStringList protoStrList = scalar.getStringData().getDataList();
                 return setNoneData(protoStrList.subList(0, protoStrList.size()), validData);
-            case JSON:
+            }
+            case Geometry: {
+                ProtocolStringList protoGeoList = scalar.getGeometryWktData().getDataList();
+                return setNoneData(protoGeoList.subList(0, protoGeoList.size()), validData);
+            }
+            case JSON: {
                 List<ByteString> dataList = scalar.getJsonData().getDataList();
                 return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());
-            case Array:
+            }
+            case Array: {
                 List<List<?>> array = new ArrayList<>();
                 ArrayArray arrArray = scalar.getArrayData();
                 boolean nullable = validData != null && validData.size() == arrArray.getDataCount();
@@ -362,6 +371,7 @@ public class FieldDataWrapper {
                     }
                 }
                 return array;
+            }
             default:
                 return new ArrayList<>();
         }

+ 1 - 0
sdk-core/src/main/java/io/milvus/v2/common/DataType.java

@@ -39,6 +39,7 @@ public enum DataType {
     VarChar(21), // variable-length strings with a specified maximum length
     Array(22),
     JSON(23),
+    Geometry(24),
     Timestamptz(26),
 
     BinaryVector(100),

+ 1 - 1
sdk-core/src/main/java/io/milvus/v2/service/vector/request/InsertReq.java

@@ -35,7 +35,7 @@ public class InsertReq {
      * Sets the row data to insert. The rows list cannot be empty.
      *
      * Internal class for insert data.
-     * If dataType is Bool/Int8/Int16/Int32/Int64/Float/Double/Varchar, use JsonObject.addProperty(key, value) to input;
+     * If dataType is Bool/Int8/Int16/Int32/Int64/Float/Double/Varchar/Geometry/Timestamptz, use JsonObject.addProperty(key, value) to input;
      * If dataType is FloatVector, use JsonObject.add(key, gson.toJsonTree(List[Float]) to input;
      * If dataType is BinaryVector/Float16Vector/BFloat16Vector/Int8Vector, use JsonObject.add(key, gson.toJsonTree(byte[])) to input;
      * If dataType is SparseFloatVector, use JsonObject.add(key, gson.toJsonTree(SortedMap[Long, Float])) to input;

+ 2 - 1
sdk-core/src/main/java/io/milvus/v2/service/vector/request/UpsertReq.java

@@ -33,11 +33,12 @@ public class UpsertReq {
      * Sets the row data to insert. The rows list cannot be empty.
      *
      * Internal class for insert data.
-     * If dataType is Bool/Int8/Int16/Int32/Int64/Float/Double/Varchar, use JsonObject.addProperty(key, value) to input;
+     * If dataType is Bool/Int8/Int16/Int32/Int64/Float/Double/Varchar/Geometry/Timestamptz, use JsonObject.addProperty(key, value) to input;
      * If dataType is FloatVector, use JsonObject.add(key, gson.toJsonTree(List[Float]) to input;
      * If dataType is BinaryVector/Float16Vector/BFloat16Vector, use JsonObject.add(key, gson.toJsonTree(byte[])) to input;
      * If dataType is SparseFloatVector, use JsonObject.add(key, gson.toJsonTree(SortedMap[Long, Float])) to input;
      * If dataType is Array, use JsonObject.add(key, gson.toJsonTree(List of Boolean/Integer/Short/Long/Float/Double/String)) to input;
+     * If dataType is Array and elementType is Struct, use JsonObject.add(key, JsonArray) to input, ensure the JsonArray is a list of JsonObject;
      * If dataType is JSON, use JsonObject.add(key, JsonElement) to input;
      *
      * Note:

+ 1 - 1
sdk-core/src/test/java/io/milvus/TestUtils.java

@@ -11,7 +11,7 @@ public class TestUtils {
     private int dimension = 256;
     private static final Random RANDOM = new Random();
 
-    public static final String MilvusDockerImageID = "milvusdb/milvus:master-20250924-20411e52-amd64";
+    public static final String MilvusDockerImageID = "milvusdb/milvus:master-20250927-cc53b25b";
 
     public TestUtils(int dimension) {
         this.dimension = dimension;

+ 121 - 0
sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

@@ -1219,6 +1219,119 @@ class MilvusClientV2DockerTest {
         Assertions.assertEquals(9L, (long)searchResults.get(1).get(0).getId());
     }
 
+    @Test
+    void testGeometry() {
+        String randomCollectionName = generator.generate(10);
+        String pkField = "pk";
+        String vectorField = "vector";
+        String geoField = "geo";
+        CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(pkField)
+                .dataType(DataType.Int64)
+                .isPrimaryKey(Boolean.TRUE)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(vectorField)
+                .dataType(DataType.FloatVector)
+                .dimension(DIMENSION)
+                .build());
+        collectionSchema.addField(AddFieldReq.builder()
+                .fieldName(geoField)
+                .dataType(DataType.Geometry)
+                .build());
+
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .build());
+
+        CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .collectionSchema(collectionSchema)
+                .build();
+        client.createCollection(requestCreate);
+
+        List<IndexParam> indexParams = new ArrayList<>();
+        indexParams.add(IndexParam.builder()
+                .fieldName(vectorField)
+                .indexType(IndexParam.IndexType.HNSW)
+                .metricType(IndexParam.MetricType.COSINE)
+                .build());
+        client.createIndex(CreateIndexReq.builder()
+                .collectionName(randomCollectionName)
+                .indexParams(indexParams)
+                .build());
+        client.loadCollection(LoadCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .build());
+
+        // describe
+        DescribeCollectionResp descResp = client.describeCollection(DescribeCollectionReq.builder()
+                .collectionName(randomCollectionName)
+                .build());
+        CreateCollectionReq.CollectionSchema descSchema = descResp.getCollectionSchema();
+        List<CreateCollectionReq.FieldSchema> fields = descSchema.getFieldSchemaList();
+        Assertions.assertEquals(collectionSchema.getFieldSchemaList().size(), fields.size());
+        Assertions.assertEquals(geoField, fields.get(2).getName());
+        Assertions.assertEquals(DataType.Geometry, fields.get(2).getDataType());
+
+//        // insert
+//        List<JsonObject> rows = new ArrayList<>();
+//        {
+//            JsonObject row = new JsonObject();
+//            row.addProperty(pkField, 1);
+//            row.addProperty(geoField, "POINT (1.0 -1.0)");
+//            row.add(vectorField, JsonUtils.toJsonTree(utils.generateFloatVector()));
+//            rows.add(row);
+//        }
+//        {
+//            JsonObject row = new JsonObject();
+//            row.addProperty(pkField, 2);
+//            row.addProperty(geoField, "POINT (2.0 2.0)");
+//            row.add(vectorField, JsonUtils.toJsonTree(utils.generateFloatVector()));
+//            rows.add(row);
+//        }
+//        InsertResp insertResp = client.insert(InsertReq.builder()
+//                .collectionName(randomCollectionName)
+//                .data(rows)
+//                .build());
+//        Assertions.assertEquals(rows.size(), insertResp.getInsertCnt());
+//
+//        // quer
+//        Map<String, Object> params = new HashMap<>();
+////        params.put("timezone", "America/Chicago");
+//        QueryResp queryResp = client.query(QueryReq.builder()
+//                .collectionName(randomCollectionName)
+//                .limit(10)
+//                .consistencyLevel(ConsistencyLevel.STRONG)
+//                .outputFields(Arrays.asList(pkField, geoField))
+//                .build());
+//        List<QueryResp.QueryResult> queryResults = queryResp.getQueryResults();
+//        Assertions.assertEquals(2, queryResults.size());
+//        for (QueryResp.QueryResult res : queryResults) {
+//            Assertions.assertTrue(res.getEntity().containsKey(geoField));
+//        }
+//
+//        // search
+//        SearchResp searchResp = client.search(SearchReq.builder()
+//                .collectionName(randomCollectionName)
+//                .annsField(vectorField)
+//                .data(Collections.singletonList(new FloatVec(utils.generateFloatVector())))
+//                .limit(10)
+//                .searchParams(params)
+//                .outputFields(Arrays.asList(pkField, geoField))
+//                .build());
+//        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+//        Assertions.assertEquals(1, searchResults.size());
+//        for (List<SearchResp.SearchResult> oneResults : searchResults) {
+//            Assertions.assertEquals(2, oneResults.size());
+//            for (SearchResp.SearchResult res : oneResults) {
+//                Assertions.assertTrue(res.getEntity().containsKey(geoField));
+//            }
+//        }
+    }
+
     @Test
     void testTimestamp() {
         String randomCollectionName = generator.generate(10);
@@ -1266,6 +1379,14 @@ class MilvusClientV2DockerTest {
                 .collectionName(randomCollectionName)
                 .build());
 
+        // set database default timezone
+        Map<String, String> props = new HashMap<>();
+        props.put("timezone", "Asia/Shanghai");
+        client.alterDatabaseProperties(AlterDatabasePropertiesReq.builder()
+                .databaseName("default")
+                .properties(props)
+                .build());
+
         // describe
         DescribeCollectionResp descResp = client.describeCollection(DescribeCollectionReq.builder()
                 .collectionName(randomCollectionName)