// VectorSearch.tsx (16 KB)

  1. import { TextField, Typography } from '@material-ui/core';
  2. import { useTranslation } from 'react-i18next';
  3. import { useNavigationHook } from '../../hooks/Navigation';
  4. import { ALL_ROUTER_TYPES } from '../../router/Types';
  5. import CustomSelector from '../../components/customSelector/CustomSelector';
  6. import { useCallback, useEffect, useMemo, useState } from 'react';
  7. import SearchParams from './SearchParams';
  8. import { DEFAULT_METRIC_VALUE_MAP } from '../../consts/Milvus';
  9. import { FieldOption, SearchResultView, VectorSearchParam } from './Types';
  10. import MilvusGrid from '../../components/grid/Grid';
  11. import EmptyCard from '../../components/cards/EmptyCard';
  12. import icons from '../../components/icons/Icons';
  13. import { usePaginationHook } from '../../hooks/Pagination';
  14. import CustomButton from '../../components/customButton/CustomButton';
  15. import SimpleMenu from '../../components/menu/SimpleMenu';
  16. import { TOP_K_OPTIONS } from './Constants';
  17. import { Option } from '../../components/customSelector/Types';
  18. import { CollectionHttp } from '../../http/Collection';
  19. import { CollectionData, DataTypeEnum } from '../collections/Types';
  20. import { IndexHttp } from '../../http/Index';
  21. import { getVectorSearchStyles } from './Styles';
  22. import { parseValue } from '../../utils/Insert';
  23. import {
  24. classifyFields,
  25. getDefaultIndexType,
  26. getEmbeddingType,
  27. getNonVectorFieldsForFilter,
  28. getVectorFieldOptions,
  29. transferSearchResult,
  30. } from '../../utils/search';
  31. import { ColDefinitionsType } from '../../components/grid/Types';
  32. import Filter from '../../components/advancedSearch';
  33. import { Field } from '../../components/advancedSearch/Types';
  34. import { useLocation } from 'react-router-dom';
  35. import { parseLocationSearch } from '../../utils/Format';
  /**
   * Vector search page.
   *
   * Lets the user enter a query vector, pick a collection and one of its vector
   * fields, tune the search params, optionally add an advanced filter expression,
   * and shows the paginated search result in a grid.
   *
   * State flow visible in this component:
   * - collections are fetched once on mount;
   * - whenever the selected collection (or the collection list) changes, the
   *   collection's fields and index info are loaded and the first vector field is
   *   selected by default;
   * - a `collectionName` found in the URL query string pre-selects that
   *   collection when it exists in the fetched list.
   */
  const VectorSearch = () => {
    useNavigationHook(ALL_ROUTER_TYPES.SEARCH);
    const location = useLocation();

    // i18n
    const { t: searchTrans } = useTranslation('search');
    const { t: btnTrans } = useTranslation('btn');
    const classes = getVectorSearchStyles();

    // data stored inside the component
    const [tableLoading, setTableLoading] = useState<boolean>(false);
    const [collections, setCollections] = useState<CollectionData[]>([]);
    const [selectedCollection, setSelectedCollection] = useState<string>('');
    const [fieldOptions, setFieldOptions] = useState<FieldOption[]>([]);
    // fields for advanced filter
    const [filterFields, setFilterFields] = useState<Field[]>([]);
    const [selectedField, setSelectedField] = useState<string>('');
    // search params form
    const [searchParam, setSearchParam] = useState<Record<string, number>>({});
    // search params disable state
    const [paramDisabled, setParamDisabled] = useState<boolean>(true);
    // use null as init value before search, empty array means no results
    const [searchResult, setSearchResult] = useState<SearchResultView[] | null>(
      null
    );
    // default topK is 100
    const [topK, setTopK] = useState<number>(100);
    const [expression, setExpression] = useState<string>('');
    const [vectors, setVectors] = useState<string>('');

    const {
      pageSize,
      handlePageSize,
      currentPage,
      handleCurrentPage,
      total,
      data: result,
    } = usePaginationHook(searchResult || []);

    // options for the collection selector, one per fetched collection
    const collectionOptions: Option[] = useMemo(
      () =>
        collections.map(c => ({
          label: c._name,
          value: c._name,
        })),
      [collections]
    );

    // names of the fields requested back from the search
    const outputFields: string[] = useMemo(() => {
      const fields =
        collections.find(c => c._name === selectedCollection)?._fields || [];
      // vector field can't be output fields
      const invalidTypes = ['BinaryVector', 'FloatVector'];
      return fields
        .filter(field => !invalidTypes.includes(field._fieldType))
        .map(field => field._fieldName);
    }, [selectedCollection, collections]);

    // grid columns are derived from the keys of the first result row
    const colDefinitions: ColDefinitionsType[] = useMemo(() => {
      if (!searchResult || searchResult.length === 0) {
        return [];
      }
      // filter id and score
      return Object.keys(searchResult[0])
        .filter(item => item !== 'id' && item !== 'score')
        .map(key => ({
          id: key,
          align: 'left',
          disablePadding: false,
          label: key,
        }));
    }, [searchResult]);

    // everything the search form needs to know about the selected vector field
    const {
      metricType,
      indexType,
      indexParams,
      fieldType,
      embeddingType,
      selectedFieldDimension,
    } = useMemo(() => {
      const selectedFieldInfo =
        selectedField !== ''
          ? fieldOptions.find(f => f.value === selectedField)
          : undefined;

      // `selectedField` and `fieldOptions` are updated by separate state setters
      // (and after an await), so a render can observe a selected field that is
      // not in the current options. Fall back to the empty defaults instead of
      // dereferencing an undefined entry.
      if (selectedFieldInfo === undefined) {
        return {
          metricType: '',
          indexType: '',
          indexParams: [],
          fieldType: 0,
          embeddingType: DataTypeEnum.FloatVector,
          selectedFieldDimension: 0,
        };
      }

      const index = selectedFieldInfo.indexInfo;
      const embeddingType = getEmbeddingType(selectedFieldInfo.fieldType);
      // without an index, fall back to the defaults for the embedding type
      const metric =
        index?._metricType || DEFAULT_METRIC_VALUE_MAP[embeddingType];

      return {
        metricType: metric,
        indexType: index?._indexType || getDefaultIndexType(embeddingType),
        indexParams: index?._indexParameterPairs || [],
        fieldType: DataTypeEnum[selectedFieldInfo.fieldType],
        embeddingType,
        selectedFieldDimension: selectedFieldInfo.dimension || 0,
      };
    }, [selectedField, fieldOptions]);

    /**
     * vector value validation
     * @return whether is valid
     */
    const vectorValueValid = useMemo(() => {
      // if user hasn't input value or not select field, don't trigger validation check
      if (vectors === '' || selectedFieldDimension === 0) {
        return true;
      }
      const value = parseValue(vectors);
      return Array.isArray(value) && value.length === selectedFieldDimension;
    }, [vectors, selectedFieldDimension]);

    /**
     * before search, user must:
     * 1. enter vector value, it should be an array and length should be equal to selected field dimension
     * 2. choose collection and field
     * 3. set extra search params
     */
    const searchDisabled = useMemo(
      () =>
        vectors === '' ||
        selectedCollection === '' ||
        selectedField === '' ||
        paramDisabled ||
        !vectorValueValid,
      [
        paramDisabled,
        selectedField,
        selectedCollection,
        vectors,
        vectorValueValid,
      ]
    );

    // fetch data
    const fetchCollections = useCallback(async () => {
      const collections = await CollectionHttp.getCollections();
      setCollections(collections);
    }, []);

    /**
     * Load index info for `collectionName` and derive the vector field options
     * and the advanced-filter fields from it.
     *
     * @param collectionName collection whose fields should be loaded
     * @param collections collection list used to look the fields up
     * @param isStale returns true once the request is no longer wanted (e.g. the
     *   user picked another collection or pressed reset while it was in flight);
     *   in that case the response is dropped without touching state
     */
    const fetchFieldsWithIndex = useCallback(
      async (
        collectionName: string,
        collections: CollectionData[],
        isStale: () => boolean = () => false
      ) => {
        const fields =
          collections.find(c => c._name === collectionName)?._fields || [];
        const indexes = await IndexHttp.getIndexInfo(collectionName);
        if (isStale()) {
          return;
        }

        const { vectorFields, nonVectorFields } = classifyFields(fields);

        // only vector type fields can be select
        const fieldOptions = getVectorFieldOptions(vectorFields, indexes);
        setFieldOptions(fieldOptions);
        if (fieldOptions.length > 0) {
          // set first option value as default field value
          const [{ value: defaultFieldValue }] = fieldOptions;
          setSelectedField(defaultFieldValue as string);
        }

        // only non vector type fields can be advanced filter
        setFilterFields(getNonVectorFieldsForFilter(nonVectorFields));
      },
      []
    );

    useEffect(() => {
      fetchCollections();
    }, [fetchCollections]);

    // get field options with index when selected collection changed
    useEffect(() => {
      let stale = false;
      if (selectedCollection !== '') {
        fetchFieldsWithIndex(selectedCollection, collections, () => stale);
      }
      // mark the in-flight request as outdated when the selection changes again
      return () => {
        stale = true;
      };
    }, [selectedCollection, collections, fetchFieldsWithIndex]);

    // set default collection value if is from overview page
    useEffect(() => {
      if (location.search && collections.length > 0) {
        const { collectionName } = parseLocationSearch(location.search);
        // collection name validation
        const isNameValid = collections.some(c => c._name === collectionName);
        if (isNameValid) {
          setSelectedCollection(collectionName);
        }
      }
    }, [location, collections]);

    // icons
    const VectorSearchIcon = icons.vectorSearch;
    const ResetIcon = icons.refresh;
    const ArrowIcon = icons.dropdown;

    // methods
    const handlePageChange = (_e: unknown, page: number) => {
      handleCurrentPage(page);
    };

    const handleReset = () => {
      /**
       * reset search includes:
       * 1. reset vectors
       * 2. reset selected collection and field
       * 3. reset advanced filter fields and expression
       * 4. clear search result
       *
       * NOTE(review): the search params form (`searchParam`) and `topK` are not
       * reset here — confirm whether that is intended.
       */
      setVectors('');
      setSelectedField('');
      setSelectedCollection('');
      setSearchResult(null);
      setFilterFields([]);
      setExpression('');
    };

    /**
     * Run the vector search and store the transformed result.
     *
     * @param searchTopK number of results to request; passed explicitly because
     *   the `topK` state may not be updated yet when called right after `setTopK`
     * @param expr filter expression, defaults to the current `expression` state
     */
    const handleSearch = async (searchTopK: number, expr = expression) => {
      const searchParamPairs = [
        // dynamic search params
        {
          key: 'params',
          value: JSON.stringify(searchParam),
        },
        {
          key: 'anns_field',
          value: selectedField,
        },
        {
          key: 'topk',
          value: searchTopK,
        },
        {
          key: 'metric_type',
          value: metricType,
        },
      ];

      const params: VectorSearchParam = {
        output_fields: outputFields,
        expr,
        search_params: searchParamPairs,
        vectors: [parseValue(vectors)],
        vector_type: fieldType,
      };

      setTableLoading(true);
      try {
        const res = await CollectionHttp.vectorSearchData(
          selectedCollection,
          params
        );
        setSearchResult(transferSearchResult(res.results));
      } catch {
        // Failed search: keep the previous result on screen.
        // NOTE(review): the error is intentionally not rethrown, as before;
        // presumably it is reported elsewhere (e.g. by the HTTP layer) — confirm.
      } finally {
        // always leave the loading state, whether the search succeeded or not
        setTableLoading(false);
      }
    };

    const handleAdvancedFilterChange = (expr: string) => {
      setExpression(expr);
      // pass the new expression explicitly, the state is not updated yet
      if (!searchDisabled) {
        handleSearch(topK, expr);
      }
    };

    const handleVectorChange = (value: string) => {
      setVectors(value);
    };

    return (
      <section className="page-wrapper">
        {/* form section */}
        <form className={classes.form}>
          {/**
           * vector value textarea
           * use field-params class because it also has error msg if invalid
           */}
          <fieldset className="field field-params">
            <Typography className="text">
              {searchTrans('firstTip', {
                dimensionTip:
                  selectedFieldDimension !== 0
                    ? `(dimension: ${selectedFieldDimension})`
                    : '',
              })}
            </Typography>
            <TextField
              className="textarea"
              InputProps={{
                classes: {
                  root: 'textfield',
                  multiline: 'multiline',
                },
              }}
              multiline
              rows={5}
              placeholder={searchTrans('vectorPlaceholder')}
              value={vectors}
              onChange={(e: React.ChangeEvent<{ value: unknown }>) => {
                handleVectorChange(e.target.value as string);
              }}
            />
            {/* validation */}
            {!vectorValueValid && (
              <Typography variant="caption" className={classes.error}>
                {searchTrans('vectorValueWarning', {
                  dimension: selectedFieldDimension,
                })}
              </Typography>
            )}
          </fieldset>
          {/* collection and field selectors */}
          <fieldset className="field field-second">
            <Typography className="text">{searchTrans('secondTip')}</Typography>
            <CustomSelector
              options={collectionOptions}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans(
                collectionOptions.length === 0 ? 'noCollection' : 'collection'
              )}
              disabled={collectionOptions.length === 0}
              value={selectedCollection}
              onChange={(e: { target: { value: unknown } }) => {
                const collection = e.target.value;
                setSelectedCollection(collection as string);
                // every time selected collection changed, reset field
                setSelectedField('');
              }}
            />
            <CustomSelector
              options={fieldOptions}
              // readOnly can't avoid all events, so we use disabled instead
              disabled={selectedCollection === ''}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans('field')}
              value={selectedField}
              onChange={(e: { target: { value: unknown } }) => {
                const field = e.target.value;
                setSelectedField(field as string);
              }}
            />
          </fieldset>
          {/* search params selectors */}
          <fieldset className="field field-params">
            <Typography className="text">{searchTrans('thirdTip')}</Typography>
            <SearchParams
              wrapperClass={classes.paramsWrapper}
              metricType={metricType!}
              embeddingType={
                embeddingType as
                  | DataTypeEnum.BinaryVector
                  | DataTypeEnum.FloatVector
              }
              indexType={indexType}
              indexParams={indexParams!}
              searchParamsForm={searchParam}
              handleFormChange={setSearchParam}
              topK={topK}
              setParamsDisabled={setParamDisabled}
            />
          </fieldset>
        </form>
        {/**
         * search toolbar section
         * including topK selector, advanced filter, search and reset btn
         */}
        <section className={classes.toolbar}>
          <div className="left">
            <Typography variant="h5" className="text">
              {`${searchTrans('result')}: `}
            </Typography>
            {/* topK selector */}
            <SimpleMenu
              label={searchTrans('topK', { number: topK })}
              menuItems={TOP_K_OPTIONS.map(item => ({
                label: item.toString(),
                callback: () => {
                  setTopK(item);
                  if (!searchDisabled) {
                    handleSearch(item);
                  }
                },
                wrapperClass: classes.menuItem,
              }))}
              buttonProps={{
                className: classes.menuLabel,
                endIcon: <ArrowIcon />,
              }}
              menuItemWidth="108px"
            />
            <Filter
              title="Advanced Filter"
              fields={filterFields}
              filterDisabled={selectedField === '' || selectedCollection === ''}
              onSubmit={handleAdvancedFilterChange}
            />
          </div>
          <div className="right">
            <CustomButton className="btn" onClick={handleReset}>
              <ResetIcon classes={{ root: 'icon' }} />
              {btnTrans('reset')}
            </CustomButton>
            <CustomButton
              variant="contained"
              disabled={searchDisabled}
              onClick={() => handleSearch(topK)}
            >
              {btnTrans('search')}
            </CustomButton>
          </div>
        </section>
        {/* search result table section */}
        {(searchResult && searchResult.length > 0) || tableLoading ? (
          <MilvusGrid
            toolbarConfigs={[]}
            colDefinitions={colDefinitions}
            rows={result}
            rowCount={total}
            primaryKey="rank"
            page={currentPage}
            onChangePage={handlePageChange}
            rowsPerPage={pageSize}
            setRowsPerPage={handlePageSize}
            openCheckBox={false}
            isLoading={tableLoading}
          />
        ) : (
          <EmptyCard
            wrapperClass={`page-empty-card`}
            icon={<VectorSearchIcon />}
            text={
              searchResult !== null
                ? searchTrans('empty')
                : searchTrans('startTip')
            }
          />
        )}
      </section>
    );
  };

  export default VectorSearch;