VectorSearch.tsx 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. import { useCallback, useEffect, useMemo, useState, useContext } from 'react';
  2. import { Typography, Button, Card, CardContent } from '@material-ui/core';
  3. import { useTranslation } from 'react-i18next';
  4. import { useLocation } from 'react-router-dom';
  5. import { ALL_ROUTER_TYPES } from '@/router/Types';
  6. import { useNavigationHook, useSearchResult, usePaginationHook } from '@/hooks';
  7. import { dataContext } from '@/context';
  8. import { saveCsvAs } from '@/utils';
  9. import CustomSelector from '@/components/customSelector/CustomSelector';
  10. import { ColDefinitionsType } from '@/components/grid/Types';
  11. import AttuGrid from '@/components/grid/Grid';
  12. import EmptyCard from '@/components/cards/EmptyCard';
  13. import icons from '@/components/icons/Icons';
  14. import CustomButton from '@/components/customButton/CustomButton';
  15. import SimpleMenu from '@/components/menu/SimpleMenu';
  16. import { Option } from '@/components/customSelector/Types';
  17. import Filter from '@/components/advancedSearch';
  18. import { DataService } from '@/http';
  19. import {
  20. parseValue,
  21. parseLocationSearch,
  22. getVectorFieldOptions,
  23. cloneObj,
  24. generateVector,
  25. formatNumber,
  26. } from '@/utils';
  27. import { LOADING_STATE, DYNAMIC_FIELD, DataTypeEnum } from '@/consts';
  28. import { getLabelDisplayedRows } from './Utils';
  29. import SearchParams from './SearchParams';
  30. import { getVectorSearchStyles } from './Styles';
  31. import { TOP_K_OPTIONS } from './Constants';
  32. import { FieldOption, SearchResultView, VectorSearchParam } from './Types';
  33. import { FieldObject, CollectionFullObject } from '@server/types';
  /**
   * Vector similarity search page.
   *
   * Lets the user pick a loaded collection and one of its vector fields, enter
   * (or generate) a query vector, tune the search parameters / topK / filter
   * expression, run the search through `DataService.vectorSearchData`, and
   * browse or export the paginated results.
   */
  const VectorSearch = () => {
    useNavigationHook(ALL_ROUTER_TYPES.SEARCH);
    const location = useLocation();
    const { database, collections } = useContext(dataContext);

    // i18n
    const { t: searchTrans } = useTranslation('search');
    const { t: btnTrans } = useTranslation('btn');
    const classes = getVectorSearchStyles();

    // data stored inside the component
    const [tableLoading, setTableLoading] = useState<boolean>(false);
    const [selectedCollection, setSelectedCollection] = useState<string>('');
    const [fieldOptions, setFieldOptions] = useState<FieldOption[]>([]);
    // fields for advanced filter
    const [filterFields, setFilterFields] = useState<FieldObject[]>([]);
    const [selectedField, setSelectedField] = useState<string>('');
    // search params form
    const [searchParam, setSearchParam] = useState<{ [key in string]: number }>(
      {}
    );
    // search params disable state
    const [paramDisabled, setParamDisabled] = useState<boolean>(true);
    // null means "no search has been run yet"; an empty array means "no results"
    const [searchResult, setSearchResult] = useState<SearchResultView[] | null>(
      null
    );
    // default topK is 100
    const [topK, setTopK] = useState<number>(100);
    const [expression, setExpression] = useState<string>('');
    const [vectors, setVectors] = useState<string>('');
    // latency reported by the last successful search
    const [latency, setLatency] = useState<number>(0);
    const [selectedMetricType, setSelectedMetricType] = useState<string>('');
    const [selectedConsistencyLevel, setSelectedConsistencyLevel] =
      useState<string>('');

    const searchResultMemo = useSearchResult(searchResult as any);
    const {
      pageSize,
      handlePageSize,
      currentPage,
      handleCurrentPage,
      total,
      data: result,
      order,
      orderBy,
      handleGridSort,
    } = usePaginationHook(searchResultMemo || []);

    // names of the fields requested back from the search (vector fields excluded)
    const outputFields: string[] = useMemo(() => {
      const selected = collections.find(
        c => c.collection_name === selectedCollection
      ) as CollectionFullObject | undefined;
      if (!selected) {
        return [];
      }
      const fields = selected.schema?.fields ?? [];
      // vector field can't be output fields
      const invalidTypes = ['BinaryVector', 'FloatVector'];
      const names = fields
        .filter(field => !invalidTypes.includes(field.data_type))
        .map(f => f.name);
      if (selected.schema?.enable_dynamic_field) {
        names.push(DYNAMIC_FIELD);
      }
      return names;
    }, [selectedCollection, collections]);

    const primaryKeyField = useMemo(() => {
      const selectedCollectionInfo = collections.find(
        c => c.collection_name === selectedCollection
      );
      const fields = selectedCollectionInfo?.schema?.fields || [];
      return fields.find(f => f.is_primary_key)?.name;
    }, [selectedCollection, collections]);

    const colDefinitions: ColDefinitionsType[] = useMemo(() => {
      // Built inside the memo: a fresh array per render used as a dependency
      // would force this memo to recompute on every render.
      const orderArray = [primaryKeyField, 'id', 'score', ...outputFields];
      /**
       * id represents primary key, score represents distance.
       * The original primary key field is already present in the row, so the
       * redundant 'id' column is dropped unless the primary key is itself
       * named 'id'.
       */
      return searchResult && searchResult.length > 0
        ? Object.keys(searchResult[0])
            .sort((a, b) => orderArray.indexOf(a) - orderArray.indexOf(b))
            .filter(item => {
              // if primary key field name is id, don't filter it
              const invalidItems = primaryKeyField === 'id' ? [] : ['id'];
              return !invalidItems.includes(item);
            })
            .map(key => ({
              id: key,
              align: 'left',
              disablePadding: false,
              label: key === DYNAMIC_FIELD ? searchTrans('dynamicFields') : key,
              needCopy: key !== 'score',
            }))
        : [];
    }, [searchResult, primaryKeyField, outputFields, searchTrans]);

    // index / type information derived from the currently selected vector field
    const {
      metricType: fieldMetricType,
      indexType,
      indexParams,
      fieldType,
      selectedFieldDimension,
    } = useMemo(() => {
      const field =
        selectedField === ''
          ? undefined
          : fieldOptions.find(f => f.value === selectedField)?.field;
      if (!field) {
        return {
          metricType: '',
          indexType: '',
          indexParams: [],
          fieldType: DataTypeEnum.FloatVector,
          selectedFieldDimension: 0,
        };
      }
      return {
        metricType: field.index.metricType || '',
        indexType: field.index.indexType,
        indexParams: field.index.indexParameterPairs || [],
        fieldType: field.dataType,
        selectedFieldDimension: field.dimension || 0,
      };
    }, [selectedField, fieldOptions]);

    // Sync the metric type from the selected field. This is done in an effect
    // because a useMemo callback must stay pure (no state updates during render).
    // The previous value is intentionally kept while no field is selected.
    useEffect(() => {
      if (selectedField !== '') {
        setSelectedMetricType(fieldMetricType);
      }
    }, [selectedField, fieldMetricType]);

    // number of array items the vector input must contain; for binary vectors
    // the dimension is divided by 8
    const expectedVectorLength =
      fieldType === DataTypeEnum.BinaryVector
        ? selectedFieldDimension / 8
        : selectedFieldDimension;

    /**
     * vector value validation
     * @return whether is valid
     */
    const vectorValueValid = useMemo(() => {
      // if user hasn't input value or not select field, don't trigger validation check
      if (vectors === '' || selectedFieldDimension === 0) {
        return true;
      }
      const value = parseValue(vectors);
      return Array.isArray(value) && value.length === expectedVectorLength;
    }, [vectors, selectedFieldDimension, expectedVectorLength]);

    /**
     * before search, user must:
     * 1. enter vector value, it should be an array and length should match the selected field
     * 2. choose collection and field
     * 3. set extra search params
     */
    const searchDisabled = useMemo(
      () =>
        vectors === '' ||
        selectedCollection === '' ||
        selectedField === '' ||
        paramDisabled ||
        !vectorValueValid,
      [
        paramDisabled,
        selectedField,
        selectedCollection,
        vectors,
        vectorValueValid,
      ]
    );

    // selector options: loaded collections only, largest row count first
    const collectionOptions: Option[] = useMemo(() => {
      const loadedCollections = collections.filter(
        c => c.status === LOADING_STATE.LOADED
      ) as CollectionFullObject[];
      // filter() returned a new array, so the in-place sort does not touch
      // the array held by the context
      loadedCollections.sort((a, b) => b.rowCount - a.rowCount);
      return loadedCollections.map(c => ({
        label: `${c.collection_name}(${formatNumber(c.rowCount)})`,
        value: c.collection_name,
      }));
    }, [collections]);

    /**
     * Populate the vector field selector and the advanced-filter fields from
     * the schema of the given collection, selecting the first vector field by
     * default. Clears both lists when the collection (or its schema) is missing.
     */
    const fetchFieldsWithIndex = useCallback(
      (collectionName: string, allCollections: CollectionFullObject[]) => {
        const col = allCollections.find(
          c => c.collection_name === collectionName
        );
        if (!col?.schema) {
          setFieldOptions([]);
          setFilterFields([]);
          return;
        }
        // only vector type fields can be select
        const options = getVectorFieldOptions(col.schema.vectorFields);
        setFieldOptions(options);
        if (options.length > 0) {
          // set first option value as default field value
          const [{ value: defaultFieldValue }] = options;
          setSelectedField(defaultFieldValue as string);
        }
        setFilterFields(col.schema.scalarFields);
      },
      []
    );

    // clear selection if database is changed
    useEffect(() => {
      setSelectedCollection('');
      setVectors('');
      setSearchResult(null);
    }, [database]);

    // get field options with index when selected collection changed
    useEffect(() => {
      if (selectedCollection !== '') {
        fetchFieldsWithIndex(
          selectedCollection,
          collections as CollectionFullObject[]
        );
      }
      // fall back to '' so the string state never holds undefined
      const level =
        collections.find(c => c.collection_name === selectedCollection)
          ?.consistency_level ?? '';
      setSelectedConsistencyLevel(level);
    }, [selectedCollection, collections, fetchFieldsWithIndex]);

    // set default collection value if is from overview page
    useEffect(() => {
      if (location.search && collections.length > 0) {
        const { collectionName } = parseLocationSearch(location.search);
        // collection name validation
        const isNameValid = collections.some(
          c => c.collection_name === collectionName
        );
        if (isNameValid) {
          setSelectedCollection(collectionName);
        }
      }
    }, [location, collections]);

    // icons
    const VectorSearchIcon = icons.vectorSearch;
    const ResetIcon = icons.refresh;
    const ArrowIcon = icons.dropdown;
    const ExportIcon = icons.download;

    // methods
    const handlePageChange = (_e: unknown, page: number) => {
      handleCurrentPage(page);
    };

    /**
     * reset search:
     * 1. clear the vector input
     * 2. clear the selected collection and field
     * 3. clear the advanced filter fields and expression
     * 4. clear the search result
     */
    const handleReset = () => {
      setVectors('');
      setSelectedField('');
      setSelectedCollection('');
      setSearchResult(null);
      setFilterFields([]);
      setExpression('');
    };

    /**
     * Run a vector search against the selected collection.
     * @param limit topK to request
     * @param expr filter expression, defaults to the current expression state
     */
    const handleSearch = async (limit: number, expr = expression) => {
      // round_decimal is sent as its own search param, not inside `params`
      const clonedSearchParams = cloneObj(searchParam);
      delete clonedSearchParams.round_decimal;

      const searchParamPairs = {
        params: JSON.stringify(clonedSearchParams),
        anns_field: selectedField,
        topk: limit,
        metric_type: selectedMetricType,
        round_decimal: searchParam.round_decimal,
      };

      const params: VectorSearchParam = {
        output_fields: outputFields,
        expr,
        search_params: searchParamPairs,
        vectors: [parseValue(vectors)],
        vector_type: fieldType as DataTypeEnum,
        consistency_level:
          selectedConsistencyLevel ||
          collections.find(c => c.collection_name === selectedCollection)
            ?.consistency_level!,
      };

      setTableLoading(true);
      try {
        const res = await DataService.vectorSearchData(
          selectedCollection,
          params
        );
        setSearchResult(res.results);
        setLatency(res.latency);
      } catch (err) {
        // Keep the previous results on failure.
        // NOTE(review): presumably the error is reported to the user by the
        // http layer — confirm.
      } finally {
        setTableLoading(false);
      }
    };

    // store the new filter expression and re-run the search when possible
    const handleAdvancedFilterChange = (expr: string) => {
      setExpression(expr);
      if (!searchDisabled) {
        handleSearch(topK, expr);
      }
    };

    const handleVectorChange = (value: string) => {
      setVectors(value);
    };

    // fill the vector input with a generated vector of the given length
    const fillWithExampleVector = (dimension: number) => {
      const v = generateVector(dimension);
      setVectors(`[${v}]`);
    };

    return (
      <section className={`page-wrapper ${classes.pageContainer}`}>
        <section className={classes.form}>
          <CardContent className={classes.s1}>
            <div className="wrapper">
              <CustomSelector
                options={collectionOptions}
                wrapperClass={classes.selector}
                variant="filled"
                label={searchTrans('collection')}
                disabled={collectionOptions.length === 0}
                value={selectedCollection}
                onChange={(e: { target: { value: unknown } }) => {
                  const collection = e.target.value as string;
                  setSelectedCollection(collection);
                  // every time selected collection changed, reset field
                  setSelectedField('');
                  // no search has been run for the new collection yet
                  setSearchResult(null);
                }}
              />
              <CustomSelector
                options={fieldOptions}
                // readOnly can't avoid all events, so we use disabled instead
                disabled={selectedCollection === ''}
                wrapperClass={classes.selector}
                variant="filled"
                label={searchTrans('field')}
                value={selectedField}
                onChange={(e: { target: { value: unknown } }) => {
                  const field = e.target.value as string;
                  setSelectedField(field);
                }}
              />
            </div>
          </CardContent>

          <CardContent className={classes.s2}>
            <textarea
              className="textarea"
              placeholder={searchTrans('vectorPlaceholder')}
              value={vectors}
              onChange={e => {
                handleVectorChange(e.target.value);
              }}
            />
            {/* validation */}
            {!vectorValueValid && (
              <Typography variant="caption" className={classes.error}>
                {searchTrans('vectorValueWarning', {
                  dimension: expectedVectorLength,
                })}
              </Typography>
            )}
            {selectedFieldDimension !== 0 ? (
              <CustomButton
                className={classes.exampleBtn}
                onClick={() => {
                  fillWithExampleVector(expectedVectorLength);
                }}
              >
                {btnTrans('example')}
              </CustomButton>
            ) : null}
            <CustomButton
              variant="contained"
              disabled={searchDisabled}
              onClick={() => handleSearch(topK)}
            >
              {btnTrans('search')}
            </CustomButton>
          </CardContent>

          <CardContent className={classes.s3}>
            <Typography className="text">
              {searchTrans('thirdTip', {
                metricType: `${
                  selectedMetricType ? `(${selectedMetricType})` : ''
                }`,
              })}
            </Typography>
            <SearchParams
              wrapperClass={classes.paramsWrapper}
              consistency_level={selectedConsistencyLevel}
              handleConsistencyChange={(level: string) => {
                setSelectedConsistencyLevel(level);
              }}
              indexType={indexType!}
              indexParams={indexParams!}
              searchParamsForm={searchParam}
              handleFormChange={setSearchParam}
              topK={topK}
              setParamsDisabled={setParamDisabled}
            />
          </CardContent>
        </section>

        <section className={classes.resultsWrapper}>
          <section className={classes.toolbar}>
            <div className="left">
              <Typography variant="h5" className="text">
                {`${searchTrans('result')}: `}
              </Typography>
              {/* topK selector */}
              <SimpleMenu
                label={searchTrans('topK', { number: topK })}
                menuItems={TOP_K_OPTIONS.map(item => ({
                  label: item.toString(),
                  callback: () => {
                    setTopK(item);
                    if (!searchDisabled) {
                      handleSearch(item);
                    }
                  },
                  wrapperClass: classes.menuItem,
                }))}
                buttonProps={{
                  className: classes.menuLabel,
                  endIcon: <ArrowIcon />,
                }}
                menuItemWidth="108px"
              />
              <Filter
                title="Advanced Filter"
                fields={filterFields}
                filterDisabled={selectedField === '' || selectedCollection === ''}
                onSubmit={handleAdvancedFilterChange}
              />
              <CustomButton
                className="btn"
                disabled={result.length === 0}
                onClick={() => {
                  saveCsvAs(searchResult, `search_result_${selectedCollection}`);
                }}
                startIcon={<ExportIcon classes={{ root: 'icon' }} />}
              >
                {btnTrans('export')}
              </CustomButton>
            </div>
            <div className="right">
              <CustomButton
                className="btn"
                onClick={handleReset}
                startIcon={<ResetIcon classes={{ root: 'icon' }} />}
              >
                {btnTrans('reset')}
              </CustomButton>
            </div>
          </section>

          {/* search result table section */}
          {(searchResult && searchResult.length > 0) || tableLoading ? (
            <AttuGrid
              toolbarConfigs={[]}
              colDefinitions={colDefinitions}
              rows={result}
              rowCount={total}
              primaryKey="rank"
              page={currentPage}
              rowHeight={41}
              onPageChange={handlePageChange}
              rowsPerPage={pageSize}
              setRowsPerPage={handlePageSize}
              openCheckBox={false}
              isLoading={tableLoading}
              orderBy={orderBy}
              order={order}
              labelDisplayedRows={getLabelDisplayedRows(`(${latency} ms)`)}
              handleSort={handleGridSort}
            />
          ) : (
            <EmptyCard
              wrapperClass={`page-empty-card`}
              icon={<VectorSearchIcon />}
              text={
                searchResult !== null
                  ? searchTrans('empty')
                  : searchTrans('startTip')
              }
            />
          )}
        </section>
      </section>
    );
  };

  export default VectorSearch;