chats.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917
  1. import json
  2. import logging
  3. from typing import Optional
  4. from open_webui.socket.main import get_event_emitter
  5. from open_webui.models.chats import (
  6. ChatForm,
  7. ChatImportForm,
  8. ChatResponse,
  9. Chats,
  10. ChatTitleIdResponse,
  11. )
  12. from open_webui.models.tags import TagModel, Tags
  13. from open_webui.models.folders import Folders
  14. from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
  15. from open_webui.constants import ERROR_MESSAGES
  16. from open_webui.env import SRC_LOG_LEVELS
  17. from fastapi import APIRouter, Depends, HTTPException, Request, status
  18. from pydantic import BaseModel
  19. from open_webui.utils.auth import get_admin_user, get_verified_user
  20. from open_webui.utils.access_control import has_permission
  21. log = logging.getLogger(__name__)
  22. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  23. router = APIRouter()
  24. ############################
  25. # GetChatList
  26. ############################
  27. @router.get("/", response_model=list[ChatTitleIdResponse])
  28. @router.get("/list", response_model=list[ChatTitleIdResponse])
  29. def get_session_user_chat_list(
  30. user=Depends(get_verified_user),
  31. page: Optional[int] = None,
  32. include_folders: Optional[bool] = False,
  33. ):
  34. try:
  35. if page is not None:
  36. limit = 60
  37. skip = (page - 1) * limit
  38. return Chats.get_chat_title_id_list_by_user_id(
  39. user.id, include_folders=include_folders, skip=skip, limit=limit
  40. )
  41. else:
  42. return Chats.get_chat_title_id_list_by_user_id(
  43. user.id, include_folders=include_folders
  44. )
  45. except Exception as e:
  46. log.exception(e)
  47. raise HTTPException(
  48. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  49. )
  50. ############################
  51. # DeleteAllChats
  52. ############################
  53. @router.delete("/", response_model=bool)
  54. async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
  55. if user.role == "user" and not has_permission(
  56. user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
  57. ):
  58. raise HTTPException(
  59. status_code=status.HTTP_401_UNAUTHORIZED,
  60. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  61. )
  62. result = Chats.delete_chats_by_user_id(user.id)
  63. return result
  64. ############################
  65. # GetUserChatList
  66. ############################
  67. @router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
  68. async def get_user_chat_list_by_user_id(
  69. user_id: str,
  70. page: Optional[int] = None,
  71. query: Optional[str] = None,
  72. order_by: Optional[str] = None,
  73. direction: Optional[str] = None,
  74. user=Depends(get_admin_user),
  75. ):
  76. if not ENABLE_ADMIN_CHAT_ACCESS:
  77. raise HTTPException(
  78. status_code=status.HTTP_401_UNAUTHORIZED,
  79. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  80. )
  81. if page is None:
  82. page = 1
  83. limit = 60
  84. skip = (page - 1) * limit
  85. filter = {}
  86. if query:
  87. filter["query"] = query
  88. if order_by:
  89. filter["order_by"] = order_by
  90. if direction:
  91. filter["direction"] = direction
  92. return Chats.get_chat_list_by_user_id(
  93. user_id, include_archived=True, filter=filter, skip=skip, limit=limit
  94. )
  95. ############################
  96. # CreateNewChat
  97. ############################
  98. @router.post("/new", response_model=Optional[ChatResponse])
  99. async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
  100. try:
  101. chat = Chats.insert_new_chat(user.id, form_data)
  102. return ChatResponse(**chat.model_dump())
  103. except Exception as e:
  104. log.exception(e)
  105. raise HTTPException(
  106. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  107. )
  108. ############################
  109. # ImportChat
  110. ############################
  111. @router.post("/import", response_model=Optional[ChatResponse])
  112. async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
  113. try:
  114. chat = Chats.import_chat(user.id, form_data)
  115. if chat:
  116. tags = chat.meta.get("tags", [])
  117. for tag_id in tags:
  118. tag_id = tag_id.replace(" ", "_").lower()
  119. tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
  120. if (
  121. tag_id != "none"
  122. and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
  123. ):
  124. Tags.insert_new_tag(tag_name, user.id)
  125. return ChatResponse(**chat.model_dump())
  126. except Exception as e:
  127. log.exception(e)
  128. raise HTTPException(
  129. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  130. )
  131. ############################
  132. # GetChats
  133. ############################
  134. @router.get("/search", response_model=list[ChatTitleIdResponse])
  135. def search_user_chats(
  136. text: str, page: Optional[int] = None, user=Depends(get_verified_user)
  137. ):
  138. if page is None:
  139. page = 1
  140. limit = 60
  141. skip = (page - 1) * limit
  142. chat_list = [
  143. ChatTitleIdResponse(**chat.model_dump())
  144. for chat in Chats.get_chats_by_user_id_and_search_text(
  145. user.id, text, skip=skip, limit=limit
  146. )
  147. ]
  148. # Delete tag if no chat is found
  149. words = text.strip().split(" ")
  150. if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
  151. tag_id = words[0].replace("tag:", "")
  152. if len(chat_list) == 0:
  153. if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
  154. log.debug(f"deleting tag: {tag_id}")
  155. Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
  156. return chat_list
  157. ############################
  158. # GetChatsByFolderId
  159. ############################
  160. @router.get("/folder/{folder_id}", response_model=list[ChatResponse])
  161. async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
  162. folder_ids = [folder_id]
  163. children_folders = Folders.get_children_folders_by_id_and_user_id(
  164. folder_id, user.id
  165. )
  166. if children_folders:
  167. folder_ids.extend([folder.id for folder in children_folders])
  168. return [
  169. ChatResponse(**chat.model_dump())
  170. for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
  171. ]
  172. @router.get("/folder/{folder_id}/list")
  173. async def get_chat_list_by_folder_id(
  174. folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
  175. ):
  176. try:
  177. limit = 60
  178. skip = (page - 1) * limit
  179. return [
  180. {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
  181. for chat in Chats.get_chats_by_folder_id_and_user_id(
  182. folder_id, user.id, skip=skip, limit=limit
  183. )
  184. ]
  185. except Exception as e:
  186. log.exception(e)
  187. raise HTTPException(
  188. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  189. )
  190. ############################
  191. # GetPinnedChats
  192. ############################
  193. @router.get("/pinned", response_model=list[ChatTitleIdResponse])
  194. async def get_user_pinned_chats(user=Depends(get_verified_user)):
  195. return [
  196. ChatTitleIdResponse(**chat.model_dump())
  197. for chat in Chats.get_pinned_chats_by_user_id(user.id)
  198. ]
  199. ############################
  200. # GetChats
  201. ############################
  202. @router.get("/all", response_model=list[ChatResponse])
  203. async def get_user_chats(user=Depends(get_verified_user)):
  204. return [
  205. ChatResponse(**chat.model_dump())
  206. for chat in Chats.get_chats_by_user_id(user.id)
  207. ]
  208. ############################
  209. # GetArchivedChats
  210. ############################
  211. @router.get("/all/archived", response_model=list[ChatResponse])
  212. async def get_user_archived_chats(user=Depends(get_verified_user)):
  213. return [
  214. ChatResponse(**chat.model_dump())
  215. for chat in Chats.get_archived_chats_by_user_id(user.id)
  216. ]
  217. ############################
  218. # GetAllTags
  219. ############################
  220. @router.get("/all/tags", response_model=list[TagModel])
  221. async def get_all_user_tags(user=Depends(get_verified_user)):
  222. try:
  223. tags = Tags.get_tags_by_user_id(user.id)
  224. return tags
  225. except Exception as e:
  226. log.exception(e)
  227. raise HTTPException(
  228. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  229. )
  230. ############################
  231. # GetAllChatsInDB
  232. ############################
  233. @router.get("/all/db", response_model=list[ChatResponse])
  234. async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
  235. if not ENABLE_ADMIN_EXPORT:
  236. raise HTTPException(
  237. status_code=status.HTTP_401_UNAUTHORIZED,
  238. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  239. )
  240. return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
  241. ############################
  242. # GetArchivedChats
  243. ############################
  244. @router.get("/archived", response_model=list[ChatTitleIdResponse])
  245. async def get_archived_session_user_chat_list(
  246. page: Optional[int] = None,
  247. query: Optional[str] = None,
  248. order_by: Optional[str] = None,
  249. direction: Optional[str] = None,
  250. user=Depends(get_verified_user),
  251. ):
  252. if page is None:
  253. page = 1
  254. limit = 60
  255. skip = (page - 1) * limit
  256. filter = {}
  257. if query:
  258. filter["query"] = query
  259. if order_by:
  260. filter["order_by"] = order_by
  261. if direction:
  262. filter["direction"] = direction
  263. chat_list = [
  264. ChatTitleIdResponse(**chat.model_dump())
  265. for chat in Chats.get_archived_chat_list_by_user_id(
  266. user.id,
  267. filter=filter,
  268. skip=skip,
  269. limit=limit,
  270. )
  271. ]
  272. return chat_list
  273. ############################
  274. # ArchiveAllChats
  275. ############################
  276. @router.post("/archive/all", response_model=bool)
  277. async def archive_all_chats(user=Depends(get_verified_user)):
  278. return Chats.archive_all_chats_by_user_id(user.id)
  279. ############################
  280. # GetSharedChatById
  281. ############################
  282. @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
  283. async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
  284. if user.role == "pending":
  285. raise HTTPException(
  286. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  287. )
  288. if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
  289. chat = Chats.get_chat_by_share_id(share_id)
  290. elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
  291. chat = Chats.get_chat_by_id(share_id)
  292. if chat:
  293. return ChatResponse(**chat.model_dump())
  294. else:
  295. raise HTTPException(
  296. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  297. )
  298. ############################
  299. # GetChatsByTags
  300. ############################
  301. class TagForm(BaseModel):
  302. name: str
  303. class TagFilterForm(TagForm):
  304. skip: Optional[int] = 0
  305. limit: Optional[int] = 50
  306. @router.post("/tags", response_model=list[ChatTitleIdResponse])
  307. async def get_user_chat_list_by_tag_name(
  308. form_data: TagFilterForm, user=Depends(get_verified_user)
  309. ):
  310. chats = Chats.get_chat_list_by_user_id_and_tag_name(
  311. user.id, form_data.name, form_data.skip, form_data.limit
  312. )
  313. if len(chats) == 0:
  314. Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
  315. return chats
  316. ############################
  317. # GetChatById
  318. ############################
  319. @router.get("/{id}", response_model=Optional[ChatResponse])
  320. async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
  321. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  322. if chat:
  323. return ChatResponse(**chat.model_dump())
  324. else:
  325. raise HTTPException(
  326. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  327. )
  328. ############################
  329. # UpdateChatById
  330. ############################
  331. @router.post("/{id}", response_model=Optional[ChatResponse])
  332. async def update_chat_by_id(
  333. id: str, form_data: ChatForm, user=Depends(get_verified_user)
  334. ):
  335. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  336. if chat:
  337. updated_chat = {**chat.chat, **form_data.chat}
  338. chat = Chats.update_chat_by_id(id, updated_chat)
  339. return ChatResponse(**chat.model_dump())
  340. else:
  341. raise HTTPException(
  342. status_code=status.HTTP_401_UNAUTHORIZED,
  343. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  344. )
  345. ############################
  346. # UpdateChatMessageById
  347. ############################
  348. class MessageForm(BaseModel):
  349. content: str
  350. @router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
  351. async def update_chat_message_by_id(
  352. id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
  353. ):
  354. chat = Chats.get_chat_by_id(id)
  355. if not chat:
  356. raise HTTPException(
  357. status_code=status.HTTP_401_UNAUTHORIZED,
  358. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  359. )
  360. if chat.user_id != user.id and user.role != "admin":
  361. raise HTTPException(
  362. status_code=status.HTTP_401_UNAUTHORIZED,
  363. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  364. )
  365. chat = Chats.upsert_message_to_chat_by_id_and_message_id(
  366. id,
  367. message_id,
  368. {
  369. "content": form_data.content,
  370. },
  371. )
  372. event_emitter = get_event_emitter(
  373. {
  374. "user_id": user.id,
  375. "chat_id": id,
  376. "message_id": message_id,
  377. },
  378. False,
  379. )
  380. if event_emitter:
  381. await event_emitter(
  382. {
  383. "type": "chat:message",
  384. "data": {
  385. "chat_id": id,
  386. "message_id": message_id,
  387. "content": form_data.content,
  388. },
  389. }
  390. )
  391. return ChatResponse(**chat.model_dump())
  392. ############################
  393. # SendChatMessageEventById
  394. ############################
  395. class EventForm(BaseModel):
  396. type: str
  397. data: dict
  398. @router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
  399. async def send_chat_message_event_by_id(
  400. id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
  401. ):
  402. chat = Chats.get_chat_by_id(id)
  403. if not chat:
  404. raise HTTPException(
  405. status_code=status.HTTP_401_UNAUTHORIZED,
  406. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  407. )
  408. if chat.user_id != user.id and user.role != "admin":
  409. raise HTTPException(
  410. status_code=status.HTTP_401_UNAUTHORIZED,
  411. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  412. )
  413. event_emitter = get_event_emitter(
  414. {
  415. "user_id": user.id,
  416. "chat_id": id,
  417. "message_id": message_id,
  418. }
  419. )
  420. try:
  421. if event_emitter:
  422. await event_emitter(form_data.model_dump())
  423. else:
  424. return False
  425. return True
  426. except:
  427. return False
  428. ############################
  429. # DeleteChatById
  430. ############################
  431. @router.delete("/{id}", response_model=bool)
  432. async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
  433. if user.role == "admin":
  434. chat = Chats.get_chat_by_id(id)
  435. for tag in chat.meta.get("tags", []):
  436. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
  437. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  438. result = Chats.delete_chat_by_id(id)
  439. return result
  440. else:
  441. if not has_permission(
  442. user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
  443. ):
  444. raise HTTPException(
  445. status_code=status.HTTP_401_UNAUTHORIZED,
  446. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  447. )
  448. chat = Chats.get_chat_by_id(id)
  449. for tag in chat.meta.get("tags", []):
  450. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
  451. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  452. result = Chats.delete_chat_by_id_and_user_id(id, user.id)
  453. return result
  454. ############################
  455. # GetPinnedStatusById
  456. ############################
  457. @router.get("/{id}/pinned", response_model=Optional[bool])
  458. async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
  459. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  460. if chat:
  461. return chat.pinned
  462. else:
  463. raise HTTPException(
  464. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  465. )
  466. ############################
  467. # PinChatById
  468. ############################
  469. @router.post("/{id}/pin", response_model=Optional[ChatResponse])
  470. async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
  471. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  472. if chat:
  473. chat = Chats.toggle_chat_pinned_by_id(id)
  474. return chat
  475. else:
  476. raise HTTPException(
  477. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  478. )
  479. ############################
  480. # CloneChat
  481. ############################
  482. class CloneForm(BaseModel):
  483. title: Optional[str] = None
  484. @router.post("/{id}/clone", response_model=Optional[ChatResponse])
  485. async def clone_chat_by_id(
  486. form_data: CloneForm, id: str, user=Depends(get_verified_user)
  487. ):
  488. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  489. if chat:
  490. updated_chat = {
  491. **chat.chat,
  492. "originalChatId": chat.id,
  493. "branchPointMessageId": chat.chat["history"]["currentId"],
  494. "title": form_data.title if form_data.title else f"Clone of {chat.title}",
  495. }
  496. chat = Chats.import_chat(
  497. user.id,
  498. ChatImportForm(
  499. **{
  500. "chat": updated_chat,
  501. "meta": chat.meta,
  502. "pinned": chat.pinned,
  503. "folder_id": chat.folder_id,
  504. }
  505. ),
  506. )
  507. return ChatResponse(**chat.model_dump())
  508. else:
  509. raise HTTPException(
  510. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  511. )
  512. ############################
  513. # CloneSharedChatById
  514. ############################
  515. @router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
  516. async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
  517. if user.role == "admin":
  518. chat = Chats.get_chat_by_id(id)
  519. else:
  520. chat = Chats.get_chat_by_share_id(id)
  521. if chat:
  522. updated_chat = {
  523. **chat.chat,
  524. "originalChatId": chat.id,
  525. "branchPointMessageId": chat.chat["history"]["currentId"],
  526. "title": f"Clone of {chat.title}",
  527. }
  528. chat = Chats.import_chat(
  529. user.id,
  530. ChatImportForm(
  531. **{
  532. "chat": updated_chat,
  533. "meta": chat.meta,
  534. "pinned": chat.pinned,
  535. "folder_id": chat.folder_id,
  536. }
  537. ),
  538. )
  539. return ChatResponse(**chat.model_dump())
  540. else:
  541. raise HTTPException(
  542. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  543. )
  544. ############################
  545. # ArchiveChat
  546. ############################
  547. @router.post("/{id}/archive", response_model=Optional[ChatResponse])
  548. async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
  549. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  550. if chat:
  551. chat = Chats.toggle_chat_archive_by_id(id)
  552. # Delete tags if chat is archived
  553. if chat.archived:
  554. for tag_id in chat.meta.get("tags", []):
  555. if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
  556. log.debug(f"deleting tag: {tag_id}")
  557. Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
  558. else:
  559. for tag_id in chat.meta.get("tags", []):
  560. tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
  561. if tag is None:
  562. log.debug(f"inserting tag: {tag_id}")
  563. tag = Tags.insert_new_tag(tag_id, user.id)
  564. return ChatResponse(**chat.model_dump())
  565. else:
  566. raise HTTPException(
  567. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  568. )
  569. ############################
  570. # ShareChatById
  571. ############################
  572. @router.post("/{id}/share", response_model=Optional[ChatResponse])
  573. async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
  574. if (user.role != "admin") and (
  575. not has_permission(
  576. user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
  577. )
  578. ):
  579. raise HTTPException(
  580. status_code=status.HTTP_401_UNAUTHORIZED,
  581. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  582. )
  583. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  584. if chat:
  585. if chat.share_id:
  586. shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
  587. return ChatResponse(**shared_chat.model_dump())
  588. shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
  589. if not shared_chat:
  590. raise HTTPException(
  591. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  592. detail=ERROR_MESSAGES.DEFAULT(),
  593. )
  594. return ChatResponse(**shared_chat.model_dump())
  595. else:
  596. raise HTTPException(
  597. status_code=status.HTTP_401_UNAUTHORIZED,
  598. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  599. )
  600. ############################
  601. # DeletedSharedChatById
  602. ############################
  603. @router.delete("/{id}/share", response_model=Optional[bool])
  604. async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
  605. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  606. if chat:
  607. if not chat.share_id:
  608. return False
  609. result = Chats.delete_shared_chat_by_chat_id(id)
  610. update_result = Chats.update_chat_share_id_by_id(id, None)
  611. return result and update_result != None
  612. else:
  613. raise HTTPException(
  614. status_code=status.HTTP_401_UNAUTHORIZED,
  615. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  616. )
  617. ############################
  618. # UpdateChatFolderIdById
  619. ############################
  620. class ChatFolderIdForm(BaseModel):
  621. folder_id: Optional[str] = None
  622. @router.post("/{id}/folder", response_model=Optional[ChatResponse])
  623. async def update_chat_folder_id_by_id(
  624. id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
  625. ):
  626. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  627. if chat:
  628. chat = Chats.update_chat_folder_id_by_id_and_user_id(
  629. id, user.id, form_data.folder_id
  630. )
  631. return ChatResponse(**chat.model_dump())
  632. else:
  633. raise HTTPException(
  634. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  635. )
  636. ############################
  637. # GetChatTagsById
  638. ############################
  639. @router.get("/{id}/tags", response_model=list[TagModel])
  640. async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
  641. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  642. if chat:
  643. tags = chat.meta.get("tags", [])
  644. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  645. else:
  646. raise HTTPException(
  647. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  648. )
  649. ############################
  650. # AddChatTagById
  651. ############################
  652. @router.post("/{id}/tags", response_model=list[TagModel])
  653. async def add_tag_by_id_and_tag_name(
  654. id: str, form_data: TagForm, user=Depends(get_verified_user)
  655. ):
  656. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  657. if chat:
  658. tags = chat.meta.get("tags", [])
  659. tag_id = form_data.name.replace(" ", "_").lower()
  660. if tag_id == "none":
  661. raise HTTPException(
  662. status_code=status.HTTP_400_BAD_REQUEST,
  663. detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
  664. )
  665. if tag_id not in tags:
  666. Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
  667. id, user.id, form_data.name
  668. )
  669. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  670. tags = chat.meta.get("tags", [])
  671. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  672. else:
  673. raise HTTPException(
  674. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  675. )
  676. ############################
  677. # DeleteChatTagById
  678. ############################
  679. @router.delete("/{id}/tags", response_model=list[TagModel])
  680. async def delete_tag_by_id_and_tag_name(
  681. id: str, form_data: TagForm, user=Depends(get_verified_user)
  682. ):
  683. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  684. if chat:
  685. Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
  686. if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
  687. Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
  688. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  689. tags = chat.meta.get("tags", [])
  690. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  691. else:
  692. raise HTTPException(
  693. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  694. )
  695. ############################
  696. # DeleteAllTagsById
  697. ############################
  698. @router.delete("/{id}/tags/all", response_model=Optional[bool])
  699. async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
  700. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  701. if chat:
  702. Chats.delete_all_tags_by_id_and_user_id(id, user.id)
  703. for tag in chat.meta.get("tags", []):
  704. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
  705. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  706. return True
  707. else:
  708. raise HTTPException(
  709. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  710. )