ソースを参照

refactor(jwt): migrate jwt to v5

Jacky 6 ヶ月 前
コミット
a93938eedb
1 ファイル変更27 行追加51 行削除
  1. 27 51
      internal/user/user.go

+ 27 - 51
internal/user/user.go

@@ -3,11 +3,11 @@ package user
 import (
 	"github.com/0xJacky/Nginx-UI/model"
 	"github.com/0xJacky/Nginx-UI/query"
-	"github.com/golang-jwt/jwt/v4"
+	"github.com/golang-jwt/jwt/v5"
 	"github.com/pkg/errors"
+	"github.com/spf13/cast"
 	"github.com/uozi-tech/cosy/logger"
 	cSettings "github.com/uozi-tech/cosy/settings"
-	"strings"
 	"time"
 )
 
@@ -16,14 +16,7 @@ const ExpiredTime = 24 * time.Hour
 type JWTClaims struct {
 	Name   string `json:"name"`
 	UserID int    `json:"user_id"`
-	jwt.StandardClaims
-}
-
-func BuildCacheTokenKey(token string) string {
-	var sb strings.Builder
-	sb.WriteString("token:")
-	sb.WriteString(token)
-	return sb.String()
+	jwt.RegisteredClaims
 }
 
 func GetUser(name string) (user *model.User, err error) {
@@ -42,6 +35,12 @@ func DeleteToken(token string) {
 }
 
 func GetTokenUser(token string) (*model.User, bool) {
+	_, err := ValidateJWT(token)
+	if err != nil {
+		logger.Error(err)
+		return nil, false
+	}
+
 	q := query.AuthToken
 	authToken, err := q.Where(q.Token.Eq(token)).First()
 	if err != nil {
@@ -59,11 +58,17 @@ func GetTokenUser(token string) (*model.User, bool) {
 }
 
 func GenerateJWT(user *model.User) (string, error) {
+	now := time.Now()
 	claims := JWTClaims{
 		Name:   user.Name,
 		UserID: user.ID,
-		StandardClaims: jwt.StandardClaims{
-			ExpiresAt: time.Now().Add(ExpiredTime).Unix(),
+		RegisteredClaims: jwt.RegisteredClaims{
+			ExpiresAt: jwt.NewNumericDate(now.Add(ExpiredTime)),
+			IssuedAt:  jwt.NewNumericDate(now),
+			NotBefore: jwt.NewNumericDate(now),
+			Issuer:    "Nginx UI",
+			Subject:   user.Name,
+			ID:        cast.ToString(user.ID),
 		},
 	}
 
@@ -77,7 +82,7 @@ func GenerateJWT(user *model.User) (string, error) {
 	err = q.Create(&model.AuthToken{
 		UserID:    user.ID,
 		Token:     signedToken,
-		ExpiredAt: time.Now().Add(ExpiredTime).Unix(),
+		ExpiredAt: now.Add(ExpiredTime).Unix(),
 	})
 
 	if err != nil {
@@ -87,49 +92,20 @@ func GenerateJWT(user *model.User) (string, error) {
 	return signedToken, err
 }
 
-func ValidateJWT(token string) (claims *JWTClaims, err error) {
-	if token == "" {
+func ValidateJWT(tokenStr string) (claims *JWTClaims, err error) {
+	if tokenStr == "" {
 		err = errors.New("token is empty")
 		return
 	}
-	unsignedToken, err := jwt.ParseWithClaims(
-		token,
-		&JWTClaims{},
-		func(token *jwt.Token) (interface{}, error) {
-			return []byte(cSettings.AppSettings.JwtSecret), nil
-		},
-	)
-	if err != nil {
-		err = errors.New("parse with claims error")
-		return
-	}
-	claims, ok := unsignedToken.Claims.(*JWTClaims)
-	if !ok {
-		err = errors.New("convert to jwt claims error")
-		return
-	}
-	if claims.ExpiresAt < time.Now().UTC().Unix() {
-		err = errors.New("jwt is expired")
-	}
-	return
-}
-
-func CurrentUser(token string) (u *model.User, err error) {
-	// validate token
-	var claims *JWTClaims
-	claims, err = ValidateJWT(token)
+	token, err := jwt.ParseWithClaims(tokenStr, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
+		return []byte(cSettings.AppSettings.JwtSecret), nil
+	})
 	if err != nil {
 		return
 	}
-
-	// get user by id
-	user := query.User
-	u, err = user.FirstByID(claims.UserID)
-	if err != nil {
-		return
+	var ok bool
+	if claims, ok = token.Claims.(*JWTClaims); ok && token.Valid {
+		return claims, nil
 	}
-
-	logger.Info("[Current User]", u.Name)
-
-	return
+	return nil, errors.New("invalid claims type")
 }