1
0

chats.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927
  1. import json
  2. import logging
  3. from typing import Optional
  4. from open_webui.socket.main import get_event_emitter
  5. from open_webui.models.chats import (
  6. ChatForm,
  7. ChatImportForm,
  8. ChatResponse,
  9. Chats,
  10. ChatTitleIdResponse,
  11. )
  12. from open_webui.models.tags import TagModel, Tags
  13. from open_webui.models.folders import Folders
  14. from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
  15. from open_webui.constants import ERROR_MESSAGES
  16. from open_webui.env import SRC_LOG_LEVELS
  17. from fastapi import APIRouter, Depends, HTTPException, Request, status
  18. from pydantic import BaseModel
  19. from open_webui.utils.auth import get_admin_user, get_verified_user
  20. from open_webui.utils.access_control import has_permission
  21. log = logging.getLogger(__name__)
  22. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  23. router = APIRouter()
  24. ############################
  25. # GetChatList
  26. ############################
  27. @router.get("/", response_model=list[ChatTitleIdResponse])
  28. @router.get("/list", response_model=list[ChatTitleIdResponse])
  29. def get_session_user_chat_list(
  30. user=Depends(get_verified_user),
  31. page: Optional[int] = None,
  32. include_folders: Optional[bool] = False,
  33. ):
  34. try:
  35. if page is not None:
  36. limit = 60
  37. skip = (page - 1) * limit
  38. return Chats.get_chat_title_id_list_by_user_id(
  39. user.id, include_folders=include_folders, skip=skip, limit=limit
  40. )
  41. else:
  42. return Chats.get_chat_title_id_list_by_user_id(
  43. user.id, include_folders=include_folders
  44. )
  45. except Exception as e:
  46. log.exception(e)
  47. raise HTTPException(
  48. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  49. )
  50. ############################
  51. # DeleteAllChats
  52. ############################
  53. @router.delete("/", response_model=bool)
  54. async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
  55. if user.role == "user" and not has_permission(
  56. user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
  57. ):
  58. raise HTTPException(
  59. status_code=status.HTTP_401_UNAUTHORIZED,
  60. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  61. )
  62. result = Chats.delete_chats_by_user_id(user.id)
  63. return result
  64. ############################
  65. # GetUserChatList
  66. ############################
  67. @router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
  68. async def get_user_chat_list_by_user_id(
  69. user_id: str,
  70. page: Optional[int] = None,
  71. query: Optional[str] = None,
  72. order_by: Optional[str] = None,
  73. direction: Optional[str] = None,
  74. user=Depends(get_admin_user),
  75. ):
  76. if not ENABLE_ADMIN_CHAT_ACCESS:
  77. raise HTTPException(
  78. status_code=status.HTTP_401_UNAUTHORIZED,
  79. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  80. )
  81. if page is None:
  82. page = 1
  83. limit = 60
  84. skip = (page - 1) * limit
  85. filter = {}
  86. if query:
  87. filter["query"] = query
  88. if order_by:
  89. filter["order_by"] = order_by
  90. if direction:
  91. filter["direction"] = direction
  92. return Chats.get_chat_list_by_user_id(
  93. user_id, include_archived=True, filter=filter, skip=skip, limit=limit
  94. )
  95. ############################
  96. # CreateNewChat
  97. ############################
  98. @router.post("/new", response_model=Optional[ChatResponse])
  99. async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
  100. try:
  101. chat = Chats.insert_new_chat(user.id, form_data)
  102. return ChatResponse(**chat.model_dump())
  103. except Exception as e:
  104. log.exception(e)
  105. raise HTTPException(
  106. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  107. )
  108. ############################
  109. # ImportChat
  110. ############################
  111. @router.post("/import", response_model=Optional[ChatResponse])
  112. async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
  113. try:
  114. chat = Chats.import_chat(user.id, form_data)
  115. if chat:
  116. tags = chat.meta.get("tags", [])
  117. for tag_id in tags:
  118. tag_id = tag_id.replace(" ", "_").lower()
  119. tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
  120. if (
  121. tag_id != "none"
  122. and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
  123. ):
  124. Tags.insert_new_tag(tag_name, user.id)
  125. return ChatResponse(**chat.model_dump())
  126. except Exception as e:
  127. log.exception(e)
  128. raise HTTPException(
  129. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  130. )
  131. ############################
  132. # GetChats
  133. ############################
  134. @router.get("/search", response_model=list[ChatTitleIdResponse])
  135. def search_user_chats(
  136. text: str, page: Optional[int] = None, user=Depends(get_verified_user)
  137. ):
  138. if page is None:
  139. page = 1
  140. limit = 60
  141. skip = (page - 1) * limit
  142. chat_list = [
  143. ChatTitleIdResponse(**chat.model_dump())
  144. for chat in Chats.get_chats_by_user_id_and_search_text(
  145. user.id, text, skip=skip, limit=limit
  146. )
  147. ]
  148. # Delete tag if no chat is found
  149. words = text.strip().split(" ")
  150. if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
  151. tag_id = words[0].replace("tag:", "")
  152. if len(chat_list) == 0:
  153. if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
  154. log.debug(f"deleting tag: {tag_id}")
  155. Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
  156. return chat_list
  157. ############################
  158. # GetChatsByFolderId
  159. ############################
  160. @router.get("/folder/{folder_id}", response_model=list[ChatResponse])
  161. async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
  162. folder_ids = [folder_id]
  163. children_folders = Folders.get_children_folders_by_id_and_user_id(
  164. folder_id, user.id
  165. )
  166. if children_folders:
  167. folder_ids.extend([folder.id for folder in children_folders])
  168. return [
  169. ChatResponse(**chat.model_dump())
  170. for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
  171. ]
  172. @router.get("/folder/{folder_id}/list")
  173. async def get_chat_list_by_folder_id(
  174. folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
  175. ):
  176. try:
  177. limit = 60
  178. skip = (page - 1) * limit
  179. return [
  180. {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
  181. for chat in Chats.get_chats_by_folder_id_and_user_id(
  182. folder_id, user.id, skip=skip, limit=limit
  183. )
  184. ]
  185. except Exception as e:
  186. log.exception(e)
  187. raise HTTPException(
  188. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  189. )
  190. ############################
  191. # GetPinnedChats
  192. ############################
  193. @router.get("/pinned", response_model=list[ChatTitleIdResponse])
  194. async def get_user_pinned_chats(user=Depends(get_verified_user)):
  195. return [
  196. ChatTitleIdResponse(**chat.model_dump())
  197. for chat in Chats.get_pinned_chats_by_user_id(user.id)
  198. ]
  199. ############################
  200. # GetChats
  201. ############################
  202. @router.get("/all", response_model=list[ChatResponse])
  203. async def get_user_chats(user=Depends(get_verified_user)):
  204. return [
  205. ChatResponse(**chat.model_dump())
  206. for chat in Chats.get_chats_by_user_id(user.id)
  207. ]
  208. ############################
  209. # GetArchivedChats
  210. ############################
  211. @router.get("/all/archived", response_model=list[ChatResponse])
  212. async def get_user_archived_chats(user=Depends(get_verified_user)):
  213. return [
  214. ChatResponse(**chat.model_dump())
  215. for chat in Chats.get_archived_chats_by_user_id(user.id)
  216. ]
  217. ############################
  218. # GetAllTags
  219. ############################
  220. @router.get("/all/tags", response_model=list[TagModel])
  221. async def get_all_user_tags(user=Depends(get_verified_user)):
  222. try:
  223. tags = Tags.get_tags_by_user_id(user.id)
  224. return tags
  225. except Exception as e:
  226. log.exception(e)
  227. raise HTTPException(
  228. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  229. )
  230. ############################
  231. # GetAllChatsInDB
  232. ############################
  233. @router.get("/all/db", response_model=list[ChatResponse])
  234. async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
  235. if not ENABLE_ADMIN_EXPORT:
  236. raise HTTPException(
  237. status_code=status.HTTP_401_UNAUTHORIZED,
  238. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  239. )
  240. return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
  241. ############################
  242. # GetArchivedChats
  243. ############################
  244. @router.get("/archived", response_model=list[ChatTitleIdResponse])
  245. async def get_archived_session_user_chat_list(
  246. page: Optional[int] = None,
  247. query: Optional[str] = None,
  248. order_by: Optional[str] = None,
  249. direction: Optional[str] = None,
  250. user=Depends(get_verified_user),
  251. ):
  252. if page is None:
  253. page = 1
  254. limit = 60
  255. skip = (page - 1) * limit
  256. filter = {}
  257. if query:
  258. filter["query"] = query
  259. if order_by:
  260. filter["order_by"] = order_by
  261. if direction:
  262. filter["direction"] = direction
  263. chat_list = [
  264. ChatTitleIdResponse(**chat.model_dump())
  265. for chat in Chats.get_archived_chat_list_by_user_id(
  266. user.id,
  267. filter=filter,
  268. skip=skip,
  269. limit=limit,
  270. )
  271. ]
  272. return chat_list
  273. ############################
  274. # ArchiveAllChats
  275. ############################
  276. @router.post("/archive/all", response_model=bool)
  277. async def archive_all_chats(user=Depends(get_verified_user)):
  278. return Chats.archive_all_chats_by_user_id(user.id)
  279. ############################
  280. # UnarchiveAllChats
  281. ############################
  282. @router.post("/unarchive/all", response_model=bool)
  283. async def unarchive_all_chats(user=Depends(get_verified_user)):
  284. return Chats.unarchive_all_chats_by_user_id(user.id)
  285. ############################
  286. # GetSharedChatById
  287. ############################
  288. @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
  289. async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
  290. if user.role == "pending":
  291. raise HTTPException(
  292. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  293. )
  294. if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
  295. chat = Chats.get_chat_by_share_id(share_id)
  296. elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
  297. chat = Chats.get_chat_by_id(share_id)
  298. if chat:
  299. return ChatResponse(**chat.model_dump())
  300. else:
  301. raise HTTPException(
  302. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  303. )
  304. ############################
  305. # GetChatsByTags
  306. ############################
# Request body carrying a single tag name.
class TagForm(BaseModel):
    name: str  # tag name as supplied by the client
# TagForm plus skip/limit pagination, used as the body of POST /tags.
class TagFilterForm(TagForm):
    skip: Optional[int] = 0  # number of matching chats to skip
    limit: Optional[int] = 50  # maximum number of chats to return
  312. @router.post("/tags", response_model=list[ChatTitleIdResponse])
  313. async def get_user_chat_list_by_tag_name(
  314. form_data: TagFilterForm, user=Depends(get_verified_user)
  315. ):
  316. chats = Chats.get_chat_list_by_user_id_and_tag_name(
  317. user.id, form_data.name, form_data.skip, form_data.limit
  318. )
  319. if len(chats) == 0:
  320. Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
  321. return chats
  322. ############################
  323. # GetChatById
  324. ############################
  325. @router.get("/{id}", response_model=Optional[ChatResponse])
  326. async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
  327. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  328. if chat:
  329. return ChatResponse(**chat.model_dump())
  330. else:
  331. raise HTTPException(
  332. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  333. )
  334. ############################
  335. # UpdateChatById
  336. ############################
  337. @router.post("/{id}", response_model=Optional[ChatResponse])
  338. async def update_chat_by_id(
  339. id: str, form_data: ChatForm, user=Depends(get_verified_user)
  340. ):
  341. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  342. if chat:
  343. updated_chat = {**chat.chat, **form_data.chat}
  344. chat = Chats.update_chat_by_id(id, updated_chat)
  345. return ChatResponse(**chat.model_dump())
  346. else:
  347. raise HTTPException(
  348. status_code=status.HTTP_401_UNAUTHORIZED,
  349. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  350. )
  351. ############################
  352. # UpdateChatMessageById
  353. ############################
# Request body for overwriting a single message's content.
class MessageForm(BaseModel):
    content: str  # new message text
  356. @router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
  357. async def update_chat_message_by_id(
  358. id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
  359. ):
  360. chat = Chats.get_chat_by_id(id)
  361. if not chat:
  362. raise HTTPException(
  363. status_code=status.HTTP_401_UNAUTHORIZED,
  364. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  365. )
  366. if chat.user_id != user.id and user.role != "admin":
  367. raise HTTPException(
  368. status_code=status.HTTP_401_UNAUTHORIZED,
  369. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  370. )
  371. chat = Chats.upsert_message_to_chat_by_id_and_message_id(
  372. id,
  373. message_id,
  374. {
  375. "content": form_data.content,
  376. },
  377. )
  378. event_emitter = get_event_emitter(
  379. {
  380. "user_id": user.id,
  381. "chat_id": id,
  382. "message_id": message_id,
  383. },
  384. False,
  385. )
  386. if event_emitter:
  387. await event_emitter(
  388. {
  389. "type": "chat:message",
  390. "data": {
  391. "chat_id": id,
  392. "message_id": message_id,
  393. "content": form_data.content,
  394. },
  395. }
  396. )
  397. return ChatResponse(**chat.model_dump())
  398. ############################
  399. # SendChatMessageEventById
  400. ############################
# Request body for an arbitrary event; POST /{id}/messages/{message_id}/event
# forwards its model_dump() to the event emitter.
class EventForm(BaseModel):
    type: str  # event type string
    data: dict  # event payload
  404. @router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
  405. async def send_chat_message_event_by_id(
  406. id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
  407. ):
  408. chat = Chats.get_chat_by_id(id)
  409. if not chat:
  410. raise HTTPException(
  411. status_code=status.HTTP_401_UNAUTHORIZED,
  412. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  413. )
  414. if chat.user_id != user.id and user.role != "admin":
  415. raise HTTPException(
  416. status_code=status.HTTP_401_UNAUTHORIZED,
  417. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  418. )
  419. event_emitter = get_event_emitter(
  420. {
  421. "user_id": user.id,
  422. "chat_id": id,
  423. "message_id": message_id,
  424. }
  425. )
  426. try:
  427. if event_emitter:
  428. await event_emitter(form_data.model_dump())
  429. else:
  430. return False
  431. return True
  432. except:
  433. return False
  434. ############################
  435. # DeleteChatById
  436. ############################
  437. @router.delete("/{id}", response_model=bool)
  438. async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
  439. if user.role == "admin":
  440. chat = Chats.get_chat_by_id(id)
  441. for tag in chat.meta.get("tags", []):
  442. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
  443. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  444. result = Chats.delete_chat_by_id(id)
  445. return result
  446. else:
  447. if not has_permission(
  448. user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
  449. ):
  450. raise HTTPException(
  451. status_code=status.HTTP_401_UNAUTHORIZED,
  452. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  453. )
  454. chat = Chats.get_chat_by_id(id)
  455. for tag in chat.meta.get("tags", []):
  456. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
  457. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  458. result = Chats.delete_chat_by_id_and_user_id(id, user.id)
  459. return result
  460. ############################
  461. # GetPinnedStatusById
  462. ############################
  463. @router.get("/{id}/pinned", response_model=Optional[bool])
  464. async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
  465. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  466. if chat:
  467. return chat.pinned
  468. else:
  469. raise HTTPException(
  470. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  471. )
  472. ############################
  473. # PinChatById
  474. ############################
  475. @router.post("/{id}/pin", response_model=Optional[ChatResponse])
  476. async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
  477. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  478. if chat:
  479. chat = Chats.toggle_chat_pinned_by_id(id)
  480. return chat
  481. else:
  482. raise HTTPException(
  483. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  484. )
  485. ############################
  486. # CloneChat
  487. ############################
# Request body for cloning a chat.
class CloneForm(BaseModel):
    # Title for the copy; when empty/None, POST /{id}/clone uses "Clone of <title>".
    title: Optional[str] = None
  490. @router.post("/{id}/clone", response_model=Optional[ChatResponse])
  491. async def clone_chat_by_id(
  492. form_data: CloneForm, id: str, user=Depends(get_verified_user)
  493. ):
  494. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  495. if chat:
  496. updated_chat = {
  497. **chat.chat,
  498. "originalChatId": chat.id,
  499. "branchPointMessageId": chat.chat["history"]["currentId"],
  500. "title": form_data.title if form_data.title else f"Clone of {chat.title}",
  501. }
  502. chat = Chats.import_chat(
  503. user.id,
  504. ChatImportForm(
  505. **{
  506. "chat": updated_chat,
  507. "meta": chat.meta,
  508. "pinned": chat.pinned,
  509. "folder_id": chat.folder_id,
  510. }
  511. ),
  512. )
  513. return ChatResponse(**chat.model_dump())
  514. else:
  515. raise HTTPException(
  516. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  517. )
  518. ############################
  519. # CloneSharedChatById
  520. ############################
  521. @router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
  522. async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
  523. if user.role == "admin":
  524. chat = Chats.get_chat_by_id(id)
  525. else:
  526. chat = Chats.get_chat_by_share_id(id)
  527. if chat:
  528. updated_chat = {
  529. **chat.chat,
  530. "originalChatId": chat.id,
  531. "branchPointMessageId": chat.chat["history"]["currentId"],
  532. "title": f"Clone of {chat.title}",
  533. }
  534. chat = Chats.import_chat(
  535. user.id,
  536. ChatImportForm(
  537. **{
  538. "chat": updated_chat,
  539. "meta": chat.meta,
  540. "pinned": chat.pinned,
  541. "folder_id": chat.folder_id,
  542. }
  543. ),
  544. )
  545. return ChatResponse(**chat.model_dump())
  546. else:
  547. raise HTTPException(
  548. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  549. )
  550. ############################
  551. # ArchiveChat
  552. ############################
  553. @router.post("/{id}/archive", response_model=Optional[ChatResponse])
  554. async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
  555. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  556. if chat:
  557. chat = Chats.toggle_chat_archive_by_id(id)
  558. # Delete tags if chat is archived
  559. if chat.archived:
  560. for tag_id in chat.meta.get("tags", []):
  561. if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
  562. log.debug(f"deleting tag: {tag_id}")
  563. Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
  564. else:
  565. for tag_id in chat.meta.get("tags", []):
  566. tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
  567. if tag is None:
  568. log.debug(f"inserting tag: {tag_id}")
  569. tag = Tags.insert_new_tag(tag_id, user.id)
  570. return ChatResponse(**chat.model_dump())
  571. else:
  572. raise HTTPException(
  573. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  574. )
  575. ############################
  576. # ShareChatById
  577. ############################
  578. @router.post("/{id}/share", response_model=Optional[ChatResponse])
  579. async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
  580. if (user.role != "admin") and (
  581. not has_permission(
  582. user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
  583. )
  584. ):
  585. raise HTTPException(
  586. status_code=status.HTTP_401_UNAUTHORIZED,
  587. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  588. )
  589. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  590. if chat:
  591. if chat.share_id:
  592. shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
  593. return ChatResponse(**shared_chat.model_dump())
  594. shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
  595. if not shared_chat:
  596. raise HTTPException(
  597. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  598. detail=ERROR_MESSAGES.DEFAULT(),
  599. )
  600. return ChatResponse(**shared_chat.model_dump())
  601. else:
  602. raise HTTPException(
  603. status_code=status.HTTP_401_UNAUTHORIZED,
  604. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  605. )
  606. ############################
  607. # DeletedSharedChatById
  608. ############################
  609. @router.delete("/{id}/share", response_model=Optional[bool])
  610. async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
  611. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  612. if chat:
  613. if not chat.share_id:
  614. return False
  615. result = Chats.delete_shared_chat_by_chat_id(id)
  616. update_result = Chats.update_chat_share_id_by_id(id, None)
  617. return result and update_result != None
  618. else:
  619. raise HTTPException(
  620. status_code=status.HTTP_401_UNAUTHORIZED,
  621. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  622. )
  623. ############################
  624. # UpdateChatFolderIdById
  625. ############################
# Request body for moving a chat to a folder.
class ChatFolderIdForm(BaseModel):
    # Target folder id. None is forwarded unchanged (presumably meaning
    # "no folder"; confirm in Chats.update_chat_folder_id_by_id_and_user_id).
    folder_id: Optional[str] = None
  628. @router.post("/{id}/folder", response_model=Optional[ChatResponse])
  629. async def update_chat_folder_id_by_id(
  630. id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
  631. ):
  632. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  633. if chat:
  634. chat = Chats.update_chat_folder_id_by_id_and_user_id(
  635. id, user.id, form_data.folder_id
  636. )
  637. return ChatResponse(**chat.model_dump())
  638. else:
  639. raise HTTPException(
  640. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  641. )
  642. ############################
  643. # GetChatTagsById
  644. ############################
  645. @router.get("/{id}/tags", response_model=list[TagModel])
  646. async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
  647. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  648. if chat:
  649. tags = chat.meta.get("tags", [])
  650. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  651. else:
  652. raise HTTPException(
  653. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  654. )
  655. ############################
  656. # AddChatTagById
  657. ############################
  658. @router.post("/{id}/tags", response_model=list[TagModel])
  659. async def add_tag_by_id_and_tag_name(
  660. id: str, form_data: TagForm, user=Depends(get_verified_user)
  661. ):
  662. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  663. if chat:
  664. tags = chat.meta.get("tags", [])
  665. tag_id = form_data.name.replace(" ", "_").lower()
  666. if tag_id == "none":
  667. raise HTTPException(
  668. status_code=status.HTTP_400_BAD_REQUEST,
  669. detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
  670. )
  671. if tag_id not in tags:
  672. Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
  673. id, user.id, form_data.name
  674. )
  675. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  676. tags = chat.meta.get("tags", [])
  677. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  678. else:
  679. raise HTTPException(
  680. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  681. )
  682. ############################
  683. # DeleteChatTagById
  684. ############################
  685. @router.delete("/{id}/tags", response_model=list[TagModel])
  686. async def delete_tag_by_id_and_tag_name(
  687. id: str, form_data: TagForm, user=Depends(get_verified_user)
  688. ):
  689. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  690. if chat:
  691. Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
  692. if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
  693. Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
  694. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  695. tags = chat.meta.get("tags", [])
  696. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  697. else:
  698. raise HTTPException(
  699. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  700. )
  701. ############################
  702. # DeleteAllTagsById
  703. ############################
  704. @router.delete("/{id}/tags/all", response_model=Optional[bool])
  705. async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
  706. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  707. if chat:
  708. Chats.delete_all_tags_by_id_and_user_id(id, user.id)
  709. for tag in chat.meta.get("tags", []):
  710. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
  711. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  712. return True
  713. else:
  714. raise HTTPException(
  715. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  716. )