Browse Source

fix: pass context to cert and cache

Jacky 1 day ago
parent
commit
0162e10c53

+ 1 - 1
api/system/install.go

@@ -25,7 +25,7 @@ func init() {
 }
 
 func installLockStatus() bool {
-	return settings.NodeSettings.SkipInstallation || "" != cSettings.AppSettings.JwtSecret
+	return settings.NodeSettings.SkipInstallation || cSettings.AppSettings.JwtSecret != ""
 }
 
 // Check if installation time limit (10 minutes) is exceeded

+ 3 - 2
internal/cache/cache.go

@@ -1,6 +1,7 @@
 package cache
 
 import (
+	"context"
 	"time"
 
 	"github.com/dgraph-io/ristretto/v2"
@@ -9,7 +10,7 @@ import (
 
 var cache *ristretto.Cache[string, any]
 
-func Init() {
+func Init(ctx context.Context) {
 	var err error
 	cache, err = ristretto.NewCache(&ristretto.Config[string, any]{
 		NumCounters: 1e7,     // number of keys to track frequency of (10M).
@@ -22,7 +23,7 @@ func Init() {
 	}
 
 	// Initialize the config scanner
-	InitScanner()
+	InitScanner(ctx)
 }
 
 func Set(key string, value interface{}, ttl time.Duration) {

+ 34 - 13
internal/cache/index.go

@@ -1,6 +1,7 @@
 package cache
 
 import (
+	"context"
 	"os"
 	"path/filepath"
 	"regexp"
@@ -19,6 +20,7 @@ type ScanCallback func(configPath string, content []byte) error
 
 // Scanner is responsible for scanning and watching nginx config files
 type Scanner struct {
+	ctx           context.Context        // Context for the scanner
 	watcher       *fsnotify.Watcher      // File system watcher
 	scanTicker    *time.Ticker           // Ticker for periodic scanning
 	initialized   bool                   // Whether the scanner has been initialized
@@ -39,24 +41,19 @@ var (
 	includeRegex = regexp.MustCompile(`include\s+([^;]+);`)
 
 	// Global callbacks that will be executed during config file scanning
-	scanCallbacks      []ScanCallback
+	scanCallbacks      = make([]ScanCallback, 0)
 	scanCallbacksMutex sync.RWMutex
 )
 
-func init() {
-	// Initialize the callbacks slice
-	scanCallbacks = make([]ScanCallback, 0)
-}
-
 // InitScanner initializes the config scanner
-func InitScanner() {
+func InitScanner(ctx context.Context) {
 	if nginx.GetConfPath() == "" {
 		logger.Error("Nginx config path is not set")
 		return
 	}
 
 	s := GetScanner()
-	err := s.Initialize()
+	err := s.Initialize(ctx)
 	if err != nil {
 		logger.Error("Failed to initialize config scanner:", err)
 	}
@@ -140,7 +137,7 @@ func UnsubscribeScanningStatus(ch chan bool) {
 }
 
 // Initialize sets up the scanner and starts watching for file changes
-func (s *Scanner) Initialize() error {
+func (s *Scanner) Initialize(ctx context.Context) error {
 	if s.initialized {
 		return nil
 	}
@@ -151,6 +148,7 @@ func (s *Scanner) Initialize() error {
 		return err
 	}
 	s.watcher = watcher
+	s.ctx = ctx
 
 	// Scan for the first time
 	err = s.ScanAllConfigs()
@@ -207,14 +205,26 @@ func (s *Scanner) Initialize() error {
 	// Setup a ticker for periodic scanning (every 5 minutes)
 	s.scanTicker = time.NewTicker(5 * time.Minute)
 	go func() {
-		for range s.scanTicker.C {
-			err := s.ScanAllConfigs()
-			if err != nil {
-				logger.Error("Periodic config scan failed:", err)
+		for {
+			select {
+			case <-s.ctx.Done():
+				return
+			case <-s.scanTicker.C:
+				err := s.ScanAllConfigs()
+				if err != nil {
+					logger.Error("Periodic config scan failed:", err)
+				}
 			}
 		}
 	}()
 
+	// Start a goroutine to listen for context cancellation
+	go func() {
+		<-s.ctx.Done()
+		logger.Debug("Context cancelled, shutting down scanner")
+		s.Shutdown()
+	}()
+
 	s.initialized = true
 	return nil
 }
@@ -223,6 +233,8 @@ func (s *Scanner) Initialize() error {
 func (s *Scanner) watchForChanges() {
 	for {
 		select {
+		case <-s.ctx.Done():
+			return
 		case event, ok := <-s.watcher.Events:
 			if !ok {
 				return
@@ -471,3 +483,12 @@ func IsScanningInProgress() bool {
 	defer s.scanMutex.RUnlock()
 	return s.scanning
 }
+
+// WithContext sets a context for the scanner that will be used to control its lifecycle
+func (s *Scanner) WithContext(ctx context.Context) *Scanner {
+	// Create a context with cancel if not already done in Initialize
+	if s.ctx == nil {
+		s.ctx = ctx
+	}
+	return s
+}

+ 24 - 12
internal/cert/mutex.go

@@ -1,6 +1,7 @@
 package cert
 
 import (
+	"context"
 	"sync"
 )
 
@@ -24,28 +25,39 @@ var (
 	processingMutex sync.RWMutex
 )
 
-func init() {
+func initBroadcastStatus(ctx context.Context) {
 	// Initialize channels and maps
 	statusChan = make(chan bool, 10) // Buffer to prevent blocking
 	subscribers = make(map[chan bool]struct{})
 
 	// Start broadcasting goroutine
-	go broadcastStatus()
+	go broadcastStatus(ctx)
 }
 
 // broadcastStatus listens for status changes and broadcasts to all subscribers
-func broadcastStatus() {
-	for status := range statusChan {
-		subscriberMux.RLock()
-		for ch := range subscribers {
-			// Non-blocking send to prevent slow subscribers from blocking others
-			select {
-			case ch <- status:
-			default:
-				// Skip if channel buffer is full
+func broadcastStatus(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			// Context cancelled, clean up resources and exit
+			close(statusChan)
+			return
+		case status, ok := <-statusChan:
+			if !ok {
+				// Channel closed, exit
+				return
 			}
+			subscriberMux.RLock()
+			for ch := range subscribers {
+				// Non-blocking send to prevent slow subscribers from blocking others
+				select {
+				case ch <- status:
+				default:
+					// Skip if channel buffer is full
+				}
+			}
+			subscriberMux.RUnlock()
 		}
-		subscriberMux.RUnlock()
 	}
 }
 

+ 2 - 0
internal/cert/register.go

@@ -52,6 +52,8 @@ func InitRegister(ctx context.Context) {
 	}
 
 	logger.Info("ACME Default User registered")
+
+	initBroadcastStatus(ctx)
 }
 
 func GetDefaultACMEUser() (user *model.AcmeUser, err error) {

+ 3 - 1
internal/kernel/boot.go

@@ -38,7 +38,9 @@ func Boot(ctx context.Context) {
 		InitNodeSecret,
 		InitCryptoSecret,
 		validation.Init,
-		cache.Init,
+		func() {
+			cache.Init(ctx)
+		},
 		CheckAndCleanupOTAContainers,
 	}
 

+ 1 - 6
internal/nginx_log/log_cache.go

@@ -13,15 +13,10 @@ type NginxLogCache struct {
 
 var (
 	// logCache is the map to store all found log files
-	logCache   map[string]*NginxLogCache
+	logCache   = make(map[string]*NginxLogCache)
 	cacheMutex sync.RWMutex
 )
 
-func init() {
-	// Initialize the cache
-	logCache = make(map[string]*NginxLogCache)
-}
-
 // AddLogPath adds a log path to the log cache
 func AddLogPath(path, logType, name string) {
 	cacheMutex.Lock()