瀏覽代碼

enhance(upstream): proxy parser to support grpc_pass

0xJacky 2 周之前
父節點
當前提交
44246c9423
共有 3 個文件被更改,包括 209 次插入87 次删除
  1. 7 7
      internal/cron/upstream_availability.go
  2. 54 80
      internal/upstream/proxy_parser.go
  3. 148 0
      internal/upstream/proxy_parser_test.go

+ 7 - 7
internal/cron/upstream_availability.go

@@ -37,24 +37,24 @@ func executeUpstreamAvailabilityTest() {
 
 	targetCount := service.GetTargetCount()
 	if targetCount == 0 {
-		// logger.Debug("No upstream targets to test")
+		logger.Debug("No upstream targets to test")
 		return
 	}
 
 	// Check if we should skip this test due to active WebSocket connections
 	// (WebSocket connections trigger more frequent checks)
 	if hasActiveWebSocketConnections() {
-		// logger.Debug("Skipping scheduled test due to active WebSocket connections")
+		logger.Debug("Skipping scheduled test due to active WebSocket connections")
 		return
 	}
 
-	// start := time.Now()
-	// logger.Debug("Starting scheduled upstream availability test for", targetCount, "targets")
+	start := time.Now()
+	logger.Debug("Starting scheduled upstream availability test for", targetCount, "targets")
 
-	// service.PerformAvailabilityTest()
+	service.PerformAvailabilityTest()
 
-	// duration := time.Since(start)
-	// logger.Debug("Upstream availability test completed in", duration)
+	duration := time.Since(start)
+	logger.Debug("Upstream availability test completed in", duration)
 }
 
 // hasActiveWebSocketConnections checks if there are active WebSocket connections

+ 54 - 80
internal/upstream/proxy_parser.go

@@ -5,7 +5,6 @@ import (
 	"regexp"
 	"strings"
 
-	"github.com/0xJacky/Nginx-UI/internal/nginx"
 	"github.com/0xJacky/Nginx-UI/settings"
 )
 
@@ -13,7 +12,7 @@ import (
 type ProxyTarget struct {
 	Host       string `json:"host"`
 	Port       string `json:"port"`
-	Type       string `json:"type"`        // "proxy_pass" or "upstream"
+	Type       string `json:"type"`        // "proxy_pass", "grpc_pass" or "upstream"
 	Resolver   string `json:"resolver"`    // DNS resolver address (e.g., "127.0.0.1:8600")
 	IsConsul   bool   `json:"is_consul"`   // Whether this is a consul service discovery target
 	ServiceURL string `json:"service_url"` // Full service URL for consul (e.g., "service.consul service=redacted-net resolve")
@@ -82,7 +81,7 @@ func ParseProxyTargetsFromRawContent(content string) []ProxyTarget {
 			proxyPassURL := strings.TrimSpace(match[1])
 			// Skip if this proxy_pass references an upstream
 			if !isUpstreamReference(proxyPassURL, upstreamNames) {
-				target := parseProxyPassURL(proxyPassURL)
+				target := parseProxyPassURL(proxyPassURL, "proxy_pass")
 				if target.Host != "" {
 					targets = append(targets, target)
 				}
@@ -90,80 +89,51 @@ func ParseProxyTargetsFromRawContent(content string) []ProxyTarget {
 		}
 	}
 
-	return deduplicateTargets(targets)
-}
-
-// parseUpstreamServers extracts server addresses from upstream blocks
-func parseUpstreamServers(upstream *nginx.NgxUpstream) []ProxyTarget {
-	var targets []ProxyTarget
-
-	// Create upstream context for this upstream block
-	ctx := &UpstreamContext{
-		Name: upstream.Name,
-	}
-
-	// Extract resolver from upstream directives
-	for _, directive := range upstream.Directives {
-		if directive.Directive == "resolver" {
-			resolverParts := strings.Fields(directive.Params)
-			if len(resolverParts) > 0 {
-				ctx.Resolver = resolverParts[0]
-			}
-		}
-	}
-
-	for _, directive := range upstream.Directives {
-		if directive.Directive == "server" {
-			target := parseServerAddress(directive.Params, "upstream", ctx)
-			if target.Host != "" {
-				targets = append(targets, target)
-			}
-		}
-	}
-
-	return targets
-}
-
-// parseLocationProxyPass extracts proxy_pass from location content
-func parseLocationProxyPass(content string) []ProxyTarget {
-	var targets []ProxyTarget
-
-	// Use regex to find proxy_pass directives
-	proxyPassRegex := regexp.MustCompile(`(?m)^\s*proxy_pass\s+([^;]+);`)
-	matches := proxyPassRegex.FindAllStringSubmatch(content, -1)
+	// Parse grpc_pass directives, but skip upstream references
+	grpcPassRegex := regexp.MustCompile(`(?m)^\s*grpc_pass\s+([^;]+);`)
+	grpcMatches := grpcPassRegex.FindAllStringSubmatch(content, -1)
 
-	for _, match := range matches {
+	for _, match := range grpcMatches {
 		if len(match) >= 2 {
-			target := parseProxyPassURL(strings.TrimSpace(match[1]))
-			if target.Host != "" {
-				targets = append(targets, target)
+			grpcPassURL := strings.TrimSpace(match[1])
+			// Skip if this grpc_pass references an upstream
+			if !isUpstreamReference(grpcPassURL, upstreamNames) {
+				target := parseProxyPassURL(grpcPassURL, "grpc_pass")
+				if target.Host != "" {
+					targets = append(targets, target)
+				}
 			}
 		}
 	}
 
-	return targets
+	return deduplicateTargets(targets)
 }
 
-// parseProxyPassURL parses a proxy_pass URL and extracts host and port
-func parseProxyPassURL(proxyPass string) ProxyTarget {
-	proxyPass = strings.TrimSpace(proxyPass)
+// parseProxyPassURL parses a proxy_pass or grpc_pass URL and extracts host and port
+func parseProxyPassURL(passURL, passType string) ProxyTarget {
+	passURL = strings.TrimSpace(passURL)
 
 	// Skip URLs that contain Nginx variables
-	if strings.Contains(proxyPass, "$") {
+	if strings.Contains(passURL, "$") {
 		return ProxyTarget{}
 	}
 
-	// Handle HTTP/HTTPS URLs (e.g., "http://backend")
-	if strings.HasPrefix(proxyPass, "http://") || strings.HasPrefix(proxyPass, "https://") {
-		if parsedURL, err := url.Parse(proxyPass); err == nil {
+	// Handle HTTP/HTTPS/gRPC URLs (e.g., "http://backend", "grpc://backend")
+	if strings.HasPrefix(passURL, "http://") || strings.HasPrefix(passURL, "https://") || strings.HasPrefix(passURL, "grpc://") || strings.HasPrefix(passURL, "grpcs://") {
+		if parsedURL, err := url.Parse(passURL); err == nil {
 			host := parsedURL.Hostname()
 			port := parsedURL.Port()
 
 			// Set default ports if not specified
 			if port == "" {
-				if parsedURL.Scheme == "https" {
+				switch parsedURL.Scheme {
+				case "https":
+					port = "443"
+				case "grpcs":
 					port = "443"
-				} else {
+				case "grpc":
+					port = "80"
+				default: // http
 					port = "80"
 				}
 			}
@@ -176,15 +146,15 @@ func parseProxyPassURL(proxyPass string) ProxyTarget {
 			return ProxyTarget{
 				Host: host,
 				Port: port,
-				Type: "proxy_pass",
+				Type: passType,
 			}
 		}
 	}
 
 	// Handle direct address format for stream module (e.g., "127.0.0.1:8080", "backend.example.com:12345")
-	// This is used in stream configurations where proxy_pass doesn't require a protocol
-	if !strings.Contains(proxyPass, "://") {
-		target := parseServerAddress(proxyPass, "proxy_pass", nil) // No upstream context for this function
+	// This is used in stream configurations where proxy_pass/grpc_pass doesn't require a protocol
+	if !strings.Contains(passURL, "://") {
+		target := parseServerAddress(passURL, passType, nil) // No upstream context for this function
 
 		// Skip if this is the HTTP challenge port used by Let's Encrypt
 		if target.Host == "127.0.0.1" && target.Port == settings.CertSettings.HTTPChallengePort {
@@ -262,7 +232,7 @@ func isConsulServiceDiscovery(serverAddr string) bool {
 	if strings.Contains(serverAddr, "service=") && strings.Contains(serverAddr, "resolve") {
 		return true
 	}
-	// Legacy consul format: "service.consul service=name resolve" 
+	// Legacy consul format: "service.consul service=name resolve"
 	return strings.Contains(serverAddr, "service.consul") &&
 		(strings.Contains(serverAddr, "service=") || strings.Contains(serverAddr, "resolve"))
 }
@@ -327,17 +297,17 @@ func deduplicateTargets(targets []ProxyTarget) []ProxyTarget {
 	return result
 }
 
-// isUpstreamReference checks if a proxy_pass URL references an upstream block
-func isUpstreamReference(proxyPass string, upstreamNames map[string]bool) bool {
-	proxyPass = strings.TrimSpace(proxyPass)
+// isUpstreamReference checks if a proxy_pass or grpc_pass URL references an upstream block
+func isUpstreamReference(passURL string, upstreamNames map[string]bool) bool {
+	passURL = strings.TrimSpace(passURL)
 
-	// For HTTP/HTTPS URLs, parse the URL to extract the hostname
-	if strings.HasPrefix(proxyPass, "http://") || strings.HasPrefix(proxyPass, "https://") {
+	// For HTTP/HTTPS/gRPC URLs, parse the URL to extract the hostname
+	if strings.HasPrefix(passURL, "http://") || strings.HasPrefix(passURL, "https://") || strings.HasPrefix(passURL, "grpc://") || strings.HasPrefix(passURL, "grpcs://") {
 		// Handle URLs with nginx variables (e.g., "https://myUpStr$request_uri")
 		// Extract the scheme and hostname part before any nginx variables
-		schemeAndHost := proxyPass
-		if dollarIndex := strings.Index(proxyPass, "$"); dollarIndex != -1 {
-			schemeAndHost = proxyPass[:dollarIndex]
+		schemeAndHost := passURL
+		if dollarIndex := strings.Index(passURL, "$"); dollarIndex != -1 {
+			schemeAndHost = passURL[:dollarIndex]
 		}
 
 		// Try to parse the URL, if it fails, try manual extraction
@@ -348,11 +318,15 @@ func isUpstreamReference(proxyPass string, upstreamNames map[string]bool) bool {
 		} else {
 			// Fallback: manually extract hostname for URLs with variables
 			// Remove scheme prefix
-			withoutScheme := proxyPass
-			if strings.HasPrefix(proxyPass, "https://") {
-				withoutScheme = strings.TrimPrefix(proxyPass, "https://")
-			} else if strings.HasPrefix(proxyPass, "http://") {
-				withoutScheme = strings.TrimPrefix(proxyPass, "http://")
+			withoutScheme := passURL
+			if strings.HasPrefix(passURL, "https://") {
+				withoutScheme = strings.TrimPrefix(passURL, "https://")
+			} else if strings.HasPrefix(passURL, "http://") {
+				withoutScheme = strings.TrimPrefix(passURL, "http://")
+			} else if strings.HasPrefix(passURL, "grpc://") {
+				withoutScheme = strings.TrimPrefix(passURL, "grpc://")
+			} else if strings.HasPrefix(passURL, "grpcs://") {
+				withoutScheme = strings.TrimPrefix(passURL, "grpcs://")
 			}
 
 			// Extract hostname before any path, port, or variable
@@ -371,10 +345,10 @@ func isUpstreamReference(proxyPass string, upstreamNames map[string]bool) bool {
 		}
 	}
 
-	// For stream module, proxy_pass can directly reference upstream name without protocol
-	// Check if the proxy_pass value directly matches an upstream name
-	if !strings.Contains(proxyPass, "://") && !strings.Contains(proxyPass, ":") {
-		return upstreamNames[proxyPass]
+	// For stream module, proxy_pass/grpc_pass can directly reference upstream name without protocol
+	// Check if the pass value directly matches an upstream name
+	if !strings.Contains(passURL, "://") && !strings.Contains(passURL, ":") {
+		return upstreamNames[passURL]
 	}
 
 	return false

+ 148 - 0
internal/upstream/proxy_parser_test.go

@@ -605,3 +605,151 @@ server {
 		}
 	}
 }
+
+func TestParseGrpcPassDirectives(t *testing.T) {
+	config := `
+upstream grpc-backend {
+    server 127.0.0.1:9090;
+    server 127.0.0.1:9091;
+}
+
+server {
+    listen 80 http2;
+    server_name grpc.example.com;
+
+    location /api.v1.Service/ {
+        grpc_pass grpc://127.0.0.1:9090;
+    }
+
+    location /api.v2.Service/ {
+        grpc_pass grpcs://secure-grpc.example.com:443;
+    }
+
+    location /upstream-service/ {
+        grpc_pass grpc://grpc-backend;
+    }
+
+    location /direct-service/ {
+        grpc_pass 192.168.1.100:9090;
+    }
+}
+`
+
+	targets := ParseProxyTargetsFromRawContent(config)
+
+	// Verify we found the expected targets
+	expected := []struct {
+		host string
+		port string
+		typ  string
+	}{
+		{"127.0.0.1", "9090", "upstream"},
+		{"127.0.0.1", "9091", "upstream"},
+		{"127.0.0.1", "9090", "grpc_pass"},
+		{"secure-grpc.example.com", "443", "grpc_pass"},
+		{"192.168.1.100", "9090", "grpc_pass"},
+	}
+
+	if len(targets) < len(expected) {
+		t.Errorf("Expected at least %d targets, got %d", len(expected), len(targets))
+		for i, target := range targets {
+			t.Logf("Target %d: Host=%s, Port=%s, Type=%s", i+1, target.Host, target.Port, target.Type)
+		}
+		return
+	}
+
+	// Count targets by type
+	grpcPassCount := 0
+	upstreamCount := 0
+	for _, target := range targets {
+		switch target.Type {
+		case "grpc_pass":
+			grpcPassCount++
+		case "upstream":
+			upstreamCount++
+		}
+	}
+
+	if grpcPassCount != 3 {
+		t.Errorf("Expected 3 grpc_pass targets, got %d", grpcPassCount)
+	}
+	if upstreamCount != 2 {
+		t.Errorf("Expected 2 upstream targets, got %d", upstreamCount)
+	}
+
+	// Verify specific targets exist
+	found := make(map[string]bool)
+	for _, target := range targets {
+		key := target.Host + ":" + target.Port + ":" + target.Type
+		found[key] = true
+	}
+
+	expectedKeys := []string{
+		"127.0.0.1:9090:upstream",
+		"127.0.0.1:9091:upstream",
+		"127.0.0.1:9090:grpc_pass",
+		"secure-grpc.example.com:443:grpc_pass",
+		"192.168.1.100:9090:grpc_pass",
+	}
+
+	for _, key := range expectedKeys {
+		if !found[key] {
+			t.Errorf("Expected to find target: %s", key)
+		}
+	}
+}
+
+func TestGrpcPassPortDefaults(t *testing.T) {
+	tests := []struct {
+		name         string
+		grpcPassURL  string
+		expectedHost string
+		expectedPort string
+		expectedType string
+	}{
+		{
+			name:         "grpc:// without port should default to 80",
+			grpcPassURL:  "grpc://api.example.com",
+			expectedHost: "api.example.com",
+			expectedPort: "80",
+			expectedType: "grpc_pass",
+		},
+		{
+			name:         "grpcs:// without port should default to 443",
+			grpcPassURL:  "grpcs://secure-api.example.com",
+			expectedHost: "secure-api.example.com",
+			expectedPort: "443",
+			expectedType: "grpc_pass",
+		},
+		{
+			name:         "grpc:// with explicit port",
+			grpcPassURL:  "grpc://api.example.com:9090",
+			expectedHost: "api.example.com",
+			expectedPort: "9090",
+			expectedType: "grpc_pass",
+		},
+		{
+			name:         "grpcs:// with explicit port",
+			grpcPassURL:  "grpcs://secure-api.example.com:9443",
+			expectedHost: "secure-api.example.com",
+			expectedPort: "9443",
+			expectedType: "grpc_pass",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			target := parseProxyPassURL(tt.grpcPassURL, "grpc_pass")
+
+			if target.Host != tt.expectedHost {
+				t.Errorf("Expected host %s, got %s", tt.expectedHost, target.Host)
+			}
+			if target.Port != tt.expectedPort {
+				t.Errorf("Expected port %s, got %s", tt.expectedPort, target.Port)
+			}
+			if target.Type != tt.expectedType {
+				t.Errorf("Expected type %s, got %s", tt.expectedType, target.Type)
+			}
+		})
+	}
+}