chats.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932
  1. import json
  2. import logging
  3. from typing import Optional
  4. from open_webui.socket.main import get_event_emitter
  5. from open_webui.models.chats import (
  6. ChatForm,
  7. ChatImportForm,
  8. ChatResponse,
  9. Chats,
  10. ChatTitleIdResponse,
  11. )
  12. from open_webui.models.tags import TagModel, Tags
  13. from open_webui.models.folders import Folders
  14. from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
  15. from open_webui.constants import ERROR_MESSAGES
  16. from open_webui.env import SRC_LOG_LEVELS
  17. from fastapi import APIRouter, Depends, HTTPException, Request, status
  18. from pydantic import BaseModel
  19. from open_webui.utils.auth import get_admin_user, get_verified_user
  20. from open_webui.utils.access_control import has_permission
  21. log = logging.getLogger(__name__)
  22. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  23. router = APIRouter()
  24. ############################
  25. # GetChatList
  26. ############################
  27. @router.get("/", response_model=list[ChatTitleIdResponse])
  28. @router.get("/list", response_model=list[ChatTitleIdResponse])
  29. def get_session_user_chat_list(
  30. user=Depends(get_verified_user),
  31. page: Optional[int] = None,
  32. include_pinned: Optional[bool] = False,
  33. include_folders: Optional[bool] = False,
  34. ):
  35. try:
  36. if page is not None:
  37. limit = 60
  38. skip = (page - 1) * limit
  39. return Chats.get_chat_title_id_list_by_user_id(
  40. user.id,
  41. include_folders=include_folders,
  42. include_pinned=include_pinned,
  43. skip=skip,
  44. limit=limit,
  45. )
  46. else:
  47. return Chats.get_chat_title_id_list_by_user_id(
  48. user.id, include_folders=include_folders, include_pinned=include_pinned
  49. )
  50. except Exception as e:
  51. log.exception(e)
  52. raise HTTPException(
  53. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  54. )
  55. ############################
  56. # DeleteAllChats
  57. ############################
  58. @router.delete("/", response_model=bool)
  59. async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
  60. if user.role == "user" and not has_permission(
  61. user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
  62. ):
  63. raise HTTPException(
  64. status_code=status.HTTP_401_UNAUTHORIZED,
  65. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  66. )
  67. result = Chats.delete_chats_by_user_id(user.id)
  68. return result
  69. ############################
  70. # GetUserChatList
  71. ############################
  72. @router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
  73. async def get_user_chat_list_by_user_id(
  74. user_id: str,
  75. page: Optional[int] = None,
  76. query: Optional[str] = None,
  77. order_by: Optional[str] = None,
  78. direction: Optional[str] = None,
  79. user=Depends(get_admin_user),
  80. ):
  81. if not ENABLE_ADMIN_CHAT_ACCESS:
  82. raise HTTPException(
  83. status_code=status.HTTP_401_UNAUTHORIZED,
  84. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  85. )
  86. if page is None:
  87. page = 1
  88. limit = 60
  89. skip = (page - 1) * limit
  90. filter = {}
  91. if query:
  92. filter["query"] = query
  93. if order_by:
  94. filter["order_by"] = order_by
  95. if direction:
  96. filter["direction"] = direction
  97. return Chats.get_chat_list_by_user_id(
  98. user_id, include_archived=True, filter=filter, skip=skip, limit=limit
  99. )
  100. ############################
  101. # CreateNewChat
  102. ############################
  103. @router.post("/new", response_model=Optional[ChatResponse])
  104. async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
  105. try:
  106. chat = Chats.insert_new_chat(user.id, form_data)
  107. return ChatResponse(**chat.model_dump())
  108. except Exception as e:
  109. log.exception(e)
  110. raise HTTPException(
  111. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  112. )
  113. ############################
  114. # ImportChat
  115. ############################
  116. @router.post("/import", response_model=Optional[ChatResponse])
  117. async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
  118. try:
  119. chat = Chats.import_chat(user.id, form_data)
  120. if chat:
  121. tags = chat.meta.get("tags", [])
  122. for tag_id in tags:
  123. tag_id = tag_id.replace(" ", "_").lower()
  124. tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
  125. if (
  126. tag_id != "none"
  127. and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
  128. ):
  129. Tags.insert_new_tag(tag_name, user.id)
  130. return ChatResponse(**chat.model_dump())
  131. except Exception as e:
  132. log.exception(e)
  133. raise HTTPException(
  134. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  135. )
  136. ############################
  137. # GetChats
  138. ############################
  139. @router.get("/search", response_model=list[ChatTitleIdResponse])
  140. def search_user_chats(
  141. text: str, page: Optional[int] = None, user=Depends(get_verified_user)
  142. ):
  143. if page is None:
  144. page = 1
  145. limit = 60
  146. skip = (page - 1) * limit
  147. chat_list = [
  148. ChatTitleIdResponse(**chat.model_dump())
  149. for chat in Chats.get_chats_by_user_id_and_search_text(
  150. user.id, text, skip=skip, limit=limit
  151. )
  152. ]
  153. # Delete tag if no chat is found
  154. words = text.strip().split(" ")
  155. if page == 1 and len(words) == 1 and words[0].startswith("tag:"):
  156. tag_id = words[0].replace("tag:", "")
  157. if len(chat_list) == 0:
  158. if Tags.get_tag_by_name_and_user_id(tag_id, user.id):
  159. log.debug(f"deleting tag: {tag_id}")
  160. Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
  161. return chat_list
  162. ############################
  163. # GetChatsByFolderId
  164. ############################
  165. @router.get("/folder/{folder_id}", response_model=list[ChatResponse])
  166. async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
  167. folder_ids = [folder_id]
  168. children_folders = Folders.get_children_folders_by_id_and_user_id(
  169. folder_id, user.id
  170. )
  171. if children_folders:
  172. folder_ids.extend([folder.id for folder in children_folders])
  173. return [
  174. ChatResponse(**chat.model_dump())
  175. for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
  176. ]
  177. @router.get("/folder/{folder_id}/list")
  178. async def get_chat_list_by_folder_id(
  179. folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
  180. ):
  181. try:
  182. limit = 60
  183. skip = (page - 1) * limit
  184. return [
  185. {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
  186. for chat in Chats.get_chats_by_folder_id_and_user_id(
  187. folder_id, user.id, skip=skip, limit=limit
  188. )
  189. ]
  190. except Exception as e:
  191. log.exception(e)
  192. raise HTTPException(
  193. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  194. )
  195. ############################
  196. # GetPinnedChats
  197. ############################
  198. @router.get("/pinned", response_model=list[ChatTitleIdResponse])
  199. async def get_user_pinned_chats(user=Depends(get_verified_user)):
  200. return [
  201. ChatTitleIdResponse(**chat.model_dump())
  202. for chat in Chats.get_pinned_chats_by_user_id(user.id)
  203. ]
  204. ############################
  205. # GetChats
  206. ############################
  207. @router.get("/all", response_model=list[ChatResponse])
  208. async def get_user_chats(user=Depends(get_verified_user)):
  209. return [
  210. ChatResponse(**chat.model_dump())
  211. for chat in Chats.get_chats_by_user_id(user.id)
  212. ]
  213. ############################
  214. # GetArchivedChats
  215. ############################
  216. @router.get("/all/archived", response_model=list[ChatResponse])
  217. async def get_user_archived_chats(user=Depends(get_verified_user)):
  218. return [
  219. ChatResponse(**chat.model_dump())
  220. for chat in Chats.get_archived_chats_by_user_id(user.id)
  221. ]
  222. ############################
  223. # GetAllTags
  224. ############################
  225. @router.get("/all/tags", response_model=list[TagModel])
  226. async def get_all_user_tags(user=Depends(get_verified_user)):
  227. try:
  228. tags = Tags.get_tags_by_user_id(user.id)
  229. return tags
  230. except Exception as e:
  231. log.exception(e)
  232. raise HTTPException(
  233. status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
  234. )
  235. ############################
  236. # GetAllChatsInDB
  237. ############################
  238. @router.get("/all/db", response_model=list[ChatResponse])
  239. async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
  240. if not ENABLE_ADMIN_EXPORT:
  241. raise HTTPException(
  242. status_code=status.HTTP_401_UNAUTHORIZED,
  243. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  244. )
  245. return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
  246. ############################
  247. # GetArchivedChats
  248. ############################
  249. @router.get("/archived", response_model=list[ChatTitleIdResponse])
  250. async def get_archived_session_user_chat_list(
  251. page: Optional[int] = None,
  252. query: Optional[str] = None,
  253. order_by: Optional[str] = None,
  254. direction: Optional[str] = None,
  255. user=Depends(get_verified_user),
  256. ):
  257. if page is None:
  258. page = 1
  259. limit = 60
  260. skip = (page - 1) * limit
  261. filter = {}
  262. if query:
  263. filter["query"] = query
  264. if order_by:
  265. filter["order_by"] = order_by
  266. if direction:
  267. filter["direction"] = direction
  268. chat_list = [
  269. ChatTitleIdResponse(**chat.model_dump())
  270. for chat in Chats.get_archived_chat_list_by_user_id(
  271. user.id,
  272. filter=filter,
  273. skip=skip,
  274. limit=limit,
  275. )
  276. ]
  277. return chat_list
  278. ############################
  279. # ArchiveAllChats
  280. ############################
  281. @router.post("/archive/all", response_model=bool)
  282. async def archive_all_chats(user=Depends(get_verified_user)):
  283. return Chats.archive_all_chats_by_user_id(user.id)
  284. ############################
  285. # UnarchiveAllChats
  286. ############################
  287. @router.post("/unarchive/all", response_model=bool)
  288. async def unarchive_all_chats(user=Depends(get_verified_user)):
  289. return Chats.unarchive_all_chats_by_user_id(user.id)
  290. ############################
  291. # GetSharedChatById
  292. ############################
  293. @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
  294. async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
  295. if user.role == "pending":
  296. raise HTTPException(
  297. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  298. )
  299. if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS):
  300. chat = Chats.get_chat_by_share_id(share_id)
  301. elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
  302. chat = Chats.get_chat_by_id(share_id)
  303. if chat:
  304. return ChatResponse(**chat.model_dump())
  305. else:
  306. raise HTTPException(
  307. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  308. )
  309. ############################
  310. # GetChatsByTags
  311. ############################
  312. class TagForm(BaseModel):
  313. name: str
  314. class TagFilterForm(TagForm):
  315. skip: Optional[int] = 0
  316. limit: Optional[int] = 50
  317. @router.post("/tags", response_model=list[ChatTitleIdResponse])
  318. async def get_user_chat_list_by_tag_name(
  319. form_data: TagFilterForm, user=Depends(get_verified_user)
  320. ):
  321. chats = Chats.get_chat_list_by_user_id_and_tag_name(
  322. user.id, form_data.name, form_data.skip, form_data.limit
  323. )
  324. if len(chats) == 0:
  325. Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
  326. return chats
  327. ############################
  328. # GetChatById
  329. ############################
  330. @router.get("/{id}", response_model=Optional[ChatResponse])
  331. async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
  332. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  333. if chat:
  334. return ChatResponse(**chat.model_dump())
  335. else:
  336. raise HTTPException(
  337. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  338. )
  339. ############################
  340. # UpdateChatById
  341. ############################
  342. @router.post("/{id}", response_model=Optional[ChatResponse])
  343. async def update_chat_by_id(
  344. id: str, form_data: ChatForm, user=Depends(get_verified_user)
  345. ):
  346. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  347. if chat:
  348. updated_chat = {**chat.chat, **form_data.chat}
  349. chat = Chats.update_chat_by_id(id, updated_chat)
  350. return ChatResponse(**chat.model_dump())
  351. else:
  352. raise HTTPException(
  353. status_code=status.HTTP_401_UNAUTHORIZED,
  354. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  355. )
  356. ############################
  357. # UpdateChatMessageById
  358. ############################
  359. class MessageForm(BaseModel):
  360. content: str
  361. @router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
  362. async def update_chat_message_by_id(
  363. id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
  364. ):
  365. chat = Chats.get_chat_by_id(id)
  366. if not chat:
  367. raise HTTPException(
  368. status_code=status.HTTP_401_UNAUTHORIZED,
  369. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  370. )
  371. if chat.user_id != user.id and user.role != "admin":
  372. raise HTTPException(
  373. status_code=status.HTTP_401_UNAUTHORIZED,
  374. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  375. )
  376. chat = Chats.upsert_message_to_chat_by_id_and_message_id(
  377. id,
  378. message_id,
  379. {
  380. "content": form_data.content,
  381. },
  382. )
  383. event_emitter = get_event_emitter(
  384. {
  385. "user_id": user.id,
  386. "chat_id": id,
  387. "message_id": message_id,
  388. },
  389. False,
  390. )
  391. if event_emitter:
  392. await event_emitter(
  393. {
  394. "type": "chat:message",
  395. "data": {
  396. "chat_id": id,
  397. "message_id": message_id,
  398. "content": form_data.content,
  399. },
  400. }
  401. )
  402. return ChatResponse(**chat.model_dump())
  403. ############################
  404. # SendChatMessageEventById
  405. ############################
  406. class EventForm(BaseModel):
  407. type: str
  408. data: dict
  409. @router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
  410. async def send_chat_message_event_by_id(
  411. id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
  412. ):
  413. chat = Chats.get_chat_by_id(id)
  414. if not chat:
  415. raise HTTPException(
  416. status_code=status.HTTP_401_UNAUTHORIZED,
  417. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  418. )
  419. if chat.user_id != user.id and user.role != "admin":
  420. raise HTTPException(
  421. status_code=status.HTTP_401_UNAUTHORIZED,
  422. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  423. )
  424. event_emitter = get_event_emitter(
  425. {
  426. "user_id": user.id,
  427. "chat_id": id,
  428. "message_id": message_id,
  429. }
  430. )
  431. try:
  432. if event_emitter:
  433. await event_emitter(form_data.model_dump())
  434. else:
  435. return False
  436. return True
  437. except:
  438. return False
  439. ############################
  440. # DeleteChatById
  441. ############################
  442. @router.delete("/{id}", response_model=bool)
  443. async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
  444. if user.role == "admin":
  445. chat = Chats.get_chat_by_id(id)
  446. for tag in chat.meta.get("tags", []):
  447. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
  448. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  449. result = Chats.delete_chat_by_id(id)
  450. return result
  451. else:
  452. if not has_permission(
  453. user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
  454. ):
  455. raise HTTPException(
  456. status_code=status.HTTP_401_UNAUTHORIZED,
  457. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  458. )
  459. chat = Chats.get_chat_by_id(id)
  460. for tag in chat.meta.get("tags", []):
  461. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
  462. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  463. result = Chats.delete_chat_by_id_and_user_id(id, user.id)
  464. return result
  465. ############################
  466. # GetPinnedStatusById
  467. ############################
  468. @router.get("/{id}/pinned", response_model=Optional[bool])
  469. async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
  470. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  471. if chat:
  472. return chat.pinned
  473. else:
  474. raise HTTPException(
  475. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  476. )
  477. ############################
  478. # PinChatById
  479. ############################
  480. @router.post("/{id}/pin", response_model=Optional[ChatResponse])
  481. async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
  482. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  483. if chat:
  484. chat = Chats.toggle_chat_pinned_by_id(id)
  485. return chat
  486. else:
  487. raise HTTPException(
  488. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  489. )
  490. ############################
  491. # CloneChat
  492. ############################
  493. class CloneForm(BaseModel):
  494. title: Optional[str] = None
  495. @router.post("/{id}/clone", response_model=Optional[ChatResponse])
  496. async def clone_chat_by_id(
  497. form_data: CloneForm, id: str, user=Depends(get_verified_user)
  498. ):
  499. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  500. if chat:
  501. updated_chat = {
  502. **chat.chat,
  503. "originalChatId": chat.id,
  504. "branchPointMessageId": chat.chat["history"]["currentId"],
  505. "title": form_data.title if form_data.title else f"Clone of {chat.title}",
  506. }
  507. chat = Chats.import_chat(
  508. user.id,
  509. ChatImportForm(
  510. **{
  511. "chat": updated_chat,
  512. "meta": chat.meta,
  513. "pinned": chat.pinned,
  514. "folder_id": chat.folder_id,
  515. }
  516. ),
  517. )
  518. return ChatResponse(**chat.model_dump())
  519. else:
  520. raise HTTPException(
  521. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  522. )
  523. ############################
  524. # CloneSharedChatById
  525. ############################
  526. @router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
  527. async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
  528. if user.role == "admin":
  529. chat = Chats.get_chat_by_id(id)
  530. else:
  531. chat = Chats.get_chat_by_share_id(id)
  532. if chat:
  533. updated_chat = {
  534. **chat.chat,
  535. "originalChatId": chat.id,
  536. "branchPointMessageId": chat.chat["history"]["currentId"],
  537. "title": f"Clone of {chat.title}",
  538. }
  539. chat = Chats.import_chat(
  540. user.id,
  541. ChatImportForm(
  542. **{
  543. "chat": updated_chat,
  544. "meta": chat.meta,
  545. "pinned": chat.pinned,
  546. "folder_id": chat.folder_id,
  547. }
  548. ),
  549. )
  550. return ChatResponse(**chat.model_dump())
  551. else:
  552. raise HTTPException(
  553. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  554. )
  555. ############################
  556. # ArchiveChat
  557. ############################
  558. @router.post("/{id}/archive", response_model=Optional[ChatResponse])
  559. async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
  560. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  561. if chat:
  562. chat = Chats.toggle_chat_archive_by_id(id)
  563. # Delete tags if chat is archived
  564. if chat.archived:
  565. for tag_id in chat.meta.get("tags", []):
  566. if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0:
  567. log.debug(f"deleting tag: {tag_id}")
  568. Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
  569. else:
  570. for tag_id in chat.meta.get("tags", []):
  571. tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id)
  572. if tag is None:
  573. log.debug(f"inserting tag: {tag_id}")
  574. tag = Tags.insert_new_tag(tag_id, user.id)
  575. return ChatResponse(**chat.model_dump())
  576. else:
  577. raise HTTPException(
  578. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  579. )
  580. ############################
  581. # ShareChatById
  582. ############################
  583. @router.post("/{id}/share", response_model=Optional[ChatResponse])
  584. async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
  585. if (user.role != "admin") and (
  586. not has_permission(
  587. user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
  588. )
  589. ):
  590. raise HTTPException(
  591. status_code=status.HTTP_401_UNAUTHORIZED,
  592. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  593. )
  594. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  595. if chat:
  596. if chat.share_id:
  597. shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
  598. return ChatResponse(**shared_chat.model_dump())
  599. shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
  600. if not shared_chat:
  601. raise HTTPException(
  602. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  603. detail=ERROR_MESSAGES.DEFAULT(),
  604. )
  605. return ChatResponse(**shared_chat.model_dump())
  606. else:
  607. raise HTTPException(
  608. status_code=status.HTTP_401_UNAUTHORIZED,
  609. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  610. )
  611. ############################
  612. # DeletedSharedChatById
  613. ############################
  614. @router.delete("/{id}/share", response_model=Optional[bool])
  615. async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
  616. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  617. if chat:
  618. if not chat.share_id:
  619. return False
  620. result = Chats.delete_shared_chat_by_chat_id(id)
  621. update_result = Chats.update_chat_share_id_by_id(id, None)
  622. return result and update_result != None
  623. else:
  624. raise HTTPException(
  625. status_code=status.HTTP_401_UNAUTHORIZED,
  626. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  627. )
  628. ############################
  629. # UpdateChatFolderIdById
  630. ############################
  631. class ChatFolderIdForm(BaseModel):
  632. folder_id: Optional[str] = None
  633. @router.post("/{id}/folder", response_model=Optional[ChatResponse])
  634. async def update_chat_folder_id_by_id(
  635. id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
  636. ):
  637. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  638. if chat:
  639. chat = Chats.update_chat_folder_id_by_id_and_user_id(
  640. id, user.id, form_data.folder_id
  641. )
  642. return ChatResponse(**chat.model_dump())
  643. else:
  644. raise HTTPException(
  645. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  646. )
  647. ############################
  648. # GetChatTagsById
  649. ############################
  650. @router.get("/{id}/tags", response_model=list[TagModel])
  651. async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
  652. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  653. if chat:
  654. tags = chat.meta.get("tags", [])
  655. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  656. else:
  657. raise HTTPException(
  658. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  659. )
  660. ############################
  661. # AddChatTagById
  662. ############################
  663. @router.post("/{id}/tags", response_model=list[TagModel])
  664. async def add_tag_by_id_and_tag_name(
  665. id: str, form_data: TagForm, user=Depends(get_verified_user)
  666. ):
  667. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  668. if chat:
  669. tags = chat.meta.get("tags", [])
  670. tag_id = form_data.name.replace(" ", "_").lower()
  671. if tag_id == "none":
  672. raise HTTPException(
  673. status_code=status.HTTP_400_BAD_REQUEST,
  674. detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
  675. )
  676. if tag_id not in tags:
  677. Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
  678. id, user.id, form_data.name
  679. )
  680. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  681. tags = chat.meta.get("tags", [])
  682. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  683. else:
  684. raise HTTPException(
  685. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
  686. )
  687. ############################
  688. # DeleteChatTagById
  689. ############################
  690. @router.delete("/{id}/tags", response_model=list[TagModel])
  691. async def delete_tag_by_id_and_tag_name(
  692. id: str, form_data: TagForm, user=Depends(get_verified_user)
  693. ):
  694. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  695. if chat:
  696. Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
  697. if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
  698. Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
  699. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  700. tags = chat.meta.get("tags", [])
  701. return Tags.get_tags_by_ids_and_user_id(tags, user.id)
  702. else:
  703. raise HTTPException(
  704. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  705. )
  706. ############################
  707. # DeleteAllTagsById
  708. ############################
  709. @router.delete("/{id}/tags/all", response_model=Optional[bool])
  710. async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)):
  711. chat = Chats.get_chat_by_id_and_user_id(id, user.id)
  712. if chat:
  713. Chats.delete_all_tags_by_id_and_user_id(id, user.id)
  714. for tag in chat.meta.get("tags", []):
  715. if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
  716. Tags.delete_tag_by_name_and_user_id(tag, user.id)
  717. return True
  718. else:
  719. raise HTTPException(
  720. status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
  721. )