VectorSearch.tsx 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. import { useCallback, useEffect, useMemo, useState, useContext } from 'react';
  2. import { Typography, Button, Card, CardContent } from '@material-ui/core';
  3. import { useTranslation } from 'react-i18next';
  4. import { useLocation } from 'react-router-dom';
  5. import { ALL_ROUTER_TYPES } from '@/router/Types';
  6. import { useNavigationHook, useSearchResult, usePaginationHook } from '@/hooks';
  7. import { dataContext } from '@/context';
  8. import CustomSelector from '@/components/customSelector/CustomSelector';
  9. import { ColDefinitionsType } from '@/components/grid/Types';
  10. import AttuGrid from '@/components/grid/Grid';
  11. import EmptyCard from '@/components/cards/EmptyCard';
  12. import icons from '@/components/icons/Icons';
  13. import CustomButton from '@/components/customButton/CustomButton';
  14. import SimpleMenu from '@/components/menu/SimpleMenu';
  15. import { Option } from '@/components/customSelector/Types';
  16. import Filter from '@/components/advancedSearch';
  17. import { Field } from '@/components/advancedSearch/Types';
  18. import { Collection, MilvusIndex } from '@/http';
  19. import {
  20. parseValue,
  21. parseLocationSearch,
  22. classifyFields,
  23. getDefaultIndexType,
  24. getEmbeddingType,
  25. getNonVectorFieldsForFilter,
  26. getVectorFieldOptions,
  27. cloneObj,
  28. generateVector,
  29. } from '@/utils';
  30. import {
  31. LOADING_STATE,
  32. DEFAULT_METRIC_VALUE_MAP,
  33. DYNAMIC_FIELD,
  34. DataTypeEnum,
  35. } from '@/consts';
  36. import { getLabelDisplayedRows } from './Utils';
  37. import SearchParams from './SearchParams';
  38. import { getVectorSearchStyles } from './Styles';
  39. import { TOP_K_OPTIONS } from './Constants';
  40. import { FieldOption, SearchResultView, VectorSearchParam } from './Types';
  41. const VectorSearch = () => {
  // Pass the SEARCH route type to the navigation hook.
  useNavigationHook(ALL_ROUTER_TYPES.SEARCH);
  const location = useLocation();
  const { database } = useContext(dataContext);

  // Translators for the 'search' and 'btn' i18n namespaces.
  const { t: searchTrans } = useTranslation('search');
  const { t: btnTrans } = useTranslation('btn');
  const classes = getVectorSearchStyles();

  // ---- component-local state ----
  // true while a search request is in flight
  const [tableLoading, setTableLoading] = useState(false);
  const [collections, setCollections] = useState<Collection[]>([]);
  const [selectedCollection, setSelectedCollection] = useState('');
  // options for the vector field selector
  const [fieldOptions, setFieldOptions] = useState<FieldOption[]>([]);
  // fields offered by the advanced filter
  const [filterFields, setFilterFields] = useState<Field[]>([]);
  const [selectedField, setSelectedField] = useState('');
  // values of the search params form
  const [searchParam, setSearchParam] = useState<Record<string, number>>({});
  // whether the search params form currently blocks searching
  const [paramDisabled, setParamDisabled] = useState(true);
  // null = no search has run yet; [] = a search ran and returned nothing
  const [searchResult, setSearchResult] = useState<SearchResultView[] | null>(
    null
  );
  // topK starts at 100
  const [topK, setTopK] = useState(100);
  const [expression, setExpression] = useState('');
  // raw text of the vector input
  const [vectors, setVectors] = useState('');
  // latency reported by the last search response
  const [latency, setLatency] = useState(0);
  // Value produced by useSearchResult from the raw results and the style classes;
  // it is what gets paginated below (an empty list is used when it is falsy).
  const searchResultMemo = useSearchResult(searchResult as any, classes);
  const pagination = usePaginationHook(searchResultMemo || []);
  const {
    data: result,
    total,
    currentPage,
    handleCurrentPage,
    pageSize,
    handlePageSize,
    order,
    orderBy,
    handleGridSort,
  } = pagination;
  // Options for the collection selector; label and value are both the collection name.
  const collectionOptions: Option[] = useMemo(() => {
    const toOption = (c: Collection): Option => ({
      label: c.collectionName,
      value: c.collectionName,
    });
    return collections.map(toOption);
  }, [collections]);
  /**
   * Names of the fields requested as search output for the selected collection.
   * Fields typed 'BinaryVector' or 'FloatVector' are excluded; DYNAMIC_FIELD is
   * appended when the collection has enableDynamicField set.
   * Returns an empty list when the selected collection is not found.
   */
  const outputFields: string[] = useMemo(() => {
    const current = collections.find(
      c => c.collectionName === selectedCollection
    );
    if (!current) {
      return [];
    }
    // vector fields can't be output fields
    const vectorTypes = ['BinaryVector', 'FloatVector'];
    const names = (current.fields || [])
      .filter(f => !vectorTypes.includes(f.fieldType))
      .map(f => f.name);
    return current.enableDynamicField ? [...names, DYNAMIC_FIELD] : names;
  }, [selectedCollection, collections]);
  // Name of the selected collection's primary key field, or undefined when
  // the collection (or a primary key field) cannot be found.
  const primaryKeyField = useMemo(() => {
    const current = collections.find(
      ({ collectionName }) => collectionName === selectedCollection
    );
    const candidates = current?.fields || [];
    const pk = candidates.find(f => f.isPrimaryKey);
    return pk?.name;
  }, [selectedCollection, collections]);
  /**
   * Preferred left-to-right order of result columns: primary key, 'id',
   * 'score', then the output fields.
   *
   * Memoized so its identity only changes when its inputs change. Rebuilding
   * it on every render would invalidate the colDefinitions memo below on
   * every render as well.
   */
  const orderArray = useMemo(
    () => [primaryKeyField, 'id', 'score', ...outputFields],
    [primaryKeyField, outputFields]
  );
  /**
   * Grid column definitions derived from the keys of the first result row.
   *
   * - Returns an empty list when there is no result row.
   * - Keys are sorted by their position in orderArray; a key that is not in
   *   orderArray has index -1 and therefore sorts before the listed ones.
   * - 'id' is dropped unless the primary key field itself is named 'id',
   *   because the primary key column already carries that value.
   * - The DYNAMIC_FIELD column gets a translated label; the primary key
   *   column is marked as copyable.
   */
  const colDefinitions: ColDefinitionsType[] = useMemo(() => {
    if (!searchResult || searchResult.length === 0) {
      return [];
    }
    // if the primary key field name is 'id', keep it
    const hiddenKeys = primaryKeyField === 'id' ? [] : ['id'];
    return Object.keys(searchResult[0])
      .sort((a, b) => orderArray.indexOf(a) - orderArray.indexOf(b))
      .filter(key => !hiddenKeys.includes(key))
      .map(key => ({
        id: key,
        align: 'left',
        disablePadding: false,
        label: key === DYNAMIC_FIELD ? searchTrans('dynamicFields') : key,
        needCopy: primaryKeyField === key,
      }));
  }, [searchResult, primaryKeyField, orderArray, searchTrans]);
  const [selectedMetricType, setSelectedMetricType] = useState<string>('');
  /**
   * Index and type details derived from the currently selected vector field.
   *
   * The calculation is pure: it only reads selectedField and fieldOptions.
   * When no field is selected, or the selected value is not present in
   * fieldOptions, placeholder values are returned (hasField: false) instead
   * of dereferencing a missing option.
   */
  const selectedFieldMeta = useMemo(() => {
    const selectedFieldInfo =
      selectedField !== ''
        ? fieldOptions.find(f => f.value === selectedField)
        : undefined;
    if (!selectedFieldInfo) {
      return {
        hasField: false as const,
        indexType: '',
        indexParams: [],
        fieldType: 0,
        embeddingType: DataTypeEnum.FloatVector,
        selectedFieldDimension: 0,
      };
    }
    const index = selectedFieldInfo.indexInfo;
    const embeddingType = getEmbeddingType(selectedFieldInfo.fieldType);
    // prefer the metric recorded on the index, otherwise the default for this embedding type
    const metric =
      index?.metricType || DEFAULT_METRIC_VALUE_MAP[embeddingType];
    return {
      hasField: true as const,
      metricType: metric,
      indexType: index?.indexType || getDefaultIndexType(embeddingType),
      indexParams: index?.indexParameterPairs || [],
      fieldType:
        DataTypeEnum[
          selectedFieldInfo.fieldType as keyof typeof DataTypeEnum
        ],
      embeddingType,
      selectedFieldDimension: selectedFieldInfo.dimension || 0,
    };
  }, [selectedField, fieldOptions]);
  const {
    indexType,
    indexParams,
    fieldType,
    embeddingType,
    selectedFieldDimension,
  } = selectedFieldMeta;
  // Keep the metric type in sync with the selected field. This runs whenever
  // the memo above is recomputed (field or field options changed), so a
  // manually chosen metric is replaced by the new field's metric. It lives in
  // an effect because state must not be set from inside a useMemo calculation.
  useEffect(() => {
    if (selectedFieldMeta.hasField) {
      setSelectedMetricType(selectedFieldMeta.metricType);
    }
  }, [selectedFieldMeta]);
  /**
   * Whether the text in the vector input is acceptable.
   *
   * Validation is skipped (result is true) while the input is empty or no
   * field dimension is known. Otherwise the parsed value must be an array
   * whose length equals the field dimension, or dimension / 8 for binary
   * vector fields.
   */
  const vectorValueValid = useMemo(() => {
    const nothingToCheck = vectors === '' || selectedFieldDimension === 0;
    if (nothingToCheck) {
      return true;
    }
    const expectedLength =
      fieldType === DataTypeEnum.BinaryVector
        ? selectedFieldDimension / 8
        : selectedFieldDimension;
    const parsed = parseValue(vectors);
    if (!Array.isArray(parsed)) {
      return false;
    }
    return parsed.length === expectedLength;
  }, [vectors, selectedFieldDimension, fieldType]);
  /**
   * True while a search must not be started. A search is allowed only when:
   * 1. a vector value is entered and passes validation
   * 2. a collection and a field are chosen
   * 3. the search params form is not in its disabled state
   */
  const searchDisabled = useMemo(() => {
    const ready =
      vectors !== '' &&
      selectedCollection !== '' &&
      selectedField !== '' &&
      !paramDisabled &&
      vectorValueValid;
    return !ready;
  }, [
    paramDisabled,
    selectedField,
    selectedCollection,
    vectors,
    vectorValueValid,
  ]);
  // Fetch all collections and keep only those whose status is LOADED.
  // The callback is recreated when `database` changes.
  const fetchCollections = useCallback(async () => {
    const fetched = await Collection.getCollections();
    const loaded = fetched.filter(c => c.status === LOADING_STATE.LOADED);
    setCollections(loaded);
  }, [database]);
  /**
   * Load index info for a collection and refresh everything derived from its fields:
   * - vector field options (the first option becomes the selected field)
   * - the non-vector fields offered by the advanced filter
   *
   * @param collectionName name of the collection to inspect
   * @param collections list in which the collection's fields are looked up
   */
  const fetchFieldsWithIndex = useCallback(
    async (collectionName: string, collections: Collection[]) => {
      const target = collections.find(
        c => c.collectionName === collectionName
      );
      const fields = target?.fields || [];
      const indexes = await MilvusIndex.getIndexInfo(collectionName);
      const { vectorFields, nonVectorFields } = classifyFields(fields);

      // only vector type fields can be selected
      const vectorOptions = getVectorFieldOptions(vectorFields, indexes);
      setFieldOptions(vectorOptions);
      if (vectorOptions.length > 0) {
        // default to the first option
        setSelectedField(vectorOptions[0].value as string);
      }

      // only non vector type fields can be used in the advanced filter
      setFilterFields(getNonVectorFieldsForFilter(nonVectorFields));
    },
    [collections]
  );
  // load collections initially and again whenever fetchCollections is recreated
  useEffect(() => {
    fetchCollections();
  }, [fetchCollections]);

  // clear the selected collection when the database changes
  useEffect(() => {
    setSelectedCollection('');
  }, [database]);

  // reload field options and index info when the selected collection changes
  useEffect(() => {
    if (selectedCollection === '') {
      return;
    }
    fetchFieldsWithIndex(selectedCollection, collections);
  }, [selectedCollection, collections, fetchFieldsWithIndex]);

  // preselect the collection named in the URL query, if it is one of the known collections
  useEffect(() => {
    if (!location.search || collections.length === 0) {
      return;
    }
    const { collectionName } = parseLocationSearch(location.search);
    const isKnown = collections.some(
      c => c.collectionName === collectionName
    );
    if (isKnown) {
      setSelectedCollection(collectionName);
    }
  }, [location, collections]);
  // icon components used by this page
  const {
    vectorSearch: VectorSearchIcon,
    refresh: ResetIcon,
    dropdown: ArrowIcon,
  } = icons;

  // grid callback: the event argument is ignored, only the page number is forwarded
  const handlePageChange = (_event: unknown, page: number) => {
    handleCurrentPage(page);
  };
  /**
   * Put the page back into its pre-search state:
   * - clear the selected collection and field
   * - clear the vector input
   * - clear the advanced filter fields and expression
   * - set the search result back to null (the "no search yet" state)
   *
   * NOTE(review): searchParam and topK are not changed here — confirm
   * whether the search params form is expected to reset as well.
   */
  const handleReset = () => {
    setSelectedCollection('');
    setSelectedField('');
    setVectors('');
    setFilterFields([]);
    setExpression('');
    setSearchResult(null);
  };
  /**
   * Run a vector search against the selected collection and store the outcome.
   *
   * @param topK number of results to request
   * @param expr filter expression; defaults to the current expression state
   *
   * The loading flag is raised before the request and lowered on both success
   * and failure. On failure the previous result and latency are left as they were.
   */
  const handleSearch = async (topK: number, expr = expression) => {
    // round_decimal is sent as its own entry, so it is removed from the
    // copy that gets serialized into `params`
    const serializableParams = cloneObj(searchParam);
    delete serializableParams.round_decimal;

    const searchParamPairs = {
      params: JSON.stringify(serializableParams),
      anns_field: selectedField,
      topk: topK,
      metric_type: selectedMetricType,
      round_decimal: searchParam.round_decimal,
    };
    const params: VectorSearchParam = {
      output_fields: outputFields,
      expr,
      search_params: searchParamPairs,
      vectors: [parseValue(vectors)],
      vector_type: fieldType,
    };

    setTableLoading(true);
    try {
      const response = await Collection.vectorSearchData(
        selectedCollection,
        params
      );
      setTableLoading(false);
      setSearchResult(response.results);
      setLatency(response.latency);
    } catch (err) {
      setTableLoading(false);
    }
  };
  // Store the new filter expression and, if searching is currently allowed,
  // immediately re-run the search with it.
  const handleAdvancedFilterChange = (expression: string) => {
    setExpression(expression);
    if (searchDisabled) {
      return;
    }
    handleSearch(topK, expression);
  };
  // Mirror the textarea content into the `vectors` state.
  const handleVectorChange = (value: string) => {
    const nextVectors = value;
    setVectors(nextVectors);
  };
  // Fill the vector input with a generated vector for the given dimension,
  // wrapped in square brackets.
  const fillWithExampleVector = (selectedFieldDimension: number) => {
    const example = generateVector(selectedFieldDimension);
    const asText = `[${example}]`;
    setVectors(asText);
  };
  // ---- values and handlers used only by the markup below ----

  // number of elements the vector input must contain: dimension / 8 for binary vectors
  const requiredVectorLength =
    fieldType === DataTypeEnum.BinaryVector
      ? selectedFieldDimension / 8
      : selectedFieldDimension;
  const hasDimension = selectedFieldDimension !== 0;
  const noCollections = collectionOptions.length === 0;
  // show the grid while loading or when there is at least one result row
  const showGrid = (searchResult && searchResult.length > 0) || tableLoading;

  const onCollectionChange = (e: { target: { value: unknown } }) => {
    const collection = e.target.value as string;
    setSelectedCollection(collection);
    // every time the selected collection changes, reset the field and results
    setSelectedField('');
    setSearchResult([]);
  };

  const onFieldChange = (e: { target: { value: unknown } }) => {
    const field = e.target.value as string;
    setSelectedField(field);
  };

  // one menu entry per topK option; picking one stores it and re-runs the search if allowed
  const topKMenuItems = TOP_K_OPTIONS.map(option => ({
    label: option.toString(),
    callback: () => {
      setTopK(option);
      if (!searchDisabled) {
        handleSearch(option);
      }
    },
    wrapperClass: classes.menuItem,
  }));

  // button that fills the input with an example vector; only shown once a dimension is known
  const exampleButton = hasDimension ? (
    <Button
      className={classes.exampleBtn}
      variant="outlined"
      size="small"
      onClick={() => {
        fillWithExampleVector(requiredVectorLength);
      }}
    >
      {btnTrans('example')}
    </Button>
  ) : null;

  return (
    <section className={`page-wrapper ${classes.pageContainer}`}>
      <Card className={classes.form}>
        {/* step: choose collection and vector field */}
        <CardContent className={classes.s1}>
          <Typography className="text">{searchTrans('secondTip')}</Typography>
          <CustomSelector
            options={collectionOptions}
            wrapperClass={classes.selector}
            variant="filled"
            label={searchTrans(noCollections ? 'noCollection' : 'collection')}
            disabled={noCollections}
            value={selectedCollection}
            onChange={onCollectionChange}
          />
          <CustomSelector
            options={fieldOptions}
            // readOnly can't avoid all events, so we use disabled instead
            disabled={selectedCollection === ''}
            wrapperClass={classes.selector}
            variant="filled"
            label={searchTrans('field')}
            value={selectedField}
            onChange={onFieldChange}
          />
        </CardContent>

        {/* step: enter the vector value */}
        <CardContent className={classes.s2}>
          <Typography className="text">
            {searchTrans('firstTip', {
              dimensionTip: hasDimension
                ? `(dimension: ${selectedFieldDimension})`
                : '',
            })}
            {exampleButton}
          </Typography>
          <textarea
            className="textarea"
            placeholder={searchTrans('vectorPlaceholder')}
            value={vectors}
            onChange={(e: React.ChangeEvent<{ value: unknown }>) => {
              handleVectorChange(e.target.value as string);
            }}
          />
          {/* validation */}
          {!vectorValueValid && (
            <Typography variant="caption" className={classes.error}>
              {searchTrans('vectorValueWarning', {
                dimension: requiredVectorLength,
              })}
            </Typography>
          )}
        </CardContent>

        {/* step: search params */}
        <CardContent className={classes.s3}>
          <Typography className="text">{searchTrans('thirdTip')}</Typography>
          <SearchParams
            wrapperClass={classes.paramsWrapper}
            metricType={selectedMetricType}
            embeddingType={
              embeddingType as
                | DataTypeEnum.BinaryVector
                | DataTypeEnum.FloatVector
            }
            indexType={indexType}
            indexParams={indexParams!}
            searchParamsForm={searchParam}
            handleFormChange={setSearchParam}
            handleMetricTypeChange={setSelectedMetricType}
            topK={topK}
            setParamsDisabled={setParamDisabled}
          />
        </CardContent>
      </Card>

      <section className={classes.resultsWrapper}>
        <section className={classes.toolbar}>
          <div className="left">
            <Typography variant="h5" className="text">
              {`${searchTrans('result')}: `}
            </Typography>
            {/* topK selector */}
            <SimpleMenu
              label={searchTrans('topK', { number: topK })}
              menuItems={topKMenuItems}
              buttonProps={{
                className: classes.menuLabel,
                endIcon: <ArrowIcon />,
              }}
              menuItemWidth="108px"
            />
            <Filter
              title="Advanced Filter"
              fields={filterFields}
              filterDisabled={selectedField === '' || selectedCollection === ''}
              onSubmit={handleAdvancedFilterChange}
            />
          </div>
          <div className="right">
            <CustomButton className="btn" onClick={handleReset}>
              <ResetIcon classes={{ root: 'icon' }} />
              {btnTrans('reset')}
            </CustomButton>
            <CustomButton
              variant="contained"
              disabled={searchDisabled}
              onClick={() => handleSearch(topK)}
            >
              {btnTrans('search')}
            </CustomButton>
          </div>
        </section>

        {/* search result table section */}
        {showGrid ? (
          <AttuGrid
            toolbarConfigs={[]}
            colDefinitions={colDefinitions}
            rows={result}
            rowCount={total}
            primaryKey="rank"
            page={currentPage}
            onPageChange={handlePageChange}
            rowsPerPage={pageSize}
            setRowsPerPage={handlePageSize}
            openCheckBox={false}
            isLoading={tableLoading}
            orderBy={orderBy}
            order={order}
            labelDisplayedRows={getLabelDisplayedRows(`(${latency} ms)`)}
            handleSort={handleGridSort}
            tableCellMaxWidth="100%"
          />
        ) : (
          <EmptyCard
            wrapperClass="page-empty-card"
            icon={<VectorSearchIcon />}
            text={
              searchResult === null
                ? searchTrans('startTip')
                : searchTrans('empty')
            }
          />
        )}
      </section>
    </section>
  );
  521. };
  522. export default VectorSearch;