Ver código fonte

enhance: validate certificate content before save

0xJacky 1 ano atrás
pai
commit
8581bdd3c6

+ 11 - 24
api/api.go

@@ -4,28 +4,12 @@ import (
 	"errors"
 	"github.com/0xJacky/Nginx-UI/internal/logger"
 	"github.com/gin-gonic/gin"
-	"github.com/gin-gonic/gin/binding"
 	val "github.com/go-playground/validator/v10"
 	"net/http"
 	"reflect"
-	"regexp"
 	"strings"
 )
 
-func init() {
-	if v, ok := binding.Validator.Engine().(*val.Validate); ok {
-		err := v.RegisterValidation("alphanumdash", func(fl val.FieldLevel) bool {
-			return regexp.MustCompile(`^[a-zA-Z0-9-]+$`).MatchString(fl.Field().String())
-		})
-
-		if err != nil {
-			logger.Fatal(err)
-		}
-		return
-	}
-	logger.Fatal("binding validator engine is not initialized")
-}
-
 func ErrHandler(c *gin.Context, err error) {
 	logger.GetLogger().Errorln(err)
 	c.JSON(http.StatusInternalServerError, gin.H{
@@ -54,11 +38,18 @@ func BindAndValid(c *gin.Context, target interface{}) bool {
 			return false
 		}
 
-		t := reflect.TypeOf(target).Elem()
+		t := reflect.TypeOf(target)
 		errorsMap := make(map[string]interface{})
 		for _, value := range verrs {
 			var path []string
-			getJsonPath(t, value.StructNamespace(), &path)
+
+			namespace := strings.Split(value.StructNamespace(), ".")
+
+			if t.Name() == "" && len(namespace) > 1 {
+				namespace = namespace[1:]
+			}
+
+			getJsonPath(t.Elem(), namespace, &path)
 			insertError(errorsMap, path, value.Tag())
 		}
 
@@ -75,11 +66,7 @@ func BindAndValid(c *gin.Context, target interface{}) bool {
 }
 
 // findField recursively finds the field in a nested struct
-func getJsonPath(t reflect.Type, namespace string, path *[]string) {
-	fields := strings.Split(namespace, ".")
-	if len(fields) == 0 {
-		return
-	}
+func getJsonPath(t reflect.Type, fields []string, path *[]string) {
 	f, ok := t.FieldByName(fields[0])
 	if !ok {
 		return
@@ -88,7 +75,7 @@ func getJsonPath(t reflect.Type, namespace string, path *[]string) {
 	*path = append(*path, f.Tag.Get("json"))
 
 	if len(fields) > 1 {
-		subFields := strings.Join(fields[1:], ".")
+		subFields := fields[1:]
 		getJsonPath(f.Type, subFields, path)
 	}
 }

+ 144 - 144
api/certificate/certificate.go

@@ -1,174 +1,174 @@
 package certificate
 
 import (
-	"github.com/0xJacky/Nginx-UI/api"
-	"github.com/0xJacky/Nginx-UI/api/cosy"
-	"github.com/0xJacky/Nginx-UI/internal/cert"
-	"github.com/0xJacky/Nginx-UI/model"
-	"github.com/0xJacky/Nginx-UI/query"
-	"github.com/gin-gonic/gin"
-	"github.com/spf13/cast"
-	"net/http"
-	"os"
+    "github.com/0xJacky/Nginx-UI/api"
+    "github.com/0xJacky/Nginx-UI/api/cosy"
+    "github.com/0xJacky/Nginx-UI/internal/cert"
+    "github.com/0xJacky/Nginx-UI/model"
+    "github.com/0xJacky/Nginx-UI/query"
+    "github.com/gin-gonic/gin"
+    "github.com/spf13/cast"
+    "net/http"
+    "os"
 )
 
 type APICertificate struct {
-	*model.Cert
-	SSLCertificate    string     `json:"ssl_certificate,omitempty"`
-	SSLCertificateKey string     `json:"ssl_certificate_key,omitempty"`
-	CertificateInfo   *cert.Info `json:"certificate_info,omitempty"`
+    *model.Cert
+    SSLCertificate    string     `json:"ssl_certificate,omitempty"`
+    SSLCertificateKey string     `json:"ssl_certificate_key,omitempty"`
+    CertificateInfo   *cert.Info `json:"certificate_info,omitempty"`
 }
 
 func Transformer(certModel *model.Cert) (certificate *APICertificate) {
-	var sslCertificationBytes, sslCertificationKeyBytes []byte
-	var certificateInfo *cert.Info
-	if certModel.SSLCertificatePath != "" {
-		if _, err := os.Stat(certModel.SSLCertificatePath); err == nil {
-			sslCertificationBytes, _ = os.ReadFile(certModel.SSLCertificatePath)
-		}
-
-		certificateInfo, _ = cert.GetCertInfo(certModel.SSLCertificatePath)
-	}
-
-	if certModel.SSLCertificateKeyPath != "" {
-		if _, err := os.Stat(certModel.SSLCertificateKeyPath); err == nil {
-			sslCertificationKeyBytes, _ = os.ReadFile(certModel.SSLCertificateKeyPath)
-		}
-	}
-
-	return &APICertificate{
-		Cert:              certModel,
-		SSLCertificate:    string(sslCertificationBytes),
-		SSLCertificateKey: string(sslCertificationKeyBytes),
-		CertificateInfo:   certificateInfo,
-	}
+    var sslCertificationBytes, sslCertificationKeyBytes []byte
+    var certificateInfo *cert.Info
+    if certModel.SSLCertificatePath != "" {
+        if _, err := os.Stat(certModel.SSLCertificatePath); err == nil {
+            sslCertificationBytes, _ = os.ReadFile(certModel.SSLCertificatePath)
+            if !cert.IsPublicKey(string(sslCertificationBytes)) {
+                sslCertificationBytes = []byte{}
+            }
+        }
+
+        certificateInfo, _ = cert.GetCertInfo(certModel.SSLCertificatePath)
+    }
+
+    if certModel.SSLCertificateKeyPath != "" {
+        if _, err := os.Stat(certModel.SSLCertificateKeyPath); err == nil {
+            sslCertificationKeyBytes, _ = os.ReadFile(certModel.SSLCertificateKeyPath)
+            if !cert.IsPrivateKey(string(sslCertificationKeyBytes)) {
+                sslCertificationKeyBytes = []byte{}
+            }
+        }
+    }
+
+    return &APICertificate{
+        Cert:              certModel,
+        SSLCertificate:    string(sslCertificationBytes),
+        SSLCertificateKey: string(sslCertificationKeyBytes),
+        CertificateInfo:   certificateInfo,
+    }
 }
 
 func GetCertList(c *gin.Context) {
-	cosy.Core[model.Cert](c).SetFussy("name", "domain").SetTransformer(func(m *model.Cert) any {
+    cosy.Core[model.Cert](c).SetFussy("name", "domain").SetTransformer(func(m *model.Cert) any {
 
-		info, _ := cert.GetCertInfo(m.SSLCertificatePath)
+        info, _ := cert.GetCertInfo(m.SSLCertificatePath)
 
-		return APICertificate{
-			Cert:            m,
-			CertificateInfo: info,
-		}
-	}).PagingList()
+        return APICertificate{
+            Cert:            m,
+            CertificateInfo: info,
+        }
+    }).PagingList()
 }
 
 func GetCert(c *gin.Context) {
-	q := query.Cert
+    q := query.Cert
 
-	certModel, err := q.FirstByID(cast.ToInt(c.Param("id")))
+    certModel, err := q.FirstByID(cast.ToInt(c.Param("id")))
 
-	if err != nil {
-		api.ErrHandler(c, err)
-		return
-	}
+    if err != nil {
+        api.ErrHandler(c, err)
+        return
+    }
 
-	c.JSON(http.StatusOK, Transformer(certModel))
+    c.JSON(http.StatusOK, Transformer(certModel))
+}
+
+type certJson struct {
+    Name                  string `json:"name"`
+    SSLCertificatePath    string `json:"ssl_certificate_path" binding:"publickey_path"`
+    SSLCertificateKeyPath string `json:"ssl_certificate_key_path" binding:"privatekey_path"`
+    SSLCertificate        string `json:"ssl_certificate" binding:"omitempty,publickey"`
+    SSLCertificateKey     string `json:"ssl_certificate_key" binding:"omitempty,privatekey"`
+    ChallengeMethod       string `json:"challenge_method"`
+    DnsCredentialID       int    `json:"dns_credential_id"`
 }
 
 func AddCert(c *gin.Context) {
-	var json struct {
-		Name                  string `json:"name"`
-		SSLCertificatePath    string `json:"ssl_certificate_path" binding:"required"`
-		SSLCertificateKeyPath string `json:"ssl_certificate_key_path" binding:"required"`
-		SSLCertificate        string `json:"ssl_certificate"`
-		SSLCertificateKey     string `json:"ssl_certificate_key"`
-		ChallengeMethod       string `json:"challenge_method"`
-		DnsCredentialID       int    `json:"dns_credential_id"`
-	}
-	if !api.BindAndValid(c, &json) {
-		return
-	}
-	certModel := &model.Cert{
-		Name:                  json.Name,
-		SSLCertificatePath:    json.SSLCertificatePath,
-		SSLCertificateKeyPath: json.SSLCertificateKeyPath,
-		ChallengeMethod:       json.ChallengeMethod,
-		DnsCredentialID:       json.DnsCredentialID,
-	}
-
-	err := certModel.Insert()
-
-	if err != nil {
-		api.ErrHandler(c, err)
-		return
-	}
-
-	content := &cert.Content{
-		SSLCertificatePath:    json.SSLCertificatePath,
-		SSLCertificateKeyPath: json.SSLCertificateKeyPath,
-		SSLCertificate:        json.SSLCertificate,
-		SSLCertificateKey:     json.SSLCertificateKey,
-	}
-
-	err = content.WriteFile()
-
-	if err != nil {
-		api.ErrHandler(c, err)
-		return
-	}
-
-	c.JSON(http.StatusOK, Transformer(certModel))
+    var json certJson
+    if !api.BindAndValid(c, &json) {
+        return
+    }
+    certModel := &model.Cert{
+        Name:                  json.Name,
+        SSLCertificatePath:    json.SSLCertificatePath,
+        SSLCertificateKeyPath: json.SSLCertificateKeyPath,
+        ChallengeMethod:       json.ChallengeMethod,
+        DnsCredentialID:       json.DnsCredentialID,
+    }
+
+    err := certModel.Insert()
+
+    if err != nil {
+        api.ErrHandler(c, err)
+        return
+    }
+
+    content := &cert.Content{
+        SSLCertificatePath:    json.SSLCertificatePath,
+        SSLCertificateKeyPath: json.SSLCertificateKeyPath,
+        SSLCertificate:        json.SSLCertificate,
+        SSLCertificateKey:     json.SSLCertificateKey,
+    }
+
+    err = content.WriteFile()
+
+    if err != nil {
+        api.ErrHandler(c, err)
+        return
+    }
+
+    c.JSON(http.StatusOK, Transformer(certModel))
 }
 
 func ModifyCert(c *gin.Context) {
-	id := cast.ToInt(c.Param("id"))
-
-	var json struct {
-		Name                  string `json:"name"`
-		SSLCertificatePath    string `json:"ssl_certificate_path" binding:"required"`
-		SSLCertificateKeyPath string `json:"ssl_certificate_key_path" binding:"required"`
-		SSLCertificate        string `json:"ssl_certificate"`
-		SSLCertificateKey     string `json:"ssl_certificate_key"`
-		ChallengeMethod       string `json:"challenge_method"`
-		DnsCredentialID       int    `json:"dns_credential_id"`
-	}
-
-	if !api.BindAndValid(c, &json) {
-		return
-	}
-
-	q := query.Cert
-
-	certModel, err := q.FirstByID(id)
-	if err != nil {
-		api.ErrHandler(c, err)
-		return
-	}
-
-	err = certModel.Updates(&model.Cert{
-		Name:                  json.Name,
-		SSLCertificatePath:    json.SSLCertificatePath,
-		SSLCertificateKeyPath: json.SSLCertificateKeyPath,
-		ChallengeMethod:       json.ChallengeMethod,
-		DnsCredentialID:       json.DnsCredentialID,
-	})
-
-	if err != nil {
-		api.ErrHandler(c, err)
-		return
-	}
-
-	content := &cert.Content{
-		SSLCertificatePath:    json.SSLCertificatePath,
-		SSLCertificateKeyPath: json.SSLCertificateKeyPath,
-		SSLCertificate:        json.SSLCertificate,
-		SSLCertificateKey:     json.SSLCertificateKey,
-	}
-
-	err = content.WriteFile()
-
-	if err != nil {
-		api.ErrHandler(c, err)
-		return
-	}
-
-	GetCert(c)
+    id := cast.ToInt(c.Param("id"))
+
+    var json certJson
+
+    if !api.BindAndValid(c, &json) {
+        return
+    }
+
+    q := query.Cert
+
+    certModel, err := q.FirstByID(id)
+    if err != nil {
+        api.ErrHandler(c, err)
+        return
+    }
+
+    err = certModel.Updates(&model.Cert{
+        Name:                  json.Name,
+        SSLCertificatePath:    json.SSLCertificatePath,
+        SSLCertificateKeyPath: json.SSLCertificateKeyPath,
+        ChallengeMethod:       json.ChallengeMethod,
+        DnsCredentialID:       json.DnsCredentialID,
+    })
+
+    if err != nil {
+        api.ErrHandler(c, err)
+        return
+    }
+
+    content := &cert.Content{
+        SSLCertificatePath:    json.SSLCertificatePath,
+        SSLCertificateKeyPath: json.SSLCertificateKeyPath,
+        SSLCertificate:        json.SSLCertificate,
+        SSLCertificateKey:     json.SSLCertificateKey,
+    }
+
+    err = content.WriteFile()
+
+    if err != nil {
+        api.ErrHandler(c, err)
+        return
+    }
+
+    GetCert(c)
 }
 
 func RemoveCert(c *gin.Context) {
-	cosy.Core[model.Cert](c).Destroy()
+    cosy.Core[model.Cert](c).Destroy()
 }

+ 70 - 0
internal/cert/helper.go

@@ -0,0 +1,70 @@
+package cert
+
+import (
+	"crypto/x509"
+	"encoding/pem"
+	"os"
+)
+
+func IsPublicKey(pemStr string) bool {
+	block, _ := pem.Decode([]byte(pemStr))
+	if block == nil {
+		return false
+	}
+
+	_, err := x509.ParsePKIXPublicKey(block.Bytes)
+	return err == nil
+}
+
+func IsPrivateKey(pemStr string) bool {
+	block, _ := pem.Decode([]byte(pemStr))
+	if block == nil {
+		return false
+	}
+
+	_, errRSA := x509.ParsePKCS1PrivateKey(block.Bytes)
+	if errRSA == nil {
+		return true
+	}
+
+	_, errECDSA := x509.ParseECPrivateKey(block.Bytes)
+	return errECDSA == nil
+}
+
+// IsPublicKeyPath checks if the file at the given path is a public key or not exists.
+func IsPublicKeyPath(path string) bool {
+	_, err := os.Stat(path)
+
+	if err != nil {
+		if os.IsNotExist(err) {
+			return true
+		}
+		return false
+	}
+
+	bytes, err := os.ReadFile(path)
+	if err != nil {
+		return false
+	}
+
+	return IsPublicKey(string(bytes))
+}
+
+// IsPrivateKeyPath checks if the file at the given path is a private key or not exists.
+func IsPrivateKeyPath(path string) bool {
+	_, err := os.Stat(path)
+
+	if err != nil {
+		if os.IsNotExist(err) {
+			return true
+		}
+		return false
+	}
+
+	bytes, err := os.ReadFile(path)
+	if err != nil {
+		return false
+	}
+
+	return IsPrivateKey(string(bytes))
+}

+ 2 - 0
internal/kernal/boot.go

@@ -4,6 +4,7 @@ import (
 	"github.com/0xJacky/Nginx-UI/internal/analytic"
 	"github.com/0xJacky/Nginx-UI/internal/cert"
 	"github.com/0xJacky/Nginx-UI/internal/logger"
+	"github.com/0xJacky/Nginx-UI/internal/validation"
 	"github.com/0xJacky/Nginx-UI/model"
 	"github.com/0xJacky/Nginx-UI/query"
 	"github.com/0xJacky/Nginx-UI/settings"
@@ -21,6 +22,7 @@ func Boot() {
 		InitJsExtensionType,
 		InitDatabase,
 		InitNodeSecret,
+		validation.Init,
 	}
 
 	syncs := []func(){

+ 10 - 0
internal/validation/alphanumdash.go

@@ -0,0 +1,10 @@
+package validation
+
+import (
+	val "github.com/go-playground/validator/v10"
+	"regexp"
+)
+
+func alphaNumDash(fl val.FieldLevel) bool {
+	return regexp.MustCompile(`^[a-zA-Z0-9-]+$`).MatchString(fl.Field().String())
+}

+ 22 - 0
internal/validation/certificate.go

@@ -0,0 +1,22 @@
+package validation
+
+import (
+	"github.com/0xJacky/Nginx-UI/internal/cert"
+	val "github.com/go-playground/validator/v10"
+)
+
+func isPublicKey(fl val.FieldLevel) bool {
+	return cert.IsPublicKey(fl.Field().String())
+}
+
+func isPrivateKey(fl val.FieldLevel) bool {
+	return cert.IsPrivateKey(fl.Field().String())
+}
+
+func isPublicKeyPath(fl val.FieldLevel) bool {
+	return cert.IsPublicKeyPath(fl.Field().String())
+}
+
+func isPrivateKeyPath(fl val.FieldLevel) bool {
+	return cert.IsPrivateKeyPath(fl.Field().String())
+}

+ 46 - 0
internal/validation/validation.go

@@ -0,0 +1,46 @@
+package validation
+
+import (
+	"github.com/0xJacky/Nginx-UI/internal/logger"
+	"github.com/gin-gonic/gin/binding"
+	val "github.com/go-playground/validator/v10"
+)
+
+func Init() {
+	v, ok := binding.Validator.Engine().(*val.Validate)
+	if !ok {
+		logger.Fatal("binding validator engine is not initialized")
+	}
+
+	err := v.RegisterValidation("alphanumdash", alphaNumDash)
+
+	if err != nil {
+		logger.Fatal(err)
+	}
+
+	err = v.RegisterValidation("publickey", isPublicKey)
+
+	if err != nil {
+		logger.Fatal(err)
+	}
+
+	err = v.RegisterValidation("privatekey", isPrivateKey)
+
+	if err != nil {
+		logger.Fatal(err)
+	}
+
+	err = v.RegisterValidation("publickey_path", isPublicKeyPath)
+
+	if err != nil {
+		logger.Fatal(err)
+	}
+
+	err = v.RegisterValidation("privatekey_path", isPrivateKeyPath)
+
+	if err != nil {
+		logger.Fatal(err)
+	}
+
+	return
+}