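"""Pipe function execution for Open WebUI.

Loads user-defined "pipe" function modules, exposes them as chat models, and
runs chat completions through them, translating pipe output into
OpenAI-compatible streaming or non-streaming responses.
"""
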
import logging
import sys
import inspect
import json
import asyncio

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
)
from starlette.responses import Response, StreamingResponse

from open_webui.constants import ERROR_MESSAGES
from open_webui.socket.main import (
    get_event_call,
    get_event_emitter,
)

from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models

from open_webui.utils.plugin import (
    load_function_module_by_id,
    get_function_module_from_cache,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access

from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

from open_webui.utils.misc import (
    add_or_update_system_message,
    get_last_user_message,
    prepend_to_first_user_message_content,
    openai_chat_chunk_message_template,
    openai_chat_completion_message_template,
)

from open_webui.utils.payload import (
    apply_model_params_to_body_openai,
    apply_system_prompt_to_body,
)

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def get_function_module_by_id(request: Request, pipe_id: str):
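    """Load a function module from the cache and apply its stored valves.

    If the module defines a ``Valves`` model, persisted valve values are
    loaded from the database (``None`` entries are dropped so defaults
    apply); otherwise the default valves are used.
    """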
    function_module, _, _ = get_function_module_from_cache(request, pipe_id)

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        Valves = function_module.Valves

        valves = Functions.get_function_valves_by_id(pipe_id)
        if valves:
            try:
                function_module.valves = Valves(
                    **{k: v for k, v in valves.items() if v is not None}
                )
            except Exception as e:
                log.exception(f"Error loading valves for function {pipe_id}: {e}")
                raise e
        else:
            function_module.valves = Valves()

    return function_module
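

# Illustrative sketch only (not part of this module's API): the attribute
# shape of a pipe function module as consumed above and below, i.e. `pipe`,
# optional `pipes`, `Valves`, `UserValves`, and `name`. The actual loader
# contract lives in open_webui.utils.plugin; the valve field is made up.
#
#     class Pipe:
#         class Valves(BaseModel):
#             api_key: str = ""  # hypothetical valve field
#
#         def pipe(self, body: dict, __user__: dict) -> str:
#             return "Hello from my pipe"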


async def get_function_models(request):
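    """List all active pipe functions as OpenAI-style model entries.

    A module that exposes a ``pipes`` attribute is a manifold: each sub-pipe
    becomes its own model with the id ``<function_id>.<sub_pipe_id>``.
    """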
    pipes = Functions.get_functions_by_type("pipe", active_only=True)
    pipe_models = []

    for pipe in pipes:
        try:
            function_module = get_function_module_by_id(request, pipe.id)

            has_user_valves = False
            if hasattr(function_module, "UserValves"):
                has_user_valves = True

            # Check if function is a manifold
            if hasattr(function_module, "pipes"):
                sub_pipes = []

                # Handle pipes being a list, sync function, or async function
                try:
                    if callable(function_module.pipes):
                        if asyncio.iscoroutinefunction(function_module.pipes):
                            sub_pipes = await function_module.pipes()
                        else:
                            sub_pipes = function_module.pipes()
                    else:
                        sub_pipes = function_module.pipes
                except Exception as e:
                    log.exception(e)
                    sub_pipes = []

                log.debug(
                    f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
                )

                for p in sub_pipes:
                    sub_pipe_id = f'{pipe.id}.{p["id"]}'
                    sub_pipe_name = p["name"]

                    if hasattr(function_module, "name"):
                        sub_pipe_name = f"{function_module.name}{sub_pipe_name}"

                    pipe_flag = {"type": pipe.type}

                    pipe_models.append(
                        {
                            "id": sub_pipe_id,
                            "name": sub_pipe_name,
                            "object": "model",
                            "created": pipe.created_at,
                            "owned_by": "openai",
                            "pipe": pipe_flag,
                            "has_user_valves": has_user_valves,
                        }
                    )
            else:
                pipe_flag = {"type": "pipe"}

                log.debug(
                    f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
                )

                pipe_models.append(
                    {
                        "id": pipe.id,
                        "name": pipe.name,
                        "object": "model",
                        "created": pipe.created_at,
                        "owned_by": "openai",
                        "pipe": pipe_flag,
                        "has_user_valves": has_user_valves,
                    }
                )
        except Exception as e:
            log.exception(e)
            continue

    return pipe_models


async def generate_function_chat_completion(
    request, form_data, user, models: dict = {}
):
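    """Run a chat completion request through a pipe function.

    Resolves the pipe from ``form_data["model"]``, assembles its parameters
    (body, valves, event hooks, tools, ...), then returns either an SSE
    ``StreamingResponse`` or a single OpenAI-style completion object.
    """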

    async def execute_pipe(pipe, params):
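        """Invoke the pipe, awaiting it when it is a coroutine function."""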
        if inspect.iscoroutinefunction(pipe):
            return await pipe(**params)
        else:
            return pipe(**params)

    async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
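        """Collapse a pipe result (string or [a]sync generator) into one string."""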
        if isinstance(res, str):
            return res
        if isinstance(res, Generator):
            return "".join(map(str, res))
        if isinstance(res, AsyncGenerator):
            return "".join([str(stream) async for stream in res])
        # Fallback so the annotated return type holds for unexpected result types
        return ""

    def process_line(form_data: dict, line):
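        """Normalize one pipe output item into an SSE ``data:`` line.

        Handles pydantic models, dicts, bytes, and plain strings; anything
        not already ``data:``-prefixed is wrapped in an OpenAI chunk template.
        """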
        if isinstance(line, BaseModel):
            line = line.model_dump_json()
            line = f"data: {line}"
        if isinstance(line, dict):
            line = f"data: {json.dumps(line)}"

        try:
            line = line.decode("utf-8")
        except Exception:
            pass

        if line.startswith("data:"):
            return f"{line}\n\n"
        else:
            line = openai_chat_chunk_message_template(form_data["model"], line)
            return f"data: {json.dumps(line)}\n\n"

    def get_pipe_id(form_data: dict) -> str:
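        """Strip the sub-pipe suffix from a manifold model id (``<pipe>.<sub>``)."""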
        pipe_id = form_data["model"]
        if "." in pipe_id:
            pipe_id, _ = pipe_id.split(".", 1)
        return pipe_id

    def get_function_params(function_module, form_data, user, extra_params=None):
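        """Build the kwargs for the pipe call.

        Only extra parameters named in the pipe's signature are passed
        through; per-user valves are attached to ``__user__`` when the
        module defines ``UserValves``.
        """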
        if extra_params is None:
            extra_params = {}

        pipe_id = get_pipe_id(form_data)

        # Get the signature of the function
        sig = inspect.signature(function_module.pipe)

        params = {"body": form_data} | {
            k: v for k, v in extra_params.items() if k in sig.parameters
        }

        if "__user__" in params and hasattr(function_module, "UserValves"):
            user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
            try:
                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
            except Exception as e:
                log.exception(e)
                params["__user__"]["valves"] = function_module.UserValves()

        return params

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    metadata = form_data.pop("metadata", {})

    files = metadata.get("files", [])
    tool_ids = metadata.get("tool_ids", [])

    # Check if tool_ids is None
    if tool_ids is None:
        tool_ids = []

    __event_emitter__ = None
    __event_call__ = None
    __task__ = None
    __task_body__ = None

    if metadata:
        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
            __event_emitter__ = get_event_emitter(metadata)
            __event_call__ = get_event_call(metadata)

        __task__ = metadata.get("task", None)
        __task_body__ = metadata.get("task_body", None)

    oauth_token = None
    try:
        if request.cookies.get("oauth_session_id", None):
            oauth_token = await request.app.state.oauth_manager.get_oauth_token(
                user.id,
                request.cookies.get("oauth_session_id", None),
            )
    except Exception as e:
        log.error(f"Error getting OAuth token: {e}")

    extra_params = {
        "__event_emitter__": __event_emitter__,
        "__event_call__": __event_call__,
        "__chat_id__": metadata.get("chat_id", None),
        "__session_id__": metadata.get("session_id", None),
        "__message_id__": metadata.get("message_id", None),
        "__task__": __task__,
        "__task_body__": __task_body__,
        "__files__": files,
        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
        "__metadata__": metadata,
        "__oauth_token__": oauth_token,
        "__request__": request,
    }
    extra_params["__tools__"] = await get_tools(
        request,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models.get(form_data["model"], None),
            "__messages__": form_data["messages"],
            "__files__": files,
        },
    )

    if model_info:
        if model_info.base_model_id:
            form_data["model"] = model_info.base_model_id

        params = model_info.params.model_dump()

        if params:
            system = params.pop("system", None)
            form_data = apply_model_params_to_body_openai(params, form_data)
            form_data = apply_system_prompt_to_body(system, form_data, metadata, user)

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module_by_id(request, pipe_id)

    pipe = function_module.pipe
    params = get_function_params(function_module, form_data, user, extra_params)

    if form_data.get("stream", False):

        async def stream_content():
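            """Yield the pipe result as OpenAI-style SSE chunks."""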
            try:
                res = await execute_pipe(pipe, params)

                # Directly return if the response is a StreamingResponse
                if isinstance(res, StreamingResponse):
                    async for data in res.body_iterator:
                        yield data
                    return
                if isinstance(res, dict):
                    yield f"data: {json.dumps(res)}\n\n"
                    return
            except Exception as e:
                log.error(f"Error: {e}")
                yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
                return

            if isinstance(res, str):
                message = openai_chat_chunk_message_template(form_data["model"], res)
                yield f"data: {json.dumps(message)}\n\n"

            if isinstance(res, Iterator):
                for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, AsyncGenerator):
                async for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, str) or isinstance(res, Generator):
                finish_message = openai_chat_chunk_message_template(
                    form_data["model"], ""
                )
                finish_message["choices"][0]["finish_reason"] = "stop"
                yield f"data: {json.dumps(finish_message)}\n\n"
                yield "data: [DONE]"

        return StreamingResponse(stream_content(), media_type="text/event-stream")
    else:
        try:
            res = await execute_pipe(pipe, params)
        except Exception as e:
            log.error(f"Error: {e}")
            return {"error": {"detail": str(e)}}

        if isinstance(res, StreamingResponse) or isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        message = await get_message_content(res)
        return openai_chat_completion_message_template(form_data["model"], message)
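

# Usage sketch (assumptions: an active FastAPI `request`, a `UserModel`
# instance `user`, and a registered pipe function; the id "my_pipe" is
# hypothetical):
#
#     form_data = {
#         "model": "my_pipe",
#         "messages": [{"role": "user", "content": "Hi"}],
#         "stream": False,
#     }
#     res = await generate_function_chat_completion(request, form_data, user)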