misc.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. import hashlib
  2. import re
  3. import threading
  4. import time
  5. import uuid
  6. import logging
  7. from datetime import timedelta
  8. from pathlib import Path
  9. from typing import Callable, Optional
  10. import json
  11. import collections.abc
  12. from open_webui.env import SRC_LOG_LEVELS
# Module logger; its level is taken from SRC_LOG_LEVELS["MAIN"] (imported from open_webui.env).
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  15. def deep_update(d, u):
  16. for k, v in u.items():
  17. if isinstance(v, collections.abc.Mapping):
  18. d[k] = deep_update(d.get(k, {}), v)
  19. else:
  20. d[k] = v
  21. return d
  22. def get_message_list(messages_map, message_id):
  23. """
  24. Reconstructs a list of messages in order up to the specified message_id.
  25. :param message_id: ID of the message to reconstruct the chain
  26. :param messages: Message history dict containing all messages
  27. :return: List of ordered messages starting from the root to the given message
  28. """
  29. # Handle case where messages is None
  30. if not messages_map:
  31. return [] # Return empty list instead of None to prevent iteration errors
  32. # Find the message by its id
  33. current_message = messages_map.get(message_id)
  34. if not current_message:
  35. return [] # Return empty list instead of None to prevent iteration errors
  36. # Reconstruct the chain by following the parentId links
  37. message_list = []
  38. while current_message:
  39. message_list.insert(
  40. 0, current_message
  41. ) # Insert the message at the beginning of the list
  42. parent_id = current_message.get("parentId") # Use .get() for safety
  43. current_message = messages_map.get(parent_id) if parent_id else None
  44. return message_list
  45. def get_messages_content(messages: list[dict]) -> str:
  46. return "\n".join(
  47. [
  48. f"{message['role'].upper()}: {get_content_from_message(message)}"
  49. for message in messages
  50. ]
  51. )
  52. def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
  53. for message in reversed(messages):
  54. if message["role"] == "user":
  55. return message
  56. return None
  57. def get_content_from_message(message: dict) -> Optional[str]:
  58. if isinstance(message.get("content"), list):
  59. for item in message["content"]:
  60. if item["type"] == "text":
  61. return item["text"]
  62. else:
  63. return message.get("content")
  64. return None
  65. def get_last_user_message(messages: list[dict]) -> Optional[str]:
  66. message = get_last_user_message_item(messages)
  67. if message is None:
  68. return None
  69. return get_content_from_message(message)
  70. def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]:
  71. for message in reversed(messages):
  72. if message["role"] == "assistant":
  73. return message
  74. return None
  75. def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
  76. for message in reversed(messages):
  77. if message["role"] == "assistant":
  78. return get_content_from_message(message)
  79. return None
  80. def get_system_message(messages: list[dict]) -> Optional[dict]:
  81. for message in messages:
  82. if message["role"] == "system":
  83. return message
  84. return None
  85. def remove_system_message(messages: list[dict]) -> list[dict]:
  86. return [message for message in messages if message["role"] != "system"]
  87. def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]:
  88. return get_system_message(messages), remove_system_message(messages)
  89. def update_message_content(message: dict, content: str, append: bool = True) -> dict:
  90. if isinstance(message["content"], list):
  91. for item in message["content"]:
  92. if item["type"] == "text":
  93. if append:
  94. item["text"] = f"{item['text']}\n{content}"
  95. else:
  96. item["text"] = f"{content}\n{item['text']}"
  97. else:
  98. if append:
  99. message["content"] = f"{message['content']}\n{content}"
  100. else:
  101. message["content"] = f"{content}\n{message['content']}"
  102. return message
  103. def replace_system_message_content(content: str, messages: list[dict]) -> dict:
  104. for message in messages:
  105. if message["role"] == "system":
  106. message["content"] = content
  107. break
  108. return messages
  109. def add_or_update_system_message(
  110. content: str, messages: list[dict], append: bool = False
  111. ):
  112. """
  113. Adds a new system message at the beginning of the messages list
  114. or updates the existing system message at the beginning.
  115. :param msg: The message to be added or appended.
  116. :param messages: The list of message dictionaries.
  117. :return: The updated list of message dictionaries.
  118. """
  119. if messages and messages[0].get("role") == "system":
  120. messages[0] = update_message_content(messages[0], content, append)
  121. else:
  122. # Insert at the beginning
  123. messages.insert(0, {"role": "system", "content": content})
  124. return messages
  125. def add_or_update_user_message(content: str, messages: list[dict], append: bool = True):
  126. """
  127. Adds a new user message at the end of the messages list
  128. or updates the existing user message at the end.
  129. :param msg: The message to be added or appended.
  130. :param messages: The list of message dictionaries.
  131. :return: The updated list of message dictionaries.
  132. """
  133. if messages and messages[-1].get("role") == "user":
  134. messages[-1] = update_message_content(messages[-1], content, append)
  135. else:
  136. # Insert at the end
  137. messages.append({"role": "user", "content": content})
  138. return messages
  139. def prepend_to_first_user_message_content(
  140. content: str, messages: list[dict]
  141. ) -> list[dict]:
  142. for message in messages:
  143. if message["role"] == "user":
  144. message = update_message_content(message, content, append=False)
  145. break
  146. return messages
  147. def append_or_update_assistant_message(content: str, messages: list[dict]):
  148. """
  149. Adds a new assistant message at the end of the messages list
  150. or updates the existing assistant message at the end.
  151. :param msg: The message to be added or appended.
  152. :param messages: The list of message dictionaries.
  153. :return: The updated list of message dictionaries.
  154. """
  155. if messages and messages[-1].get("role") == "assistant":
  156. messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
  157. else:
  158. # Insert at the end
  159. messages.append({"role": "assistant", "content": content})
  160. return messages
  161. def openai_chat_message_template(model: str):
  162. return {
  163. "id": f"{model}-{str(uuid.uuid4())}",
  164. "created": int(time.time()),
  165. "model": model,
  166. "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
  167. }
  168. def openai_chat_chunk_message_template(
  169. model: str,
  170. content: Optional[str] = None,
  171. reasoning_content: Optional[str] = None,
  172. tool_calls: Optional[list[dict]] = None,
  173. usage: Optional[dict] = None,
  174. ) -> dict:
  175. template = openai_chat_message_template(model)
  176. template["object"] = "chat.completion.chunk"
  177. template["choices"][0]["index"] = 0
  178. template["choices"][0]["delta"] = {}
  179. if content:
  180. template["choices"][0]["delta"]["content"] = content
  181. if reasoning_content:
  182. template["choices"][0]["delta"]["reasoning_content"] = reasoning_content
  183. if tool_calls:
  184. template["choices"][0]["delta"]["tool_calls"] = tool_calls
  185. if not content and not reasoning_content and not tool_calls:
  186. template["choices"][0]["finish_reason"] = "stop"
  187. if usage:
  188. template["usage"] = usage
  189. return template
  190. def openai_chat_completion_message_template(
  191. model: str,
  192. message: Optional[str] = None,
  193. reasoning_content: Optional[str] = None,
  194. tool_calls: Optional[list[dict]] = None,
  195. usage: Optional[dict] = None,
  196. ) -> dict:
  197. template = openai_chat_message_template(model)
  198. template["object"] = "chat.completion"
  199. if message is not None:
  200. template["choices"][0]["message"] = {
  201. "role": "assistant",
  202. "content": message,
  203. **({"reasoning_content": reasoning_content} if reasoning_content else {}),
  204. **({"tool_calls": tool_calls} if tool_calls else {}),
  205. }
  206. template["choices"][0]["finish_reason"] = "stop"
  207. if usage:
  208. template["usage"] = usage
  209. return template
  210. def get_gravatar_url(email):
  211. # Trim leading and trailing whitespace from
  212. # an email address and force all characters
  213. # to lower case
  214. address = str(email).strip().lower()
  215. # Create a SHA256 hash of the final string
  216. hash_object = hashlib.sha256(address.encode())
  217. hash_hex = hash_object.hexdigest()
  218. # Grab the actual image URL
  219. return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
  220. def calculate_sha256(file_path, chunk_size):
  221. # Compute SHA-256 hash of a file efficiently in chunks
  222. sha256 = hashlib.sha256()
  223. with open(file_path, "rb") as f:
  224. while chunk := f.read(chunk_size):
  225. sha256.update(chunk)
  226. return sha256.hexdigest()
  227. def calculate_sha256_string(string):
  228. # Create a new SHA-256 hash object
  229. sha256_hash = hashlib.sha256()
  230. # Update the hash object with the bytes of the input string
  231. sha256_hash.update(string.encode("utf-8"))
  232. # Get the hexadecimal representation of the hash
  233. hashed_string = sha256_hash.hexdigest()
  234. return hashed_string
  235. def validate_email_format(email: str) -> bool:
  236. if email.endswith("@localhost"):
  237. return True
  238. return bool(re.match(r"[^@]+@[^@]+\.[^@]+", email))
  239. def sanitize_filename(file_name):
  240. # Convert to lowercase
  241. lower_case_file_name = file_name.lower()
  242. # Remove special characters using regular expression
  243. sanitized_file_name = re.sub(r"[^\w\s]", "", lower_case_file_name)
  244. # Replace spaces with dashes
  245. final_file_name = re.sub(r"\s+", "-", sanitized_file_name)
  246. return final_file_name
  247. def extract_folders_after_data_docs(path):
  248. # Convert the path to a Path object if it's not already
  249. path = Path(path)
  250. # Extract parts of the path
  251. parts = path.parts
  252. # Find the index of '/data/docs' in the path
  253. try:
  254. index_data_docs = parts.index("data") + 1
  255. index_docs = parts.index("docs", index_data_docs) + 1
  256. except ValueError:
  257. return []
  258. # Exclude the filename and accumulate folder names
  259. tags = []
  260. folders = parts[index_docs:-1]
  261. for idx, _ in enumerate(folders):
  262. tags.append("/".join(folders[: idx + 1]))
  263. return tags
  264. def parse_duration(duration: str) -> Optional[timedelta]:
  265. if duration == "-1" or duration == "0":
  266. return None
  267. # Regular expression to find number and unit pairs
  268. pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)"
  269. matches = re.findall(pattern, duration)
  270. if not matches:
  271. raise ValueError("Invalid duration string")
  272. total_duration = timedelta()
  273. for number, _, unit in matches:
  274. number = float(number)
  275. if unit == "ms":
  276. total_duration += timedelta(milliseconds=number)
  277. elif unit == "s":
  278. total_duration += timedelta(seconds=number)
  279. elif unit == "m":
  280. total_duration += timedelta(minutes=number)
  281. elif unit == "h":
  282. total_duration += timedelta(hours=number)
  283. elif unit == "d":
  284. total_duration += timedelta(days=number)
  285. elif unit == "w":
  286. total_duration += timedelta(weeks=number)
  287. return total_duration
  288. def parse_ollama_modelfile(model_text):
  289. parameters_meta = {
  290. "mirostat": int,
  291. "mirostat_eta": float,
  292. "mirostat_tau": float,
  293. "num_ctx": int,
  294. "repeat_last_n": int,
  295. "repeat_penalty": float,
  296. "temperature": float,
  297. "seed": int,
  298. "tfs_z": float,
  299. "num_predict": int,
  300. "top_k": int,
  301. "top_p": float,
  302. "num_keep": int,
  303. "presence_penalty": float,
  304. "frequency_penalty": float,
  305. "num_batch": int,
  306. "num_gpu": int,
  307. "use_mmap": bool,
  308. "use_mlock": bool,
  309. "num_thread": int,
  310. }
  311. data = {"base_model_id": None, "params": {}}
  312. # Parse base model
  313. base_model_match = re.search(
  314. r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE
  315. )
  316. if base_model_match:
  317. data["base_model_id"] = base_model_match.group(1)
  318. # Parse template
  319. template_match = re.search(
  320. r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  321. )
  322. if template_match:
  323. data["params"] = {"template": template_match.group(1).strip()}
  324. # Parse stops
  325. stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
  326. if stops:
  327. data["params"]["stop"] = stops
  328. # Parse other parameters from the provided list
  329. for param, param_type in parameters_meta.items():
  330. param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
  331. if param_match:
  332. value = param_match.group(1)
  333. try:
  334. if param_type is int:
  335. value = int(value)
  336. elif param_type is float:
  337. value = float(value)
  338. elif param_type is bool:
  339. value = value.lower() == "true"
  340. except Exception as e:
  341. log.exception(f"Failed to parse parameter {param}: {e}")
  342. continue
  343. data["params"][param] = value
  344. # Parse adapter
  345. adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
  346. if adapter_match:
  347. data["params"]["adapter"] = adapter_match.group(1)
  348. # Parse system description
  349. system_desc_match = re.search(
  350. r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  351. )
  352. system_desc_match_single = re.search(
  353. r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
  354. )
  355. if system_desc_match:
  356. data["params"]["system"] = system_desc_match.group(1).strip()
  357. elif system_desc_match_single:
  358. data["params"]["system"] = system_desc_match_single.group(1).strip()
  359. # Parse messages
  360. messages = []
  361. message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE)
  362. for role, content in message_matches:
  363. messages.append({"role": role, "content": content})
  364. if messages:
  365. data["params"]["messages"] = messages
  366. return data
  367. def convert_logit_bias_input_to_json(user_input):
  368. logit_bias_pairs = user_input.split(",")
  369. logit_bias_json = {}
  370. for pair in logit_bias_pairs:
  371. token, bias = pair.split(":")
  372. token = str(token.strip())
  373. bias = int(bias.strip())
  374. bias = 100 if bias > 100 else -100 if bias < -100 else bias
  375. logit_bias_json[token] = bias
  376. return json.dumps(logit_bias_json)
  377. def freeze(value):
  378. """
  379. Freeze a value to make it hashable.
  380. """
  381. if isinstance(value, dict):
  382. return frozenset((k, freeze(v)) for k, v in value.items())
  383. elif isinstance(value, list):
  384. return tuple(freeze(v) for v in value)
  385. return value
  386. def throttle(interval: float = 10.0):
  387. """
  388. Decorator to prevent a function from being called more than once within a specified duration.
  389. If the function is called again within the duration, it returns None. To avoid returning
  390. different types, the return type of the function should be Optional[T].
  391. :param interval: Duration in seconds to wait before allowing the function to be called again.
  392. """
  393. def decorator(func):
  394. last_calls = {}
  395. lock = threading.Lock()
  396. def wrapper(*args, **kwargs):
  397. if interval is None:
  398. return func(*args, **kwargs)
  399. key = (args, freeze(kwargs))
  400. now = time.time()
  401. if now - last_calls.get(key, 0) < interval:
  402. return None
  403. with lock:
  404. if now - last_calls.get(key, 0) < interval:
  405. return None
  406. last_calls[key] = now
  407. return func(*args, **kwargs)
  408. return wrapper
  409. return decorator
  410. def extract_urls(text: str) -> list[str]:
  411. # Regex pattern to match URLs
  412. url_pattern = re.compile(
  413. r"(https?://[^\s]+)", re.IGNORECASE
  414. ) # Matches http and https URLs
  415. return url_pattern.findall(text)