Ver Fonte

perf: optimize get_notes_by_user_id to reduce database queries

- Replace inefficient memory-based filtering with database-level filtering
- Add proper access control conditions to SQL query
- Reduce memory usage by filtering at database level instead of loading all notes
- Maintain access control validation with post-filtering for complex cases

This change significantly improves performance for users with many notes
by reducing the number of database queries and memory usage.

Signed-off-by: Sihyeon Jang <sihyeon.jang@navercorp.com>
Sihyeon Jang há 2 semanas atrás
pai
commit
6ae6cc9741
2 ficheiros alterados com 57 adições e 10 exclusões
  1. 55 8
      backend/open_webui/models/notes.py
  2. 2 2
      backend/open_webui/routers/notes.py

+ 55 - 8
backend/open_webui/models/notes.py

@@ -2,6 +2,7 @@ import json
 import time
 import uuid
 from typing import Optional
+from functools import lru_cache
 
 from open_webui.internal.db import Base, get_db
 from open_webui.models.groups import Groups
@@ -110,20 +111,66 @@ class NoteTable:
             return [NoteModel.model_validate(note) for note in notes]
 
     def get_notes_by_user_id(
+        self,
+        user_id: str,
+        skip: Optional[int] = None,
+        limit: Optional[int] = None,
+    ) -> list[NoteModel]:
+        with get_db() as db:
+            query = db.query(Note).filter(Note.user_id == user_id)
+            query = query.order_by(Note.updated_at.desc())
+
+            if skip is not None:
+                query = query.offset(skip)
+            if limit is not None:
+                query = query.limit(limit)
+
+            notes = query.all()
+            return [NoteModel.model_validate(note) for note in notes]
+
+    def get_notes_by_access(
         self,
         user_id: str,
         permission: str = "write",
         skip: Optional[int] = None,
         limit: Optional[int] = None,
     ) -> list[NoteModel]:
-        notes = self.get_notes(skip=skip, limit=limit)
-        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
-        return [
-            note
-            for note in notes
-            if note.user_id == user_id
-            or has_access(user_id, permission, note.access_control, user_group_ids)
-        ]
+        with get_db() as db:
+            user_groups = Groups.get_groups_by_member_id(user_id)
+            user_group_ids = {group_id for group_id in user_groups}
+
+            query = db.query(Note)
+
+            access_conditions = [Note.user_id == user_id]
+
+            if user_group_ids:
+                access_conditions.append(
+                    and_(
+                        Note.access_control.isnot(None),
+                        Note.access_control != '{}',
+                        Note.access_control != 'null'
+                    )
+                )
+
+            query = query.filter(or_(*access_conditions))
+
+            query = query.order_by(Note.updated_at.desc())
+
+            if skip is not None:
+                query = query.offset(skip)
+            if limit is not None:
+                query = query.limit(limit)
+
+            notes = query.all()
+            note_models = [NoteModel.model_validate(note) for note in notes]
+
+            filtered_notes = []
+            for note in note_models:
+                if (note.user_id == user_id or
+                    has_access(user_id, permission, note.access_control, user_group_ids)):
+                    filtered_notes.append(note)
+
+            return filtered_notes
 
     def get_note_by_id(self, id: str) -> Optional[NoteModel]:
         with get_db() as db:

+ 2 - 2
backend/open_webui/routers/notes.py

@@ -48,7 +48,7 @@ async def get_notes(request: Request, user=Depends(get_verified_user)):
                 "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
             }
         )
-        for note in Notes.get_notes_by_user_id(user.id, "write")
+        for note in Notes.get_notes_by_access(user.id, "write")
     ]
 
     return notes
@@ -81,7 +81,7 @@ async def get_note_list(
 
     notes = [
         NoteTitleIdResponse(**note.model_dump())
-        for note in Notes.get_notes_by_user_id(user.id, "write", skip=skip, limit=limit)
+        for note in Notes.get_notes_by_access(user.id, "write", skip=skip, limit=limit)
     ]
 
     return notes