FieldDataWrapper.java 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing,
  13. * software distributed under the License is distributed on an
  14. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. * KIND, either express or implied. See the License for the
  16. * specific language governing permissions and limitations
  17. * under the License.
  18. */
  19. package io.milvus.response;
  import static io.milvus.grpc.DataType.JSON;

  import com.google.gson.*;
  import com.google.protobuf.ByteString;
  import com.google.protobuf.ProtocolStringList;
  import io.milvus.exception.IllegalResponseException;
  import io.milvus.exception.ParamException;
  import io.milvus.grpc.*;
  import io.milvus.param.ParamUtils;
  import java.nio.ByteBuffer;
  import java.nio.ByteOrder;
  import java.nio.charset.StandardCharsets;
  import java.util.ArrayList;
  import java.util.List;
  import java.util.SortedMap;
  import java.util.TreeMap;
  import java.util.stream.Collectors;
  import lombok.NonNull;
  36. /**
  37. * Utility class to wrap response of <code>query/search</code> interface.
  38. */
  39. public class FieldDataWrapper {
  40. private final FieldData fieldData;
  41. public FieldDataWrapper(@NonNull FieldData fieldData) {
  42. this.fieldData = fieldData;
  43. }
  44. public boolean isVectorField() {
  45. return ParamUtils.isVectorDataType(fieldData.getType());
  46. }
  47. public boolean isJsonField() {
  48. return fieldData.getType() == JSON;
  49. }
  50. public boolean isDynamicField() {
  51. return fieldData.getType() == JSON && fieldData.getIsDynamic();
  52. }
  53. /**
  54. * Gets the dimension value of a vector field.
  55. * Throw {@link IllegalResponseException} if the field is not a vector filed.
  56. *
  57. * @return <code>int</code> dimension of the vector field
  58. */
  59. public int getDim() throws IllegalResponseException {
  60. if (!isVectorField()) {
  61. throw new IllegalResponseException("Not a vector field");
  62. }
  63. return (int) fieldData.getVectors().getDim();
  64. }
  65. // this method returns bytes size of each vector according to vector type
  66. private int checkDim(DataType dt, ByteString data, int dim) {
  67. if (dt == DataType.BinaryVector) {
  68. if ((data.size()*8) % dim != 0) {
  69. String msg = String.format("Returned binary vector data array size %d doesn't match dimension %d",
  70. data.size(), dim);
  71. throw new IllegalResponseException(msg);
  72. }
  73. return dim/8;
  74. } else if (dt == DataType.Float16Vector || dt == DataType.BFloat16Vector) {
  75. if (data.size() % (dim*2) != 0) {
  76. String msg = String.format("Returned float16 vector data array size %d doesn't match dimension %d",
  77. data.size(), dim);
  78. throw new IllegalResponseException(msg);
  79. }
  80. return dim*2;
  81. }
  82. return 0;
  83. }
  84. /**
  85. * Gets the row count of a field.
  86. * * Throws {@link IllegalResponseException} if the field type is illegal.
  87. *
  88. * @return <code>long</code> row count of the field
  89. */
  90. public long getRowCount() throws IllegalResponseException {
  91. DataType dt = fieldData.getType();
  92. switch (dt) {
  93. case FloatVector: {
  94. int dim = getDim();
  95. List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
  96. if (data.size() % dim != 0) {
  97. String msg = String.format("Returned float vector data array size %d doesn't match dimension %d",
  98. data.size(), dim);
  99. throw new IllegalResponseException(msg);
  100. }
  101. return data.size()/dim;
  102. }
  103. case BinaryVector: {
  104. // for binary vector, each dimension is one bit, each byte is 8 dim
  105. int dim = getDim();
  106. ByteString data = fieldData.getVectors().getBinaryVector();
  107. int bytePerVec = checkDim(dt, data, dim);
  108. return data.size()/bytePerVec;
  109. }
  110. case Float16Vector:
  111. case BFloat16Vector: {
  112. // for float16 vector, each dimension 2 bytes
  113. int dim = getDim();
  114. ByteString data = (dt == DataType.Float16Vector) ?
  115. fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector();
  116. int bytePerVec = checkDim(dt, data, dim);
  117. return data.size()/bytePerVec;
  118. }
  119. case SparseFloatVector: {
  120. // for sparse vector, each content is a vector
  121. return fieldData.getVectors().getSparseFloatVector().getContentsCount();
  122. }
  123. case Int64:
  124. return fieldData.getScalars().getLongData().getDataCount();
  125. case Int32:
  126. case Int16:
  127. case Int8:
  128. return fieldData.getScalars().getIntData().getDataCount();
  129. case Bool:
  130. return fieldData.getScalars().getBoolData().getDataCount();
  131. case Float:
  132. return fieldData.getScalars().getFloatData().getDataCount();
  133. case Double:
  134. return fieldData.getScalars().getDoubleData().getDataCount();
  135. case VarChar:
  136. case String:
  137. return fieldData.getScalars().getStringData().getDataCount();
  138. case JSON:
  139. return fieldData.getScalars().getJsonData().getDataCount();
  140. case Array:
  141. return fieldData.getScalars().getArrayData().getDataCount();
  142. default:
  143. throw new IllegalResponseException("Unsupported data type returned by FieldData");
  144. }
  145. }
  146. /**
  147. * Returns the field data according to its type:
  148. * FloatVector field returns List of List Float,
  149. * BinaryVector/Float16Vector/BFloat16Vector fields return List of ByteBuffer
  150. * SparseFloatVector field returns List of SortedMap[Long, Float]
  151. * Int64 field returns List of Long
  152. * Int32/Int16/Int8 fields return List of Integer
  153. * Bool field returns List of Boolean
  154. * Float field returns List of Float
  155. * Double field returns List of Double
  156. * Varchar field returns List of String
  157. * Array field returns List of List
  158. * JSON field returns List of String;
  159. * etc.
  160. *
  161. * Throws {@link IllegalResponseException} if the field type is illegal.
  162. *
  163. * @return <code>List</code>
  164. */
  165. public List<?> getFieldData() throws IllegalResponseException {
  166. DataType dt = fieldData.getType();
  167. switch (dt) {
  168. case FloatVector: {
  169. int dim = getDim();
  170. List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
  171. if (data.size() % dim != 0) {
  172. String msg = String.format("Returned float vector data array size %d doesn't match dimension %d",
  173. data.size(), dim);
  174. throw new IllegalResponseException(msg);
  175. }
  176. List<List<Float>> packData = new ArrayList<>();
  177. int count = data.size() / dim;
  178. for (int i = 0; i < count; ++i) {
  179. packData.add(data.subList(i * dim, (i + 1) * dim));
  180. }
  181. return packData;
  182. }
  183. case BinaryVector:
  184. case Float16Vector:
  185. case BFloat16Vector: {
  186. int dim = getDim();
  187. ByteString data = null;
  188. if (dt == DataType.BinaryVector) {
  189. data = fieldData.getVectors().getBinaryVector();
  190. } else if (dt == DataType.Float16Vector) {
  191. data = fieldData.getVectors().getFloat16Vector();
  192. } else {
  193. data = fieldData.getVectors().getBfloat16Vector();
  194. }
  195. int bytePerVec = checkDim(dt, data, dim);
  196. int count = data.size()/bytePerVec;
  197. List<ByteBuffer> packData = new ArrayList<>();
  198. for (int i = 0; i < count; ++i) {
  199. ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
  200. bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
  201. packData.add(bf);
  202. }
  203. return packData;
  204. }
  205. case SparseFloatVector: {
  206. // in Java sdk, each sparse vector is pairs of long+float
  207. // in server side, each sparse vector is stored as uint+float (8 bytes)
  208. // don't use sparseArray.getDim() because the dim is the max index of each rows
  209. SparseFloatArray sparseArray = fieldData.getVectors().getSparseFloatVector();
  210. List<SortedMap<Long, Float>> packData = new ArrayList<>();
  211. for (int i = 0; i < sparseArray.getContentsCount(); ++i) {
  212. ByteString bs = sparseArray.getContents(i);
  213. ByteBuffer bf = ByteBuffer.wrap(bs.toByteArray());
  214. bf.order(ByteOrder.LITTLE_ENDIAN);
  215. SortedMap<Long, Float> sparse = new TreeMap<>();
  216. long num = bf.limit()/8; // each uint+float pair is 8 bytes
  217. for (long j = 0; j < num; j++) {
  218. // here we convert an uint 4-bytes to a long value
  219. ByteBuffer pBuf = ByteBuffer.allocate(Long.BYTES);
  220. pBuf.order(ByteOrder.LITTLE_ENDIAN);
  221. int offset = 8*(int)j;
  222. byte[] aa = bf.array();
  223. for (int k = offset; k < offset + 4; k++) {
  224. pBuf.put(aa[k]); // fill the first 4 bytes with the unit bytes
  225. }
  226. pBuf.putInt(0); // fill the last 4 bytes to zero
  227. pBuf.rewind(); // reset position to head
  228. long k = pBuf.getLong(); // this is the long value converted from the uint
  229. // here we get the float value as normal
  230. bf.position(offset+4); // position offsets 4 bytes since they were converted to long
  231. float v = bf.getFloat(); // this is the float value
  232. sparse.put(k, v);
  233. }
  234. packData.add(sparse);
  235. }
  236. return packData;
  237. }
  238. case Array:
  239. List<List<?>> array = new ArrayList<>();
  240. ArrayArray arrArray = fieldData.getScalars().getArrayData();
  241. for (int i = 0; i < arrArray.getDataCount(); i++) {
  242. ScalarField scalar = arrArray.getData(i);
  243. array.add(getScalarData(arrArray.getElementType(), scalar));
  244. }
  245. return array;
  246. case Int64:
  247. case Int32:
  248. case Int16:
  249. case Int8:
  250. case Bool:
  251. case Float:
  252. case Double:
  253. case VarChar:
  254. case String:
  255. case JSON:
  256. return getScalarData(dt, fieldData.getScalars());
  257. default:
  258. throw new IllegalResponseException("Unsupported data type returned by FieldData");
  259. }
  260. }
  261. private List<?> getScalarData(DataType dt, ScalarField scalar) {
  262. switch (dt) {
  263. case Int64:
  264. return scalar.getLongData().getDataList();
  265. case Int32:
  266. case Int16:
  267. case Int8:
  268. return scalar.getIntData().getDataList();
  269. case Bool:
  270. return scalar.getBoolData().getDataList();
  271. case Float:
  272. return scalar.getFloatData().getDataList();
  273. case Double:
  274. return scalar.getDoubleData().getDataList();
  275. case VarChar:
  276. case String:
  277. ProtocolStringList protoStrList = scalar.getStringData().getDataList();
  278. return protoStrList.subList(0, protoStrList.size());
  279. case JSON:
  280. List<ByteString> dataList = scalar.getJsonData().getDataList();
  281. return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());
  282. default:
  283. return new ArrayList<>();
  284. }
  285. }
  286. public Integer getAsInt(int index, String paramName) throws IllegalResponseException {
  287. if (isJsonField()) {
  288. String result = getAsString(index, paramName);
  289. return result == null ? null : Integer.parseInt(result);
  290. }
  291. throw new IllegalResponseException("Only JSON type support this operation");
  292. }
  293. public String getAsString(int index, String paramName) throws IllegalResponseException {
  294. if (isJsonField()) {
  295. JsonElement jsonElement = parseObjectData(index);
  296. if (jsonElement instanceof JsonObject) {
  297. return ((JsonObject)jsonElement).get(paramName).getAsString();
  298. } else {
  299. throw new IllegalResponseException("The JSON element is not a dict");
  300. }
  301. }
  302. throw new IllegalResponseException("Only JSON type support this operation");
  303. }
  304. public Boolean getAsBool(int index, String paramName) throws IllegalResponseException {
  305. if (isJsonField()) {
  306. String result = getAsString(index, paramName);
  307. return result == null ? null : Boolean.parseBoolean(result);
  308. }
  309. throw new IllegalResponseException("Only JSON type support this operation");
  310. }
  311. public Double getAsDouble(int index, String paramName) throws IllegalResponseException {
  312. if (isJsonField()) {
  313. String result = getAsString(index, paramName);
  314. return result == null ? null : Double.parseDouble(result);
  315. }
  316. throw new IllegalResponseException("Only JSON type support this operation");
  317. }
  318. /**
  319. * Gets a field's value by field name.
  320. *
  321. * @param index which row
  322. * @param paramName which field
  323. * @return returns Long for integer value, returns Double for decimal value,
  324. * returns String for string value, returns JsonElement for JSON object and Array.
  325. */
  326. public Object get(int index, String paramName) throws IllegalResponseException {
  327. if (!isJsonField()) {
  328. throw new IllegalResponseException("Only JSON type support this operation");
  329. }
  330. JsonElement jsonElement = parseObjectData(index);
  331. if (!(jsonElement instanceof JsonObject)) {
  332. throw new IllegalResponseException("The JSON element is not a dict");
  333. }
  334. JsonElement element = ((JsonObject)jsonElement).get(paramName);
  335. return ValueOfJSONElement(element);
  336. }
  337. public Object valueByIdx(int index) throws ParamException {
  338. List<?> data = getFieldData();
  339. if (index < 0 || index >= data.size()) {
  340. throw new ParamException(String.format("Value index %d out of range %d", index, data.size()));
  341. }
  342. return data.get(index);
  343. }
  344. private JsonElement parseObjectData(int index) {
  345. Object object = valueByIdx(index);
  346. return ParseJSONObject(object);
  347. }
  348. public static JsonElement ParseJSONObject(Object object) {
  349. if (object instanceof String) {
  350. return JsonParser.parseString((String)object);
  351. } else if (object instanceof byte[]) {
  352. return JsonParser.parseString(new String((byte[]) object));
  353. } else {
  354. throw new IllegalResponseException("Illegal type value for JSON parser");
  355. }
  356. }
  357. public static Object ValueOfJSONElement(JsonElement element) {
  358. if (element == null || element.isJsonNull()) {
  359. return null;
  360. }
  361. if (element.isJsonPrimitive()) {
  362. JsonPrimitive primitive = (JsonPrimitive) element;
  363. if (primitive.isString()) {
  364. return primitive.getAsString();
  365. } else if (primitive.isBoolean()) {
  366. return primitive.getAsBoolean();
  367. } else if (primitive.isNumber()) {
  368. if (primitive.getAsBigDecimal().scale() == 0) {
  369. return primitive.getAsLong();
  370. } else {
  371. return primitive.getAsDouble();
  372. }
  373. }
  374. }
  375. return element;
  376. }
  377. }