# tasks.py
  1. from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
  2. from fastapi.responses import JSONResponse, RedirectResponse
  3. from pydantic import BaseModel
  4. from typing import Optional
  5. import logging
  6. import re
  7. from open_webui.utils.chat import generate_chat_completion
  8. from open_webui.utils.task import (
  9. title_generation_template,
  10. follow_up_generation_template,
  11. query_generation_template,
  12. image_prompt_generation_template,
  13. autocomplete_generation_template,
  14. tags_generation_template,
  15. emoji_generation_template,
  16. moa_response_generation_template,
  17. )
  18. from open_webui.utils.auth import get_admin_user, get_verified_user
  19. from open_webui.constants import TASKS
  20. from open_webui.routers.pipelines import process_pipeline_inlet_filter
  21. from open_webui.utils.task import get_task_model_id
  22. from open_webui.config import (
  23. DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
  24. DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
  25. DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
  26. DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
  27. DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
  28. DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
  29. DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
  30. DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
  31. )
  32. from open_webui.env import SRC_LOG_LEVELS
# Module-level logger for this router, tuned to the "MODELS" log level
# configured in the environment.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

router = APIRouter()

##################################
#
# Task Endpoints
#
##################################
  41. @router.get("/config")
  42. async def get_task_config(request: Request, user=Depends(get_verified_user)):
  43. return {
  44. "TASK_MODEL": request.app.state.config.TASK_MODEL,
  45. "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
  46. "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
  47. "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
  48. "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
  49. "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
  50. "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
  51. "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
  52. "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
  53. "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
  54. "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
  55. "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
  56. "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
  57. "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
  58. "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
  59. }
class TaskConfigForm(BaseModel):
    """Request body for POST /config/update: the full set of task settings.

    NOTE: Optional[...] without a default is still a required field in
    pydantic; Optional here only permits an explicit null value.
    """

    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_FOLLOW_UP_GENERATION: bool
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
  76. @router.post("/config/update")
  77. async def update_task_config(
  78. request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
  79. ):
  80. request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
  81. request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
  82. request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
  83. request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
  84. form_data.TITLE_GENERATION_PROMPT_TEMPLATE
  85. )
  86. request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
  87. form_data.ENABLE_FOLLOW_UP_GENERATION
  88. )
  89. request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
  90. form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
  91. )
  92. request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
  93. form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
  94. )
  95. request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
  96. form_data.ENABLE_AUTOCOMPLETE_GENERATION
  97. )
  98. request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
  99. form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
  100. )
  101. request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
  102. form_data.TAGS_GENERATION_PROMPT_TEMPLATE
  103. )
  104. request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
  105. request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
  106. form_data.ENABLE_SEARCH_QUERY_GENERATION
  107. )
  108. request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
  109. form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
  110. )
  111. request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
  112. form_data.QUERY_GENERATION_PROMPT_TEMPLATE
  113. )
  114. request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
  115. form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
  116. )
  117. return {
  118. "TASK_MODEL": request.app.state.config.TASK_MODEL,
  119. "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
  120. "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
  121. "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
  122. "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
  123. "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
  124. "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
  125. "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
  126. "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
  127. "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
  128. "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
  129. "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
  130. "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
  131. "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
  132. "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
  133. }
  134. @router.post("/title/completions")
  135. async def generate_title(
  136. request: Request, form_data: dict, user=Depends(get_verified_user)
  137. ):
  138. if not request.app.state.config.ENABLE_TITLE_GENERATION:
  139. return JSONResponse(
  140. status_code=status.HTTP_200_OK,
  141. content={"detail": "Title generation is disabled"},
  142. )
  143. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  144. models = {
  145. request.state.model["id"]: request.state.model,
  146. }
  147. else:
  148. models = request.app.state.MODELS
  149. model_id = form_data["model"]
  150. if model_id not in models:
  151. raise HTTPException(
  152. status_code=status.HTTP_404_NOT_FOUND,
  153. detail="Model not found",
  154. )
  155. # Check if the user has a custom task model
  156. # If the user has a custom task model, use that model
  157. task_model_id = get_task_model_id(
  158. model_id,
  159. request.app.state.config.TASK_MODEL,
  160. request.app.state.config.TASK_MODEL_EXTERNAL,
  161. models,
  162. )
  163. log.debug(
  164. f"generating chat title using model {task_model_id} for user {user.email} "
  165. )
  166. if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
  167. template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
  168. else:
  169. template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
  170. content = title_generation_template(template, form_data["messages"], user)
  171. max_tokens = (
  172. models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
  173. )
  174. payload = {
  175. "model": task_model_id,
  176. "messages": [{"role": "user", "content": content}],
  177. "stream": False,
  178. **(
  179. {"max_tokens": max_tokens}
  180. if models[task_model_id].get("owned_by") == "ollama"
  181. else {
  182. "max_completion_tokens": max_tokens,
  183. }
  184. ),
  185. "metadata": {
  186. **(request.state.metadata if hasattr(request.state, "metadata") else {}),
  187. "task": str(TASKS.TITLE_GENERATION),
  188. "task_body": form_data,
  189. "chat_id": form_data.get("chat_id", None),
  190. },
  191. }
  192. # Process the payload through the pipeline
  193. try:
  194. payload = await process_pipeline_inlet_filter(request, payload, user, models)
  195. except Exception as e:
  196. raise e
  197. try:
  198. return await generate_chat_completion(request, form_data=payload, user=user)
  199. except Exception as e:
  200. log.error("Exception occurred", exc_info=True)
  201. return JSONResponse(
  202. status_code=status.HTTP_400_BAD_REQUEST,
  203. content={"detail": "An internal error has occurred."},
  204. )
  205. @router.post("/follow_up/completions")
  206. async def generate_follow_ups(
  207. request: Request, form_data: dict, user=Depends(get_verified_user)
  208. ):
  209. if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
  210. return JSONResponse(
  211. status_code=status.HTTP_200_OK,
  212. content={"detail": "Follow-up generation is disabled"},
  213. )
  214. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  215. models = {
  216. request.state.model["id"]: request.state.model,
  217. }
  218. else:
  219. models = request.app.state.MODELS
  220. model_id = form_data["model"]
  221. if model_id not in models:
  222. raise HTTPException(
  223. status_code=status.HTTP_404_NOT_FOUND,
  224. detail="Model not found",
  225. )
  226. # Check if the user has a custom task model
  227. # If the user has a custom task model, use that model
  228. task_model_id = get_task_model_id(
  229. model_id,
  230. request.app.state.config.TASK_MODEL,
  231. request.app.state.config.TASK_MODEL_EXTERNAL,
  232. models,
  233. )
  234. log.debug(
  235. f"generating chat title using model {task_model_id} for user {user.email} "
  236. )
  237. if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
  238. template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
  239. else:
  240. template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
  241. content = follow_up_generation_template(template, form_data["messages"], user)
  242. payload = {
  243. "model": task_model_id,
  244. "messages": [{"role": "user", "content": content}],
  245. "stream": False,
  246. "metadata": {
  247. **(request.state.metadata if hasattr(request.state, "metadata") else {}),
  248. "task": str(TASKS.FOLLOW_UP_GENERATION),
  249. "task_body": form_data,
  250. "chat_id": form_data.get("chat_id", None),
  251. },
  252. }
  253. # Process the payload through the pipeline
  254. try:
  255. payload = await process_pipeline_inlet_filter(request, payload, user, models)
  256. except Exception as e:
  257. raise e
  258. try:
  259. return await generate_chat_completion(request, form_data=payload, user=user)
  260. except Exception as e:
  261. log.error("Exception occurred", exc_info=True)
  262. return JSONResponse(
  263. status_code=status.HTTP_400_BAD_REQUEST,
  264. content={"detail": "An internal error has occurred."},
  265. )
  266. @router.post("/tags/completions")
  267. async def generate_chat_tags(
  268. request: Request, form_data: dict, user=Depends(get_verified_user)
  269. ):
  270. if not request.app.state.config.ENABLE_TAGS_GENERATION:
  271. return JSONResponse(
  272. status_code=status.HTTP_200_OK,
  273. content={"detail": "Tags generation is disabled"},
  274. )
  275. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  276. models = {
  277. request.state.model["id"]: request.state.model,
  278. }
  279. else:
  280. models = request.app.state.MODELS
  281. model_id = form_data["model"]
  282. if model_id not in models:
  283. raise HTTPException(
  284. status_code=status.HTTP_404_NOT_FOUND,
  285. detail="Model not found",
  286. )
  287. # Check if the user has a custom task model
  288. # If the user has a custom task model, use that model
  289. task_model_id = get_task_model_id(
  290. model_id,
  291. request.app.state.config.TASK_MODEL,
  292. request.app.state.config.TASK_MODEL_EXTERNAL,
  293. models,
  294. )
  295. log.debug(
  296. f"generating chat tags using model {task_model_id} for user {user.email} "
  297. )
  298. if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
  299. template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
  300. else:
  301. template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE
  302. content = tags_generation_template(template, form_data["messages"], user)
  303. payload = {
  304. "model": task_model_id,
  305. "messages": [{"role": "user", "content": content}],
  306. "stream": False,
  307. "metadata": {
  308. **(request.state.metadata if hasattr(request.state, "metadata") else {}),
  309. "task": str(TASKS.TAGS_GENERATION),
  310. "task_body": form_data,
  311. "chat_id": form_data.get("chat_id", None),
  312. },
  313. }
  314. # Process the payload through the pipeline
  315. try:
  316. payload = await process_pipeline_inlet_filter(request, payload, user, models)
  317. except Exception as e:
  318. raise e
  319. try:
  320. return await generate_chat_completion(request, form_data=payload, user=user)
  321. except Exception as e:
  322. log.error(f"Error generating chat completion: {e}")
  323. return JSONResponse(
  324. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  325. content={"detail": "An internal error has occurred."},
  326. )
  327. @router.post("/image_prompt/completions")
  328. async def generate_image_prompt(
  329. request: Request, form_data: dict, user=Depends(get_verified_user)
  330. ):
  331. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  332. models = {
  333. request.state.model["id"]: request.state.model,
  334. }
  335. else:
  336. models = request.app.state.MODELS
  337. model_id = form_data["model"]
  338. if model_id not in models:
  339. raise HTTPException(
  340. status_code=status.HTTP_404_NOT_FOUND,
  341. detail="Model not found",
  342. )
  343. # Check if the user has a custom task model
  344. # If the user has a custom task model, use that model
  345. task_model_id = get_task_model_id(
  346. model_id,
  347. request.app.state.config.TASK_MODEL,
  348. request.app.state.config.TASK_MODEL_EXTERNAL,
  349. models,
  350. )
  351. log.debug(
  352. f"generating image prompt using model {task_model_id} for user {user.email} "
  353. )
  354. if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
  355. template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
  356. else:
  357. template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
  358. content = image_prompt_generation_template(template, form_data["messages"], user)
  359. payload = {
  360. "model": task_model_id,
  361. "messages": [{"role": "user", "content": content}],
  362. "stream": False,
  363. "metadata": {
  364. **(request.state.metadata if hasattr(request.state, "metadata") else {}),
  365. "task": str(TASKS.IMAGE_PROMPT_GENERATION),
  366. "task_body": form_data,
  367. "chat_id": form_data.get("chat_id", None),
  368. },
  369. }
  370. # Process the payload through the pipeline
  371. try:
  372. payload = await process_pipeline_inlet_filter(request, payload, user, models)
  373. except Exception as e:
  374. raise e
  375. try:
  376. return await generate_chat_completion(request, form_data=payload, user=user)
  377. except Exception as e:
  378. log.error("Exception occurred", exc_info=True)
  379. return JSONResponse(
  380. status_code=status.HTTP_400_BAD_REQUEST,
  381. content={"detail": "An internal error has occurred."},
  382. )
  383. @router.post("/queries/completions")
  384. async def generate_queries(
  385. request: Request, form_data: dict, user=Depends(get_verified_user)
  386. ):
  387. type = form_data.get("type")
  388. if type == "web_search":
  389. if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
  390. raise HTTPException(
  391. status_code=status.HTTP_400_BAD_REQUEST,
  392. detail=f"Search query generation is disabled",
  393. )
  394. elif type == "retrieval":
  395. if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
  396. raise HTTPException(
  397. status_code=status.HTTP_400_BAD_REQUEST,
  398. detail=f"Query generation is disabled",
  399. )
  400. if getattr(request.state, "cached_queries", None):
  401. log.info(f"Reusing cached queries: {request.state.cached_queries}")
  402. return request.state.cached_queries
  403. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  404. models = {
  405. request.state.model["id"]: request.state.model,
  406. }
  407. else:
  408. models = request.app.state.MODELS
  409. model_id = form_data["model"]
  410. if model_id not in models:
  411. raise HTTPException(
  412. status_code=status.HTTP_404_NOT_FOUND,
  413. detail="Model not found",
  414. )
  415. # Check if the user has a custom task model
  416. # If the user has a custom task model, use that model
  417. task_model_id = get_task_model_id(
  418. model_id,
  419. request.app.state.config.TASK_MODEL,
  420. request.app.state.config.TASK_MODEL_EXTERNAL,
  421. models,
  422. )
  423. log.debug(
  424. f"generating {type} queries using model {task_model_id} for user {user.email}"
  425. )
  426. if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
  427. template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
  428. else:
  429. template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
  430. content = query_generation_template(template, form_data["messages"], user)
  431. payload = {
  432. "model": task_model_id,
  433. "messages": [{"role": "user", "content": content}],
  434. "stream": False,
  435. "metadata": {
  436. **(request.state.metadata if hasattr(request.state, "metadata") else {}),
  437. "task": str(TASKS.QUERY_GENERATION),
  438. "task_body": form_data,
  439. "chat_id": form_data.get("chat_id", None),
  440. },
  441. }
  442. # Process the payload through the pipeline
  443. try:
  444. payload = await process_pipeline_inlet_filter(request, payload, user, models)
  445. except Exception as e:
  446. raise e
  447. try:
  448. return await generate_chat_completion(request, form_data=payload, user=user)
  449. except Exception as e:
  450. return JSONResponse(
  451. status_code=status.HTTP_400_BAD_REQUEST,
  452. content={"detail": str(e)},
  453. )
  454. @router.post("/auto/completions")
  455. async def generate_autocompletion(
  456. request: Request, form_data: dict, user=Depends(get_verified_user)
  457. ):
  458. if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
  459. raise HTTPException(
  460. status_code=status.HTTP_400_BAD_REQUEST,
  461. detail=f"Autocompletion generation is disabled",
  462. )
  463. type = form_data.get("type")
  464. prompt = form_data.get("prompt")
  465. messages = form_data.get("messages")
  466. if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
  467. if (
  468. len(prompt)
  469. > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
  470. ):
  471. raise HTTPException(
  472. status_code=status.HTTP_400_BAD_REQUEST,
  473. detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
  474. )
  475. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  476. models = {
  477. request.state.model["id"]: request.state.model,
  478. }
  479. else:
  480. models = request.app.state.MODELS
  481. model_id = form_data["model"]
  482. if model_id not in models:
  483. raise HTTPException(
  484. status_code=status.HTTP_404_NOT_FOUND,
  485. detail="Model not found",
  486. )
  487. # Check if the user has a custom task model
  488. # If the user has a custom task model, use that model
  489. task_model_id = get_task_model_id(
  490. model_id,
  491. request.app.state.config.TASK_MODEL,
  492. request.app.state.config.TASK_MODEL_EXTERNAL,
  493. models,
  494. )
  495. log.debug(
  496. f"generating autocompletion using model {task_model_id} for user {user.email}"
  497. )
  498. if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
  499. template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
  500. else:
  501. template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
  502. content = autocomplete_generation_template(template, prompt, messages, type, user)
  503. payload = {
  504. "model": task_model_id,
  505. "messages": [{"role": "user", "content": content}],
  506. "stream": False,
  507. "metadata": {
  508. **(request.state.metadata if hasattr(request.state, "metadata") else {}),
  509. "task": str(TASKS.AUTOCOMPLETE_GENERATION),
  510. "task_body": form_data,
  511. "chat_id": form_data.get("chat_id", None),
  512. },
  513. }
  514. # Process the payload through the pipeline
  515. try:
  516. payload = await process_pipeline_inlet_filter(request, payload, user, models)
  517. except Exception as e:
  518. raise e
  519. try:
  520. return await generate_chat_completion(request, form_data=payload, user=user)
  521. except Exception as e:
  522. log.error(f"Error generating chat completion: {e}")
  523. return JSONResponse(
  524. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  525. content={"detail": "An internal error has occurred."},
  526. )
  527. @router.post("/emoji/completions")
  528. async def generate_emoji(
  529. request: Request, form_data: dict, user=Depends(get_verified_user)
  530. ):
  531. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  532. models = {
  533. request.state.model["id"]: request.state.model,
  534. }
  535. else:
  536. models = request.app.state.MODELS
  537. model_id = form_data["model"]
  538. if model_id not in models:
  539. raise HTTPException(
  540. status_code=status.HTTP_404_NOT_FOUND,
  541. detail="Model not found",
  542. )
  543. # Check if the user has a custom task model
  544. # If the user has a custom task model, use that model
  545. task_model_id = get_task_model_id(
  546. model_id,
  547. request.app.state.config.TASK_MODEL,
  548. request.app.state.config.TASK_MODEL_EXTERNAL,
  549. models,
  550. )
  551. log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")
  552. template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE
  553. content = emoji_generation_template(template, form_data["prompt"], user)
  554. payload = {
  555. "model": task_model_id,
  556. "messages": [{"role": "user", "content": content}],
  557. "stream": False,
  558. **(
  559. {"max_tokens": 4}
  560. if models[task_model_id].get("owned_by") == "ollama"
  561. else {
  562. "max_completion_tokens": 4,
  563. }
  564. ),
  565. "metadata": {
  566. **(request.state.metadata if hasattr(request.state, "metadata") else {}),
  567. "task": str(TASKS.EMOJI_GENERATION),
  568. "task_body": form_data,
  569. "chat_id": form_data.get("chat_id", None),
  570. },
  571. }
  572. # Process the payload through the pipeline
  573. try:
  574. payload = await process_pipeline_inlet_filter(request, payload, user, models)
  575. except Exception as e:
  576. raise e
  577. try:
  578. return await generate_chat_completion(request, form_data=payload, user=user)
  579. except Exception as e:
  580. return JSONResponse(
  581. status_code=status.HTTP_400_BAD_REQUEST,
  582. content={"detail": str(e)},
  583. )
  584. @router.post("/moa/completions")
  585. async def generate_moa_response(
  586. request: Request, form_data: dict, user=Depends(get_verified_user)
  587. ):
  588. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  589. models = {
  590. request.state.model["id"]: request.state.model,
  591. }
  592. else:
  593. models = request.app.state.MODELS
  594. model_id = form_data["model"]
  595. if model_id not in models:
  596. raise HTTPException(
  597. status_code=status.HTTP_404_NOT_FOUND,
  598. detail="Model not found",
  599. )
  600. template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE
  601. content = moa_response_generation_template(
  602. template,
  603. form_data["prompt"],
  604. form_data["responses"],
  605. )
  606. payload = {
  607. "model": model_id,
  608. "messages": [{"role": "user", "content": content}],
  609. "stream": form_data.get("stream", False),
  610. "metadata": {
  611. **(request.state.metadata if hasattr(request.state, "metadata") else {}),
  612. "chat_id": form_data.get("chat_id", None),
  613. "task": str(TASKS.MOA_RESPONSE_GENERATION),
  614. "task_body": form_data,
  615. },
  616. }
  617. # Process the payload through the pipeline
  618. try:
  619. payload = await process_pipeline_inlet_filter(request, payload, user, models)
  620. except Exception as e:
  621. raise e
  622. try:
  623. return await generate_chat_completion(request, form_data=payload, user=user)
  624. except Exception as e:
  625. return JSONResponse(
  626. status_code=status.HTTP_400_BAD_REQUEST,
  627. content={"detail": str(e)},
  628. )