VectorSearch.tsx 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509
  1. import { TextField, Typography } from '@material-ui/core';
  2. import { useTranslation } from 'react-i18next';
  3. import { useNavigationHook } from '../../hooks/Navigation';
  4. import { ALL_ROUTER_TYPES } from '../../router/Types';
  5. import CustomSelector from '../../components/customSelector/CustomSelector';
  6. import { useCallback, useEffect, useMemo, useState } from 'react';
  7. import SearchParams from './SearchParams';
  8. import { DEFAULT_METRIC_VALUE_MAP } from '../../consts/Milvus';
  9. import { FieldOption, SearchResultView, VectorSearchParam } from './Types';
  10. import MilvusGrid from '../../components/grid/Grid';
  11. import EmptyCard from '../../components/cards/EmptyCard';
  12. import icons from '../../components/icons/Icons';
  13. import { usePaginationHook } from '../../hooks/Pagination';
  14. import CustomButton from '../../components/customButton/CustomButton';
  15. import SimpleMenu from '../../components/menu/SimpleMenu';
  16. import { TOP_K_OPTIONS } from './Constants';
  17. import { Option } from '../../components/customSelector/Types';
  18. import { CollectionHttp } from '../../http/Collection';
  19. import { CollectionData, DataTypeEnum } from '../collections/Types';
  20. import { IndexHttp } from '../../http/Index';
  21. import { getVectorSearchStyles } from './Styles';
  22. import { parseValue } from '../../utils/Insert';
  23. import {
  24. classifyFields,
  25. getDefaultIndexType,
  26. getEmbeddingType,
  27. getNonVectorFieldsForFilter,
  28. getVectorFieldOptions,
  29. transferSearchResult,
  30. } from '../../utils/search';
  31. import { ColDefinitionsType } from '../../components/grid/Types';
  32. import Filter from '../../components/advancedSearch';
  33. import { Field } from '../../components/advancedSearch/Types';
  34. import { useLocation } from 'react-router-dom';
  35. import { parseLocationSearch } from '../../utils/Format';
  /**
   * Vector search page.
   *
   * Lets the user paste a query vector, pick a collection and one of its vector
   * fields, tune search params / topK / an advanced filter expression, and shows
   * the search result in a paginated grid.
   */
  const VectorSearch = () => {
    useNavigationHook(ALL_ROUTER_TYPES.SEARCH);
    const location = useLocation();

    // i18n
    const { t: searchTrans } = useTranslation('search');
    const { t: btnTrans } = useTranslation('btn');
    const classes = getVectorSearchStyles();

    // data stored inside the component
    const [tableLoading, setTableLoading] = useState<boolean>(false);
    const [collections, setCollections] = useState<CollectionData[]>([]);
    const [selectedCollection, setSelectedCollection] = useState<string>('');
    const [fieldOptions, setFieldOptions] = useState<FieldOption[]>([]);
    // fields for advanced filter
    const [filterFields, setFilterFields] = useState<Field[]>([]);
    const [selectedField, setSelectedField] = useState<string>('');
    // search params form
    const [searchParam, setSearchParam] = useState<{ [key in string]: number }>(
      {}
    );
    // search params disable state
    const [paramDisabled, setParamDisabled] = useState<boolean>(true);
    // null means "no search performed yet"; an empty array means "searched, no results"
    const [searchResult, setSearchResult] = useState<SearchResultView[] | null>(
      null
    );
    // default topK is 100
    const [topK, setTopK] = useState<number>(100);
    const [expression, setExpression] = useState<string>('');
    const [vectors, setVectors] = useState<string>('');

    const {
      pageSize,
      handlePageSize,
      currentPage,
      handleCurrentPage,
      total,
      data: result,
      order,
      orderBy,
      handleGridSort,
    } = usePaginationHook(searchResult ?? []);

    const collectionOptions: Option[] = useMemo(
      () =>
        collections.map(c => ({
          label: c._name,
          value: c._name,
        })),
      [collections]
    );

    // non-vector fields of the selected collection, requested as output fields
    const outputFields: string[] = useMemo(() => {
      const fields =
        collections.find(c => c._name === selectedCollection)?._fields || [];
      // vector field can't be output fields
      const invalidTypes = ['BinaryVector', 'FloatVector'];
      return fields
        .filter(field => !invalidTypes.includes(field._fieldType))
        .map(f => f._fieldName);
    }, [selectedCollection, collections]);

    const primaryKeyField = useMemo(() => {
      const fields =
        collections.find(c => c._name === selectedCollection)?._fields || [];
      return fields.find(f => f._isPrimaryKey)?._fieldName;
    }, [selectedCollection, collections]);

    const colDefinitions: ColDefinitionsType[] = useMemo(() => {
      if (!searchResult || searchResult.length === 0) {
        return [];
      }
      /**
       * id represents primary key, score represents distance.
       * Score is shown as distance in the view, and the original primary key
       * field is already in the table, so 'id' and 'score' are dropped to avoid
       * redundant columns — unless the primary key field itself is named 'id'.
       */
      const invalidItems =
        primaryKeyField === 'id' ? ['score'] : ['id', 'score'];
      return Object.keys(searchResult[0])
        .filter(item => !invalidItems.includes(item))
        .map(key => ({
          id: key,
          align: 'left',
          disablePadding: false,
          label: key,
        }));
    }, [searchResult, primaryKeyField]);

    // index / metric / dimension info derived from the selected vector field
    const {
      metricType,
      indexType,
      indexParams,
      fieldType,
      embeddingType,
      selectedFieldDimension,
    } = useMemo(() => {
      /**
       * selectedField can transiently point at a field that is not (yet) in
       * fieldOptions, e.g. while the options of a newly selected collection are
       * being applied. Fall back to the "no field" defaults instead of
       * dereferencing an undefined option.
       */
      const selectedFieldInfo =
        selectedField !== ''
          ? fieldOptions.find(f => f.value === selectedField)
          : undefined;

      if (!selectedFieldInfo) {
        return {
          metricType: '',
          indexType: '',
          indexParams: [],
          fieldType: 0,
          embeddingType: DataTypeEnum.FloatVector,
          selectedFieldDimension: 0,
        };
      }

      const index = selectedFieldInfo.indexInfo;
      const embeddingType = getEmbeddingType(selectedFieldInfo.fieldType);
      return {
        metricType:
          index?._metricType || DEFAULT_METRIC_VALUE_MAP[embeddingType],
        indexType: index?._indexType || getDefaultIndexType(embeddingType),
        indexParams: index?._indexParameterPairs || [],
        fieldType: DataTypeEnum[selectedFieldInfo.fieldType],
        embeddingType,
        selectedFieldDimension: selectedFieldInfo.dimension || 0,
      };
    }, [selectedField, fieldOptions]);

    /**
     * vector value validation
     * @return whether the typed vector is an array whose length equals the
     * selected field dimension (true while there is nothing to validate yet)
     */
    const vectorValueValid = useMemo(() => {
      // if user hasn't input value or not select field, don't trigger validation check
      if (vectors === '' || selectedFieldDimension === 0) {
        return true;
      }
      try {
        const value = parseValue(vectors);
        return Array.isArray(value) && value.length === selectedFieldDimension;
      } catch (err) {
        // NOTE(review): guard in case parseValue throws on malformed input — it
        // runs during render, so a throw would otherwise crash the page
        return false;
      }
    }, [vectors, selectedFieldDimension]);

    /**
     * before search, user must:
     * 1. enter vector value, it should be an array and length should be equal to selected field dimension
     * 2. choose collection and field
     * 3. set extra search params
     */
    const searchDisabled = useMemo(
      () =>
        vectors === '' ||
        selectedCollection === '' ||
        selectedField === '' ||
        paramDisabled ||
        !vectorValueValid,
      [paramDisabled, selectedField, selectedCollection, vectors, vectorValueValid]
    );

    // fetch data
    const fetchCollections = useCallback(async () => {
      const collections = await CollectionHttp.getCollections();
      setCollections(collections);
    }, []);

    /**
     * Load vector field options (with index info) and advanced-filter fields of
     * a collection, and default the selected field to the first vector field.
     *
     * @param collectionName collection whose fields are loaded
     * @param collections loaded collections, used to look up its fields
     * @param isStale returns true when the result is no longer wanted (the
     *   selected collection changed while the index request was in flight); in
     *   that case no state is updated
     */
    const fetchFieldsWithIndex = useCallback(
      async (
        collectionName: string,
        collections: CollectionData[],
        isStale: () => boolean = () => false
      ) => {
        const fields =
          collections.find(c => c._name === collectionName)?._fields || [];
        const indexes = await IndexHttp.getIndexInfo(collectionName);
        if (isStale()) {
          return;
        }
        const { vectorFields, nonVectorFields } = classifyFields(fields);
        // only vector type fields can be select
        const fieldOptions = getVectorFieldOptions(vectorFields, indexes);
        setFieldOptions(fieldOptions);
        // default to the first option; clear any field left over from another collection
        const firstOption = fieldOptions[0];
        setSelectedField(firstOption ? (firstOption.value as string) : '');
        // only non vector type fields can be advanced filter
        setFilterFields(getNonVectorFieldsForFilter(nonVectorFields));
      },
      []
    );

    useEffect(() => {
      fetchCollections();
    }, [fetchCollections]);

    // get field options with index when selected collection changed
    useEffect(() => {
      // flipped by the cleanup so a late response for a previous collection is ignored
      let stale = false;
      if (selectedCollection !== '') {
        fetchFieldsWithIndex(selectedCollection, collections, () => stale);
      }
      return () => {
        stale = true;
      };
    }, [selectedCollection, collections, fetchFieldsWithIndex]);

    // set default collection value if is from overview page
    useEffect(() => {
      if (location.search && collections.length > 0) {
        const { collectionName } = parseLocationSearch(location.search);
        // collection name validation
        const isNameValid = collections
          .map(c => c._name)
          .includes(collectionName);
        if (isNameValid) {
          setSelectedCollection(collectionName);
        }
      }
    }, [location, collections]);

    // icons
    const VectorSearchIcon = icons.vectorSearch;
    const ResetIcon = icons.refresh;
    const ArrowIcon = icons.dropdown;

    // methods
    const handlePageChange = (_e: unknown, page: number) => {
      handleCurrentPage(page);
    };

    /**
     * reset search includes:
     * 1. reset vectors
     * 2. reset selected collection, field and field options
     *    (clearing the field also resets the metric/index info passed to SearchParams)
     * 3. reset advanced filter fields and expression
     * 4. clear search result
     */
    const handleReset = () => {
      setVectors('');
      setSelectedField('');
      setSelectedCollection('');
      setFieldOptions([]);
      setSearchResult(null);
      setFilterFields([]);
      setExpression('');
    };

    /**
     * Run a vector search against the selected collection.
     *
     * @param topK number of results to request (passed explicitly so callers can
     *   use a value that has not been committed to state yet)
     * @param expr filter expression; defaults to the current expression state
     */
    const handleSearch = async (topK: number, expr = expression) => {
      const searchParamPairs = {
        params: JSON.stringify(searchParam),
        anns_field: selectedField,
        topk: topK,
        metric_type: metricType,
        round_decimal: searchParam.round_decimal,
      };
      const params: VectorSearchParam = {
        output_fields: outputFields,
        expr,
        search_params: searchParamPairs,
        vectors: [parseValue(vectors)],
        vector_type: fieldType,
      };

      setTableLoading(true);
      try {
        const res = await CollectionHttp.vectorSearchData(
          selectedCollection,
          params
        );
        setSearchResult(transferSearchResult(res.results));
      } catch (err) {
        // NOTE(review): errors are not surfaced here; presumably the http layer
        // reports them — confirm
      } finally {
        setTableLoading(false);
      }
    };

    // apply a new filter expression and re-run the search when possible
    const handleAdvancedFilterChange = (expression: string) => {
      setExpression(expression);
      if (!searchDisabled) {
        handleSearch(topK, expression);
      }
    };

    const handleVectorChange = (value: string) => {
      setVectors(value);
    };

    return (
      <section className="page-wrapper">
        {/* form section */}
        <form className={classes.form}>
          {/**
           * vector value textarea
           * use field-params class because it also has error msg if invalid
           */}
          <fieldset className="field field-params">
            <Typography className="text">
              {searchTrans('firstTip', {
                dimensionTip:
                  selectedFieldDimension !== 0
                    ? `(dimension: ${selectedFieldDimension})`
                    : '',
              })}
            </Typography>
            <TextField
              className="textarea"
              InputProps={{
                classes: {
                  root: 'textfield',
                  multiline: 'multiline',
                },
              }}
              multiline
              rows={5}
              placeholder={searchTrans('vectorPlaceholder')}
              value={vectors}
              onChange={(e: React.ChangeEvent<{ value: unknown }>) => {
                handleVectorChange(e.target.value as string);
              }}
            />
            {/* validation */}
            {!vectorValueValid && (
              <Typography variant="caption" className={classes.error}>
                {searchTrans('vectorValueWarning', {
                  dimension: selectedFieldDimension,
                })}
              </Typography>
            )}
          </fieldset>
          {/* collection and field selectors */}
          <fieldset className="field field-second">
            <Typography className="text">{searchTrans('secondTip')}</Typography>
            <CustomSelector
              options={collectionOptions}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans(
                collectionOptions.length === 0 ? 'noCollection' : 'collection'
              )}
              disabled={collectionOptions.length === 0}
              value={selectedCollection}
              onChange={(e: { target: { value: unknown } }) => {
                const collection = e.target.value;
                setSelectedCollection(collection as string);
                // every time selected collection changed, reset field
                setSelectedField('');
              }}
            />
            <CustomSelector
              options={fieldOptions}
              // readOnly can't avoid all events, so we use disabled instead
              disabled={selectedCollection === ''}
              wrapperClass={classes.selector}
              variant="filled"
              label={searchTrans('field')}
              value={selectedField}
              onChange={(e: { target: { value: unknown } }) => {
                const field = e.target.value;
                setSelectedField(field as string);
              }}
            />
          </fieldset>
          {/* search params selectors */}
          <fieldset className="field field-params">
            <Typography className="text">{searchTrans('thirdTip')}</Typography>
            <SearchParams
              wrapperClass={classes.paramsWrapper}
              metricType={metricType!}
              embeddingType={
                embeddingType as
                  | DataTypeEnum.BinaryVector
                  | DataTypeEnum.FloatVector
              }
              indexType={indexType}
              indexParams={indexParams!}
              searchParamsForm={searchParam}
              handleFormChange={setSearchParam}
              topK={topK}
              setParamsDisabled={setParamDisabled}
            />
          </fieldset>
        </form>
        {/**
         * search toolbar section
         * including topK selector, advanced filter, search and reset btn
         */}
        <section className={classes.toolbar}>
          <div className="left">
            <Typography variant="h5" className="text">
              {`${searchTrans('result')}: `}
            </Typography>
            {/* topK selector */}
            <SimpleMenu
              label={searchTrans('topK', { number: topK })}
              menuItems={TOP_K_OPTIONS.map(item => ({
                label: item.toString(),
                callback: () => {
                  setTopK(item);
                  if (!searchDisabled) {
                    handleSearch(item);
                  }
                },
                wrapperClass: classes.menuItem,
              }))}
              buttonProps={{
                className: classes.menuLabel,
                endIcon: <ArrowIcon />,
              }}
              menuItemWidth="108px"
            />
            <Filter
              title="Advanced Filter"
              fields={filterFields}
              filterDisabled={selectedField === '' || selectedCollection === ''}
              onSubmit={handleAdvancedFilterChange}
            />
          </div>
          <div className="right">
            <CustomButton className="btn" onClick={handleReset}>
              <ResetIcon classes={{ root: 'icon' }} />
              {btnTrans('reset')}
            </CustomButton>
            <CustomButton
              variant="contained"
              disabled={searchDisabled}
              onClick={() => handleSearch(topK)}
            >
              {btnTrans('search')}
            </CustomButton>
          </div>
        </section>
        {/* search result table section */}
        {(searchResult && searchResult.length > 0) || tableLoading ? (
          <MilvusGrid
            toolbarConfigs={[]}
            colDefinitions={colDefinitions}
            rows={result}
            rowCount={total}
            primaryKey="rank"
            page={currentPage}
            onChangePage={handlePageChange}
            rowsPerPage={pageSize}
            setRowsPerPage={handlePageSize}
            openCheckBox={false}
            isLoading={tableLoading}
            orderBy={orderBy}
            order={order}
            handleSort={handleGridSort}
          />
        ) : (
          <EmptyCard
            wrapperClass={`page-empty-card`}
            icon={<VectorSearchIcon />}
            text={
              searchResult !== null
                ? searchTrans('empty')
                : searchTrans('startTip')
            }
          />
        )}
      </section>
    );
  };

  export default VectorSearch;