chats.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891
  1. import json
  2. import logging
  3. from typing import Optional
  4. from open_webui.socket.main import get_event_emitter
  5. from open_webui.models.chats import (
  6. ChatForm,
  7. ChatImportForm,
  8. ChatResponse,
  9. Chats,
  10. ChatTitleIdResponse,
  11. )
  12. from open_webui.models.tags import TagModel, Tags
  13. from open_webui.models.folders import Folders
  14. from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
  15. from open_webui.constants import ERROR_MESSAGES
  16. from open_webui.env import SRC_LOG_LEVELS
  17. from fastapi import APIRouter, Depends, HTTPException, Request, status
  18. from pydantic import BaseModel
  19. from open_webui.utils.auth import get_admin_user, get_verified_user
  20. from open_webui.utils.access_control import has_permission
  # Module logger; its threshold is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(level=SRC_LOG_LEVELS["MODELS"])

# Router on which every chat endpoint of this module is registered.
router: APIRouter = APIRouter()
  ############################
# GetChatList
############################
@router.get("/", response_model=list[ChatTitleIdResponse])
@router.get("/list", response_model=list[ChatTitleIdResponse])
def get_session_user_chat_list(
    user=Depends(get_verified_user), page: Optional[int] = None
):
    """List the caller's chats as ``ChatTitleIdResponse`` entries.

    Without ``page`` the whole list is returned. With it, results are sliced
    60 per page, and page 1 is the first page. Any lookup failure is logged
    and reported as a generic 400.
    """
    window = {}
    if page is not None:
        page_size = 60
        window = {"skip": (page - 1) * page_size, "limit": page_size}

    try:
        return Chats.get_chat_title_id_list_by_user_id(user.id, **window)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )
  ############################
# DeleteAllChats
############################
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
    """Delete every chat owned by the caller.

    Callers with role ``"user"`` must hold the ``chat.delete`` permission.
    Other roles skip the permission lookup entirely.
    """
    if user.role == "user":
        may_delete = has_permission(
            user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
        )
        if not may_delete:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )

    return Chats.delete_chats_by_user_id(user.id)
  ############################
# GetUserChatList
############################
@router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
    user_id: str,
    page: Optional[int] = None,
    query: Optional[str] = None,
    order_by: Optional[str] = None,
    direction: Optional[str] = None,
    user=Depends(get_admin_user),
):
    """Admin-only: list another user's chats, archived ones included, 60 per page.

    Refused with 401 unless ``ENABLE_ADMIN_CHAT_ACCESS`` is set. Any of
    ``query``, ``order_by`` and ``direction`` that is non-empty is forwarded
    in the filter mapping.
    """
    if not ENABLE_ADMIN_CHAT_ACCESS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    page_size = 60
    offset = ((1 if page is None else page) - 1) * page_size

    # Keep only truthy options, in a fixed key order.
    chat_filter = {
        name: value
        for name, value in (
            ("query", query),
            ("order_by", order_by),
            ("direction", direction),
        )
        if value
    }

    return Chats.get_chat_list_by_user_id(
        user_id,
        include_archived=True,
        filter=chat_filter,
        skip=offset,
        limit=page_size,
    )
  ############################
# CreateNewChat
############################
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
    """Create a chat for the caller from ``form_data``.

    Any failure, whether during insertion or response building, is logged and
    reported as a generic 400.
    """
    try:
        created = Chats.insert_new_chat(user.id, form_data)
        payload = created.model_dump()
        return ChatResponse(**payload)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )
  ############################
# ImportChat
############################
@router.post("/import", response_model=Optional[ChatResponse])
async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
    """Import a chat for the caller and make sure its tags exist.

    Each tag in the imported chat's ``meta["tags"]`` is normalised: spaces
    become underscores and the text is lower-cased. It is then title-cased
    back into a display name, e.g. ``"my tag"`` -> ``"My Tag"``. Tags
    normalising to ``"none"`` are skipped, and missing ones are created.
    Any failure is logged and reported as a generic 400.
    """
    try:
        imported = Chats.import_chat(user.id, form_data)
        if imported:
            for raw_tag in imported.meta.get("tags", []):
                normalized = raw_tag.replace(" ", "_").lower()
                if normalized == "none":
                    continue
                display_name = " ".join(
                    part.capitalize() for part in normalized.split("_")
                )
                if Tags.get_tag_by_name_and_user_id(display_name, user.id) is None:
                    Tags.insert_new_tag(display_name, user.id)
        return ChatResponse(**imported.model_dump())
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )
  ############################
# SearchChats
############################
@router.get("/search", response_model=list[ChatTitleIdResponse])
async def search_user_chats(
    text: str, page: Optional[int] = None, user=Depends(get_verified_user)
):
    """Search the caller's chats for ``text``, 60 results per page.

    Side effect: when the query is exactly one ``tag:<name>`` token and the
    first page comes back empty, the caller's tag ``<name>`` is deleted as
    stale (if it exists).
    """
    if page is None:
        page = 1
    limit = 60
    skip = (page - 1) * limit

    chat_list = [
        ChatTitleIdResponse(**chat.model_dump())
        for chat in Chats.get_chats_by_user_id_and_search_text(
            user.id, text, skip=skip, limit=limit
        )
    ]

    # Delete the tag if a single-tag search finds no chat at all.
    words = text.strip().split(" ")
    if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
        # Strip only the leading marker. str.replace would also remove any
        # "tag:" that appears inside the tag name itself.
        tag_id = words[0].removeprefix("tag:")
        if not chat_list and Tags.get_tag_by_name_and_user_id(tag_id, user.id):
            log.debug(f"deleting tag: {tag_id}")
            Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
    return chat_list
  ############################
# GetChatsByFolderId
############################
@router.get("/folder/{folder_id}", response_model=list[ChatResponse])
async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
    """Return the caller's chats in ``folder_id`` and its child folders.

    "Child folders" are whatever
    ``Folders.get_children_folders_by_id_and_user_id`` returns for this user.
    """
    children = Folders.get_children_folders_by_id_and_user_id(folder_id, user.id)
    folder_ids = [folder_id, *(child.id for child in (children or []))]

    matching = Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
    return [ChatResponse(**entry.model_dump()) for entry in matching]
  ############################
# GetPinnedChats
############################
@router.get("/pinned", response_model=list[ChatTitleIdResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)):
    """Return the caller's pinned chats as ``ChatTitleIdResponse`` entries."""
    pinned = Chats.get_pinned_chats_by_user_id(user.id)
    return [ChatTitleIdResponse(**entry.model_dump()) for entry in pinned]
  ############################
# GetAllChats
############################
@router.get("/all", response_model=list[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)):
    """Return the caller's chats, as given by ``Chats.get_chats_by_user_id``."""
    owned = Chats.get_chats_by_user_id(user.id)
    return [ChatResponse(**entry.model_dump()) for entry in owned]
  ############################
# GetAllArchivedChats
############################
@router.get("/all/archived", response_model=list[ChatResponse])
async def get_user_archived_chats(user=Depends(get_verified_user)):
    """Return the caller's archived chats as full ``ChatResponse`` objects."""
    archived = Chats.get_archived_chats_by_user_id(user.id)
    return [ChatResponse(**entry.model_dump()) for entry in archived]
  ############################
# GetAllTags
############################
@router.get("/all/tags", response_model=list[TagModel])
async def get_all_user_tags(user=Depends(get_verified_user)):
    """Return every tag owned by the caller; failures become a generic 400."""
    try:
        return Tags.get_tags_by_user_id(user.id)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )
  ############################
# GetAllChatsInDB
############################
@router.get("/all/db", response_model=list[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
    """Admin-only export of every chat returned by ``Chats.get_chats``.

    Refused with 401 unless ``ENABLE_ADMIN_EXPORT`` is set.
    """
    if ENABLE_ADMIN_EXPORT:
        return [ChatResponse(**entry.model_dump()) for entry in Chats.get_chats()]

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
    )
  ############################
# GetArchivedChatList
############################
@router.get("/archived", response_model=list[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
    page: Optional[int] = None,
    query: Optional[str] = None,
    order_by: Optional[str] = None,
    direction: Optional[str] = None,
    user=Depends(get_verified_user),
):
    """List the caller's archived chats, 60 per page (page 1 is the first).

    Any of ``query``, ``order_by`` and ``direction`` that is non-empty is
    forwarded in the filter mapping.
    """
    page_size = 60
    offset = ((1 if page is None else page) - 1) * page_size

    # Keep only truthy options, in a fixed key order.
    chat_filter = {
        name: value
        for name, value in (
            ("query", query),
            ("order_by", order_by),
            ("direction", direction),
        )
        if value
    }

    archived = Chats.get_archived_chat_list_by_user_id(
        user.id,
        filter=chat_filter,
        skip=offset,
        limit=page_size,
    )
    return [ChatTitleIdResponse(**entry.model_dump()) for entry in archived]
  ############################
# ArchiveAllChats
############################
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_verified_user)):
    """Archive all of the caller's chats and return the model layer's result."""
    archived_ok = Chats.archive_all_chats_by_user_id(user.id)
    return archived_ok
  ############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
    """Resolve a shared chat for viewing.

    Pending users are refused. Admins with ``ENABLE_ADMIN_CHAT_ACCESS`` look
    the id up directly as a chat id. Every other role resolves it through
    ``Chats.get_chat_by_share_id``.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` for pending users or when nothing
            matches.
    """
    if user.role == "pending":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    # A two-way branch always binds `chat`. An if/elif keyed on specific roles
    # leaves it unbound (UnboundLocalError -> 500) for any other role.
    if user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
        chat = Chats.get_chat_by_id(share_id)
    else:
        chat = Chats.get_chat_by_share_id(share_id)

    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return ChatResponse(**chat.model_dump())
  ############################
# GetChatsByTags
############################
class TagForm(BaseModel):
    """Request body naming a single tag."""

    name: str


class TagFilterForm(TagForm):
    """Tag name plus the skip/limit window used when listing chats by tag."""

    skip: Optional[int] = 0
    limit: Optional[int] = 50


@router.post("/tags", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
    form_data: TagFilterForm, user=Depends(get_verified_user)
):
    """List the caller's chats carrying tag ``form_data.name`` (paginated).

    Side effect: if the tag no longer has any chat, it is deleted.
    """
    chats = Chats.get_chat_list_by_user_id_and_tag_name(
        user.id, form_data.name, form_data.skip, form_data.limit
    )
    # An empty page alone does not prove the tag is unused: skip may be past
    # the end, or limit may be 0. Confirm with a full count before deleting.
    # NOTE(review): assumes the count and the list apply the same archived
    # filter. Confirm in the Chats model.
    if not chats and (
        Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0
    ):
        Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
    return chats
  ############################
# GetChatById
############################
@router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Return one of the caller's chats; 401 ``NOT_FOUND`` if it is not theirs."""
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return ChatResponse(**chat.model_dump())
  ############################
# UpdateChatById
############################
@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(
    id: str, form_data: ChatForm, user=Depends(get_verified_user)
):
    """Shallow-merge ``form_data.chat`` into one of the caller's chats.

    Keys from the request override existing top-level keys; other keys are
    kept as they are.
    """
    existing = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    merged = {**existing.chat, **form_data.chat}
    saved = Chats.update_chat_by_id(id, merged)
    return ChatResponse(**saved.model_dump())
  ############################
# UpdateChatMessageById
############################
class MessageForm(BaseModel):
    """Replacement text for a single chat message."""

    content: str


@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
async def update_chat_message_by_id(
    id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
):
    """Overwrite a message's content and broadcast a ``chat:message`` event.

    Only the chat owner or an admin may edit. A missing chat and a forbidden
    chat produce the same 401 response.
    """
    chat = Chats.get_chat_by_id(id)
    if not chat or (chat.user_id != user.id and user.role != "admin"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    updated = Chats.upsert_message_to_chat_by_id_and_message_id(
        id, message_id, {"content": form_data.content}
    )

    emitter_context = {"user_id": user.id, "chat_id": id, "message_id": message_id}
    emitter = get_event_emitter(emitter_context, False)
    if emitter:
        event_data = {
            "chat_id": id,
            "message_id": message_id,
            "content": form_data.content,
        }
        await emitter({"type": "chat:message", "data": event_data})

    return ChatResponse(**updated.model_dump())
  ############################
# SendChatMessageEventById
############################
class EventForm(BaseModel):
    """Arbitrary event: a ``type`` string plus a free-form ``data`` dict."""

    type: str
    data: dict


@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
async def send_chat_message_event_by_id(
    id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
):
    """Forward ``form_data`` to the event emitter for this chat message.

    Only the chat owner or an admin may emit.

    Returns:
        True once the event has been handed to the emitter. False when no
        emitter is available or emitting raised an ``Exception``.
    """
    chat = Chats.get_chat_by_id(id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    if chat.user_id != user.id and user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    event_emitter = get_event_emitter(
        {
            "user_id": user.id,
            "chat_id": id,
            "message_id": message_id,
        }
    )
    if not event_emitter:
        return False

    try:
        await event_emitter(form_data.model_dump())
    except Exception as e:
        # Emission is best-effort: report failure as False, but keep a trace.
        # Catch Exception only (not a bare except) so asyncio cancellation and
        # KeyboardInterrupt still propagate.
        log.exception(e)
        return False
    return True
  ############################
# DeleteChatById
############################
def _prune_tags_before_chat_delete(chat) -> None:
    """Delete the owner's tags for which ``chat`` is the last tagged chat."""
    # Count against the chat's owner, not the caller. An admin deleting
    # someone else's chat must prune the owner's tags, not the admin's.
    # NOTE(review): if the count excludes archived chats, deleting an archived
    # chat may prune a tag still used by one other chat. Confirm in Chats.
    for tag in chat.meta.get("tags", []):
        if Chats.count_chats_by_tag_name_and_user_id(tag, chat.user_id) == 1:
            Tags.delete_tag_by_name_and_user_id(tag, chat.user_id)


@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
    """Delete a chat and prune tags it was the last user of.

    Admins may delete any chat by id. Other callers need the ``chat.delete``
    permission and may only delete their own chats.

    Returns:
        The result of the underlying model-layer delete.

    Raises:
        HTTPException: 401 ``ACCESS_PROHIBITED`` without permission. 401
            ``NOT_FOUND`` when no matching chat is visible to the caller.
    """
    is_admin = user.role == "admin"
    if is_admin:
        chat = Chats.get_chat_by_id(id)
    else:
        if not has_permission(
            user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )
        # Scope the lookup to the caller. Otherwise another user's chat id
        # could trigger pruning of the caller's own tags.
        chat = Chats.get_chat_by_id_and_user_id(id, user.id)

    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    _prune_tags_before_chat_delete(chat)
    if is_admin:
        return Chats.delete_chat_by_id(id)
    return Chats.delete_chat_by_id_and_user_id(id, user.id)
  ############################
# GetPinnedStatusById
############################
@router.get("/{id}/pinned", response_model=Optional[bool])
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
    """Return the ``pinned`` flag of one of the caller's chats."""
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )
    return chat.pinned
  ############################
# PinChatById
############################
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Toggle the pinned state of one of the caller's chats.

    Returns whatever ``Chats.toggle_chat_pinned_by_id`` returns.
    """
    if not Chats.get_chat_by_id_and_user_id(id, user.id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )
    return Chats.toggle_chat_pinned_by_id(id)
  ############################
# CloneChat
############################
class CloneForm(BaseModel):
    """Optional title for the clone; a falsy value uses "Clone of <title>"."""

    title: Optional[str] = None


@router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(
    form_data: CloneForm, id: str, user=Depends(get_verified_user)
):
    """Clone one of the caller's chats into a new chat owned by the caller.

    The copy records ``originalChatId`` and the source's current message id
    as ``branchPointMessageId``. It keeps the source's meta, pinned state and
    folder.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    # A chat payload without a "history" entry has no branch point. Clone it
    # with None rather than failing with KeyError.
    history = chat.chat.get("history") or {}
    updated_chat = {
        **chat.chat,
        "originalChatId": chat.id,
        "branchPointMessageId": history.get("currentId"),
        "title": form_data.title if form_data.title else f"Clone of {chat.title}",
    }
    cloned = Chats.import_chat(
        user.id,
        ChatImportForm(
            **{
                "chat": updated_chat,
                "meta": chat.meta,
                "pinned": chat.pinned,
                "folder_id": chat.folder_id,
            }
        ),
    )
    return ChatResponse(**cloned.model_dump())
  ############################
# CloneSharedChatById
############################
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Clone a shared chat into a new chat owned by the caller.

    Admins with ``ENABLE_ADMIN_CHAT_ACCESS`` may pass a raw chat id. Everyone
    else passes a share id.
    """
    # A raw-id lookup is admin chat access, so it honours
    # ENABLE_ADMIN_CHAT_ACCESS exactly as get_shared_chat_by_id does.
    if user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
        chat = Chats.get_chat_by_id(id)
    else:
        chat = Chats.get_chat_by_share_id(id)

    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    # Without a "history" entry there is no branch point; use None instead of
    # raising KeyError.
    history = chat.chat.get("history") or {}
    updated_chat = {
        **chat.chat,
        "originalChatId": chat.id,
        "branchPointMessageId": history.get("currentId"),
        "title": f"Clone of {chat.title}",
    }

    # Folders are per-user (this router always looks them up by user id), and
    # so is the pinned state. Only carry them over when cloning one's own chat.
    is_own_chat = chat.user_id == user.id
    cloned = Chats.import_chat(
        user.id,
        ChatImportForm(
            **{
                "chat": updated_chat,
                "meta": chat.meta,
                "pinned": chat.pinned if is_own_chat else False,
                "folder_id": chat.folder_id if is_own_chat else None,
            }
        ),
    )
    return ChatResponse(**cloned.model_dump())
  ############################
# ArchiveChat
############################
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Toggle the archived state of one of the caller's chats and sync its tags.

    After archiving, tags whose chat count drops to zero are deleted. After
    unarchiving, any of the chat's tags missing from the tag table is
    recreated.
    """
    if not Chats.get_chat_by_id_and_user_id(id, user.id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    chat = Chats.toggle_chat_archive_by_id(id)
    tag_names = chat.meta.get("tags", [])

    if chat.archived:
        for tag_name in tag_names:
            if Chats.count_chats_by_tag_name_and_user_id(tag_name, user.id) == 0:
                log.debug(f"deleting tag: {tag_name}")
                Tags.delete_tag_by_name_and_user_id(tag_name, user.id)
    else:
        for tag_name in tag_names:
            if Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None:
                log.debug(f"inserting tag: {tag_name}")
                Tags.insert_new_tag(tag_name, user.id)

    return ChatResponse(**chat.model_dump())
  ############################
# ShareChatById
############################
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
    """Share one of the caller's chats, or update its existing shared record.

    Non-admins need the ``chat.share`` permission.

    Raises:
        HTTPException: 401 when not permitted or the chat is not the caller's.
            500 when the model layer returns no shared record.
    """
    if user.role != "admin" and not has_permission(
        user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    if chat.share_id:
        shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
    else:
        shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)

    # Both paths may return None. Report that explicitly instead of crashing
    # on `None.model_dump()`.
    if not shared_chat:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return ChatResponse(**shared_chat.model_dump())
  ############################
# DeleteSharedChatById
############################
@router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
    """Stop sharing one of the caller's chats.

    Returns:
        False if the chat was not shared. Otherwise the delete result, and-ed
        with whether clearing the chat's ``share_id`` succeeded.

    Raises:
        HTTPException: 401 when the chat is not the caller's.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    if not chat.share_id:
        return False

    result = Chats.delete_shared_chat_by_chat_id(id)
    update_result = Chats.update_chat_share_id_by_id(id, None)
    return result and update_result is not None
  ############################
# UpdateChatFolderIdById
############################
class ChatFolderIdForm(BaseModel):
    """Folder id to assign to a chat; may be null."""

    folder_id: Optional[str] = None


@router.post("/{id}/folder", response_model=Optional[ChatResponse])
async def update_chat_folder_id_by_id(
    id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
):
    """Set ``folder_id`` on one of the caller's chats and return the chat."""
    if not Chats.get_chat_by_id_and_user_id(id, user.id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    moved = Chats.update_chat_folder_id_by_id_and_user_id(
        id, user.id, form_data.folder_id
    )
    return ChatResponse(**moved.model_dump())
  ############################
# GetChatTagsById
############################
@router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
    """Resolve the tag ids in a chat's ``meta["tags"]`` to the caller's tags."""
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return Tags.get_tags_by_ids_and_user_id(chat.meta.get("tags", []), user.id)
  ############################
# AddChatTagById
############################
@router.post("/{id}/tags", response_model=list[TagModel])
async def add_tag_by_id_and_tag_name(
    id: str, form_data: TagForm, user=Depends(get_verified_user)
):
    """Attach tag ``form_data.name`` to one of the caller's chats.

    The name is normalised (spaces become underscores, text is lower-cased)
    to test membership. ``"none"`` is rejected with 400. Returns the chat's
    tags after the update.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    current_tags = chat.meta.get("tags", [])
    normalized = form_data.name.replace(" ", "_").lower()
    if normalized == "none":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
        )

    if normalized not in current_tags:
        Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
            id, user.id, form_data.name
        )

    refreshed = Chats.get_chat_by_id_and_user_id(id, user.id)
    return Tags.get_tags_by_ids_and_user_id(refreshed.meta.get("tags", []), user.id)
  ############################
# DeleteChatTagById
############################
@router.delete("/{id}/tags", response_model=list[TagModel])
async def delete_tag_by_id_and_tag_name(
    id: str, form_data: TagForm, user=Depends(get_verified_user)
):
    """Detach tag ``form_data.name`` from one of the caller's chats.

    The tag itself is deleted once no chat counts it any more. Returns the
    chat's remaining tags.
    """
    if not Chats.get_chat_by_id_and_user_id(id, user.id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
    if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
        Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)

    refreshed = Chats.get_chat_by_id_and_user_id(id, user.id)
    return Tags.get_tags_by_ids_and_user_id(refreshed.meta.get("tags", []), user.id)
  ############################
# DeleteAllTagsById
############################
@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
    """Remove every tag from one of the caller's chats.

    Tags left with no chat are deleted afterwards.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    Chats.delete_all_tags_by_id_and_user_id(id, user.id)
    # `chat` was loaded before the wipe, so (presumably as a detached snapshot)
    # its meta still lists the tags that were just removed.
    for removed_tag in chat.meta.get("tags", []):
        if Chats.count_chats_by_tag_name_and_user_id(removed_tag, user.id) == 0:
            Tags.delete_tag_by_name_and_user_id(removed_tag, user.id)
    return True
  685. else:
  686. raise HTTPException(
  687. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  688. )