Просмотр исходного кода

feat(otp): encrypt recovery codes with AES

Hintay 2 месяцев назад
Родитель
Сommit
5ade465ac6
5 измененных файлов с 77 добавлено и 31 удалено
  1. 1 1
      api/user/otp.go
  2. 11 20
      api/user/recovery.go
  3. 52 1
      internal/crypto/aes.go
  4. 1 1
      internal/user/otp.go
  5. 12 8
      model/user.go

+ 1 - 1
api/user/otp.go

@@ -87,7 +87,7 @@ func EnrollTOTP(c *gin.Context) {
 		return
 	}
 
-	t := time.Now()
+	t := time.Now().Unix()
 	recoveryCodes := model.RecoveryCodes{Codes: generateRecoveryCodes(16), LastViewed: &t}
 	codesJson, err := json.Marshal(&recoveryCodes)
 	if err != nil {

+ 11 - 20
api/user/recovery.go

@@ -1,7 +1,6 @@
 package user
 
 import (
-	"encoding/json"
 	"fmt"
 	"math/rand"
 	"net/http"
@@ -23,10 +22,12 @@ func generateRecoveryCode() string {
 	return fmt.Sprintf("%05x-%05x", rand.Intn(0x100000), rand.Intn(0x100000))
 }
 
-func generateRecoveryCodes(count int) []model.RecoveryCode {
-	recoveryCodes := make([]model.RecoveryCode, count)
+func generateRecoveryCodes(count int) []*model.RecoveryCode {
+	recoveryCodes := make([]*model.RecoveryCode, count)
 	for i := 0; i < count; i++ {
-		recoveryCodes[i].Code = generateRecoveryCode()
+		recoveryCodes[i] = &model.RecoveryCode{
+			Code: generateRecoveryCode(),
+		}
 	}
 	return recoveryCodes
 }
@@ -34,17 +35,11 @@ func generateRecoveryCodes(count int) []model.RecoveryCode {
 func ViewRecoveryCodes(c *gin.Context) {
 	user := api.CurrentUser(c)
 
-	u := query.User
-	user, err := u.Where(u.ID.Eq(user.ID)).First()
-	if err != nil {
-		api.ErrHandler(c, err)
-		return
-	}
-
 	// update last viewed time
-	t := time.Now()
+	u := query.User
+	t := time.Now().Unix()
 	user.RecoveryCodes.LastViewed = &t
-	_, err = u.Where(u.ID.Eq(user.ID)).Updates(user)
+	_, err := u.Where(u.ID.Eq(user.ID)).Updates(user)
 	if err != nil {
 		api.ErrHandler(c, err)
 		return
@@ -59,16 +54,12 @@ func ViewRecoveryCodes(c *gin.Context) {
 func GenerateRecoveryCodes(c *gin.Context) {
 	user := api.CurrentUser(c)
 
-	t := time.Now()
+	t := time.Now().Unix()
 	recoveryCodes := model.RecoveryCodes{Codes: generateRecoveryCodes(16), LastViewed: &t}
-	codesJson, err := json.Marshal(&recoveryCodes)
-	if err != nil {
-		api.ErrHandler(c, err)
-		return
-	}
+	user.RecoveryCodes = recoveryCodes
 
 	u := query.User
-	_, err = u.Where(u.ID.Eq(user.ID)).Update(u.RecoveryCodes, codesJson)
+	_, err := u.Where(u.ID.Eq(user.ID)).Updates(user)
 	if err != nil {
 		api.ErrHandler(c, err)
 		return

+ 52 - 1
internal/crypto/aes.go

@@ -1,12 +1,17 @@
 package crypto
 
 import (
+	"context"
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/rand"
 	"encoding/base64"
-	"github.com/0xJacky/Nginx-UI/settings"
+	"encoding/json"
 	"io"
+	"reflect"
+
+	"github.com/0xJacky/Nginx-UI/settings"
+	"gorm.io/gorm/schema"
 )
 
 // AesEncrypt encrypts text and given key with AES.
@@ -55,3 +60,49 @@ func AesDecrypt(text []byte) ([]byte, error) {
 
 	return data, nil
 }
+
+type JSONAesSerializer struct{}
+
+func (JSONAesSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
+	fieldValue := reflect.New(field.FieldType)
+
+	if dbValue != nil {
+		var bytes []byte
+		switch v := dbValue.(type) {
+		case []byte:
+			bytes = v
+		case string:
+			bytes = []byte(v)
+		default:
+			bytes, err = json.Marshal(v)
+			if err != nil {
+				return err
+			}
+		}
+
+		if len(bytes) > 0 {
+			bytes, err = AesDecrypt(bytes)
+			if err != nil {
+				return err
+			}
+			err = json.Unmarshal(bytes, fieldValue.Interface())
+		}
+	}
+
+	field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
+	return
+}
+
+// Value implements serializer interface
+func (JSONAesSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
+	result, err := json.Marshal(fieldValue)
+	if string(result) == "null" {
+		if field.TagSettings["NOT NULL"] != "" {
+			return "", nil
+		}
+		return nil, err
+	}
+
+	encrypt, err := AesEncrypt(result)
+	return string(encrypt), err
+}

+ 1 - 1
internal/user/otp.go

@@ -52,7 +52,7 @@ func VerifyOTP(user *model.User, otp, recoveryCode string) (err error) {
 		// check recovery code
 		for _, code := range user.RecoveryCodes.Codes {
 			if code.Code == recoveryCode && code.UsedTime == nil {
-				t := time.Now()
+				t := time.Now().Unix()
 				code.UsedTime = &t
 				_, err = u.Where(u.ID.Eq(user.ID)).Updates(user)
 				return

+ 12 - 8
model/user.go

@@ -1,22 +1,26 @@
 package model
 
 import (
-	"time"
-
+	"github.com/0xJacky/Nginx-UI/internal/crypto"
 	"github.com/go-webauthn/webauthn/webauthn"
 	"github.com/spf13/cast"
 	"gorm.io/gorm"
+	"gorm.io/gorm/schema"
 )
 
+func init() {
+	schema.RegisterSerializer("json[aes]", crypto.JSONAesSerializer{})
+}
+
 type RecoveryCode struct {
-	Code     string     `json:"code"`
-	UsedTime *time.Time `json:"used_time,omitempty"  gorm:"type:datetime;default:null"`
+	Code     string `json:"code"`
+	UsedTime *int64 `json:"used_time,omitempty"  gorm:"type:datetime;default:null"`
 }
 
 type RecoveryCodes struct {
-	Codes          []RecoveryCode `json:"codes"`
-	LastViewed     *time.Time     `json:"last_viewed,omitempty" gorm:"type:datetime;default:null"`
-	LastDownloaded *time.Time     `json:"last_downloaded,omitempty" gorm:"type:datetime;default:null"`
+	Codes          []*RecoveryCode `json:"codes"`
+	LastViewed     *int64          `json:"last_viewed,omitempty" gorm:"serializer:unixtime;type:datetime;default:null"`
+	LastDownloaded *int64          `json:"last_downloaded,omitempty" gorm:"serializer:unixtime;type:datetime;default:null"`
 }
 
 type User struct {
@@ -26,7 +30,7 @@ type User struct {
 	Password      string        `json:"-" cosy:"json:password;add:required,max=20;update:omitempty,max=20"`
 	Status        bool          `json:"status" gorm:"default:1"`
 	OTPSecret     []byte        `json:"-" gorm:"type:blob"`
-	RecoveryCodes RecoveryCodes `json:"-" gorm:"serializer:json"`
+	RecoveryCodes RecoveryCodes `json:"-" gorm:"serializer:json[aes]"`
 	EnabledTwoFA  bool          `json:"enabled_2fa" gorm:"-"`
 }