FieldDataWrapper.java 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package io.milvus.response;
  2. import com.alibaba.fastjson.JSONObject;
  3. import com.google.protobuf.ProtocolStringList;
  4. import io.milvus.exception.ParamException;
  5. import io.milvus.grpc.DataType;
  6. import io.milvus.grpc.FieldData;
  7. import io.milvus.exception.IllegalResponseException;
  8. import lombok.NonNull;
  9. import java.nio.ByteBuffer;
  10. import java.util.ArrayList;
  11. import java.util.List;
  12. import java.util.stream.Collectors;
  13. import com.google.protobuf.ByteString;
  14. import static io.milvus.grpc.DataType.JSON;
  15. /**
  16. * Utility class to wrap response of <code>query/search</code> interface.
  17. */
  18. public class FieldDataWrapper {
  19. private final FieldData fieldData;
  20. public FieldDataWrapper(@NonNull FieldData fieldData) {
  21. this.fieldData = fieldData;
  22. }
  23. public boolean isVectorField() {
  24. return fieldData.getType() == DataType.FloatVector || fieldData.getType() == DataType.BinaryVector;
  25. }
  26. public boolean isJsonField() {
  27. return fieldData.getType() == JSON;
  28. }
  29. public boolean isDynamicField() {
  30. return fieldData.getType() == JSON && fieldData.getIsDynamic();
  31. }
  32. /**
  33. * Gets the dimension value of a vector field.
  34. * Throw {@link IllegalResponseException} if the field is not a vector filed.
  35. *
  36. * @return <code>int</code> dimension of the vector field
  37. */
  38. public int getDim() throws IllegalResponseException {
  39. if (!isVectorField()) {
  40. throw new IllegalResponseException("Not a vector field");
  41. }
  42. return (int) fieldData.getVectors().getDim();
  43. }
  44. /**
  45. * Gets the row count of a field.
  46. * * Throws {@link IllegalResponseException} if the field type is illegal.
  47. *
  48. * @return <code>long</code> row count of the field
  49. */
  50. public long getRowCount() throws IllegalResponseException {
  51. DataType dt = fieldData.getType();
  52. switch (dt) {
  53. case FloatVector: {
  54. int dim = getDim();
  55. List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
  56. if (data.size() % dim != 0) {
  57. throw new IllegalResponseException("Returned float vector field data array size doesn't match dimension");
  58. }
  59. return data.size()/dim;
  60. }
  61. case BinaryVector: {
  62. int dim = getDim();
  63. ByteString data = fieldData.getVectors().getBinaryVector();
  64. if (data.size() % dim != 0) {
  65. throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
  66. }
  67. return data.size()/dim;
  68. }
  69. case Int64:
  70. return fieldData.getScalars().getLongData().getDataList().size();
  71. case Int32:
  72. case Int16:
  73. case Int8:
  74. return fieldData.getScalars().getIntData().getDataList().size();
  75. case Bool:
  76. return fieldData.getScalars().getBoolData().getDataList().size();
  77. case Float:
  78. return fieldData.getScalars().getFloatData().getDataList().size();
  79. case Double:
  80. return fieldData.getScalars().getDoubleData().getDataList().size();
  81. case VarChar:
  82. case String:
  83. return fieldData.getScalars().getStringData().getDataList().size();
  84. case JSON:
  85. return fieldData.getScalars().getJsonData().getDataList().size();
  86. default:
  87. throw new IllegalResponseException("Unsupported data type returned by FieldData");
  88. }
  89. }
  90. /**
  91. * Returns the field data according to its type:
  92. * float vector field return List of List Float,
  93. * binary vector field return List of ByteBuffer
  94. * int64 field return List of Long
  95. * int32/int16/int8 field return List of Integer
  96. * boolean field return List of Boolean
  97. * float field return List of Float
  98. * double field return List of Double
  99. * varchar field return List of String
  100. * etc.
  101. *
  102. * Throws {@link IllegalResponseException} if the field type is illegal.
  103. *
  104. * @return <code>List</code>
  105. */
  106. public List<?> getFieldData() throws IllegalResponseException {
  107. DataType dt = fieldData.getType();
  108. switch (dt) {
  109. case FloatVector: {
  110. int dim = getDim();
  111. List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
  112. if (data.size() % dim != 0) {
  113. throw new IllegalResponseException("Returned float vector field data array size doesn't match dimension");
  114. }
  115. List<List<Float>> packData = new ArrayList<>();
  116. int count = data.size() / dim;
  117. for (int i = 0; i < count; ++i) {
  118. packData.add(data.subList(i * dim, (i + 1) * dim));
  119. }
  120. return packData;
  121. }
  122. case BinaryVector: {
  123. int dim = getDim();
  124. ByteString data = fieldData.getVectors().getBinaryVector();
  125. if (data.size() % dim != 0) {
  126. throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
  127. }
  128. List<ByteBuffer> packData = new ArrayList<>();
  129. int count = data.size() / dim;
  130. for (int i = 0; i < count; ++i) {
  131. ByteBuffer bf = ByteBuffer.allocate(dim);
  132. bf.put(data.substring(i * dim, (i + 1) * dim).toByteArray());
  133. packData.add(bf);
  134. }
  135. return packData;
  136. }
  137. case Int64:
  138. return fieldData.getScalars().getLongData().getDataList();
  139. case Int32:
  140. case Int16:
  141. case Int8:
  142. return fieldData.getScalars().getIntData().getDataList();
  143. case Bool:
  144. return fieldData.getScalars().getBoolData().getDataList();
  145. case Float:
  146. return fieldData.getScalars().getFloatData().getDataList();
  147. case Double:
  148. return fieldData.getScalars().getDoubleData().getDataList();
  149. case VarChar:
  150. case String:
  151. ProtocolStringList protoStrList = fieldData.getScalars().getStringData().getDataList();
  152. return protoStrList.subList(0, protoStrList.size());
  153. case JSON:
  154. List<ByteString> dataList = fieldData.getScalars().getJsonData().getDataList();
  155. return dataList.stream().map(ByteString::toByteArray).collect(Collectors.toList());
  156. default:
  157. throw new IllegalResponseException("Unsupported data type returned by FieldData");
  158. }
  159. }
  160. public Integer getAsInt(int index, String paramName) throws IllegalResponseException {
  161. if (isJsonField()) {
  162. String result = getAsString(index, paramName);
  163. return result == null ? null : Integer.parseInt(result);
  164. }
  165. throw new IllegalResponseException("Only JSON type support this operation");
  166. }
  167. public String getAsString(int index, String paramName) throws IllegalResponseException {
  168. if (isJsonField()) {
  169. JSONObject jsonObject = parseObjectData(index);
  170. return jsonObject.getString(paramName);
  171. }
  172. throw new IllegalResponseException("Only JSON type support this operation");
  173. }
  174. public Boolean getAsBool(int index, String paramName) throws IllegalResponseException {
  175. if (isJsonField()) {
  176. String result = getAsString(index, paramName);
  177. return result == null ? null : Boolean.parseBoolean(result);
  178. }
  179. throw new IllegalResponseException("Only JSON type support this operation");
  180. }
  181. public Double getAsDouble(int index, String paramName) throws IllegalResponseException {
  182. if (isJsonField()) {
  183. String result = getAsString(index, paramName);
  184. return result == null ? null : Double.parseDouble(result);
  185. }
  186. throw new IllegalResponseException("Only JSON type support this operation");
  187. }
  188. public Object get(int index, String paramName) throws IllegalResponseException {
  189. if (isJsonField()) {
  190. JSONObject jsonObject = parseObjectData(index);
  191. return jsonObject.get(paramName);
  192. }
  193. throw new IllegalResponseException("Only JSON type support this operation");
  194. }
  195. public Object valueByIdx(int index) throws ParamException {
  196. if (index < 0 || index >= getFieldData().size()) {
  197. throw new ParamException("index out of range");
  198. }
  199. return getFieldData().get(index);
  200. }
  201. private JSONObject parseObjectData(int index) {
  202. Object object = valueByIdx(index);
  203. return JSONObject.parseObject(new String((byte[])object));
  204. }
  205. }