prompts.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
import logging
import time
from typing import Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import JSON, BigInteger, Column, String, Text

from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
from open_webui.models.users import UserResponse, Users
from open_webui.utils.access_control import has_access
####################
# Prompts DB Schema
####################


class Prompt(Base):
    """SQLAlchemy mapping for the ``prompt`` table; each row is keyed by ``command``."""

    __tablename__ = "prompt"

    # Primary key; prompts are looked up, updated and deleted by this value.
    command = Column(String, primary_key=True)
    # ID of the user who created the prompt (treated as its owner by access checks).
    user_id = Column(String)
    title = Column(Text)
    content = Column(Text)
    # Epoch seconds; PromptsTable sets it to int(time.time()) on insert and on update.
    timestamp = Column(BigInteger)

    access_control = Column(JSON, nullable=True)  # Controls data access levels.
    # Defines access control rules for this entry.
    # - `None`: Public access, available to all users with the "user" role.
    # - `{}`: Private access, restricted exclusively to the owner.
    # - Custom permissions: Specific access control for reading and writing;
    #   Can specify group or user-level restrictions:
    #   {
    #       "read": {
    #           "group_ids": ["group_id1", "group_id2"],
    #           "user_ids": ["user_id1", "user_id2"]
    #       },
    #       "write": {
    #           "group_ids": ["group_id1", "group_id2"],
    #           "user_ids": ["user_id1", "user_id2"]
    #       }
    #   }
class PromptModel(BaseModel):
    """Validated, detached representation of a ``prompt`` row."""

    command: str
    user_id: str
    title: str
    content: str
    timestamp: int  # timestamp in epoch seconds
    access_control: Optional[dict] = None  # same semantics as Prompt.access_control

    # Lets model_validate() read attributes directly from ORM Prompt instances.
    model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################


class PromptUserResponse(PromptModel):
    """A prompt together with the user record of its creator.

    ``user`` is left as ``None`` when no user matching ``user_id`` was found.
    """

    user: Optional[UserResponse] = None
class PromptForm(BaseModel):
    """Input fields used when creating or updating a prompt."""

    command: str
    title: str
    content: str
    access_control: Optional[dict] = None  # None = public, {} = owner-only, or a read/write rule dict
  53. class PromptsTable:
  54. def insert_new_prompt(
  55. self, user_id: str, form_data: PromptForm
  56. ) -> Optional[PromptModel]:
  57. prompt = PromptModel(
  58. **{
  59. "user_id": user_id,
  60. **form_data.model_dump(),
  61. "timestamp": int(time.time()),
  62. }
  63. )
  64. try:
  65. with get_db() as db:
  66. result = Prompt(**prompt.model_dump())
  67. db.add(result)
  68. db.commit()
  69. db.refresh(result)
  70. if result:
  71. return PromptModel.model_validate(result)
  72. else:
  73. return None
  74. except Exception:
  75. return None
  76. def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
  77. try:
  78. with get_db() as db:
  79. prompt = db.query(Prompt).filter_by(command=command).first()
  80. return PromptModel.model_validate(prompt)
  81. except Exception:
  82. return None
  83. def get_prompts(self) -> list[PromptUserResponse]:
  84. with get_db() as db:
  85. all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all()
  86. user_ids = list(set(prompt.user_id for prompt in all_prompts))
  87. users = Users.get_users_by_user_ids(user_ids) if user_ids else []
  88. users_dict = {user.id: user for user in users}
  89. prompts = []
  90. for prompt in all_prompts:
  91. user = users_dict.get(prompt.user_id)
  92. prompts.append(
  93. PromptUserResponse.model_validate(
  94. {
  95. **PromptModel.model_validate(prompt).model_dump(),
  96. "user": user.model_dump() if user else None,
  97. }
  98. )
  99. )
  100. return prompts
  101. def get_prompts_by_user_id(
  102. self, user_id: str, permission: str = "write"
  103. ) -> list[PromptUserResponse]:
  104. prompts = self.get_prompts()
  105. user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)}
  106. return [
  107. prompt
  108. for prompt in prompts
  109. if prompt.user_id == user_id
  110. or has_access(user_id, permission, prompt.access_control, user_group_ids)
  111. ]
  112. def update_prompt_by_command(
  113. self, command: str, form_data: PromptForm
  114. ) -> Optional[PromptModel]:
  115. try:
  116. with get_db() as db:
  117. prompt = db.query(Prompt).filter_by(command=command).first()
  118. prompt.title = form_data.title
  119. prompt.content = form_data.content
  120. prompt.access_control = form_data.access_control
  121. prompt.timestamp = int(time.time())
  122. db.commit()
  123. return PromptModel.model_validate(prompt)
  124. except Exception:
  125. return None
  126. def delete_prompt_by_command(self, command: str) -> bool:
  127. try:
  128. with get_db() as db:
  129. db.query(Prompt).filter_by(command=command).delete()
  130. db.commit()
  131. return True
  132. except Exception:
  133. return False
# Module-level PromptsTable instance through which the prompt data-access
# methods are reached.
Prompts = PromptsTable()