users.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
import datetime
import logging
import time
from typing import Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Date, String, Text, or_

from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups
from open_webui.utils.misc import throttle
  12. ####################
  13. # User DB Schema
  14. ####################
class User(Base):
    """SQLAlchemy ORM mapping for the ``user`` table.

    Rows are converted to ``UserModel`` via ``UserModel.model_validate`` by the
    ``UsersTable`` accessors in this module.
    """

    __tablename__ = "user"
    id = Column(String, primary_key=True)
    name = Column(String)
    email = Column(String)
    # Optional handle; declared as a 50-character string column.
    username = Column(String(50), nullable=True)
    # Access role string; insert_new_user defaults it to "pending" and
    # get_super_admin_user looks for "admin".
    role = Column(String)
    profile_image_url = Column(Text)
    bio = Column(Text, nullable=True)
    gender = Column(Text, nullable=True)
    date_of_birth = Column(Date, nullable=True)
    # Free-form JSON blob stored via the project's JSONField type; its schema
    # is not defined in this module.
    info = Column(JSONField, nullable=True)
    # Per-user settings JSON; UserSettings models it (a "ui" section plus any
    # extra keys).
    settings = Column(JSONField, nullable=True)
    # Per-user API key; unique so get_user_by_api_key resolves to one user.
    api_key = Column(String, nullable=True, unique=True)
    # OAuth subject identifier; unique so get_user_by_oauth_sub resolves to
    # one user.
    oauth_sub = Column(Text, unique=True)
    # Timestamps in epoch seconds (insert_new_user writes int(time.time())).
    last_active_at = Column(BigInteger)
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)
  33. class UserSettings(BaseModel):
  34. ui: Optional[dict] = {}
  35. model_config = ConfigDict(extra="allow")
  36. pass
class UserModel(BaseModel):
    """Pydantic representation of a ``user`` row.

    ``from_attributes=True`` lets ``UserModel.model_validate`` read directly
    from a ``User`` ORM instance, which is how ``UsersTable`` builds it.
    """

    id: str
    name: str
    email: str
    username: Optional[str] = None
    role: str = "pending"  # same default insert_new_user uses
    profile_image_url: str
    bio: Optional[str] = None
    gender: Optional[str] = None
    date_of_birth: Optional[datetime.date] = None
    info: Optional[dict] = None
    # Parsed into UserSettings; unknown top-level keys are preserved.
    settings: Optional[UserSettings] = None
    api_key: Optional[str] = None
    oauth_sub: Optional[str] = None
    last_active_at: int # timestamp in epoch
    updated_at: int # timestamp in epoch
    created_at: int # timestamp in epoch
    model_config = ConfigDict(from_attributes=True)
  55. ####################
  56. # Forms
  57. ####################
class UpdateProfileForm(BaseModel):
    """Request body for editing a user's own profile fields."""

    profile_image_url: str
    name: str
    # Optional profile fields; None presumably means "leave unset" — confirm
    # against the route handler that consumes this form.
    bio: Optional[str] = None
    gender: Optional[str] = None
    date_of_birth: Optional[datetime.date] = None
class UserListResponse(BaseModel):
    """A page of users plus a total count (same keys UsersTable.get_users returns)."""

    users: list[UserModel]
    total: int
class UserInfoResponse(BaseModel):
    """Minimal user projection: identity, contact email, and role."""

    id: str
    name: str
    email: str
    role: str
class UserInfoListResponse(BaseModel):
    """A list of ``UserInfoResponse`` entries plus a total count."""

    users: list[UserInfoResponse]
    total: int
class UserResponse(BaseModel):
    """User projection including the avatar URL but no settings or keys."""

    id: str
    name: str
    email: str
    role: str
    profile_image_url: str
class UserNameResponse(BaseModel):
    """User projection without the email address (id, name, role, avatar)."""

    id: str
    name: str
    role: str
    profile_image_url: str
class UserRoleUpdateForm(BaseModel):
    """Request body naming a user by ``id`` and the ``role`` to assign."""

    id: str
    role: str
class UserUpdateForm(BaseModel):
    """Request body for editing another user's core account fields."""

    role: str
    name: str
    email: str
    profile_image_url: str
    # Optional new password; presumably None means "keep current" and hashing
    # happens in the caller — not visible in this module, confirm there.
    password: Optional[str] = None
  95. class UsersTable:
  96. def insert_new_user(
  97. self,
  98. id: str,
  99. name: str,
  100. email: str,
  101. profile_image_url: str = "/user.png",
  102. role: str = "pending",
  103. oauth_sub: Optional[str] = None,
  104. ) -> Optional[UserModel]:
  105. with get_db() as db:
  106. user = UserModel(
  107. **{
  108. "id": id,
  109. "name": name,
  110. "email": email,
  111. "role": role,
  112. "profile_image_url": profile_image_url,
  113. "last_active_at": int(time.time()),
  114. "created_at": int(time.time()),
  115. "updated_at": int(time.time()),
  116. "oauth_sub": oauth_sub,
  117. }
  118. )
  119. result = User(**user.model_dump())
  120. db.add(result)
  121. db.commit()
  122. db.refresh(result)
  123. if result:
  124. return user
  125. else:
  126. return None
  127. def get_user_by_id(self, id: str) -> Optional[UserModel]:
  128. try:
  129. with get_db() as db:
  130. user = db.query(User).filter_by(id=id).first()
  131. return UserModel.model_validate(user)
  132. except Exception:
  133. return None
  134. def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
  135. try:
  136. with get_db() as db:
  137. user = db.query(User).filter_by(api_key=api_key).first()
  138. return UserModel.model_validate(user)
  139. except Exception:
  140. return None
  141. def get_user_by_email(self, email: str) -> Optional[UserModel]:
  142. try:
  143. with get_db() as db:
  144. user = db.query(User).filter_by(email=email).first()
  145. return UserModel.model_validate(user)
  146. except Exception:
  147. return None
  148. def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
  149. try:
  150. with get_db() as db:
  151. user = db.query(User).filter_by(oauth_sub=sub).first()
  152. return UserModel.model_validate(user)
  153. except Exception:
  154. return None
  155. def get_users(
  156. self,
  157. filter: Optional[dict] = None,
  158. skip: Optional[int] = None,
  159. limit: Optional[int] = None,
  160. ) -> UserListResponse:
  161. with get_db() as db:
  162. query = db.query(User)
  163. if filter:
  164. query_key = filter.get("query")
  165. if query_key:
  166. query = query.filter(
  167. or_(
  168. User.name.ilike(f"%{query_key}%"),
  169. User.email.ilike(f"%{query_key}%"),
  170. )
  171. )
  172. order_by = filter.get("order_by")
  173. direction = filter.get("direction")
  174. if order_by == "name":
  175. if direction == "asc":
  176. query = query.order_by(User.name.asc())
  177. else:
  178. query = query.order_by(User.name.desc())
  179. elif order_by == "email":
  180. if direction == "asc":
  181. query = query.order_by(User.email.asc())
  182. else:
  183. query = query.order_by(User.email.desc())
  184. elif order_by == "created_at":
  185. if direction == "asc":
  186. query = query.order_by(User.created_at.asc())
  187. else:
  188. query = query.order_by(User.created_at.desc())
  189. elif order_by == "last_active_at":
  190. if direction == "asc":
  191. query = query.order_by(User.last_active_at.asc())
  192. else:
  193. query = query.order_by(User.last_active_at.desc())
  194. elif order_by == "updated_at":
  195. if direction == "asc":
  196. query = query.order_by(User.updated_at.asc())
  197. else:
  198. query = query.order_by(User.updated_at.desc())
  199. elif order_by == "role":
  200. if direction == "asc":
  201. query = query.order_by(User.role.asc())
  202. else:
  203. query = query.order_by(User.role.desc())
  204. else:
  205. query = query.order_by(User.created_at.desc())
  206. if skip:
  207. query = query.offset(skip)
  208. if limit:
  209. query = query.limit(limit)
  210. users = query.all()
  211. return {
  212. "users": [UserModel.model_validate(user) for user in users],
  213. "total": db.query(User).count(),
  214. }
  215. def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
  216. with get_db() as db:
  217. users = db.query(User).filter(User.id.in_(user_ids)).all()
  218. return [UserModel.model_validate(user) for user in users]
  219. def get_num_users(self) -> Optional[int]:
  220. with get_db() as db:
  221. return db.query(User).count()
  222. def has_users(self) -> bool:
  223. with get_db() as db:
  224. return db.query(db.query(User).exists()).scalar()
  225. def get_first_user(self) -> UserModel:
  226. try:
  227. with get_db() as db:
  228. user = db.query(User).order_by(User.created_at).first()
  229. return UserModel.model_validate(user)
  230. except Exception:
  231. return None
  232. def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
  233. try:
  234. with get_db() as db:
  235. user = db.query(User).filter_by(id=id).first()
  236. if user.settings is None:
  237. return None
  238. else:
  239. return (
  240. user.settings.get("ui", {})
  241. .get("notifications", {})
  242. .get("webhook_url", None)
  243. )
  244. except Exception:
  245. return None
  246. def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
  247. try:
  248. with get_db() as db:
  249. db.query(User).filter_by(id=id).update({"role": role})
  250. db.commit()
  251. user = db.query(User).filter_by(id=id).first()
  252. return UserModel.model_validate(user)
  253. except Exception:
  254. return None
  255. def update_user_profile_image_url_by_id(
  256. self, id: str, profile_image_url: str
  257. ) -> Optional[UserModel]:
  258. try:
  259. with get_db() as db:
  260. db.query(User).filter_by(id=id).update(
  261. {"profile_image_url": profile_image_url}
  262. )
  263. db.commit()
  264. user = db.query(User).filter_by(id=id).first()
  265. return UserModel.model_validate(user)
  266. except Exception:
  267. return None
  268. @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
  269. def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
  270. try:
  271. with get_db() as db:
  272. db.query(User).filter_by(id=id).update(
  273. {"last_active_at": int(time.time())}
  274. )
  275. db.commit()
  276. user = db.query(User).filter_by(id=id).first()
  277. return UserModel.model_validate(user)
  278. except Exception:
  279. return None
  280. def update_user_oauth_sub_by_id(
  281. self, id: str, oauth_sub: str
  282. ) -> Optional[UserModel]:
  283. try:
  284. with get_db() as db:
  285. db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
  286. db.commit()
  287. user = db.query(User).filter_by(id=id).first()
  288. return UserModel.model_validate(user)
  289. except Exception:
  290. return None
  291. def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
  292. try:
  293. with get_db() as db:
  294. db.query(User).filter_by(id=id).update(updated)
  295. db.commit()
  296. user = db.query(User).filter_by(id=id).first()
  297. return UserModel.model_validate(user)
  298. # return UserModel(**user.dict())
  299. except Exception as e:
  300. print(e)
  301. return None
  302. def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
  303. try:
  304. with get_db() as db:
  305. user_settings = db.query(User).filter_by(id=id).first().settings
  306. if user_settings is None:
  307. user_settings = {}
  308. user_settings.update(updated)
  309. db.query(User).filter_by(id=id).update({"settings": user_settings})
  310. db.commit()
  311. user = db.query(User).filter_by(id=id).first()
  312. return UserModel.model_validate(user)
  313. except Exception:
  314. return None
  315. def delete_user_by_id(self, id: str) -> bool:
  316. try:
  317. # Remove User from Groups
  318. Groups.remove_user_from_all_groups(id)
  319. # Delete User Chats
  320. result = Chats.delete_chats_by_user_id(id)
  321. if result:
  322. with get_db() as db:
  323. # Delete User
  324. db.query(User).filter_by(id=id).delete()
  325. db.commit()
  326. return True
  327. else:
  328. return False
  329. except Exception:
  330. return False
  331. def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
  332. try:
  333. with get_db() as db:
  334. result = db.query(User).filter_by(id=id).update({"api_key": api_key})
  335. db.commit()
  336. return True if result == 1 else False
  337. except Exception:
  338. return False
  339. def get_user_api_key_by_id(self, id: str) -> Optional[str]:
  340. try:
  341. with get_db() as db:
  342. user = db.query(User).filter_by(id=id).first()
  343. return user.api_key
  344. except Exception:
  345. return None
  346. def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
  347. with get_db() as db:
  348. users = db.query(User).filter(User.id.in_(user_ids)).all()
  349. return [user.id for user in users]
  350. def get_super_admin_user(self) -> Optional[UserModel]:
  351. with get_db() as db:
  352. user = db.query(User).filter_by(role="admin").first()
  353. if user:
  354. return UserModel.model_validate(user)
  355. else:
  356. return None
# Module-level UsersTable instance; presumably the shared entry point other
# modules import for user-table access (e.g. ``Users.get_user_by_id(...)``).
Users = UsersTable()