/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package io.milvus.bulkwriter;

import com.google.common.collect.Lists;
import com.google.gson.Gson;
import io.milvus.bulkwriter.common.clientenum.BulkFileType;
import io.milvus.bulkwriter.common.utils.ParquetUtils;
import io.milvus.common.utils.ExceptionUtils;
import io.milvus.grpc.DataType;
import io.milvus.param.collection.CollectionSchemaParam;
import io.milvus.param.collection.FieldType;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.simple.SimpleGroupFactory;
import org.apache.parquet.hadoop.ParquetFileWriter;
import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.hadoop.example.GroupWriteSupport;
import org.apache.parquet.hadoop.metadata.CompressionCodecName;
import org.apache.parquet.schema.MessageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.stream.Collectors;

import static io.milvus.param.Constant.DYNAMIC_FIELD_NAME;
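/**
 * Column-wise, in-memory buffer used by the bulk writer: rows appended via
 * {@link #appendRow(Map)} are stored per field and flushed to a local
 * Parquet file by {@link #persist(String, Integer, Integer)}.
 */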
public class Buffer {
    private static final Logger logger = LoggerFactory.getLogger(Buffer.class);
    private static final Gson GSON_INSTANCE = new Gson();

    private final CollectionSchemaParam collectionSchema;
    private final BulkFileType fileType;
    private final Map<String, List<Object>> buffer;
    private final Map<String, FieldType> fields;
    public Buffer(CollectionSchemaParam collectionSchema, BulkFileType fileType) {
        this.collectionSchema = collectionSchema;
        this.fileType = fileType;

        buffer = new HashMap<>();
        fields = new HashMap<>();
        for (FieldType fieldType : collectionSchema.getFieldTypes()) {
            // auto-generated primary keys are assigned by the server, so they are not buffered
            if (fieldType.isPrimaryKey() && fieldType.isAutoID()) {
                continue;
            }
            buffer.put(fieldType.getName(), Lists.newArrayList());
            fields.put(fieldType.getName(), fieldType);
        }

        if (buffer.isEmpty()) {
            ExceptionUtils.throwUnExpectedException("Illegal collection schema: fields list is empty");
        }

        // the dynamic field is stored as one extra JSON column
        if (collectionSchema.isEnableDynamicField()) {
            buffer.put(DYNAMIC_FIELD_NAME, Lists.newArrayList());
            fields.put(DYNAMIC_FIELD_NAME, FieldType.newBuilder()
                    .withName(DYNAMIC_FIELD_NAME)
                    .withDataType(DataType.JSON)
                    .build());
        }
    }
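    /**
     * Returns the number of buffered rows. appendRow() keeps every column the
     * same length, so the size of any single column is the row count.
     */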
    public Integer getRowCount() {
        if (buffer.isEmpty()) {
            return 0;
        }
        return buffer.values().iterator().next().size();
    }
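    /**
     * Appends one row to the buffer. Keys must match the schema field names,
     * and values must already be the Java types appendGroup() expects
     * (for example Long for Int64, List&lt;Float&gt; for FloatVector, a JSON
     * string for JSON fields). A minimal sketch with hypothetical field names:
     *
     * <pre>{@code
     * Map<String, Object> row = new HashMap<>();
     * row.put("id", 1L);                             // Int64 field
     * row.put("vector", Lists.newArrayList(0.1f));   // FloatVector field
     * buffer.appendRow(row);
     * }</pre>
     */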
    public void appendRow(Map<String, Object> row) {
        for (Map.Entry<String, Object> entry : row.entrySet()) {
            String key = entry.getKey();
            if (key.equals(DYNAMIC_FIELD_NAME) && !this.collectionSchema.isEnableDynamicField()) {
                continue; // skip the dynamic field if it is disabled
            }
            buffer.get(key).add(entry.getValue());
        }
    }
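    /**
     * Verifies that every column holds the same number of rows, then flushes
     * the buffer to a local file and returns the generated file paths.
     * bufferSize (bytes) and bufferRowCount are caller-supplied hints used to
     * estimate the Parquet row group size.
     */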
    public List<String> persist(String localPath, Integer bufferSize, Integer bufferRowCount) {
        // verify that all fields hold the same number of rows
        int rowCount = -1;
        for (String key : buffer.keySet()) {
            if (rowCount < 0) {
                rowCount = buffer.get(key).size();
            } else if (rowCount != buffer.get(key).size()) {
                String msg = String.format("Column `%s` row count %s doesn't equal the first column's row count %s",
                        key, buffer.get(key).size(), rowCount);
                ExceptionUtils.throwUnExpectedException(msg);
            }
        }

        // output files
        if (fileType == BulkFileType.PARQUET) {
            return persistParquet(localPath, bufferSize, bufferRowCount);
        }
        ExceptionUtils.throwUnExpectedException("Unsupported file type: " + fileType);
        return null;
    }
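    // Worked example of the row group sizing below (hypothetical numbers):
    // with bufferSize = 67108864 (64 MB) and bufferRowCount = 100000,
    // sizePerRow = 67108864 / 100000 + 1 = 672 bytes, so
    // rowGroupSize = 33554432 / 672 = 49932 rows, already within [1000, 1000000].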
    private List<String> persistParquet(String localPath, Integer bufferSize, Integer bufferRowCount) {
        String filePath = localPath + ".parquet";

        // calculate a proper row group size
        int rowGroupSizeMin = 1000;
        int rowGroupSizeMax = 1000000;
        // 32MB is an empirical value that avoids high memory usage of the parquet reader on the server side
        int rowGroupBytes = 32 * 1024 * 1024;
        int sizePerRow = (bufferSize / bufferRowCount) + 1;
        int rowGroupSize = rowGroupBytes / sizePerRow;
        rowGroupSize = Math.max(rowGroupSizeMin, Math.min(rowGroupSizeMax, rowGroupSize));

        // declare the messageType of the Parquet file
        MessageType messageType = ParquetUtils.parseCollectionSchema(collectionSchema);

        // declare and define the ParquetWriter
        Path path = new Path(filePath);
        Configuration configuration = new Configuration();
        GroupWriteSupport.setSchema(messageType, configuration);
        GroupWriteSupport writeSupport = new GroupWriteSupport();

        try (ParquetWriter<Group> writer = new ParquetWriter<>(path,
                ParquetFileWriter.Mode.CREATE,
                writeSupport,
                CompressionCodecName.UNCOMPRESSED,
                rowGroupBytes,
                5 * 1024 * 1024,      // page size
                5 * 1024 * 1024,      // dictionary page size
                ParquetWriter.DEFAULT_IS_DICTIONARY_ENABLED,
                ParquetWriter.DEFAULT_IS_VALIDATING_ENABLED,
                ParquetWriter.DEFAULT_WRITER_VERSION,
                configuration)) {
            Map<String, FieldType> nameFieldType = collectionSchema.getFieldTypes().stream()
                    .collect(Collectors.toMap(FieldType::getName, e -> e));
            if (collectionSchema.isEnableDynamicField()) {
                nameFieldType.put(DYNAMIC_FIELD_NAME, FieldType.newBuilder()
                        .withName(DYNAMIC_FIELD_NAME)
                        .withDataType(DataType.JSON)
                        .build());
            }

            List<String> fieldNameList = Lists.newArrayList(buffer.keySet());
            int size = buffer.get(fieldNameList.get(0)).size();
            SimpleGroupFactory groupFactory = new SimpleGroupFactory(messageType);
            for (int i = 0; i < size; ++i) {
                // build the Parquet data and encapsulate it into a group
                Group group = groupFactory.newGroup();
                for (String fieldName : fieldNameList) {
                    appendGroup(group, fieldName, buffer.get(fieldName).get(i), nameFieldType.get(fieldName));
                }
                writer.write(group);
            }
        } catch (IOException e) {
            // surface the failure instead of swallowing it with printStackTrace()
            logger.error("Failed to persist parquet file {}", filePath, e);
            ExceptionUtils.throwUnExpectedException("Failed to persist parquet file: " + e.getMessage());
        }

        String msg = String.format("Successfully persisted file %s, total size: %s, row count: %s, row group size: %s",
                filePath, bufferSize, bufferRowCount, rowGroupSize);
        logger.info(msg);
        return Lists.newArrayList(filePath);
    }
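    /**
     * Writes a single value into a Parquet group, dispatching on the Milvus
     * data type: scalars map to Parquet primitives, vectors and arrays are
     * written as repeated groups via the addXxxArray helpers below, and
     * sparse vectors are serialized to a JSON string.
     */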
    @SuppressWarnings("unchecked")
    private void appendGroup(Group group, String paramName, Object value, FieldType fieldType) {
        DataType dataType = fieldType.getDataType();
        switch (dataType) {
            case Int8:
            case Int16:
                group.append(paramName, (Short) value);
                break;
            case Int32:
                group.append(paramName, (Integer) value);
                break;
            case Int64:
                group.append(paramName, (Long) value);
                break;
            case Float:
                group.append(paramName, (Float) value);
                break;
            case Double:
                group.append(paramName, (Double) value);
                break;
            case Bool:
                group.append(paramName, (Boolean) value);
                break;
            case VarChar:
            case String:
            case JSON:
                group.append(paramName, (String) value);
                break;
            case FloatVector:
                addFloatArray(group, paramName, (List<Float>) value);
                break;
            case BinaryVector:
            case Float16Vector:
            case BFloat16Vector:
                addBinaryVector(group, paramName, (ByteBuffer) value);
                break;
            case SparseFloatVector:
                addSparseVector(group, paramName, (SortedMap<Long, Float>) value);
                break;
            case Array:
                DataType elementType = fieldType.getElementType();
                switch (elementType) {
                    case Int8:
                    case Int16:
                    case Int32:
                        addIntArray(group, paramName, (List<Integer>) value);
                        break;
                    case Int64:
                        addLongArray(group, paramName, (List<Long>) value);
                        break;
                    case Float:
                        addFloatArray(group, paramName, (List<Float>) value);
                        break;
                    case Double:
                        addDoubleArray(group, paramName, (List<Double>) value);
                        break;
                    case String:
                    case VarChar:
                        addStringArray(group, paramName, (List<String>) value);
                        break;
                    case Bool:
                        addBooleanArray(group, paramName, (List<Boolean>) value);
                        break;
                    default:
                        // previously unsupported element types were silently dropped
                        ExceptionUtils.throwUnExpectedException("Unsupported array element type: " + elementType);
                }
                break;
            default:
                ExceptionUtils.throwUnExpectedException("Unsupported data type: " + dataType);
        }
    }
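    // The helpers below write Java lists as Parquet repeated groups:
    // group.addGroup(fieldName) opens the outer list group, addGroup(0) appends
    // one repeated entry, and add(0, value) sets that entry's element field.
    // The 0/0 indices assume the list layout produced by
    // ParquetUtils.parseCollectionSchema for this collection's schema.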
    private static void addLongArray(Group group, String fieldName, List<Long> values) {
        Group arrayGroup = group.addGroup(fieldName);
        for (long value : values) {
            Group addGroup = arrayGroup.addGroup(0);
            addGroup.add(0, value);
        }
    }

    private static void addStringArray(Group group, String fieldName, List<String> values) {
        Group arrayGroup = group.addGroup(fieldName);
        for (String value : values) {
            Group addGroup = arrayGroup.addGroup(0);
            addGroup.add(0, value);
        }
    }

    private static void addIntArray(Group group, String fieldName, List<Integer> values) {
        Group arrayGroup = group.addGroup(fieldName);
        for (int value : values) {
            Group addGroup = arrayGroup.addGroup(0);
            addGroup.add(0, value);
        }
    }

    private static void addFloatArray(Group group, String fieldName, List<Float> values) {
        Group arrayGroup = group.addGroup(fieldName);
        for (float value : values) {
            Group addGroup = arrayGroup.addGroup(0);
            addGroup.add(0, value);
        }
    }

    private static void addDoubleArray(Group group, String fieldName, List<Double> values) {
        Group arrayGroup = group.addGroup(fieldName);
        for (double value : values) {
            Group addGroup = arrayGroup.addGroup(0);
            addGroup.add(0, value);
        }
    }

    private static void addBooleanArray(Group group, String fieldName, List<Boolean> values) {
        Group arrayGroup = group.addGroup(fieldName);
        for (boolean value : values) {
            Group addGroup = arrayGroup.addGroup(0);
            addGroup.add(0, value);
        }
    }
    private static void addBinaryVector(Group group, String fieldName, ByteBuffer byteBuffer) {
        Group arrayGroup = group.addGroup(fieldName);
        // note: ByteBuffer.array() requires an array-backed (non-direct) buffer
        byte[] bytes = byteBuffer.array();
        for (byte value : bytes) {
            Group addGroup = arrayGroup.addGroup(0);
            addGroup.add(0, value);
        }
    }

    private static void addSparseVector(Group group, String fieldName, SortedMap<Long, Float> sparse) {
        // sparse vectors are parsed as JSON-format strings on the server side
        String jsonString = GSON_INSTANCE.toJson(sparse);
        group.append(fieldName, jsonString);
    }
}
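/*
 * Usage sketch (hypothetical schema, values, and path; not part of this file):
 *
 *   CollectionSchemaParam schema = CollectionSchemaParam.newBuilder()
 *       .addFieldType(FieldType.newBuilder()
 *           .withName("id").withDataType(DataType.Int64)
 *           .withPrimaryKey(true).withAutoID(false).build())
 *       .addFieldType(FieldType.newBuilder()
 *           .withName("vector").withDataType(DataType.FloatVector)
 *           .withDimension(2).build())
 *       .build();
 *
 *   Buffer buffer = new Buffer(schema, BulkFileType.PARQUET);
 *   Map<String, Object> row = new HashMap<>();
 *   row.put("id", 1L);
 *   row.put("vector", Lists.newArrayList(0.1f, 0.2f));
 *   buffer.appendRow(row);
 *
 *   // writes /tmp/part-0.parquet; the size/count hints drive row group sizing
 *   List<String> files = buffer.persist("/tmp/part-0", 1024, 1);
 */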