Browse Source

fix(vector-search): add dimension validation before search

tumao 3 years ago
parent
commit
fe99b89e26

+ 2 - 0
client/src/i18n/cn/search.ts

@@ -11,6 +11,8 @@ const searchTrans = {
   result: 'Search Results',
   result: 'Search Results',
   topK: 'TopK {{number}}',
   topK: 'TopK {{number}}',
   filter: 'Advanced Filter',
   filter: 'Advanced Filter',
+  vectorValueWarning:
+    'Vector value should be an array of length {{dimension}}(dimension)',
 };
 };
 
 
 export default searchTrans;
 export default searchTrans;

+ 1 - 0
client/src/i18n/en/search.ts

@@ -11,6 +11,7 @@ const searchTrans = {
   result: 'Search Results',
   result: 'Search Results',
   topK: 'TopK {{number}}',
   topK: 'TopK {{number}}',
   filter: 'Advanced Filter',
   filter: 'Advanced Filter',
+  vectorValueWarning: 'Vector value should be an array of length {{dimension}}',
 };
 };
 
 
 export default searchTrans;
 export default searchTrans;

+ 4 - 0
client/src/pages/seach/Styles.ts

@@ -122,4 +122,8 @@ export const getVectorSearchStyles = makeStyles((theme: Theme) => ({
     lineHeight: '16px',
     lineHeight: '16px',
     color: theme.palette.milvusGrey.dark,
     color: theme.palette.milvusGrey.dark,
   },
   },
+  error: {
+    marginTop: theme.spacing(1),
+    color: theme.palette.error.main,
+  },
 }));
 }));

+ 2 - 0
client/src/pages/seach/Types.ts

@@ -34,6 +34,8 @@ export interface FieldOption extends Option {
   // used to get metric type, index type and index params for search params
   // used to get metric type, index type and index params for search params
   // if user doesn't create index, default value is null
   // if user doesn't create index, default value is null
   indexInfo: IndexView | null;
   indexInfo: IndexView | null;
+  // used for check vector input validation
+  dimension: number;
 }
 }
 
 
 export interface SearchParamInputConfig {
 export interface SearchParamInputConfig {

+ 85 - 44
client/src/pages/seach/VectorSearch.tsx

@@ -75,21 +75,6 @@ const VectorSearch = () => {
     data: result,
     data: result,
   } = usePaginationHook(searchResult || []);
   } = usePaginationHook(searchResult || []);
 
 
-  const searchDisabled = useMemo(() => {
-    /**
-     * before search, user must:
-     * 1. enter vector value
-     * 2. choose collection and field
-     * 3. set extra search params
-     */
-    const isInvalid =
-      vectors === '' ||
-      selectedCollection === '' ||
-      selectedField === '' ||
-      paramDisabled;
-    return isInvalid;
-  }, [paramDisabled, selectedField, selectedCollection, vectors]);
-
   const collectionOptions: Option[] = useMemo(
   const collectionOptions: Option[] = useMemo(
     () =>
     () =>
       collections.map(c => ({
       collections.map(c => ({
@@ -124,36 +109,81 @@ const VectorSearch = () => {
       : [];
       : [];
   }, [searchResult]);
   }, [searchResult]);
 
 
-  const { metricType, indexType, indexParams, fieldType, embeddingType } =
-    useMemo(() => {
-      if (selectedField !== '') {
-        // field options must contain selected field, so selectedFieldInfo will never undefined
-        const selectedFieldInfo = fieldOptions.find(
-          f => f.value === selectedField
-        );
-        const index = selectedFieldInfo?.indexInfo;
-        const embeddingType = getEmbeddingType(selectedFieldInfo!.fieldType);
-        const metric =
-          index?._metricType || DEFAULT_METRIC_VALUE_MAP[embeddingType];
-        const indexParams = index?._indexParameterPairs || [];
-
-        return {
-          metricType: metric,
-          indexType: index?._indexType || getDefaultIndexType(embeddingType),
-          indexParams,
-          fieldType: DataTypeEnum[selectedFieldInfo?.fieldType!],
-          embeddingType,
-        };
-      }
+  const {
+    metricType,
+    indexType,
+    indexParams,
+    fieldType,
+    embeddingType,
+    selectedFieldDimension,
+  } = useMemo(() => {
+    if (selectedField !== '') {
+      // field options must contain selected field, so selectedFieldInfo will never undefined
+      const selectedFieldInfo = fieldOptions.find(
+        f => f.value === selectedField
+      );
+      const index = selectedFieldInfo?.indexInfo;
+      const embeddingType = getEmbeddingType(selectedFieldInfo!.fieldType);
+      const metric =
+        index?._metricType || DEFAULT_METRIC_VALUE_MAP[embeddingType];
+      const indexParams = index?._indexParameterPairs || [];
+      const dim = selectedFieldInfo?.dimension || 0;
 
 
       return {
       return {
-        metricType: '',
-        indexType: '',
-        indexParams: [],
-        fieldType: 0,
-        embeddingType: DataTypeEnum.FloatVector,
+        metricType: metric,
+        indexType: index?._indexType || getDefaultIndexType(embeddingType),
+        indexParams,
+        fieldType: DataTypeEnum[selectedFieldInfo?.fieldType!],
+        embeddingType,
+        selectedFieldDimension: dim,
       };
       };
-    }, [selectedField, fieldOptions]);
+    }
+
+    return {
+      metricType: '',
+      indexType: '',
+      indexParams: [],
+      fieldType: 0,
+      embeddingType: DataTypeEnum.FloatVector,
+      selectedFieldDimension: 0,
+    };
+  }, [selectedField, fieldOptions]);
+
+  /**
+   * vector value validation
+   * @return whether is valid
+   */
+  const vectorValueValid = useMemo(() => {
+    // if user hasn't input value or not select field, don't trigger validation check
+    if (vectors === '' || selectedFieldDimension === 0) {
+      return true;
+    }
+    const value = parseValue(vectors);
+    const isArray = Array.isArray(value);
+    return isArray && value.length === selectedFieldDimension;
+  }, [vectors, selectedFieldDimension]);
+
+  const searchDisabled = useMemo(() => {
+    /**
+     * before search, user must:
+     * 1. enter vector value, it should be an array and length should be equal to selected field dimension
+     * 2. choose collection and field
+     * 3. set extra search params
+     */
+    const isInvalid =
+      vectors === '' ||
+      selectedCollection === '' ||
+      selectedField === '' ||
+      paramDisabled ||
+      !vectorValueValid;
+    return isInvalid;
+  }, [
+    paramDisabled,
+    selectedField,
+    selectedCollection,
+    vectors,
+    vectorValueValid,
+  ]);
 
 
   // fetch data
   // fetch data
   const fetchCollections = useCallback(async () => {
   const fetchCollections = useCallback(async () => {
@@ -285,8 +315,11 @@ const VectorSearch = () => {
     <section className="page-wrapper">
     <section className="page-wrapper">
       {/* form section */}
       {/* form section */}
       <form className={classes.form}>
       <form className={classes.form}>
-        {/* vector value textarea */}
-        <fieldset className="field">
+        {/**
+         * vector value textarea
+         * use field-params class because it also has error msg if invalid
+         */}
+        <fieldset className="field field-params">
           <Typography className="text">{searchTrans('firstTip')}</Typography>
           <Typography className="text">{searchTrans('firstTip')}</Typography>
           <TextField
           <TextField
             className="textarea"
             className="textarea"
@@ -304,6 +337,14 @@ const VectorSearch = () => {
               handleVectorChange(e.target.value as string);
               handleVectorChange(e.target.value as string);
             }}
             }}
           />
           />
+          {/* validation */}
+          {!vectorValueValid && (
+            <Typography variant="caption" className={classes.error}>
+              {searchTrans('vectorValueWarning', {
+                dimension: selectedFieldDimension,
+              })}
+            </Typography>
+          )}
         </fieldset>
         </fieldset>
         {/* collection and field selectors */}
         {/* collection and field selectors */}
         <fieldset className="field field-second">
         <fieldset className="field field-second">

+ 1 - 0
client/src/utils/search.ts

@@ -89,6 +89,7 @@ export const getVectorFieldOptions = (
       value: f._fieldName,
       value: f._fieldName,
       fieldType: f._fieldType,
       fieldType: f._fieldType,
       indexInfo: index || null,
       indexInfo: index || null,
+      dimension: Number(f._dimension),
     };
     };
   });
   });