瀏覽代碼

refac: use dependencies to verify token

- feat: added new util to get the current user when needed. Middleware was adding authentication logic to all the routes. let's revisit if we can move the non-auth endpoints to a separate route.
- refac: update the routes to use new helpers for verification and retrieving user
- chore: added black for local formatting of py code
Anuraag Jain 1 年之前
父節點
當前提交
bdd153d8f5

+ 7 - 3
backend/apps/ollama/main.py

@@ -8,7 +8,7 @@ import json
 
 
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
-from utils.utils import extract_token_from_auth_header
+from utils.utils import decode_token
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
 
 app = Flask(__name__)
 app = Flask(__name__)
@@ -34,8 +34,12 @@ def proxy(path):
     # Basic RBAC support
     # Basic RBAC support
     if WEBUI_AUTH:
     if WEBUI_AUTH:
         if "Authorization" in headers:
         if "Authorization" in headers:
-            token = extract_token_from_auth_header(headers["Authorization"])
-            user = Users.get_user_by_token(token)
+            _, credentials = headers["Authorization"].split()
+            token_data = decode_token(credentials)
+            if token_data is None or "email" not in token_data:
+                return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
+
+            user = Users.get_user_by_email(token_data["email"])
             if user:
             if user:
                 # Only user and admin roles can access
                 # Only user and admin roles can access
                 if user.role in ["user", "admin"]:
                 if user.role in ["user", "admin"]:

+ 21 - 8
backend/apps/web/main.py

@@ -1,9 +1,10 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, Depends
+from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.middleware.authentication import AuthenticationMiddleware
 from starlette.middleware.authentication import AuthenticationMiddleware
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from config import WEBUI_VERSION, WEBUI_AUTH
 from config import WEBUI_VERSION, WEBUI_AUTH
-from apps.web.middlewares.auth import BearerTokenAuthBackend, on_auth_error
+from utils.utils import verify_auth_token
 
 
 app = FastAPI()
 app = FastAPI()
 
 
@@ -17,14 +18,26 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
-
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 
 
-app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(), on_error=on_auth_error)
-
-app.include_router(users.router, prefix="/users", tags=["users"])
-app.include_router(chats.router, prefix="/chats", tags=["chats"])
-app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
+app.include_router(
+    users.router,
+    prefix="/users",
+    tags=["users"],
+    dependencies=[Depends(verify_auth_token)],
+)
+app.include_router(
+    chats.router,
+    prefix="/chats",
+    tags=["chats"],
+    dependencies=[Depends(verify_auth_token)],
+)
+app.include_router(
+    modelfiles.router,
+    prefix="/modelfiles",
+    tags=["modelfiles"],
+    dependencies=[Depends(verify_auth_token)],
+)
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 
 
 
 

+ 0 - 27
backend/apps/web/middlewares/auth.py

@@ -1,27 +0,0 @@
-from apps.web.models.users import Users
-from fastapi import Request, status
-from starlette.authentication import (
-    AuthCredentials, AuthenticationBackend, AuthenticationError, 
-)
-from starlette.requests import HTTPConnection
-from utils.utils import verify_token
-from starlette.responses import JSONResponse
-from constants import ERROR_MESSAGES
-
-class BearerTokenAuthBackend(AuthenticationBackend):
-
-    async def authenticate(self, conn: HTTPConnection):
-        if "Authorization" not in conn.headers:
-            return
-        data = verify_token(conn)
-        if data != None and 'email' in data:
-            user = Users.get_user_by_email(data['email'])
-            if user is None:
-                raise AuthenticationError('Invalid credentials') 
-            return AuthCredentials([user.role]), user
-        else:
-            raise AuthenticationError('Invalid credentials') 
-
-def on_auth_error(request: Request, exc: Exception):
-    print('Authentication failed: ', exc)
-    return JSONResponse({"detail": ERROR_MESSAGES.INVALID_TOKEN}, status_code=status.HTTP_401_UNAUTHORIZED)

+ 0 - 10
backend/apps/web/models/users.py

@@ -3,8 +3,6 @@ from peewee import *
 from playhouse.shortcuts import model_to_dict
 from playhouse.shortcuts import model_to_dict
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 import time
 import time
-
-from utils.utils import decode_token
 from utils.misc import get_gravatar_url
 from utils.misc import get_gravatar_url
 
 
 from apps.web.internal.db import DB
 from apps.web.internal.db import DB
@@ -83,14 +81,6 @@ class UsersTable:
         except:
         except:
             return None
             return None
 
 
-    def get_user_by_token(self, token: str) -> Optional[UserModel]:
-        data = decode_token(token)
-
-        if data != None and "email" in data:
-            return self.get_user_by_email(data["email"])
-        else:
-            return None
-
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
         return [
         return [
             UserModel(**model_to_dict(user))
             UserModel(**model_to_dict(user))

+ 9 - 17
backend/apps/web/routers/auths.py

@@ -20,7 +20,7 @@ from apps.web.models.users import Users
 
 
 from utils.utils import (
 from utils.utils import (
     get_password_hash,
     get_password_hash,
-    bearer_scheme,
+    get_current_user,
     create_token,
     create_token,
 )
 )
 from utils.misc import get_gravatar_url
 from utils.misc import get_gravatar_url
@@ -35,22 +35,14 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=UserResponse)
 @router.get("/", response_model=UserResponse)
-async def get_session_user(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-    if user:
-        return {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-            "profile_image_url": user.profile_image_url,
-        }
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_session_user(user=Depends(get_current_user)):
+    return {
+        "id": user.id,
+        "email": user.email,
+        "name": user.name,
+        "role": user.role,
+        "profile_image_url": user.profile_image_url,
+    }
 
 
 
 
 ############################
 ############################

+ 30 - 26
backend/apps/web/routers/chats.py

@@ -1,8 +1,7 @@
-
 from fastapi import Depends, Request, HTTPException, status
 from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
 from typing import List, Union, Optional
-
+from utils.utils import get_current_user
 from fastapi import APIRouter
 from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
@@ -30,8 +29,10 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[ChatTitleIdResponse])
 @router.get("/", response_model=List[ChatTitleIdResponse])
-async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
-    return Chats.get_chat_lists_by_user_id(request.user.id, skip, limit)
+async def get_user_chats(
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+):
+    return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
 
 
 
 
 ############################
 ############################
@@ -40,11 +41,11 @@ async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
 
 
 
 
 @router.get("/all", response_model=List[ChatResponse])
 @router.get("/all", response_model=List[ChatResponse])
-async def get_all_user_chats(request:Request,):
+async def get_all_user_chats(user=Depends(get_current_user)):
     return [
     return [
-            ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-            for chat in Chats.get_all_chats_by_user_id(request.user.id)
-        ]
+        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        for chat in Chats.get_all_chats_by_user_id(user.id)
+    ]
 
 
 
 
 ############################
 ############################
@@ -53,8 +54,8 @@ async def get_all_user_chats(request:Request,):
 
 
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm,request:Request):
-    chat = Chats.insert_new_chat(request.user.id, form_data)
+async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+    chat = Chats.insert_new_chat(user.id, form_data)
     return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
 
 
 
 
@@ -64,14 +65,15 @@ async def create_new_chat(form_data: ChatForm,request:Request):
 
 
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, request:Request):
-    chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
+async def get_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
 
 
     if chat:
     if chat:
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
-        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
-                            detail=ERROR_MESSAGES.NOT_FOUND)
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
+        )
 
 
 
 
 ############################
 ############################
@@ -80,18 +82,20 @@ async def get_chat_by_id(id: str, request:Request):
 
 
 
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
 @router.post("/{id}", response_model=Optional[ChatResponse])
-async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
-    chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
+async def update_chat_by_id(
+    id: str, form_data: ChatForm, user=Depends(get_current_user)
+):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
     if chat:
-            updated_chat = {**json.loads(chat.chat), **form_data.chat}
+        updated_chat = {**json.loads(chat.chat), **form_data.chat}
 
 
-            chat = Chats.update_chat_by_id(id, updated_chat)
-            return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        chat = Chats.update_chat_by_id(id, updated_chat)
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
 
 
 
 
 ############################
 ############################
@@ -100,6 +104,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
 
 
 
 
 @router.delete("/{id}", response_model=bool)
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(id: str, request: Request):
-    result = Chats.delete_chat_by_id_and_user_id(id, request.user.id)
-    return result
+async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
+    result = Chats.delete_chat_by_id_and_user_id(id, user.id)
+    return result

+ 62 - 112
backend/apps/web/routers/modelfiles.py

@@ -1,4 +1,3 @@
-from fastapi import Response
 from fastapi import Depends, FastAPI, HTTPException, status
 from fastapi import Depends, FastAPI, HTTPException, status
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
 from typing import List, Union, Optional
@@ -16,9 +15,7 @@ from apps.web.models.modelfiles import (
     ModelfileResponse,
     ModelfileResponse,
 )
 )
 
 
-from utils.utils import (
-    bearer_scheme,
-)
+from utils.utils import bearer_scheme, get_current_user
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 router = APIRouter()
 router = APIRouter()
@@ -30,16 +27,7 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[ModelfileResponse])
 @router.get("/", response_model=List[ModelfileResponse])
 async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
 async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return Modelfiles.get_modelfiles(skip, limit)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+    return Modelfiles.get_modelfiles(skip, limit)
 
 
 
 
 ############################
 ############################
@@ -48,36 +36,28 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
 
 
 
 
 @router.post("/create", response_model=Optional[ModelfileResponse])
 @router.post("/create", response_model=Optional[ModelfileResponse])
-async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        # Admin Only
-        if user.role == "admin":
-            modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
-
-            if modelfile:
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.DEFAULT(),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def create_new_modelfile(
+    form_data: ModelfileForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )
         )
 
 
 
 
@@ -87,31 +67,20 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
 
 
 
 
 @router.post("/", response_model=Optional[ModelfileResponse])
 @router.post("/", response_model=Optional[ModelfileResponse])
-async def get_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
-):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-
-        if modelfile:
-            return ModelfileResponse(
-                **{
-                    **modelfile.model_dump(),
-                    "modelfile": json.loads(modelfile.modelfile),
-                }
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm):
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.NOT_FOUND,
         )
         )
 
 
 
 
@@ -122,44 +91,34 @@ async def get_modelfile_by_tag_name(
 
 
 @router.post("/update", response_model=Optional[ModelfileResponse])
 @router.post("/update", response_model=Optional[ModelfileResponse])
 async def update_modelfile_by_tag_name(
 async def update_modelfile_by_tag_name(
-    form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileUpdateForm, user=Depends(get_current_user)
 ):
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-            if modelfile:
-                updated_modelfile = {
-                    **json.loads(modelfile.modelfile),
-                    **form_data.modelfile,
-                }
-
-                modelfile = Modelfiles.update_modelfile_by_tag_name(
-                    form_data.tag_name, updated_modelfile
-                )
-
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+    if modelfile:
+        updated_modelfile = {
+            **json.loads(modelfile.modelfile),
+            **form_data.modelfile,
+        }
+
+        modelfile = Modelfiles.update_modelfile_by_tag_name(
+            form_data.tag_name, updated_modelfile
+        )
+
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
 
 
 
 
@@ -170,22 +129,13 @@ async def update_modelfile_by_tag_name(
 
 
 @router.delete("/delete", response_model=bool)
 @router.delete("/delete", response_model=bool)
 async def delete_modelfile_by_tag_name(
 async def delete_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileTagNameForm, user=Depends(get_current_user)
 ):
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
-            return result
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+    if user.role != "admin":
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
+
+    result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
+    return result

+ 18 - 39
backend/apps/web/routers/users.py

@@ -10,11 +10,7 @@ import uuid
 
 
 from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
 from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
 
 
-from utils.utils import (
-    get_password_hash,
-    bearer_scheme,
-    create_token,
-)
+from utils.utils import get_current_user
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 router = APIRouter()
 router = APIRouter()
@@ -25,23 +21,13 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[UserModel])
 @router.get("/", response_model=List[UserModel])
-async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            return Users.get_users(skip, limit)
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
+    if user.role != "admin":
         raise HTTPException(
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
+    return Users.get_users(skip, limit)
 
 
 
 
 ############################
 ############################
@@ -50,26 +36,19 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
 
 
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
 @router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
+async def update_user_role(
+    form_data: UserRoleUpdateForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
 
 
-    if user:
-        if user.role == "admin":
-            if user.id != form_data.id:
-                return Users.update_user_role_by_id(form_data.id, form_data.role)
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_403_FORBIDDEN,
-                    detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+    if user.id != form_data.id:
+        return Users.update_user_role_by_id(form_data.id, form_data.role)
     else:
     else:
         raise HTTPException(
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
         )
         )

+ 2 - 0
backend/requirements.txt

@@ -18,3 +18,5 @@ bcrypt
 
 
 PyJWT
 PyJWT
 pyjwt[crypto]
 pyjwt[crypto]
+
+black

+ 23 - 14
backend/utils/utils.py

@@ -1,7 +1,9 @@
-from fastapi.security import HTTPBasicCredentials, HTTPBearer
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from fastapi import HTTPException, status, Depends
+from apps.web.models.users import Users
 from pydantic import BaseModel
 from pydantic import BaseModel
 from typing import Union, Optional
 from typing import Union, Optional
-
+from constants import ERROR_MESSAGES
 from passlib.context import CryptContext
 from passlib.context import CryptContext
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 import requests
 import requests
@@ -53,16 +55,23 @@ def extract_token_from_auth_header(auth_header: str):
     return auth_header[len("Bearer ") :]
     return auth_header[len("Bearer ") :]
 
 
 
 
-def verify_token(request):
-    try:
-        authorization = request.headers["authorization"]
-        if authorization:
-            _, token = authorization.split()
-            decoded_token = jwt.decode(
-                token, JWT_SECRET_KEY, options={"verify_signature": False}
+def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
+    data = decode_token(auth_token.credentials)
+    if data != None and "email" in data:
+        user = Users.get_user_by_email(data["email"])
+        if user is None:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
             )
-            return decoded_token
-        else:
-            return None
-    except Exception as e:
-        return None
+        return
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
+
+def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
+    data = decode_token(auth_token.credentials)
+    return Users.get_user_by_email(data["email"])