Browse Source

fix: certificate dir cannot be created in some windows-docker setup #403

Jacky 1 year ago
parent
commit
11e460765a
2 changed files with 68 additions and 31 deletions
  1. 4 11
      api/certificate/issue.go
  2. 64 20
      internal/cert/payload.go

+ 4 - 11
api/certificate/issue.go

@@ -3,13 +3,11 @@ package certificate
 import (
 	"github.com/0xJacky/Nginx-UI/internal/cert"
 	"github.com/0xJacky/Nginx-UI/internal/logger"
-	"github.com/0xJacky/Nginx-UI/internal/nginx"
 	"github.com/0xJacky/Nginx-UI/model"
 	"github.com/gin-gonic/gin"
 	"github.com/go-acme/lego/v4/certcrypto"
 	"github.com/gorilla/websocket"
 	"net/http"
-	"strings"
 )
 
 const (
@@ -71,7 +69,6 @@ func IssueCert(c *gin.Context) {
 	payload := &cert.ConfigPayload{}
 
 	err = ws.ReadJSON(payload)
-
 	if err != nil {
 		logger.Error(err)
 		return
@@ -122,14 +119,10 @@ func IssueCert(c *gin.Context) {
 		return
 	}
 
-	certDirName := strings.Join(payload.ServerName, "_") + "_" + string(payload.GetKeyType())
-	sslCertificatePath := nginx.GetConfPath("ssl", certDirName, "fullchain.cer")
-	sslCertificateKeyPath := nginx.GetConfPath("ssl", certDirName, "private.key")
-
 	err = certModel.Updates(&model.Cert{
 		Domains:               payload.ServerName,
-		SSLCertificatePath:    sslCertificatePath,
-		SSLCertificateKeyPath: sslCertificateKeyPath,
+		SSLCertificatePath:    payload.GetCertificatePath(),
+		SSLCertificateKeyPath: payload.GetCertificateKeyPath(),
 		AutoCert:              model.AutoCertEnabled,
 		KeyType:               payload.KeyType,
 		ChallengeMethod:       payload.ChallengeMethod,
@@ -152,8 +145,8 @@ func IssueCert(c *gin.Context) {
 	err = ws.WriteJSON(IssueCertResponse{
 		Status:            Success,
 		Message:           "Issued certificate successfully",
-		SSLCertificate:    sslCertificatePath,
-		SSLCertificateKey: sslCertificateKeyPath,
+		SSLCertificate:    payload.GetCertificatePath(),
+		SSLCertificateKey: payload.GetCertificateKeyPath(),
 		KeyType:           payload.GetKeyType(),
 	})
 

+ 64 - 20
internal/cert/payload.go

@@ -16,14 +16,17 @@ import (
 )
 
 type ConfigPayload struct {
-	CertID          int                        `json:"cert_id"`
-	ServerName      []string                   `json:"server_name"`
-	ChallengeMethod string                     `json:"challenge_method"`
-	DNSCredentialID int                        `json:"dns_credential_id"`
-	ACMEUserID      int                        `json:"acme_user_id"`
-	KeyType         certcrypto.KeyType         `json:"key_type"`
-	Resource        *model.CertificateResource `json:"resource,omitempty"`
-	NotBefore       time.Time
+	CertID                int                        `json:"cert_id"`
+	ServerName            []string                   `json:"server_name"`
+	ChallengeMethod       string                     `json:"challenge_method"`
+	DNSCredentialID       int                        `json:"dns_credential_id"`
+	ACMEUserID            int                        `json:"acme_user_id"`
+	KeyType               certcrypto.KeyType         `json:"key_type"`
+	Resource              *model.CertificateResource `json:"resource,omitempty"`
+	NotBefore             time.Time                  `json:"-"`
+	CertificateDir        string                     `json:"-"`
+	SSLCertificatePath    string                     `json:"-"`
+	SSLCertificateKeyPath string                     `json:"-"`
 }
 
 func (c *ConfigPayload) GetACMEUser() (user *model.AcmeUser, err error) {
@@ -46,21 +49,38 @@ func (c *ConfigPayload) GetKeyType() certcrypto.KeyType {
 	return helper.GetKeyType(c.KeyType)
 }
 
-func (c *ConfigPayload) WriteFile(l *log.Logger, errChan chan error) {
-	name := strings.Join(c.ServerName, "_")
-	saveDir := nginx.GetConfPath("ssl/" + name + "_" + string(c.KeyType))
-	if _, err := os.Stat(saveDir); os.IsNotExist(err) {
-		err = os.MkdirAll(saveDir, 0755)
-		if err != nil {
-			errChan <- errors.Wrap(err, "mkdir error")
-			return
+func (c *ConfigPayload) mkCertificateDir() (err error) {
+	dir := c.getCertificateDirPath()
+	if _, err = os.Stat(dir); os.IsNotExist(err) {
+		err = os.MkdirAll(dir, 0755)
+		if err == nil {
+			return nil
+		}
+	}
+
+	// For windows, replace # with * (issue #403)
+	c.CertificateDir = strings.ReplaceAll(c.CertificateDir, "#", "*")
+	if _, err = os.Stat(c.CertificateDir); os.IsNotExist(err) {
+		err = os.MkdirAll(c.CertificateDir, 0755)
+		if err == nil {
+			return nil
 		}
 	}
 
+	return
+}
+
+func (c *ConfigPayload) WriteFile(l *log.Logger, errChan chan error) {
+	err := c.mkCertificateDir()
+	if err != nil {
+		errChan <- errors.Wrap(err, "make certificate dir error")
+		return
+	}
+
 	// Each certificate comes back with the cert bytes, the bytes of the client's
 	// private key, and a certificate URL. SAVE THESE TO DISK.
 	l.Println("[INFO] [Nginx UI] Writing certificate to disk")
-	err := os.WriteFile(filepath.Join(saveDir, "fullchain.cer"),
+	err = os.WriteFile(c.GetCertificatePath(),
 		c.Resource.Certificate, 0644)
 
 	if err != nil {
@@ -69,7 +89,7 @@ func (c *ConfigPayload) WriteFile(l *log.Logger, errChan chan error) {
 	}
 
 	l.Println("[INFO] [Nginx UI] Writing certificate private key to disk")
-	err = os.WriteFile(filepath.Join(saveDir, "private.key"),
+	err = os.WriteFile(c.GetCertificateKeyPath(),
 		c.Resource.PrivateKey, 0644)
 
 	if err != nil {
@@ -84,7 +104,31 @@ func (c *ConfigPayload) WriteFile(l *log.Logger, errChan chan error) {
 
 	db := model.UseDB()
 	db.Where("id = ?", c.CertID).Updates(&model.Cert{
-		SSLCertificatePath:    filepath.Join(saveDir, "fullchain.cer"),
-		SSLCertificateKeyPath: filepath.Join(saveDir, "private.key"),
+		SSLCertificatePath:    c.GetCertificatePath(),
+		SSLCertificateKeyPath: c.GetCertificateKeyPath(),
 	})
 }
+
+func (c *ConfigPayload) getCertificateDirPath() string {
+	if c.CertificateDir != "" {
+		return c.CertificateDir
+	}
+	c.CertificateDir = nginx.GetConfPath("ssl", strings.Join(c.ServerName, "_")+"_"+string(c.GetKeyType()))
+	return c.CertificateDir
+}
+
+func (c *ConfigPayload) GetCertificatePath() string {
+	if c.SSLCertificatePath != "" {
+		return c.SSLCertificatePath
+	}
+	c.SSLCertificatePath = filepath.Join(c.getCertificateDirPath(), "fullchain.cer")
+	return c.SSLCertificatePath
+}
+
+func (c *ConfigPayload) GetCertificateKeyPath() string {
+	if c.SSLCertificateKeyPath != "" {
+		return c.SSLCertificateKeyPath
+	}
+	c.SSLCertificateKeyPath = filepath.Join(c.getCertificateDirPath(), "private.key")
+	return c.SSLCertificateKeyPath
+}