otp.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package user
  2. import (
  3. "bytes"
  4. "crypto/sha1"
  5. "encoding/hex"
  6. "fmt"
  7. "time"
  8. "github.com/0xJacky/Nginx-UI/internal/cache"
  9. "github.com/0xJacky/Nginx-UI/internal/crypto"
  10. "github.com/0xJacky/Nginx-UI/model"
  11. "github.com/0xJacky/Nginx-UI/query"
  12. "github.com/google/uuid"
  13. "github.com/pquerna/otp/totp"
  14. )
  15. func VerifyOTP(user *model.User, otp, recoveryCode string) (err error) {
  16. if otp != "" {
  17. decrypted, err := crypto.AesDecrypt(user.OTPSecret)
  18. if err != nil {
  19. return err
  20. }
  21. if ok := totp.Validate(otp, string(decrypted)); !ok {
  22. return ErrOTPCode
  23. }
  24. } else {
  25. // get user from db
  26. u := query.User
  27. user, err = u.Where(u.ID.Eq(user.ID)).First()
  28. if err != nil {
  29. return err
  30. }
  31. // legacy recovery code
  32. if !user.RecoveryCodeGenerated() {
  33. if user.OTPSecret == nil {
  34. return ErrTOTPNotEnabled
  35. }
  36. recoverCode, err := hex.DecodeString(recoveryCode)
  37. if err != nil {
  38. return err
  39. }
  40. k := sha1.Sum(user.OTPSecret)
  41. if !bytes.Equal(k[:], recoverCode) {
  42. return ErrRecoveryCode
  43. }
  44. }
  45. // check recovery code
  46. for _, code := range user.RecoveryCodes.Codes {
  47. if code.Code == recoveryCode && code.UsedTime == nil {
  48. t := time.Now().Unix()
  49. code.UsedTime = &t
  50. _, err = u.Where(u.ID.Eq(user.ID)).Updates(user)
  51. return
  52. }
  53. }
  54. return ErrRecoveryCode
  55. }
  56. return
  57. }
  // secureSessionIDCacheKey maps a secure-session ID to the cache key that
  // SetSecureSessionID writes and VerifySecureSessionID reads.
  func secureSessionIDCacheKey(sessionId string) (key string) {
  	key = fmt.Sprintf("2fa_secure_session:_%s", sessionId)
  	return
  }
  61. func SetSecureSessionID(userId uint64) (sessionId string) {
  62. sessionId = uuid.NewString()
  63. cache.Set(secureSessionIDCacheKey(sessionId), userId, 5*time.Minute)
  64. return
  65. }
  66. func VerifySecureSessionID(sessionId string, userId uint64) bool {
  67. if v, ok := cache.Get(secureSessionIDCacheKey(sessionId)); ok {
  68. if v.(uint64) == userId {
  69. return true
  70. }
  71. }
  72. return false
  73. }