VectorSearch.tsx 19 KB


  1. import { TextField, Typography, Button } from '@material-ui/core';
  2. import { useTranslation } from 'react-i18next';
  3. import { useNavigationHook } from '@/hooks/Navigation';
  4. import { ALL_ROUTER_TYPES } from '@/router/Types';
  5. import CustomSelector from '@/components/customSelector/CustomSelector';
  6. import { useCallback, useEffect, useMemo, useState } from 'react';
  7. import SearchParams from './SearchParams';
  8. import { DEFAULT_METRIC_VALUE_MAP } from '../../consts/Milvus';
  9. import { FieldOption, SearchResultView, VectorSearchParam } from './Types';
  10. import AttuGrid from '@/components/grid/Grid';
  11. import EmptyCard from '@/components/cards/EmptyCard';
  12. import icons from '@/components/icons/Icons';
  13. import { usePaginationHook } from '@/hooks/Pagination';
  14. import CustomButton from '@/components/customButton/CustomButton';
  15. import SimpleMenu from '@/components/menu/SimpleMenu';
  16. import { TOP_K_OPTIONS } from './Constants';
  17. import { Option } from '@/components/customSelector/Types';
  18. import { CollectionHttp } from '@/http/Collection';
  19. import { CollectionData, DataTypeEnum } from '../collections/Types';
  20. import { IndexHttp } from '@/http/Index';
  21. import { getVectorSearchStyles } from './Styles';
  22. import { parseValue } from '@/utils/Insert';
  23. import {
  24. classifyFields,
  25. getDefaultIndexType,
  26. getEmbeddingType,
  27. getNonVectorFieldsForFilter,
  28. getVectorFieldOptions,
  29. transferSearchResult,
  30. } from '@/utils/search';
  31. import { ColDefinitionsType } from '@/components/grid/Types';
  32. import Filter from '@/components/advancedSearch';
  33. import { Field } from '@/components/advancedSearch/Types';
  34. import { useLocation } from 'react-router-dom';
  35. import { parseLocationSearch } from '@/utils/Format';
  36. import { cloneObj, generateVector } from '@/utils/Common';
  37. import { CustomDatePicker } from '@/components/customDatePicker/CustomDatePicker';
  38. import { useTimeTravelHook } from '@/hooks/TimeTravel';
  39. import { LOADING_STATE } from '../../consts/Milvus';
  40. import { getLabelDisplayedRows } from './Utils';
  /**
   * Vector search page.
   *
   * Lets the user pick a loaded collection and one of its vector fields, enter a
   * query vector, tune search params / topK / advanced filter / time travel, and
   * renders the search results in a paginated grid.
   */
  const VectorSearch = () => {
    useNavigationHook(ALL_ROUTER_TYPES.SEARCH);
    const location = useLocation();

    // i18n
    const { t: searchTrans } = useTranslation('search');
    const { t: btnTrans } = useTranslation('btn');
    const classes = getVectorSearchStyles();

    // data stored inside the component
    const [tableLoading, setTableLoading] = useState<boolean>(false);
    const [collections, setCollections] = useState<CollectionData[]>([]);
    const [selectedCollection, setSelectedCollection] = useState<string>('');
    const [fieldOptions, setFieldOptions] = useState<FieldOption[]>([]);
    // fields for advanced filter
    const [filterFields, setFilterFields] = useState<Field[]>([]);
    const [selectedField, setSelectedField] = useState<string>('');
    // search params form
    const [searchParam, setSearchParam] = useState<{ [key in string]: number }>(
      {}
    );
    // search params disable state
    const [paramDisabled, setParamDisabled] = useState<boolean>(true);
    // null means "no search performed yet", an empty array means "no results"
    const [searchResult, setSearchResult] = useState<SearchResultView[] | null>(
      null
    );
    // default topK is 100
    const [topK, setTopK] = useState<number>(100);
    const [expression, setExpression] = useState<string>('');
    const [vectors, setVectors] = useState<string>('');
    // latency of the last search, shown in the grid's displayed-rows label
    const [latency, setLatency] = useState<number>(0);
    // metric used for the search; defaulted from the selected field, user-editable
    const [selectedMetricType, setSelectedMetricType] = useState<string>('');

    const {
      pageSize,
      handlePageSize,
      currentPage,
      handleCurrentPage,
      total,
      data: result,
      order,
      orderBy,
      handleGridSort,
    } = usePaginationHook(searchResult || []);
    const { timeTravel, setTimeTravel, timeTravelInfo, handleDateTimeChange } =
      useTimeTravelHook();

    const collectionOptions: Option[] = useMemo(
      () =>
        collections.map(c => ({
          label: c._name,
          value: c._name,
        })),
      [collections]
    );

    // fields of the currently selected collection (empty when none is selected)
    const selectedCollectionFields = useMemo(
      () =>
        collections.find(c => c._name === selectedCollection)?._fields || [],
      [selectedCollection, collections]
    );

    const outputFields: string[] = useMemo(() => {
      // vector fields can't be output fields
      const invalidTypes = ['BinaryVector', 'FloatVector'];
      return selectedCollectionFields
        .filter(field => !invalidTypes.includes(field._fieldType))
        .map(field => field._fieldName);
    }, [selectedCollectionFields]);

    const primaryKeyField = useMemo(
      () => selectedCollectionFields.find(f => f._isPrimaryKey)?._fieldName,
      [selectedCollectionFields]
    );

    const colDefinitions: ColDefinitionsType[] = useMemo(() => {
      /**
       * 'id' represents the primary key and 'score' represents the distance.
       * The score is already shown as distance and the primary key field is
       * already a column, so both are hidden to avoid redundant data -- unless
       * the primary key field itself is named 'id'.
       */
      const hiddenKeys = primaryKeyField === 'id' ? ['score'] : ['id', 'score'];
      return searchResult && searchResult.length > 0
        ? Object.keys(searchResult[0])
            .filter(key => !hiddenKeys.includes(key))
            .map(key => ({
              id: key,
              align: 'left',
              disablePadding: false,
              label: key,
              needCopy: primaryKeyField === key,
            }))
        : [];
    }, [searchResult, primaryKeyField]);

    /**
     * Search configuration derived from the selected vector field.
     * `defaultMetricType` is null when no (known) field is selected, meaning
     * "leave the current metric untouched".
     */
    const fieldConfig = useMemo(() => {
      // `selectedField` can briefly be absent from `fieldOptions` (e.g. while
      // the options of a newly selected collection are being applied), so guard
      // instead of asserting non-null.
      const selectedFieldInfo =
        selectedField === ''
          ? undefined
          : fieldOptions.find(f => f.value === selectedField);

      if (!selectedFieldInfo) {
        return {
          defaultMetricType: null,
          indexType: '',
          indexParams: [],
          fieldType: 0,
          embeddingType: DataTypeEnum.FloatVector,
          selectedFieldDimension: 0,
        };
      }

      const index = selectedFieldInfo.indexInfo;
      const fieldEmbeddingType = getEmbeddingType(selectedFieldInfo.fieldType);
      return {
        defaultMetricType:
          index?._metricType || DEFAULT_METRIC_VALUE_MAP[fieldEmbeddingType],
        indexType: index?._indexType || getDefaultIndexType(fieldEmbeddingType),
        indexParams: index?._indexParameterPairs || [],
        fieldType: DataTypeEnum[selectedFieldInfo.fieldType],
        embeddingType: fieldEmbeddingType,
        selectedFieldDimension: selectedFieldInfo.dimension || 0,
      };
    }, [selectedField, fieldOptions]);

    const {
      indexType,
      indexParams,
      fieldType,
      embeddingType,
      selectedFieldDimension,
    } = fieldConfig;

    // Reset the metric to the field's default whenever the field config is
    // recomputed. This is a side effect, so it lives in an effect rather than
    // inside the (pure) useMemo factory above.
    useEffect(() => {
      if (fieldConfig.defaultMetricType !== null) {
        setSelectedMetricType(fieldConfig.defaultMetricType);
      }
    }, [fieldConfig]);

    /**
     * Expected length of the query vector array. For binary vectors the field
     * dimension is divided by 8 (presumably one array element per byte of bits).
     */
    const vectorInputDimension = useMemo(
      () =>
        fieldType === DataTypeEnum.BinaryVector
          ? selectedFieldDimension / 8
          : selectedFieldDimension,
      [fieldType, selectedFieldDimension]
    );

    /**
     * vector value validation
     * @return whether the entered vector is valid
     */
    const vectorValueValid = useMemo(() => {
      // if user hasn't input value or not select field, don't trigger validation check
      if (vectors === '' || selectedFieldDimension === 0) {
        return true;
      }
      const value = parseValue(vectors);
      return Array.isArray(value) && value.length === vectorInputDimension;
    }, [vectors, selectedFieldDimension, vectorInputDimension]);

    /**
     * before search, user must:
     * 1. enter vector value, it should be an array and length should be equal to selected field dimension
     * 2. choose collection and field
     * 3. set extra search params
     */
    const searchDisabled = useMemo(
      () =>
        vectors === '' ||
        selectedCollection === '' ||
        selectedField === '' ||
        paramDisabled ||
        !vectorValueValid,
      [paramDisabled, selectedField, selectedCollection, vectors, vectorValueValid]
    );

    // fetch data
    const fetchCollections = useCallback(async () => {
      const allCollections = await CollectionHttp.getCollections();
      // only loaded collections can be searched
      setCollections(
        allCollections.filter(c => c._status === LOADING_STATE.LOADED)
      );
    }, []);

    const fetchFieldsWithIndex = useCallback(
      async (collectionName: string, collectionList: CollectionData[]) => {
        const fields =
          collectionList.find(c => c._name === collectionName)?._fields || [];
        const indexes = await IndexHttp.getIndexInfo(collectionName);
        const { vectorFields, nonVectorFields } = classifyFields(fields);
        // only vector type fields can be selected
        const options = getVectorFieldOptions(vectorFields, indexes);
        setFieldOptions(options);
        if (options.length > 0) {
          // set first option value as default field value
          const [{ value: defaultFieldValue }] = options;
          setSelectedField(defaultFieldValue as string);
        }
        // only non vector type fields can be used in the advanced filter
        setFilterFields(getNonVectorFieldsForFilter(nonVectorFields));
      },
      []
    );

    useEffect(() => {
      fetchCollections();
    }, [fetchCollections]);

    // get field options with index when selected collection changed
    useEffect(() => {
      if (selectedCollection !== '') {
        fetchFieldsWithIndex(selectedCollection, collections);
      }
    }, [selectedCollection, collections, fetchFieldsWithIndex]);

    // set default collection value if is from overview page
    useEffect(() => {
      if (location.search && collections.length > 0) {
        const { collectionName } = parseLocationSearch(location.search);
        // collection name validation
        const isNameValid = collections.some(c => c._name === collectionName);
        if (isNameValid) {
          setSelectedCollection(collectionName);
        }
      }
    }, [location, collections]);

    // icons
    const VectorSearchIcon = icons.vectorSearch;
    const ResetIcon = icons.refresh;
    const ArrowIcon = icons.dropdown;

    // methods
    const handlePageChange = (_e: unknown, page: number) => {
      handleCurrentPage(page);
    };

    const handleReset = () => {
      /**
       * reset search includes:
       * 1. reset vectors
       * 2. reset selected collection and field
       * 3. reset advanced filter fields and expression
       * 4. reset time travel
       * 5. clear search result
       * Search params are not reset here directly; presumably SearchParams
       * re-initialises them once the field is cleared -- TODO confirm.
       */
      setVectors('');
      setSelectedField('');
      setSelectedCollection('');
      setSearchResult(null);
      setFilterFields([]);
      setExpression('');
      setTimeTravel(null);
    };

    /**
     * Run a vector search against the selected collection/field.
     * @param topK number of results to return
     * @param expr advanced filter expression, defaults to the current one
     */
    const handleSearch = async (topK: number, expr = expression) => {
      // round_decimal is sent as its own search param, not inside `params`
      const clonedSearchParams = cloneObj(searchParam);
      delete clonedSearchParams.round_decimal;
      const searchParamPairs = {
        params: JSON.stringify(clonedSearchParams),
        anns_field: selectedField,
        topk: topK,
        metric_type: selectedMetricType,
        round_decimal: searchParam.round_decimal,
      };
      const params: VectorSearchParam = {
        output_fields: outputFields,
        expr,
        search_params: searchParamPairs,
        vectors: [parseValue(vectors)],
        vector_type: fieldType,
        travel_timestamp: timeTravelInfo.timestamp,
      };

      setTableLoading(true);
      try {
        const res = await CollectionHttp.vectorSearchData(
          selectedCollection,
          params
        );
        setSearchResult(transferSearchResult(res.results));
        setLatency(res.latency);
      } catch (err) {
        // NOTE(review): user-facing error reporting is presumably handled by
        // the http layer -- confirm; log here so failures are not silent.
        console.error('vector search failed', err);
      } finally {
        setTableLoading(false);
      }
    };

    const handleAdvancedFilterChange = (expr: string) => {
      setExpression(expr);
      if (!searchDisabled) {
        handleSearch(topK, expr);
      }
    };

    const handleVectorChange = (value: string) => {
      setVectors(value);
    };

    const fillWithExampleVector = (dimension: number) => {
      const v = generateVector(dimension);
      setVectors(`[${v}]`);
    };

    return (
      <section className="page-wrapper">
        {/* form section */}
        <form className={classes.form}>
          {/* collection and field selectors */}
          <fieldset className="field">
            <Typography className="text">{searchTrans('secondTip')}</Typography>
            <CustomSelector
              options={collectionOptions}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans(
                collectionOptions.length === 0 ? 'noCollection' : 'collection'
              )}
              disabled={collectionOptions.length === 0}
              value={selectedCollection}
              onChange={(e: { target: { value: unknown } }) => {
                const collection = e.target.value;
                setSelectedCollection(collection as string);
                // every time selected collection changed, reset field
                setSelectedField('');
                // no search has been run against the new collection yet
                setSearchResult(null);
              }}
            />
            <CustomSelector
              options={fieldOptions}
              // readOnly can't avoid all events, so we use disabled instead
              disabled={selectedCollection === ''}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans('field')}
              value={selectedField}
              onChange={(e: { target: { value: unknown } }) => {
                const field = e.target.value;
                setSelectedField(field as string);
              }}
            />
          </fieldset>
          {/**
           * vector value textarea
           * use field-params class because it also has error msg if invalid
           */}
          <fieldset className="field field-params field-second">
            <Typography className="text">
              {searchTrans('firstTip', {
                dimensionTip:
                  selectedFieldDimension !== 0
                    ? `(dimension: ${selectedFieldDimension})`
                    : '',
              })}
              {selectedFieldDimension !== 0 ? (
                <Button
                  variant="outlined"
                  size="small"
                  onClick={() => {
                    fillWithExampleVector(vectorInputDimension);
                  }}
                >
                  {btnTrans('example')}
                </Button>
              ) : null}
            </Typography>
            <TextField
              className="textarea"
              InputProps={{
                classes: {
                  root: 'textfield',
                  multiline: 'multiline',
                },
              }}
              multiline
              rows={5}
              placeholder={searchTrans('vectorPlaceholder')}
              value={vectors}
              onChange={(e: React.ChangeEvent<{ value: unknown }>) => {
                handleVectorChange(e.target.value as string);
              }}
            />
            {/* validation */}
            {!vectorValueValid && (
              <Typography variant="caption" className={classes.error}>
                {searchTrans('vectorValueWarning', {
                  dimension: vectorInputDimension,
                })}
              </Typography>
            )}
          </fieldset>
          {/* search params selectors */}
          <fieldset className="field field-params">
            <Typography className="text">{searchTrans('thirdTip')}</Typography>
            <SearchParams
              wrapperClass={classes.paramsWrapper}
              metricType={selectedMetricType}
              embeddingType={
                embeddingType as
                  | DataTypeEnum.BinaryVector
                  | DataTypeEnum.FloatVector
              }
              indexType={indexType}
              indexParams={indexParams}
              searchParamsForm={searchParam}
              handleFormChange={setSearchParam}
              handleMetricTypeChange={setSelectedMetricType}
              topK={topK}
              setParamsDisabled={setParamDisabled}
            />
          </fieldset>
        </form>
        {/**
         * search toolbar section
         * including topK selector, advanced filter, search and reset btn
         */}
        <section className={classes.toolbar}>
          <div className="left">
            <Typography variant="h5" className="text">
              {`${searchTrans('result')}: `}
            </Typography>
            {/* topK selector */}
            <SimpleMenu
              label={searchTrans('topK', { number: topK })}
              menuItems={TOP_K_OPTIONS.map(item => ({
                label: item.toString(),
                callback: () => {
                  setTopK(item);
                  if (!searchDisabled) {
                    handleSearch(item);
                  }
                },
                wrapperClass: classes.menuItem,
              }))}
              buttonProps={{
                className: classes.menuLabel,
                endIcon: <ArrowIcon />,
              }}
              menuItemWidth="108px"
            />
            <Filter
              title="Advanced Filter"
              fields={filterFields}
              filterDisabled={selectedField === '' || selectedCollection === ''}
              onSubmit={handleAdvancedFilterChange}
            />
            <CustomDatePicker
              label={timeTravelInfo.label}
              onChange={handleDateTimeChange}
              date={timeTravel}
              setDate={setTimeTravel}
            />
          </div>
          <div className="right">
            <CustomButton className="btn" onClick={handleReset}>
              <ResetIcon classes={{ root: 'icon' }} />
              {btnTrans('reset')}
            </CustomButton>
            <CustomButton
              variant="contained"
              disabled={searchDisabled}
              onClick={() => handleSearch(topK)}
            >
              {btnTrans('search')}
            </CustomButton>
          </div>
        </section>
        {/* search result table section */}
        {(searchResult && searchResult.length > 0) || tableLoading ? (
          <AttuGrid
            toolbarConfigs={[]}
            colDefinitions={colDefinitions}
            rows={result}
            rowCount={total}
            primaryKey="rank"
            page={currentPage}
            onChangePage={handlePageChange}
            rowsPerPage={pageSize}
            setRowsPerPage={handlePageSize}
            openCheckBox={false}
            isLoading={tableLoading}
            orderBy={orderBy}
            order={order}
            labelDisplayedRows={getLabelDisplayedRows(`(${latency} ms)`)}
            handleSort={handleGridSort}
            tableCellMaxWidth="100%"
          />
        ) : (
          <EmptyCard
            wrapperClass="page-empty-card"
            icon={<VectorSearchIcon />}
            text={
              searchResult !== null
                ? searchTrans('empty')
                : searchTrans('startTip')
            }
          />
        )}
      </section>
    );
  };

  export default VectorSearch;