|
@@ -2,6 +2,7 @@ import json
|
|
|
import time
|
|
|
import uuid
|
|
|
from typing import Optional
|
|
|
+from functools import lru_cache
|
|
|
|
|
|
from open_webui.internal.db import Base, get_db
|
|
|
from open_webui.models.groups import Groups
|
|
@@ -110,20 +111,66 @@ class NoteTable:
|
|
|
return [NoteModel.model_validate(note) for note in notes]
|
|
|
|
|
|
def get_notes_by_user_id(
|
|
|
+ self,
|
|
|
+ user_id: str,
|
|
|
+ skip: Optional[int] = None,
|
|
|
+ limit: Optional[int] = None,
|
|
|
+ ) -> list[NoteModel]:
|
|
|
+ with get_db() as db:
|
|
|
+ query = db.query(Note).filter(Note.user_id == user_id)
|
|
|
+ query = query.order_by(Note.updated_at.desc())
|
|
|
+
|
|
|
+ if skip is not None:
|
|
|
+ query = query.offset(skip)
|
|
|
+ if limit is not None:
|
|
|
+ query = query.limit(limit)
|
|
|
+
|
|
|
+ notes = query.all()
|
|
|
+ return [NoteModel.model_validate(note) for note in notes]
|
|
|
+
|
|
|
+ def get_notes_by_access(
|
|
|
self,
|
|
|
user_id: str,
|
|
|
permission: str = "write",
|
|
|
skip: Optional[int] = None,
|
|
|
limit: Optional[int] = None,
|
|
|
) -> list[NoteModel]:
|
|
|
- notes = self.get_notes(skip=skip, limit=limit)
|
|
|
- user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
|
|
|
- return [
|
|
|
- note
|
|
|
- for note in notes
|
|
|
- if note.user_id == user_id
|
|
|
- or has_access(user_id, permission, note.access_control, user_group_ids)
|
|
|
- ]
|
|
|
+ with get_db() as db:
|
|
|
+ user_groups = Groups.get_groups_by_member_id(user_id)
|
|
|
+ user_group_ids = {group_id for group_id in user_groups}
|
|
|
+
|
|
|
+ query = db.query(Note)
|
|
|
+
|
|
|
+ access_conditions = [Note.user_id == user_id]
|
|
|
+
|
|
|
+ if user_group_ids:
|
|
|
+ access_conditions.append(
|
|
|
+ and_(
|
|
|
+ Note.access_control.isnot(None),
|
|
|
+ Note.access_control != '{}',
|
|
|
+ Note.access_control != 'null'
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ query = query.filter(or_(*access_conditions))
|
|
|
+
|
|
|
+ query = query.order_by(Note.updated_at.desc())
|
|
|
+
|
|
|
+ if skip is not None:
|
|
|
+ query = query.offset(skip)
|
|
|
+ if limit is not None:
|
|
|
+ query = query.limit(limit)
|
|
|
+
|
|
|
+ notes = query.all()
|
|
|
+ note_models = [NoteModel.model_validate(note) for note in notes]
|
|
|
+
|
|
|
+ filtered_notes = []
|
|
|
+ for note in note_models:
|
|
|
+ if (note.user_id == user_id or
|
|
|
+ has_access(user_id, permission, note.access_control, user_group_ids)):
|
|
|
+ filtered_notes.append(note)
|
|
|
+
|
|
|
+ return filtered_notes
|
|
|
|
|
|
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
|
|
|
with get_db() as db:
|