1
0

models.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. import time
  2. import logging
  3. import asyncio
  4. import sys
  5. from aiocache import cached
  6. from fastapi import Request
  7. from open_webui.routers import openai, ollama
  8. from open_webui.functions import get_function_models
  9. from open_webui.models.functions import Functions
  10. from open_webui.models.models import Models
  11. from open_webui.models.groups import Groups
  12. from open_webui.utils.plugin import (
  13. load_function_module_by_id,
  14. get_function_module_from_cache,
  15. )
  16. from open_webui.utils.access_control import has_access
  17. from open_webui.config import (
  18. BYPASS_ADMIN_ACCESS_CONTROL,
  19. DEFAULT_ARENA_MODEL,
  20. )
  21. from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
  22. from open_webui.models.users import UserModel
  23. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  24. log = logging.getLogger(__name__)
  25. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  26. async def fetch_ollama_models(request: Request, user: UserModel = None):
  27. raw_ollama_models = await ollama.get_all_models(request, user=user)
  28. return [
  29. {
  30. "id": model["model"],
  31. "name": model["name"],
  32. "object": "model",
  33. "created": int(time.time()),
  34. "owned_by": "ollama",
  35. "ollama": model,
  36. "connection_type": model.get("connection_type", "local"),
  37. "tags": model.get("tags", []),
  38. }
  39. for model in raw_ollama_models["models"]
  40. ]
  41. async def fetch_openai_models(request: Request, user: UserModel = None):
  42. openai_response = await openai.get_all_models(request, user=user)
  43. return openai_response["data"]
  44. async def get_all_base_models(request: Request, user: UserModel = None):
  45. openai_task = (
  46. fetch_openai_models(request, user)
  47. if request.app.state.config.ENABLE_OPENAI_API
  48. else asyncio.sleep(0, result=[])
  49. )
  50. ollama_task = (
  51. fetch_ollama_models(request, user)
  52. if request.app.state.config.ENABLE_OLLAMA_API
  53. else asyncio.sleep(0, result=[])
  54. )
  55. function_task = get_function_models(request)
  56. openai_models, ollama_models, function_models = await asyncio.gather(
  57. openai_task, ollama_task, function_task
  58. )
  59. return function_models + openai_models + ollama_models
  60. async def get_all_models(request, refresh: bool = False, user: UserModel = None):
  61. if (
  62. request.app.state.MODELS
  63. and request.app.state.BASE_MODELS
  64. and (request.app.state.config.ENABLE_BASE_MODELS_CACHE and not refresh)
  65. ):
  66. base_models = request.app.state.BASE_MODELS
  67. else:
  68. base_models = await get_all_base_models(request, user=user)
  69. request.app.state.BASE_MODELS = base_models
  70. # deep copy the base models to avoid modifying the original list
  71. models = [model.copy() for model in base_models]
  72. # If there are no models, return an empty list
  73. if len(models) == 0:
  74. return []
  75. # Add arena models
  76. if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
  77. arena_models = []
  78. if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
  79. arena_models = [
  80. {
  81. "id": model["id"],
  82. "name": model["name"],
  83. "info": {
  84. "meta": model["meta"],
  85. },
  86. "object": "model",
  87. "created": int(time.time()),
  88. "owned_by": "arena",
  89. "arena": True,
  90. }
  91. for model in request.app.state.config.EVALUATION_ARENA_MODELS
  92. ]
  93. else:
  94. # Add default arena model
  95. arena_models = [
  96. {
  97. "id": DEFAULT_ARENA_MODEL["id"],
  98. "name": DEFAULT_ARENA_MODEL["name"],
  99. "info": {
  100. "meta": DEFAULT_ARENA_MODEL["meta"],
  101. },
  102. "object": "model",
  103. "created": int(time.time()),
  104. "owned_by": "arena",
  105. "arena": True,
  106. }
  107. ]
  108. models = models + arena_models
  109. global_action_ids = [
  110. function.id for function in Functions.get_global_action_functions()
  111. ]
  112. enabled_action_ids = [
  113. function.id
  114. for function in Functions.get_functions_by_type("action", active_only=True)
  115. ]
  116. global_filter_ids = [
  117. function.id for function in Functions.get_global_filter_functions()
  118. ]
  119. enabled_filter_ids = [
  120. function.id
  121. for function in Functions.get_functions_by_type("filter", active_only=True)
  122. ]
  123. custom_models = Models.get_all_models()
  124. for custom_model in custom_models:
  125. if custom_model.base_model_id is None:
  126. # Applied directly to a base model
  127. for model in models:
  128. if custom_model.id == model["id"] or (
  129. model.get("owned_by") == "ollama"
  130. and custom_model.id
  131. == model["id"].split(":")[
  132. 0
  133. ] # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b')
  134. ):
  135. if custom_model.is_active:
  136. model["name"] = custom_model.name
  137. model["info"] = custom_model.model_dump()
  138. # Set action_ids and filter_ids
  139. action_ids = []
  140. filter_ids = []
  141. if "info" in model:
  142. if "meta" in model["info"]:
  143. action_ids.extend(
  144. model["info"]["meta"].get("actionIds", [])
  145. )
  146. filter_ids.extend(
  147. model["info"]["meta"].get("filterIds", [])
  148. )
  149. if "params" in model["info"]:
  150. # Remove params to avoid exposing sensitive info
  151. del model["info"]["params"]
  152. model["action_ids"] = action_ids
  153. model["filter_ids"] = filter_ids
  154. else:
  155. models.remove(model)
  156. elif custom_model.is_active and (
  157. custom_model.id not in [model["id"] for model in models]
  158. ):
  159. # Custom model based on a base model
  160. owned_by = "openai"
  161. pipe = None
  162. for m in models:
  163. if (
  164. custom_model.base_model_id == m["id"]
  165. or custom_model.base_model_id == m["id"].split(":")[0]
  166. ):
  167. owned_by = m.get("owned_by", "unknown")
  168. if "pipe" in m:
  169. pipe = m["pipe"]
  170. break
  171. model = {
  172. "id": f"{custom_model.id}",
  173. "name": custom_model.name,
  174. "object": "model",
  175. "created": custom_model.created_at,
  176. "owned_by": owned_by,
  177. "preset": True,
  178. **({"pipe": pipe} if pipe is not None else {}),
  179. }
  180. info = custom_model.model_dump()
  181. if "params" in info:
  182. # Remove params to avoid exposing sensitive info
  183. del info["params"]
  184. model["info"] = info
  185. action_ids = []
  186. filter_ids = []
  187. if custom_model.meta:
  188. meta = custom_model.meta.model_dump()
  189. if "actionIds" in meta:
  190. action_ids.extend(meta["actionIds"])
  191. if "filterIds" in meta:
  192. filter_ids.extend(meta["filterIds"])
  193. model["action_ids"] = action_ids
  194. model["filter_ids"] = filter_ids
  195. models.append(model)
  196. # Process action_ids to get the actions
  197. def get_action_items_from_module(function, module):
  198. actions = []
  199. if hasattr(module, "actions"):
  200. actions = module.actions
  201. return [
  202. {
  203. "id": f"{function.id}.{action['id']}",
  204. "name": action.get("name", f"{function.name} ({action['id']})"),
  205. "description": function.meta.description,
  206. "icon": action.get(
  207. "icon_url",
  208. function.meta.manifest.get("icon_url", None)
  209. or getattr(module, "icon_url", None)
  210. or getattr(module, "icon", None),
  211. ),
  212. }
  213. for action in actions
  214. ]
  215. else:
  216. return [
  217. {
  218. "id": function.id,
  219. "name": function.name,
  220. "description": function.meta.description,
  221. "icon": function.meta.manifest.get("icon_url", None)
  222. or getattr(module, "icon_url", None)
  223. or getattr(module, "icon", None),
  224. }
  225. ]
  226. # Process filter_ids to get the filters
  227. def get_filter_items_from_module(function, module):
  228. return [
  229. {
  230. "id": function.id,
  231. "name": function.name,
  232. "description": function.meta.description,
  233. "icon": function.meta.manifest.get("icon_url", None)
  234. or getattr(module, "icon_url", None)
  235. or getattr(module, "icon", None),
  236. "has_user_valves": hasattr(module, "UserValves"),
  237. }
  238. ]
  239. def get_function_module_by_id(function_id):
  240. function_module, _, _ = get_function_module_from_cache(request, function_id)
  241. return function_module
  242. for model in models:
  243. action_ids = [
  244. action_id
  245. for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
  246. if action_id in enabled_action_ids
  247. ]
  248. filter_ids = [
  249. filter_id
  250. for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids))
  251. if filter_id in enabled_filter_ids
  252. ]
  253. model["actions"] = []
  254. for action_id in action_ids:
  255. action_function = Functions.get_function_by_id(action_id)
  256. if action_function is None:
  257. raise Exception(f"Action not found: {action_id}")
  258. function_module = get_function_module_by_id(action_id)
  259. model["actions"].extend(
  260. get_action_items_from_module(action_function, function_module)
  261. )
  262. model["filters"] = []
  263. for filter_id in filter_ids:
  264. filter_function = Functions.get_function_by_id(filter_id)
  265. if filter_function is None:
  266. raise Exception(f"Filter not found: {filter_id}")
  267. function_module = get_function_module_by_id(filter_id)
  268. if getattr(function_module, "toggle", None):
  269. model["filters"].extend(
  270. get_filter_items_from_module(filter_function, function_module)
  271. )
  272. log.debug(f"get_all_models() returned {len(models)} models")
  273. request.app.state.MODELS = {model["id"]: model for model in models}
  274. return models
  275. def check_model_access(user, model):
  276. if model.get("arena"):
  277. if not has_access(
  278. user.id,
  279. type="read",
  280. access_control=model.get("info", {})
  281. .get("meta", {})
  282. .get("access_control", {}),
  283. ):
  284. raise Exception("Model not found")
  285. else:
  286. model_info = Models.get_model_by_id(model.get("id"))
  287. if not model_info:
  288. raise Exception("Model not found")
  289. elif not (
  290. user.id == model_info.user_id
  291. or has_access(
  292. user.id, type="read", access_control=model_info.access_control
  293. )
  294. ):
  295. raise Exception("Model not found")
  296. def get_filtered_models(models, user):
  297. # Filter out models that the user does not have access to
  298. if (
  299. user.role == "user"
  300. or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL)
  301. ) and not BYPASS_MODEL_ACCESS_CONTROL:
  302. filtered_models = []
  303. user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
  304. for model in models:
  305. if model.get("arena"):
  306. if has_access(
  307. user.id,
  308. type="read",
  309. access_control=model.get("info", {})
  310. .get("meta", {})
  311. .get("access_control", {}),
  312. user_group_ids=user_group_ids,
  313. ):
  314. filtered_models.append(model)
  315. continue
  316. model_info = Models.get_model_by_id(model["id"])
  317. if model_info:
  318. if (
  319. (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
  320. or user.id == model_info.user_id
  321. or has_access(
  322. user.id,
  323. type="read",
  324. access_control=model_info.access_control,
  325. user_group_ids=user_group_ids,
  326. )
  327. ):
  328. filtered_models.append(model)
  329. return filtered_models
  330. else:
  331. return models