Browse Source

only return permitted databases (#429)

* only return permitted database

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

* remove roles

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

---------

Signed-off-by: ryjiang <jiangruiyi@gmail.com>
ryjiang 1 year ago
parent
commit
253493eaf0

+ 4 - 3
client/src/context/Data.tsx

@@ -278,15 +278,16 @@ export const DataProvider = (props: { children: React.ReactNode }) => {
 
   useEffect(() => {
     if (isAuth) {
-      // fetch db
-      fetchDatabases();
       // connect to socket server
       socket.current = io(url as string);
       // register client
       socket.current.emit(WS_EVENTS.REGISTER, clientId);
 
-      socket.current.on('connect', function () {
+      socket.current.on('connect', async () => {
         console.log('--- ws connected ---', clientId);
+        // fetch db
+        await fetchDatabases();
+        // set connected to trues
         setConnected(true);
       });
     } else {

+ 1 - 1
client/src/pages/databases/Databases.tsx

@@ -1,4 +1,4 @@
-import { useEffect, useContext } from 'react';
+import { useContext } from 'react';
 import { useParams } from 'react-router-dom';
 import { useTranslation } from 'react-i18next';
 import { makeStyles, Theme } from '@material-ui/core';

+ 12 - 6
server/src/collections/collections.service.ts

@@ -342,12 +342,16 @@ export class CollectionsService {
     });
 
     // get collection statistic data
-    const collectionStatisticsRes = await this.getCollectionStatistics(
-      clientId,
-      {
+    let collectionStatisticsRes;
+
+    try {
+      collectionStatisticsRes = await this.getCollectionStatistics(clientId, {
         collection_name: collection.name,
-      }
-    );
+      });
+    } catch (e) {
+      console.log('ignore getCollectionStatistics');
+    }
+
     // extract autoID
     const autoID = collectionInfo.schema.fields.find(
       v => v.is_primary_key === true
@@ -380,7 +384,9 @@ export class CollectionsService {
     return {
       collection_name: collection.name,
       schema: collectionInfo.schema,
-      rowCount: Number(collectionStatisticsRes.data.row_count),
+      rowCount: Number(
+        (collectionStatisticsRes && collectionStatisticsRes.data.row_count) || 0
+      ),
       createdTime: parseInt(collectionInfo.created_utc_timestamp, 10),
       aliases: collectionInfo.aliases,
       description: collectionInfo.schema.description,

+ 28 - 7
server/src/database/databases.service.ts

@@ -1,31 +1,52 @@
-import { MilvusService } from '../milvus/milvus.service';
 import {
   CreateDatabaseRequest,
   ListDatabasesRequest,
   DropDatabasesRequest,
+  ListDatabasesResponse,
 } from '@zilliz/milvus2-sdk-node';
 import { throwErrorFromSDK } from '../utils/Error';
 import { clientCache } from '../app';
 
 export class DatabasesService {
   async createDatabase(clientId: string, data: CreateDatabaseRequest) {
-        const { milvusClient } = clientCache.get(clientId);
+    const { milvusClient } = clientCache.get(clientId);
 
     const res = await milvusClient.createDatabase(data);
     throwErrorFromSDK(res);
     return res;
   }
 
-  async listDatabase(clientId: string, data?: ListDatabasesRequest) {
-        const { milvusClient } = clientCache.get(clientId);
+  async listDatabase(
+    clientId: string,
+    data?: ListDatabasesRequest
+  ): Promise<ListDatabasesResponse> {
+    const { milvusClient, database } = clientCache.get(clientId);
 
     const res = await milvusClient.listDatabases(data);
+
+    // test if the user has permission to access the database, loop through all databases
+    // and check if the user has permission to access the database
+    const availableDatabases: string[] = [];
+
+    for (const db of res.db_names) {
+      try {
+        await milvusClient.use({ db_name: db });
+        await milvusClient.listDatabases(data);
+        availableDatabases.push(db);
+      } catch (e) {
+        // ignore
+      }
+    }
+
+    // recover current database
+    await milvusClient.use({ db_name: database });
+
     throwErrorFromSDK(res.status);
-    return res;
+    return { ...res, db_names: availableDatabases };
   }
 
   async dropDatabase(clientId: string, data: DropDatabasesRequest) {
-        const { milvusClient } = clientCache.get(clientId);
+    const { milvusClient } = clientCache.get(clientId);
 
     const res = await milvusClient.dropDatabase(data);
     throwErrorFromSDK(res);
@@ -33,7 +54,7 @@ export class DatabasesService {
   }
 
   async use(clientId: string, db_name: string) {
-        const { milvusClient } = clientCache.get(clientId);
+    const { milvusClient } = clientCache.get(clientId);
 
     return await await milvusClient.use({ db_name });
   }

+ 16 - 9
server/src/milvus/milvus.service.ts

@@ -7,18 +7,12 @@ import {
 import { LRUCache } from 'lru-cache';
 import { DEFAULT_MILVUS_PORT, INDEX_TTL, SimpleQueue } from '../utils';
 import { connectivityState } from '@grpc/grpc-js';
-import { DatabasesService } from '../database/databases.service';
 import { clientCache } from '../app';
 import { DescribeIndexRes } from '../types';
 
 export class MilvusService {
-  private databaseService: DatabasesService;
   private DEFAULT_DATABASE = 'default';
 
-  constructor() {
-    this.databaseService = new DatabasesService();
-  }
-
   get sdkInfo() {
     return MilvusClient.sdkInfo;
   }
@@ -91,6 +85,9 @@ export class MilvusService {
         throw new Error('Milvus is not ready yet.');
       }
 
+      // database
+      const db = database || this.DEFAULT_DATABASE;
+
       // If the server is healthy, set the active address and add the client to the cache
       clientCache.set(milvusClient.clientId, {
         milvusClient,
@@ -99,16 +96,26 @@ export class MilvusService {
           ttl: INDEX_TTL,
           ttlAutopurge: true,
         }),
-        database,
+        database: db,
         collectionsQueue: new SimpleQueue<string>(),
       });
 
-      await this.databaseService.use(milvusClient.clientId, database);
+      // test ListDatabases permission
+      try {
+        await milvusClient.use({ db_name: db });
+        await milvusClient.listDatabases();
+      } catch (e) {
+        throw new Error(
+          `You don't have permission to access the database: ${db}.`
+        );
+      }
+
+      await milvusClient.use({ db_name: db });
 
       // Return the address and the database (if it exists, otherwise return 'default')
       return {
         address,
-        database: database || this.DEFAULT_DATABASE,
+        database: db,
         clientId: milvusClient.clientId,
       };
     } catch (error) {