Browse Source

enhance: file scanning with recursion protection and abs path resolution #1178

Jacky 1 day ago
parent
commit
dd8dfe0f8b
1 changed files with 52 additions and 20 deletions
  1. 52 20
      internal/cache/index.go

+ 52 - 20
internal/cache/index.go

@@ -296,34 +296,66 @@ func (s *Scanner) watchForChanges() {
 
 
 // scanSingleFile scans a single file and executes all registered callbacks
 // scanSingleFile scans a single file and executes all registered callbacks
 func (s *Scanner) scanSingleFile(filePath string) error {
 func (s *Scanner) scanSingleFile(filePath string) error {
-	// Set scanning state to true
-	s.scanMutex.Lock()
-	wasScanning := s.scanning
-	s.scanning = true
-	if !wasScanning {
-		// Only broadcast if status changed from not scanning to scanning
-		s.statusChan <- true
+	return s.scanSingleFileWithDepth(filePath, make(map[string]bool), 0)
+}
+
+// scanSingleFileWithDepth scans a single file with recursion protection
+func (s *Scanner) scanSingleFileWithDepth(filePath string, visited map[string]bool, depth int) error {
+	// Maximum recursion depth to prevent infinite recursion
+	const maxDepth = 10
+
+	if depth > maxDepth {
+		logger.Warn("Maximum recursion depth reached for file:", filePath)
+		return nil
 	}
 	}
-	s.scanMutex.Unlock()
 
 
-	// Ensure we reset scanning state when done
-	defer func() {
+	// Resolve the absolute path to handle symlinks properly
+	absPath, err := filepath.Abs(filePath)
+	if err != nil {
+		logger.Error("Failed to resolve absolute path for:", filePath, err)
+		return err
+	}
+
+	// Check for circular includes
+	if visited[absPath] {
+		// Circular include detected, skip this file
+		return nil
+	}
+
+	// Mark this file as visited
+	visited[absPath] = true
+
+	// Set scanning state to true only for the root call (depth 0)
+	var wasScanning bool
+	if depth == 0 {
 		s.scanMutex.Lock()
 		s.scanMutex.Lock()
-		s.scanning = false
-		// Broadcast the completion
-		s.statusChan <- false
+		wasScanning = s.scanning
+		s.scanning = true
+		if !wasScanning {
+			// Only broadcast if status changed from not scanning to scanning
+			s.statusChan <- true
+		}
 		s.scanMutex.Unlock()
 		s.scanMutex.Unlock()
-	}()
+
+		// Ensure we reset scanning state when done (only for root call)
+		defer func() {
+			s.scanMutex.Lock()
+			s.scanning = false
+			// Broadcast the completion
+			s.statusChan <- false
+			s.scanMutex.Unlock()
+		}()
+	}
 
 
 	// Open the file
 	// Open the file
-	file, err := os.Open(filePath)
+	file, err := os.Open(absPath)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	defer file.Close()
 	defer file.Close()
 
 
 	// Read the entire file content
 	// Read the entire file content
-	content, err := os.ReadFile(filePath)
+	content, err := os.ReadFile(absPath)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -331,9 +363,9 @@ func (s *Scanner) scanSingleFile(filePath string) error {
 	// Execute all registered callbacks
 	// Execute all registered callbacks
 	scanCallbacksMutex.RLock()
 	scanCallbacksMutex.RLock()
 	for _, callback := range scanCallbacks {
 	for _, callback := range scanCallbacks {
-		err := callback(filePath, content)
+		err := callback(absPath, content)
 		if err != nil {
 		if err != nil {
-			logger.Error("Callback error for file", filePath, ":", err)
+			logger.Error("Callback error for file", absPath, ":", err)
 		}
 		}
 	}
 	}
 	scanCallbacksMutex.RUnlock()
 	scanCallbacksMutex.RUnlock()
@@ -363,7 +395,7 @@ func (s *Scanner) scanSingleFile(filePath string) error {
 				for _, matchedFile := range matchedFiles {
 				for _, matchedFile := range matchedFiles {
 					fileInfo, err := os.Stat(matchedFile)
 					fileInfo, err := os.Stat(matchedFile)
 					if err == nil && !fileInfo.IsDir() {
 					if err == nil && !fileInfo.IsDir() {
-						err = s.scanSingleFile(matchedFile)
+						err = s.scanSingleFileWithDepth(matchedFile, visited, depth+1)
 						if err != nil {
 						if err != nil {
 							logger.Error("Failed to scan included file:", matchedFile, err)
 							logger.Error("Failed to scan included file:", matchedFile, err)
 						}
 						}
@@ -379,7 +411,7 @@ func (s *Scanner) scanSingleFile(filePath string) error {
 
 
 				fileInfo, err := os.Stat(includePath)
 				fileInfo, err := os.Stat(includePath)
 				if err == nil && !fileInfo.IsDir() {
 				if err == nil && !fileInfo.IsDir() {
-					err = s.scanSingleFile(includePath)
+					err = s.scanSingleFileWithDepth(includePath, visited, depth+1)
 					if err != nil {
 					if err != nil {
 						logger.Error("Failed to scan included file:", includePath, err)
 						logger.Error("Failed to scan included file:", includePath, err)
 					}
 					}