chats.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. import json
  2. import logging
  3. from typing import Optional
  4. from open_webui.apps.webui.models.chats import (
  5. ChatForm,
  6. ChatResponse,
  7. Chats,
  8. ChatTitleIdResponse,
  9. )
  10. from open_webui.apps.webui.models.tags import TagModel, Tags
  11. from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
  12. from open_webui.constants import ERROR_MESSAGES
  13. from open_webui.env import SRC_LOG_LEVELS
  14. from fastapi import APIRouter, Depends, HTTPException, Request, status
  15. from pydantic import BaseModel
  16. from open_webui.utils.utils import get_admin_user, get_verified_user
# Module logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
# Router that every endpoint below registers on via @router.<method>(...).
# It must be created before those decorators run. The URL prefix it is
# mounted under is decided by the including app (not visible here).
router = APIRouter()
  20. ############################
  21. # GetChatList
  22. ############################
  23. @router.get("/", response_model=list[ChatTitleIdResponse])
  24. @router.get("/list", response_model=list[ChatTitleIdResponse])
  25. async def get_session_user_chat_list(
  26. user=Depends(get_verified_user), page: Optional[int] = None
  27. ):
  28. if page is not None:
  29. limit = 60
  30. skip = (page - 1) * limit
  31. return Chats.get_chat_title_id_list_by_user_id(user.id, skip=skip, limit=limit)
  32. else:
  33. return Chats.get_chat_title_id_list_by_user_id(user.id)
  34. ############################
  35. # DeleteAllChats
  36. ############################
  37. @router.delete("/", response_model=bool)
  38. async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
  39. if user.role == "user" and not request.app.state.config.USER_PERMISSIONS.get(
  40. "chat", {}
  41. ).get("deletion", {}):
  42. raise HTTPException(
  43. status_code=status.HTTP_401_UNAUTHORIZED,
  44. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  45. )
  46. result = Chats.delete_chats_by_user_id(user.id)
  47. return result
  48. ############################
  49. # GetUserChatList
  50. ############################
  51. @router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
  52. async def get_user_chat_list_by_user_id(
  53. user_id: str,
  54. user=Depends(get_admin_user),
  55. skip: int = 0,
  56. limit: int = 50,
  57. ):
  58. if not ENABLE_ADMIN_CHAT_ACCESS:
  59. raise HTTPException(
  60. status_code=status.HTTP_401_UNAUTHORIZED,
  61. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  62. )
  63. return Chats.get_chat_list_by_user_id(
  64. user_id, include_archived=True, skip=skip, limit=limit
  65. )
  66. ############################
  67. # CreateNewChat
  68. ############################
  69. @router.post("/new", response_model=Optional[ChatResponse])
  70. async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
  71. try:
  72. chat = Chats.insert_new_chat(user.id, form_data)
  73. return ChatResponse(**chat.model_dump())
  74. except Exception as e:
  75. log.exception(e)
  76. raise HTTPException(
  77. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  78. )
  79. ############################
  80. # GetChats
  81. ############################
  82. @router.get("/search", response_model=list[ChatTitleIdResponse])
  83. async def search_user_chats(
  84. text: str, page: Optional[int] = None, user=Depends(get_verified_user)
  85. ):
  86. if page is None:
  87. page = 1
  88. limit = 60
  89. skip = (page - 1) * limit
  90. chat_list = [
  91. ChatTitleIdResponse(**chat.model_dump())
  92. for chat in Chats.get_chats_by_user_id_and_search_text(
  93. user.id, text, skip=skip, limit=limit
  94. )
  95. ]
  96. # Delete tag if no chat is found
  97. words = text.strip().split(" ")
  98. if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
  99. tag_id = words[0].replace("tag:", "")
  100. if len(chat_list) == 0:
  101. if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
  102. log.debug(f"deleting tag: {tag_id}")
  103. Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
  104. return chat_list
  105. ############################
  106. # GetPinnedChats
  107. ############################
  108. @router.get("/pinned", response_model=list[ChatResponse])
  109. async def get_user_pinned_chats(user=Depends(get_verified_user)):
  110. return [
  111. ChatResponse(**chat.model_dump())
  112. for chat in Chats.get_pinned_chats_by_user_id(user.id)
  113. ]
  114. ############################
  115. # GetChats
  116. ############################
  117. @router.get("/all", response_model=list[ChatResponse])
  118. async def get_user_chats(user=Depends(get_verified_user)):
  119. return [
  120. ChatResponse(**chat.model_dump())
  121. for chat in Chats.get_chats_by_user_id(user.id)
  122. ]
  123. ############################
  124. # GetArchivedChats
  125. ############################
  126. @router.get("/all/archived", response_model=list[ChatResponse])
  127. async def get_user_archived_chats(user=Depends(get_verified_user)):
  128. return [
  129. ChatResponse(**chat.model_dump())
  130. for chat in Chats.get_archived_chats_by_user_id(user.id)
  131. ]
  132. ############################
  133. # GetAllTags
  134. ############################
  135. @router.get("/all/tags", response_model=list[TagModel])
  136. async def get_all_user_tags(user=Depends(get_verified_user)):
  137. try:
  138. tags = Tags.get_tags_by_user_id(user.id)
  139. return tags
  140. except Exception as e:
  141. log.exception(e)
  142. raise HTTPException(
  143. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  144. )
  145. ############################
  146. # GetAllChatsInDB
  147. ############################
  148. @router.get("/all/db", response_model=list[ChatResponse])
  149. async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
  150. if not ENABLE_ADMIN_EXPORT:
  151. raise HTTPException(
  152. status_code=status.HTTP_401_UNAUTHORIZED,
  153. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  154. )
  155. return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
  156. ############################
  157. # GetArchivedChats
  158. ############################
  159. @router.get("/archived", response_model=list[ChatTitleIdResponse])
  160. async def get_archived_session_user_chat_list(
  161. user=Depends(get_verified_user), skip: int = 0, limit: int = 50
  162. ):
  163. return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
  164. ############################
  165. # ArchiveAllChats
  166. ############################
  167. @router.post("/archive/all", response_model=bool)
  168. async def archive_all_chats(user=Depends(get_verified_user)):
  169. return Chats.archive_all_chats_by_user_id(user.id)
  170. ############################
  171. # GetSharedChatById
  172. ############################
  173. @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
  174. async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
  175. if user.role == "pending":
  176. raise HTTPException(
  177. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  178. )
  179. if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
  180. chat = Chats.get_chat_by_share_id(share_id)
  181. elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
  182. chat = Chats.get_chat_by_id(share_id)
  183. if chat:
  184. return ChatResponse(**chat.model_dump())
  185. else:
  186. raise HTTPException(
  187. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  188. )
############################
# GetChatsByTags
############################
# Request body carrying a single tag name; used by the /tags listing and the
# per-chat /{id}/tags add/delete endpoints below.
class TagForm(BaseModel):
    name: str  # tag name as supplied by the client (not normalised here)
# TagForm plus pagination, for POST /tags. Both fields are forwarded as-is to
# Chats.get_chat_list_by_user_id_and_tag_name; the types also admit None.
class TagFilterForm(TagForm):
    skip: Optional[int] = 0  # presumably a row offset — confirm in the model layer
    limit: Optional[int] = 50  # presumably a maximum page size — confirm likewise
  197. @router.post("/tags", response_model=list[ChatTitleIdResponse])
  198. async def get_user_chat_list_by_tag_name(
  199. form_data: TagFilterForm, user=Depends(get_verified_user)
  200. ):
  201. chats = Chats.get_chat_list_by_user_id_and_tag_name(
  202. user.id, form_data.name, form_data.skip, form_data.limit
  203. )
  204. if len(chats) == 0:
  205. Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
  206. return chats
  207. ############################
  208. # GetChatById
  209. ############################
  210. @router.get("/{id}", response_model=Optional[ChatResponse])
  211. async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
  212. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  213. if chat:
  214. return ChatResponse(**chat.model_dump())
  215. else:
  216. raise HTTPException(
  217. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  218. )
  219. ############################
  220. # UpdateChatById
  221. ############################
  222. @router.post("/{id}", response_model=Optional[ChatResponse])
  223. async def update_chat_by_id(
  224. id: str, form_data: ChatForm, user=Depends(get_verified_user)
  225. ):
  226. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  227. if chat:
  228. updated_chat = {**chat.chat, **form_data.chat}
  229. chat = Chats.update_chat_by_id(id, updated_chat)
  230. return ChatResponse(**chat.model_dump())
  231. else:
  232. raise HTTPException(
  233. status_code=status.HTTP_401_UNAUTHORIZED,
  234. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  235. )
  236. ############################
  237. # DeleteChatById
  238. ############################
  239. @router.delete("/{id}", response_model=bool)
  240. async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
  241. if user.role == "admin":
  242. chat = Chats.get_chat_by_id(id)
  243. for tag in chat.meta.get("tags", []):
  244. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
  245. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  246. result = Chats.delete_chat_by_id(id)
  247. return result
  248. else:
  249. if not request.app.state.config.USER_PERMISSIONS.get("chat", {}).get(
  250. "deletion", {}
  251. ):
  252. raise HTTPException(
  253. status_code=status.HTTP_401_UNAUTHORIZED,
  254. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  255. )
  256. chat = Chats.get_chat_by_id(id)
  257. for tag in chat.meta.get("tags", []):
  258. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
  259. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  260. result = Chats.delete_chat_by_id_and_user_id(id, user.id)
  261. return result
  262. ############################
  263. # GetPinnedStatusById
  264. ############################
  265. @router.get("/{id}/pinned", response_model=Optional[bool])
  266. async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
  267. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  268. if chat:
  269. return chat.pinned
  270. else:
  271. raise HTTPException(
  272. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  273. )
  274. ############################
  275. # PinChatById
  276. ############################
  277. @router.post("/{id}/pin", response_model=Optional[ChatResponse])
  278. async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
  279. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  280. if chat:
  281. chat = Chats.toggle_chat_pinned_by_id(id)
  282. return chat
  283. else:
  284. raise HTTPException(
  285. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  286. )
  287. ############################
  288. # CloneChat
  289. ############################
  290. @router.post("/{id}/clone", response_model=Optional[ChatResponse])
  291. async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
  292. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  293. if chat:
  294. updated_chat = {
  295. **chat.chat,
  296. "originalChatId": chat.id,
  297. "branchPointMessageId": chat.chat["history"]["currentId"],
  298. "title": f"Clone of {chat.title}",
  299. }
  300. chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
  301. return ChatResponse(**chat.model_dump())
  302. else:
  303. raise HTTPException(
  304. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  305. )
  306. ############################
  307. # ArchiveChat
  308. ############################
  309. @router.post("/{id}/archive", response_model=Optional[ChatResponse])
  310. async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
  311. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  312. if chat:
  313. chat = Chats.toggle_chat_archive_by_id(id)
  314. # Delete tags if chat is archived
  315. if chat.archived:
  316. for tag_id in chat.meta.get("tags", []):
  317. if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
  318. log.debug(f"deleting tag: {tag_id}")
  319. Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
  320. else:
  321. for tag_id in chat.meta.get("tags", []):
  322. tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
  323. if tag is None:
  324. log.debug(f"inserting tag: {tag_id}")
  325. tag = Tags.insert_new_tag(tag_id, user.id)
  326. return ChatResponse(**chat.model_dump())
  327. else:
  328. raise HTTPException(
  329. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  330. )
  331. ############################
  332. # ShareChatById
  333. ############################
  334. @router.post("/{id}/share", response_model=Optional[ChatResponse])
  335. async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
  336. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  337. if chat:
  338. if chat.share_id:
  339. shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
  340. return ChatResponse(**shared_chat.model_dump())
  341. shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
  342. if not shared_chat:
  343. raise HTTPException(
  344. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  345. detail=ERROR_MESSAGES.DEFAULT(),
  346. )
  347. return ChatResponse(**shared_chat.model_dump())
  348. else:
  349. raise HTTPException(
  350. status_code=status.HTTP_401_UNAUTHORIZED,
  351. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  352. )
  353. ############################
  354. # DeletedSharedChatById
  355. ############################
  356. @router.delete("/{id}/share", response_model=Optional[bool])
  357. async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
  358. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  359. if chat:
  360. if not chat.share_id:
  361. return False
  362. result = Chats.delete_shared_chat_by_chat_id(id)
  363. update_result = Chats.update_chat_share_id_by_id(id, None)
  364. return result and update_result != None
  365. else:
  366. raise HTTPException(
  367. status_code=status.HTTP_401_UNAUTHORIZED,
  368. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  369. )
############################
# UpdateChatFolderIdById
############################
# Request body for POST /{id}/folder. folder_id defaults to None and is passed
# unchanged to Chats.update_chat_folder_id_by_id_and_user_id (None presumably
# means "no folder" — confirm in the model layer).
class ChatFolderIdForm(BaseModel):
    folder_id: Optional[str] = None
  375. @router.post("/{id}/folder", response_model=Optional[ChatResponse])
  376. async def update_chat_folder_id_by_id(
  377. id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
  378. ):
  379. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  380. if chat:
  381. chat = Chats.update_chat_folder_id_by_id_and_user_id(
  382. id, user.id, form_data.folder_id
  383. )
  384. return ChatResponse(**chat.model_dump())
  385. else:
  386. raise HTTPException(
  387. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  388. )
  389. ############################
  390. # GetChatTagsById
  391. ############################
  392. @router.get("/{id}/tags", response_model=list[TagModel])
  393. async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
  394. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  395. if chat:
  396. tags = chat.meta.get("tags", [])
  397. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  398. else:
  399. raise HTTPException(
  400. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  401. )
  402. ############################
  403. # AddChatTagById
  404. ############################
  405. @router.post("/{id}/tags", response_model=list[TagModel])
  406. async def add_tag_by_id_and_tag_name(
  407. id: str, form_data: TagForm, user=Depends(get_verified_user)
  408. ):
  409. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  410. if chat:
  411. tags = chat.meta.get("tags", [])
  412. tag_id = form_data.name.replace(" ", "_").lower()
  413. print(tags, tag_id)
  414. if tag_id not in tags:
  415. Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
  416. id, user.id, form_data.name
  417. )
  418. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  419. tags = chat.meta.get("tags", [])
  420. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  421. else:
  422. raise HTTPException(
  423. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  424. )
  425. ############################
  426. # DeleteChatTagById
  427. ############################
  428. @router.delete("/{id}/tags", response_model=list[TagModel])
  429. async def delete_tag_by_id_and_tag_name(
  430. id: str, form_data: TagForm, user=Depends(get_verified_user)
  431. ):
  432. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  433. if chat:
  434. Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
  435. if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
  436. Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
  437. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  438. tags = chat.meta.get("tags", [])
  439. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  440. else:
  441. raise HTTPException(
  442. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  443. )
  444. ############################
  445. # DeleteAllChatTagsById
  446. ############################
  447. @router.delete("/{id}/tags/all", response_model=Optional[bool])
  448. async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
  449. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  450. if chat:
  451. Chats.delete_all_tags_by_id_and_user_id(id, user.id)
  452. for tag in chat.meta.get("tags", []):
  453. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
  454. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  455. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  456. tags = chat.meta.get("tags", [])
  457. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  458. else:
  459. raise HTTPException(
  460. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  461. )