models.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. import time
  2. import logging
  3. import asyncio
  4. import sys
  5. from aiocache import cached
  6. from fastapi import Request
  7. from open_webui.routers import openai, ollama
  8. from open_webui.functions import get_function_models
  9. from open_webui.models.functions import Functions
  10. from open_webui.models.models import Models
  11. from open_webui.utils.plugin import (
  12. load_function_module_by_id,
  13. get_function_module_from_cache,
  14. )
  15. from open_webui.utils.access_control import has_access
  16. from open_webui.config import (
  17. DEFAULT_ARENA_MODEL,
  18. )
  19. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
  20. from open_webui.models.users import UserModel
  21. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  22. log = logging.getLogger(__name__)
  23. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  24. async def fetch_ollama_models(request: Request, user: UserModel = None):
  25. raw_ollama_models = await ollama.get_all_models(request, user=user)
  26. return [
  27. {
  28. "id": model["model"],
  29. "name": model["name"],
  30. "object": "model",
  31. "created": int(time.time()),
  32. "owned_by": "ollama",
  33. "ollama": model,
  34. "connection_type": model.get("connection_type", "local"),
  35. "tags": model.get("tags", []),
  36. }
  37. for model in raw_ollama_models["models"]
  38. ]
  39. async def fetch_openai_models(request: Request, user: UserModel = None):
  40. openai_response = await openai.get_all_models(request, user=user)
  41. return openai_response["data"]
  42. async def get_all_base_models(request: Request, user: UserModel = None):
  43. openai_task = (
  44. fetch_openai_models(request, user)
  45. if request.app.state.config.ENABLE_OPENAI_API
  46. else asyncio.sleep(0, result=[])
  47. )
  48. ollama_task = (
  49. fetch_ollama_models(request, user)
  50. if request.app.state.config.ENABLE_OLLAMA_API
  51. else asyncio.sleep(0, result=[])
  52. )
  53. function_task = get_function_models(request)
  54. openai_models, ollama_models, function_models = await asyncio.gather(
  55. openai_task, ollama_task, function_task
  56. )
  57. return function_models + openai_models + ollama_models
  58. async def get_all_models(request, refresh: bool = False, user: UserModel = None):
  59. if (
  60. request.app.state.MODELS
  61. and request.app.state.BASE_MODELS
  62. and (request.app.state.config.ENABLE_BASE_MODELS_CACHE and not refresh)
  63. ):
  64. models = request.app.state.BASE_MODELS
  65. else:
  66. models = await get_all_base_models(request, user=user)
  67. request.app.state.BASE_MODELS = models
  68. # If there are no models, return an empty list
  69. if len(models) == 0:
  70. return []
  71. # Add arena models
  72. if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
  73. arena_models = []
  74. if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
  75. arena_models = [
  76. {
  77. "id": model["id"],
  78. "name": model["name"],
  79. "info": {
  80. "meta": model["meta"],
  81. },
  82. "object": "model",
  83. "created": int(time.time()),
  84. "owned_by": "arena",
  85. "arena": True,
  86. }
  87. for model in request.app.state.config.EVALUATION_ARENA_MODELS
  88. ]
  89. else:
  90. # Add default arena model
  91. arena_models = [
  92. {
  93. "id": DEFAULT_ARENA_MODEL["id"],
  94. "name": DEFAULT_ARENA_MODEL["name"],
  95. "info": {
  96. "meta": DEFAULT_ARENA_MODEL["meta"],
  97. },
  98. "object": "model",
  99. "created": int(time.time()),
  100. "owned_by": "arena",
  101. "arena": True,
  102. }
  103. ]
  104. models = models + arena_models
  105. global_action_ids = [
  106. function.id for function in Functions.get_global_action_functions()
  107. ]
  108. enabled_action_ids = [
  109. function.id
  110. for function in Functions.get_functions_by_type("action", active_only=True)
  111. ]
  112. global_filter_ids = [
  113. function.id for function in Functions.get_global_filter_functions()
  114. ]
  115. enabled_filter_ids = [
  116. function.id
  117. for function in Functions.get_functions_by_type("filter", active_only=True)
  118. ]
  119. custom_models = Models.get_all_models()
  120. for custom_model in custom_models:
  121. if custom_model.base_model_id is None:
  122. for model in models:
  123. if custom_model.id == model["id"] or (
  124. model.get("owned_by") == "ollama"
  125. and custom_model.id
  126. == model["id"].split(":")[
  127. 0
  128. ] # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b')
  129. ):
  130. if custom_model.is_active:
  131. model["name"] = custom_model.name
  132. model["info"] = custom_model.model_dump()
  133. # Set action_ids and filter_ids
  134. action_ids = []
  135. filter_ids = []
  136. if "info" in model and "meta" in model["info"]:
  137. action_ids.extend(
  138. model["info"]["meta"].get("actionIds", [])
  139. )
  140. filter_ids.extend(
  141. model["info"]["meta"].get("filterIds", [])
  142. )
  143. model["action_ids"] = action_ids
  144. model["filter_ids"] = filter_ids
  145. else:
  146. models.remove(model)
  147. elif custom_model.is_active and (
  148. custom_model.id not in [model["id"] for model in models]
  149. ):
  150. owned_by = "openai"
  151. pipe = None
  152. action_ids = []
  153. filter_ids = []
  154. for model in models:
  155. if (
  156. custom_model.base_model_id == model["id"]
  157. or custom_model.base_model_id == model["id"].split(":")[0]
  158. ):
  159. owned_by = model.get("owned_by", "unknown owner")
  160. if "pipe" in model:
  161. pipe = model["pipe"]
  162. break
  163. if custom_model.meta:
  164. meta = custom_model.meta.model_dump()
  165. if "actionIds" in meta:
  166. action_ids.extend(meta["actionIds"])
  167. if "filterIds" in meta:
  168. filter_ids.extend(meta["filterIds"])
  169. models.append(
  170. {
  171. "id": f"{custom_model.id}",
  172. "name": custom_model.name,
  173. "object": "model",
  174. "created": custom_model.created_at,
  175. "owned_by": owned_by,
  176. "info": custom_model.model_dump(),
  177. "preset": True,
  178. **({"pipe": pipe} if pipe is not None else {}),
  179. "action_ids": action_ids,
  180. "filter_ids": filter_ids,
  181. }
  182. )
  183. # Process action_ids to get the actions
  184. def get_action_items_from_module(function, module):
  185. actions = []
  186. if hasattr(module, "actions"):
  187. actions = module.actions
  188. return [
  189. {
  190. "id": f"{function.id}.{action['id']}",
  191. "name": action.get("name", f"{function.name} ({action['id']})"),
  192. "description": function.meta.description,
  193. "icon": action.get(
  194. "icon_url",
  195. function.meta.manifest.get("icon_url", None)
  196. or getattr(module, "icon_url", None)
  197. or getattr(module, "icon", None),
  198. ),
  199. }
  200. for action in actions
  201. ]
  202. else:
  203. return [
  204. {
  205. "id": function.id,
  206. "name": function.name,
  207. "description": function.meta.description,
  208. "icon": function.meta.manifest.get("icon_url", None)
  209. or getattr(module, "icon_url", None)
  210. or getattr(module, "icon", None),
  211. }
  212. ]
  213. # Process filter_ids to get the filters
  214. def get_filter_items_from_module(function, module):
  215. return [
  216. {
  217. "id": function.id,
  218. "name": function.name,
  219. "description": function.meta.description,
  220. "icon": function.meta.manifest.get("icon_url", None)
  221. or getattr(module, "icon_url", None)
  222. or getattr(module, "icon", None),
  223. }
  224. ]
  225. def get_function_module_by_id(function_id):
  226. function_module, _, _ = get_function_module_from_cache(request, function_id)
  227. return function_module
  228. for model in models:
  229. action_ids = [
  230. action_id
  231. for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
  232. if action_id in enabled_action_ids
  233. ]
  234. filter_ids = [
  235. filter_id
  236. for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids))
  237. if filter_id in enabled_filter_ids
  238. ]
  239. model["actions"] = []
  240. for action_id in action_ids:
  241. action_function = Functions.get_function_by_id(action_id)
  242. if action_function is None:
  243. raise Exception(f"Action not found: {action_id}")
  244. function_module = get_function_module_by_id(action_id)
  245. model["actions"].extend(
  246. get_action_items_from_module(action_function, function_module)
  247. )
  248. model["filters"] = []
  249. for filter_id in filter_ids:
  250. filter_function = Functions.get_function_by_id(filter_id)
  251. if filter_function is None:
  252. raise Exception(f"Filter not found: {filter_id}")
  253. function_module = get_function_module_by_id(filter_id)
  254. if getattr(function_module, "toggle", None):
  255. model["filters"].extend(
  256. get_filter_items_from_module(filter_function, function_module)
  257. )
  258. log.debug(f"get_all_models() returned {len(models)} models")
  259. request.app.state.MODELS = {model["id"]: model for model in models}
  260. return models
  261. def check_model_access(user, model):
  262. if model.get("arena"):
  263. if not has_access(
  264. user.id,
  265. type="read",
  266. access_control=model.get("info", {})
  267. .get("meta", {})
  268. .get("access_control", {}),
  269. ):
  270. raise Exception("Model not found")
  271. else:
  272. model_info = Models.get_model_by_id(model.get("id"))
  273. if not model_info:
  274. raise Exception("Model not found")
  275. elif not (
  276. user.id == model_info.user_id
  277. or has_access(
  278. user.id, type="read", access_control=model_info.access_control
  279. )
  280. ):
  281. raise Exception("Model not found")