Browse Source

feat(server): implement tls certificate hot-reload

Jacky 2 months ago
parent
commit
de1860718e
4 changed files with 64 additions and 11 deletions
  1. 5 7
      api/settings/settings.go
  2. 6 0
      internal/cert/cert.go
  3. 36 0
      internal/cert/server_tls.go
  4. 17 4
      main.go

+ 5 - 7
api/settings/settings.go

@@ -3,14 +3,13 @@ package settings
 import (
 	"fmt"
 	"net/http"
-	"time"
 
+	"github.com/0xJacky/Nginx-UI/internal/cert"
 	"github.com/0xJacky/Nginx-UI/internal/cron"
 	"github.com/0xJacky/Nginx-UI/internal/nginx"
 	"github.com/0xJacky/Nginx-UI/internal/system"
 	"github.com/0xJacky/Nginx-UI/settings"
 	"github.com/gin-gonic/gin"
-	"github.com/jpillora/overseer"
 	"github.com/uozi-tech/cosy"
 	cSettings "github.com/uozi-tech/cosy/settings"
 )
@@ -84,9 +83,9 @@ func SaveSettings(c *gin.Context) {
 	}
 
 	// Validate SSL certificates if HTTPS is enabled
-	needRestart := false
+	needReloadCert := false
 	if json.Server.EnableHTTPS != cSettings.ServerSettings.EnableHTTPS {
-		needRestart = true
+		needReloadCert = true
 	}
 
 	if json.Server.EnableHTTPS {
@@ -112,10 +111,9 @@ func SaveSettings(c *gin.Context) {
 		return
 	}
 
-	if needRestart {
+	if needReloadCert {
 		go func() {
-			time.Sleep(2 * time.Second)
-			overseer.Restart()
+			cert.ReloadServerTLSCertificate()
 		}()
 	}
 

+ 6 - 0
internal/cert/cert.go

@@ -17,6 +17,7 @@ import (
 	dnsproviders "github.com/go-acme/lego/v4/providers/dns"
 	"github.com/pkg/errors"
 	"github.com/uozi-tech/cosy/logger"
+	cSettings "github.com/uozi-tech/cosy/settings"
 )
 
 const (
@@ -174,6 +175,11 @@ func IssueCert(payload *ConfigPayload, logChan chan string, errChan chan error)
 
 	l.Println("[INFO] [Nginx UI] Finished")
 
+	if payload.GetCertificatePath() == cSettings.ServerSettings.SSLCert &&
+		payload.GetCertificateKeyPath() == cSettings.ServerSettings.SSLKey {
+		ReloadServerTLSCertificate()
+	}
+
 	// Wait log to be written
 	time.Sleep(2 * time.Second)
 }

+ 36 - 0
internal/cert/server_tls.go

@@ -0,0 +1,36 @@
+package cert
+
+import (
+	"crypto/tls"
+	"errors"
+	"sync/atomic"
+
+	cSettings "github.com/uozi-tech/cosy/settings"
+)
+
+var tlsCert atomic.Value
+
+// LoadServerTLSCertificate loads the TLS certificate
+func LoadServerTLSCertificate() error {
+	return ReloadServerTLSCertificate()
+}
+
+// ReloadServerTLSCertificate reloads the TLS certificate
+func ReloadServerTLSCertificate() error {
+	newCert, err := tls.LoadX509KeyPair(cSettings.ServerSettings.SSLCert, cSettings.ServerSettings.SSLKey)
+	if err != nil {
+		return err
+	}
+
+	tlsCert.Store(newCert)
+	return nil
+}
+
+// GetServerTLSCertificate returns the current TLS certificate
+func GetServerTLSCertificate() (*tls.Certificate, error) {
+	cert, ok := tlsCert.Load().(*tls.Certificate)
+	if !ok {
+		return nil, errors.New("no certificate available")
+	}
+	return cert, nil
+}

+ 17 - 4
main.go

@@ -1,11 +1,13 @@
 package main
 
 import (
+	"crypto/tls"
 	"errors"
 	"fmt"
 	"net/http"
 	"time"
 
+	"github.com/0xJacky/Nginx-UI/internal/cert"
 	"github.com/0xJacky/Nginx-UI/internal/cmd"
 	"github.com/0xJacky/Nginx-UI/internal/kernel"
 	"github.com/0xJacky/Nginx-UI/model"
@@ -56,12 +58,23 @@ func Program(confPath string) func(state overseer.State) {
 		}
 		var err error
 		if cSettings.ServerSettings.EnableHTTPS {
-			// Convert SSL certificate and key paths to absolute paths if they are relative
-			sslCert := cSettings.ServerSettings.SSLCert
-			sslKey := cSettings.ServerSettings.SSLKey
+			// Load TLS certificate
+			err = cert.LoadServerTLSCertificate()
+			if err != nil {
+				logger.Fatalf("Failed to load TLS certificate: %v", err)
+				return
+			}
+
+			tlsConfig := &tls.Config{
+				GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+					return cert.GetServerTLSCertificate()
+				},
+			}
+
+			srv.TLSConfig = tlsConfig
 
 			logger.Info("Starting HTTPS server")
-			err = srv.ServeTLS(state.Listener, sslCert, sslKey)
+			err = srv.ServeTLS(state.Listener, "", "")
 		} else {
 			logger.Info("Starting HTTP server")
 			err = srv.Serve(state.Listener)