tasks.py

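# Task endpoints: generate chat titles, tags, image prompts, search/retrieval queries,
# autocompletions, emojis, and MoA responses using the configured task model.
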
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from typing import Optional
import logging
import re

from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
    title_generation_template,
    query_generation_template,
    image_prompt_generation_template,
    autocomplete_generation_template,
    tags_generation_template,
    emoji_generation_template,
    moa_response_generation_template,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS

from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.task import get_task_model_id

from open_webui.config import (
    DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

router = APIRouter()


##################################
#
# Task Endpoints
#
##################################


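# Return the current task configuration (task models, prompt templates, feature toggles).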
@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    return {
        "TASK_MODEL": request.app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
        "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }


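# Request schema for POST /config/update (admin only); mirrors the keys returned by GET /config.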
class TaskConfigForm(BaseModel):
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str


@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
    request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
    request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
    request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
        form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
        form_data.ENABLE_AUTOCOMPLETE_GENERATION
    )
    request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
        form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    )
    request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
        form_data.TAGS_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
    request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
        form_data.ENABLE_SEARCH_QUERY_GENERATION
    )
    request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
        form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
    )
    request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
        form_data.QUERY_GENERATION_PROMPT_TEMPLATE
    )
    request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    )

    return {
        "TASK_MODEL": request.app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
        "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }


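# Generate a short chat title from the conversation messages using the resolved task model.
# Illustrative request body (fields taken from the handler below):
#   {"model": "<model_id>", "messages": [...], "chat_id": "<optional>"}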
@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Title generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(
        template,
        form_data["messages"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    max_tokens = (
        models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **(
            {"max_tokens": max_tokens}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": max_tokens,
            }
        ),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )


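# Generate tags for a chat from its messages; the request body mirrors /title/completions.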
@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.TAGS_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )


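# Generate an image-generation prompt from the chat messages.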
@router.post("/image_prompt/completions")
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email} "
    )

    if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE

    content = image_prompt_generation_template(
        template,
        form_data["messages"],
        user={
            "name": user.name,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )


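# Generate web-search or retrieval queries from the chat messages. The request body
# carries a "type" field ("web_search" or "retrieval") that selects which feature
# toggle applies.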
@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    type = form_data.get("type")
    if type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Search query generation is disabled",
            )
    elif type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Query generation is disabled",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {type} queries using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )


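# Generate an autocompletion for a partially typed prompt, enforcing the configured
# maximum input length. The body carries "type", "prompt", and "messages" fields.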
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Autocompletion generation is disabled",
        )

    type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    if request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH > 0:
        if (
            len(prompt)
            > request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
        ):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
            )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )


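# Generate a single emoji for a prompt; the completion is capped at a few tokens.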
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **(
            {"max_tokens": 4}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 4,
            }
        ),
        "chat_id": form_data.get("chat_id", None),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )


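# Combine several candidate responses into one answer (MoA). Unlike the other task
# endpoints, this one uses the requested model directly rather than the task model.
# Illustrative request body:
#   {"model": "<model_id>", "prompt": "...", "responses": [...], "stream": false}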
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )