# groups.py (9.8 KB)
  1. import json
  2. import logging
  3. import time
  4. from typing import Optional
  5. import uuid
  6. from open_webui.internal.db import Base, get_db
  7. from open_webui.env import SRC_LOG_LEVELS
  8. from open_webui.models.files import FileMetadataResponse
  9. from pydantic import BaseModel, ConfigDict
  10. from sqlalchemy import BigInteger, Column, String, Text, JSON, func
  11. log = logging.getLogger(__name__)
  12. log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# UserGroup DB Schema
####################


class Group(Base):
    """SQLAlchemy ORM mapping for the ``group`` table.

    NOTE(review): attribute declaration order sets the column order of the
    generated table DDL; keep it stable.
    """

    __tablename__ = "group"

    # Primary key; the insert paths in this module use str(uuid.uuid4()).
    id = Column(Text, unique=True, primary_key=True)
    # User id supplied by the caller at insert time (presumably the creator).
    user_id = Column(Text)
    name = Column(Text)
    # Nullable (SQLAlchemy default), although GroupModel types it as a plain str.
    description = Column(Text)
    # Free-form JSON blobs; their schema is not defined in this module.
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)
    permissions = Column(JSON, nullable=True)
    # JSON list of member user ids (serialized text is LIKE-searched by member).
    user_ids = Column(JSON, nullable=True)
    # Epoch seconds, written as int(time.time()).
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
class GroupModel(BaseModel):
    """Pydantic representation of a ``Group`` row.

    ``from_attributes=True`` lets ``GroupModel.model_validate`` read ORM
    objects directly.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    name: str
    # NOTE(review): Group.description is nullable in the DB; a NULL value would
    # fail validation here — confirm rows always carry a description.
    description: str
    data: Optional[dict] = None
    meta: Optional[dict] = None
    permissions: Optional[dict] = None
    # Mutable default is safe: pydantic copies defaults per instance.
    user_ids: list[str] = []
    created_at: int # timestamp in epoch
    updated_at: int # timestamp in epoch
####################
# Forms
####################


class GroupResponse(BaseModel):
    """Group payload with the same fields as ``GroupModel``.

    NOTE(review): unlike ``GroupModel`` this class sets no ``model_config``,
    so it does not enable ``from_attributes`` — confirm it is never validated
    directly from ORM objects.
    """

    id: str
    user_id: str
    name: str
    description: str
    permissions: Optional[dict] = None
    data: Optional[dict] = None
    meta: Optional[dict] = None
    user_ids: list[str] = []
    created_at: int # timestamp in epoch
    updated_at: int # timestamp in epoch
class GroupForm(BaseModel):
    """Input fields for creating a group (name, description, permissions)."""

    name: str
    description: str
    # Optional; when omitted, insert paths that dump with exclude_none skip it.
    permissions: Optional[dict] = None
class UserIdsForm(BaseModel):
    """Payload carrying a list of user ids; ``None`` when the field is omitted."""

    user_ids: Optional[list[str]] = None
class GroupUpdateForm(GroupForm, UserIdsForm):
    """Update payload: the ``GroupForm`` fields plus an optional ``user_ids``."""

    pass
  62. class GroupTable:
  63. def insert_new_group(
  64. self, user_id: str, form_data: GroupForm
  65. ) -> Optional[GroupModel]:
  66. with get_db() as db:
  67. group = GroupModel(
  68. **{
  69. **form_data.model_dump(exclude_none=True),
  70. "id": str(uuid.uuid4()),
  71. "user_id": user_id,
  72. "created_at": int(time.time()),
  73. "updated_at": int(time.time()),
  74. }
  75. )
  76. try:
  77. result = Group(**group.model_dump())
  78. db.add(result)
  79. db.commit()
  80. db.refresh(result)
  81. if result:
  82. return GroupModel.model_validate(result)
  83. else:
  84. return None
  85. except Exception:
  86. return None
  87. def get_groups(self) -> list[GroupModel]:
  88. with get_db() as db:
  89. return [
  90. GroupModel.model_validate(group)
  91. for group in db.query(Group).order_by(Group.updated_at.desc()).all()
  92. ]
  93. def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
  94. with get_db() as db:
  95. return [
  96. GroupModel.model_validate(group)
  97. for group in db.query(Group)
  98. .filter(
  99. func.json_array_length(Group.user_ids) > 0
  100. ) # Ensure array exists
  101. .filter(
  102. Group.user_ids.cast(String).like(f'%"{user_id}"%')
  103. ) # String-based check
  104. .order_by(Group.updated_at.desc())
  105. .all()
  106. ]
  107. def get_group_by_id(self, id: str) -> Optional[GroupModel]:
  108. try:
  109. with get_db() as db:
  110. group = db.query(Group).filter_by(id=id).first()
  111. return GroupModel.model_validate(group) if group else None
  112. except Exception:
  113. return None
  114. def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
  115. group = self.get_group_by_id(id)
  116. if group:
  117. return group.user_ids
  118. else:
  119. return None
  120. def update_group_by_id(
  121. self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
  122. ) -> Optional[GroupModel]:
  123. try:
  124. with get_db() as db:
  125. db.query(Group).filter_by(id=id).update(
  126. {
  127. **form_data.model_dump(exclude_none=True),
  128. "updated_at": int(time.time()),
  129. }
  130. )
  131. db.commit()
  132. return self.get_group_by_id(id=id)
  133. except Exception as e:
  134. log.exception(e)
  135. return None
  136. def delete_group_by_id(self, id: str) -> bool:
  137. try:
  138. with get_db() as db:
  139. db.query(Group).filter_by(id=id).delete()
  140. db.commit()
  141. return True
  142. except Exception:
  143. return False
  144. def delete_all_groups(self) -> bool:
  145. with get_db() as db:
  146. try:
  147. db.query(Group).delete()
  148. db.commit()
  149. return True
  150. except Exception:
  151. return False
  152. def remove_user_from_all_groups(self, user_id: str) -> bool:
  153. with get_db() as db:
  154. try:
  155. groups = self.get_groups_by_member_id(user_id)
  156. for group in groups:
  157. group.user_ids.remove(user_id)
  158. db.query(Group).filter_by(id=group.id).update(
  159. {
  160. "user_ids": group.user_ids,
  161. "updated_at": int(time.time()),
  162. }
  163. )
  164. db.commit()
  165. return True
  166. except Exception:
  167. return False
  168. def create_groups_by_group_names(
  169. self, user_id: str, group_names: list[str]
  170. ) -> list[GroupModel]:
  171. # check for existing groups
  172. existing_groups = self.get_groups()
  173. existing_group_names = {group.name for group in existing_groups}
  174. new_groups = []
  175. with get_db() as db:
  176. for group_name in group_names:
  177. if group_name not in existing_group_names:
  178. new_group = GroupModel(
  179. id=str(uuid.uuid4()),
  180. user_id=user_id,
  181. name=group_name,
  182. description="",
  183. created_at=int(time.time()),
  184. updated_at=int(time.time()),
  185. )
  186. try:
  187. result = Group(**new_group.model_dump())
  188. db.add(result)
  189. db.commit()
  190. db.refresh(result)
  191. new_groups.append(GroupModel.model_validate(result))
  192. except Exception as e:
  193. log.exception(e)
  194. continue
  195. return new_groups
  196. def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
  197. with get_db() as db:
  198. try:
  199. groups = db.query(Group).filter(Group.name.in_(group_names)).all()
  200. group_ids = [group.id for group in groups]
  201. # Remove user from groups not in the new list
  202. existing_groups = self.get_groups_by_member_id(user_id)
  203. for group in existing_groups:
  204. if group.id not in group_ids:
  205. group.user_ids.remove(user_id)
  206. db.query(Group).filter_by(id=group.id).update(
  207. {
  208. "user_ids": group.user_ids,
  209. "updated_at": int(time.time()),
  210. }
  211. )
  212. # Add user to new groups
  213. for group in groups:
  214. if user_id not in group.user_ids:
  215. group.user_ids.append(user_id)
  216. db.query(Group).filter_by(id=group.id).update(
  217. {
  218. "user_ids": group.user_ids,
  219. "updated_at": int(time.time()),
  220. }
  221. )
  222. db.commit()
  223. return True
  224. except Exception as e:
  225. log.exception(e)
  226. return False
  227. def add_users_to_group(
  228. self, id: str, user_ids: Optional[list[str]] = None
  229. ) -> Optional[GroupModel]:
  230. try:
  231. with get_db() as db:
  232. group = db.query(Group).filter_by(id=id).first()
  233. if not group:
  234. return None
  235. if not group.user_ids:
  236. group.user_ids = []
  237. for user_id in user_ids:
  238. if user_id not in group.user_ids:
  239. group.user_ids.append(user_id)
  240. group.updated_at = int(time.time())
  241. db.commit()
  242. db.refresh(group)
  243. return GroupModel.model_validate(group)
  244. except Exception as e:
  245. log.exception(e)
  246. return None
  247. def remove_users_from_group(
  248. self, id: str, user_ids: Optional[list[str]] = None
  249. ) -> Optional[GroupModel]:
  250. try:
  251. with get_db() as db:
  252. group = db.query(Group).filter_by(id=id).first()
  253. if not group:
  254. return None
  255. if not group.user_ids:
  256. return GroupModel.model_validate(group)
  257. for user_id in user_ids:
  258. if user_id in group.user_ids:
  259. group.user_ids.remove(user_id)
  260. group.updated_at = int(time.time())
  261. db.commit()
  262. db.refresh(group)
  263. return GroupModel.model_validate(group)
  264. except Exception as e:
  265. log.exception(e)
  266. return None
  267. Groups = GroupTable()