tasks.py
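"""FastAPI router exposing Open WebUI's background "task" completions: chat title,
follow-up, tag, search/retrieval query, autocomplete, image-prompt, emoji, and
mixture-of-agents (MoA) response generation, plus the task configuration endpoints."""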

from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from typing import Optional
import logging
import re

from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
    title_generation_template,
    follow_up_generation_template,
    query_generation_template,
    image_prompt_generation_template,
    autocomplete_generation_template,
    tags_generation_template,
    emoji_generation_template,
    moa_response_generation_template,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.task import get_task_model_id
from open_webui.config import (
    DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

router = APIRouter()


##################################
#
# Task Endpoints
#
##################################

@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    return {
        "TASK_MODEL": request.app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
        "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }

class TaskConfigForm(BaseModel):
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_FOLLOW_UP_GENERATION: bool
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str

@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
    request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
    request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
    request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
        form_data.ENABLE_FOLLOW_UP_GENERATION
    )
    request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
        form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
        form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
        form_data.ENABLE_AUTOCOMPLETE_GENERATION
    )
    request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
        form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    )
    request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
        form_data.TAGS_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
    request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
        form_data.ENABLE_SEARCH_QUERY_GENERATION
    )
    request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
        form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
    )
    request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
        form_data.QUERY_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    )

    return {
        "TASK_MODEL": request.app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
        "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }

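# The generation endpoints below share a common shape: resolve the task model with
# get_task_model_id (falling back to TASK_MODEL / TASK_MODEL_EXTERNAL as configured),
# render the task-specific prompt template into a single user message, run the payload
# through the pipeline inlet filter, and forward it to generate_chat_completion.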
@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Title generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(
        template,
        form_data["messages"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    max_tokens = (
        models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **(
            {"max_tokens": max_tokens}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": max_tokens,
            }
        ),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )

@router.post("/follow_up/completions")
async def generate_follow_ups(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Follow-up generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating follow-ups using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE

    content = follow_up_generation_template(
        template,
        form_data["messages"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.FOLLOW_UP_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )

@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TAGS_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )

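# Unlike title, tag, and follow-up generation, image prompt generation has no
# ENABLE_* toggle in the task config; it is always available to verified users.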
@router.post("/image_prompt/completions")
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE

    content = image_prompt_generation_template(
        template,
        form_data["messages"],
        user={
            "name": user.name,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )

@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    type = form_data.get("type")
    if type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {type} queries using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

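# Autocomplete generation additionally enforces AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
# on the incoming prompt; a limit of 0 (or less) disables the length check.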
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
        if (
            len(prompt)
            > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
        ):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )

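# Emoji generation always uses the built-in default template and caps the completion
# at 4 tokens ("max_tokens" for Ollama-owned models, "max_completion_tokens" otherwise).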
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email}")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **(
            {"max_tokens": 4}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 4,
            }
        ),
        "chat_id": form_data.get("chat_id", None),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

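# Mixture-of-Agents (MoA): synthesizes a final answer from several candidate responses.
# Unlike the other task endpoints, it uses the requested model itself (not the task
# model) and honors the client's "stream" flag.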
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )
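
# Note: this router defines no URL prefix of its own; the main application is expected
# to mount it (e.g. app.include_router(router, prefix="/api/v1/tasks")). The exact
# prefix is an assumption here; check the application entrypoint.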