models.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. import time
  2. import logging
  3. import asyncio
  4. import sys
  5. from aiocache import cached
  6. from fastapi import Request
  7. from open_webui.routers import openai, ollama
  8. from open_webui.functions import get_function_models
  9. from open_webui.models.functions import Functions
  10. from open_webui.models.models import Models
  11. from open_webui.utils.plugin import (
  12. load_function_module_by_id,
  13. get_function_module_from_cache,
  14. )
  15. from open_webui.utils.access_control import has_access
  16. from open_webui.config import (
  17. BYPASS_ADMIN_ACCESS_CONTROL,
  18. DEFAULT_ARENA_MODEL,
  19. )
  20. from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
  21. from open_webui.models.users import UserModel
  22. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  23. log = logging.getLogger(__name__)
  24. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  25. async def fetch_ollama_models(request: Request, user: UserModel = None):
  26. raw_ollama_models = await ollama.get_all_models(request, user=user)
  27. return [
  28. {
  29. "id": model["model"],
  30. "name": model["name"],
  31. "object": "model",
  32. "created": int(time.time()),
  33. "owned_by": "ollama",
  34. "ollama": model,
  35. "connection_type": model.get("connection_type", "local"),
  36. "tags": model.get("tags", []),
  37. }
  38. for model in raw_ollama_models["models"]
  39. ]
  40. async def fetch_openai_models(request: Request, user: UserModel = None):
  41. openai_response = await openai.get_all_models(request, user=user)
  42. return openai_response["data"]
  43. async def get_all_base_models(request: Request, user: UserModel = None):
  44. openai_task = (
  45. fetch_openai_models(request, user)
  46. if request.app.state.config.ENABLE_OPENAI_API
  47. else asyncio.sleep(0, result=[])
  48. )
  49. ollama_task = (
  50. fetch_ollama_models(request, user)
  51. if request.app.state.config.ENABLE_OLLAMA_API
  52. else asyncio.sleep(0, result=[])
  53. )
  54. function_task = get_function_models(request)
  55. openai_models, ollama_models, function_models = await asyncio.gather(
  56. openai_task, ollama_task, function_task
  57. )
  58. return function_models + openai_models + ollama_models
  59. async def get_all_models(request, refresh: bool = False, user: UserModel = None):
  60. if (
  61. request.app.state.MODELS
  62. and request.app.state.BASE_MODELS
  63. and (request.app.state.config.ENABLE_BASE_MODELS_CACHE and not refresh)
  64. ):
  65. base_models = request.app.state.BASE_MODELS
  66. else:
  67. base_models = await get_all_base_models(request, user=user)
  68. request.app.state.BASE_MODELS = base_models
  69. # deep copy the base models to avoid modifying the original list
  70. models = [model.copy() for model in base_models]
  71. # If there are no models, return an empty list
  72. if len(models) == 0:
  73. return []
  74. # Add arena models
  75. if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
  76. arena_models = []
  77. if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
  78. arena_models = [
  79. {
  80. "id": model["id"],
  81. "name": model["name"],
  82. "info": {
  83. "meta": model["meta"],
  84. },
  85. "object": "model",
  86. "created": int(time.time()),
  87. "owned_by": "arena",
  88. "arena": True,
  89. }
  90. for model in request.app.state.config.EVALUATION_ARENA_MODELS
  91. ]
  92. else:
  93. # Add default arena model
  94. arena_models = [
  95. {
  96. "id": DEFAULT_ARENA_MODEL["id"],
  97. "name": DEFAULT_ARENA_MODEL["name"],
  98. "info": {
  99. "meta": DEFAULT_ARENA_MODEL["meta"],
  100. },
  101. "object": "model",
  102. "created": int(time.time()),
  103. "owned_by": "arena",
  104. "arena": True,
  105. }
  106. ]
  107. models = models + arena_models
  108. global_action_ids = [
  109. function.id for function in Functions.get_global_action_functions()
  110. ]
  111. enabled_action_ids = [
  112. function.id
  113. for function in Functions.get_functions_by_type("action", active_only=True)
  114. ]
  115. global_filter_ids = [
  116. function.id for function in Functions.get_global_filter_functions()
  117. ]
  118. enabled_filter_ids = [
  119. function.id
  120. for function in Functions.get_functions_by_type("filter", active_only=True)
  121. ]
  122. custom_models = Models.get_all_models()
  123. for custom_model in custom_models:
  124. if custom_model.base_model_id is None:
  125. # Applied directly to a base model
  126. for model in models:
  127. if custom_model.id == model["id"] or (
  128. model.get("owned_by") == "ollama"
  129. and custom_model.id
  130. == model["id"].split(":")[
  131. 0
  132. ] # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b')
  133. ):
  134. if custom_model.is_active:
  135. model["name"] = custom_model.name
  136. model["info"] = custom_model.model_dump()
  137. # Set action_ids and filter_ids
  138. action_ids = []
  139. filter_ids = []
  140. if "info" in model and "meta" in model["info"]:
  141. action_ids.extend(
  142. model["info"]["meta"].get("actionIds", [])
  143. )
  144. filter_ids.extend(
  145. model["info"]["meta"].get("filterIds", [])
  146. )
  147. model["action_ids"] = action_ids
  148. model["filter_ids"] = filter_ids
  149. else:
  150. models.remove(model)
  151. elif custom_model.is_active and (
  152. custom_model.id not in [model["id"] for model in models]
  153. ):
  154. owned_by = "openai"
  155. pipe = None
  156. action_ids = []
  157. filter_ids = []
  158. for model in models:
  159. if (
  160. custom_model.base_model_id == model["id"]
  161. or custom_model.base_model_id == model["id"].split(":")[0]
  162. ):
  163. owned_by = model.get("owned_by", "unknown owner")
  164. if "pipe" in model:
  165. pipe = model["pipe"]
  166. break
  167. if custom_model.meta:
  168. meta = custom_model.meta.model_dump()
  169. if "actionIds" in meta:
  170. action_ids.extend(meta["actionIds"])
  171. if "filterIds" in meta:
  172. filter_ids.extend(meta["filterIds"])
  173. models.append(
  174. {
  175. "id": f"{custom_model.id}",
  176. "name": custom_model.name,
  177. "object": "model",
  178. "created": custom_model.created_at,
  179. "owned_by": owned_by,
  180. "info": custom_model.model_dump(),
  181. "preset": True,
  182. **({"pipe": pipe} if pipe is not None else {}),
  183. "action_ids": action_ids,
  184. "filter_ids": filter_ids,
  185. }
  186. )
  187. # Process action_ids to get the actions
  188. def get_action_items_from_module(function, module):
  189. actions = []
  190. if hasattr(module, "actions"):
  191. actions = module.actions
  192. return [
  193. {
  194. "id": f"{function.id}.{action['id']}",
  195. "name": action.get("name", f"{function.name} ({action['id']})"),
  196. "description": function.meta.description,
  197. "icon": action.get(
  198. "icon_url",
  199. function.meta.manifest.get("icon_url", None)
  200. or getattr(module, "icon_url", None)
  201. or getattr(module, "icon", None),
  202. ),
  203. }
  204. for action in actions
  205. ]
  206. else:
  207. return [
  208. {
  209. "id": function.id,
  210. "name": function.name,
  211. "description": function.meta.description,
  212. "icon": function.meta.manifest.get("icon_url", None)
  213. or getattr(module, "icon_url", None)
  214. or getattr(module, "icon", None),
  215. }
  216. ]
  217. # Process filter_ids to get the filters
  218. def get_filter_items_from_module(function, module):
  219. return [
  220. {
  221. "id": function.id,
  222. "name": function.name,
  223. "description": function.meta.description,
  224. "icon": function.meta.manifest.get("icon_url", None)
  225. or getattr(module, "icon_url", None)
  226. or getattr(module, "icon", None),
  227. "has_user_valves": hasattr(module, "UserValves"),
  228. }
  229. ]
  230. def get_function_module_by_id(function_id):
  231. function_module, _, _ = get_function_module_from_cache(request, function_id)
  232. return function_module
  233. for model in models:
  234. action_ids = [
  235. action_id
  236. for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
  237. if action_id in enabled_action_ids
  238. ]
  239. filter_ids = [
  240. filter_id
  241. for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids))
  242. if filter_id in enabled_filter_ids
  243. ]
  244. model["actions"] = []
  245. for action_id in action_ids:
  246. action_function = Functions.get_function_by_id(action_id)
  247. if action_function is None:
  248. raise Exception(f"Action not found: {action_id}")
  249. function_module = get_function_module_by_id(action_id)
  250. model["actions"].extend(
  251. get_action_items_from_module(action_function, function_module)
  252. )
  253. model["filters"] = []
  254. for filter_id in filter_ids:
  255. filter_function = Functions.get_function_by_id(filter_id)
  256. if filter_function is None:
  257. raise Exception(f"Filter not found: {filter_id}")
  258. function_module = get_function_module_by_id(filter_id)
  259. if getattr(function_module, "toggle", None):
  260. model["filters"].extend(
  261. get_filter_items_from_module(filter_function, function_module)
  262. )
  263. log.debug(f"get_all_models() returned {len(models)} models")
  264. request.app.state.MODELS = {model["id"]: model for model in models}
  265. return models
  266. def check_model_access(user, model):
  267. if model.get("arena"):
  268. if not has_access(
  269. user.id,
  270. type="read",
  271. access_control=model.get("info", {})
  272. .get("meta", {})
  273. .get("access_control", {}),
  274. ):
  275. raise Exception("Model not found")
  276. else:
  277. model_info = Models.get_model_by_id(model.get("id"))
  278. if not model_info:
  279. raise Exception("Model not found")
  280. elif not (
  281. user.id == model_info.user_id
  282. or has_access(
  283. user.id, type="read", access_control=model_info.access_control
  284. )
  285. ):
  286. raise Exception("Model not found")
  287. def get_filtered_models(models, user):
  288. # Filter out models that the user does not have access to
  289. if (
  290. user.role == "user"
  291. or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL)
  292. ) and not BYPASS_MODEL_ACCESS_CONTROL:
  293. filtered_models = []
  294. for model in models:
  295. if model.get("arena"):
  296. if has_access(
  297. user.id,
  298. type="read",
  299. access_control=model.get("info", {})
  300. .get("meta", {})
  301. .get("access_control", {}),
  302. ):
  303. filtered_models.append(model)
  304. continue
  305. model_info = Models.get_model_by_id(model["id"])
  306. if model_info:
  307. if (
  308. (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
  309. or user.id == model_info.user_id
  310. or has_access(
  311. user.id,
  312. type="read",
  313. access_control=model_info.access_control,
  314. )
  315. ):
  316. filtered_models.append(model)
  317. return filtered_models
  318. else:
  319. return models