Преглед изворни кода

feat(socket): implement formatSocketAddress function for IPv6 handling

0xJacky пре 1 недеља
родитељ
комит
8de49f7759

+ 4 - 4
internal/upstream/dynamic_resolver.go

@@ -242,7 +242,7 @@ func TestDynamicTargets(dynamicTargets []ProxyTarget) map[string]*Status {
 			dynamicTargetsByResolver[target.Resolver] = append(dynamicTargetsByResolver[target.Resolver], target)
 		} else {
 			// No resolver specified, mark as offline
-			key := target.Host + ":" + target.Port
+			key := formatSocketAddress(target.Host, target.Port)
 			result[key] = &Status{
 				Online:  false,
 				Latency: 0,
@@ -255,7 +255,7 @@ func TestDynamicTargets(dynamicTargets []ProxyTarget) map[string]*Status {
 		dynamicResolver := NewDynamicResolver(resolver)
 
 		for _, target := range targets {
-			key := target.Host + ":" + target.Port
+			key := formatSocketAddress(target.Host, target.Port)
 
 			// Try to resolve the service
 			addresses, err := dynamicResolver.ResolveService(target.ServiceURL)
@@ -305,8 +305,8 @@ func EnhancedAvailabilityTest(targets []ProxyTarget) map[string]*Status {
 		if target.IsConsul && target.Resolver != "" {
 			dynamicTargets = append(dynamicTargets, target)
 		} else {
-			// Regular target - use existing format for traditional AvailabilityTest
-			key := target.Host + ":" + target.Port
+			// Regular target - use properly formatted socket address for traditional AvailabilityTest
+			key := formatSocketAddress(target.Host, target.Port)
 			regularTargets = append(regularTargets, key)
 		}
 	}

+ 133 - 0
internal/upstream/ipv6_socket_test.go

@@ -0,0 +1,133 @@
+package upstream
+
+import (
+	"sync"
+	"testing"
+)
+
+func TestFormatSocketAddress_IPv6(t *testing.T) {
+	tests := []struct {
+		name     string
+		host     string
+		port     string
+		expected string
+	}{
+		{
+			name:     "IPv6 all addresses",
+			host:     "::",
+			port:     "9001",
+			expected: "[::]:9001",
+		},
+		{
+			name:     "IPv6 localhost",
+			host:     "::1",
+			port:     "8080",
+			expected: "[::1]:8080",
+		},
+		{
+			name:     "IPv6 full address",
+			host:     "2001:db8::1",
+			port:     "9000",
+			expected: "[2001:db8::1]:9000",
+		},
+		{
+			name:     "IPv6 with brackets already",
+			host:     "[::1]",
+			port:     "8080",
+			expected: "[::1]:8080",
+		},
+		{
+			name:     "IPv4 address",
+			host:     "127.0.0.1",
+			port:     "9001",
+			expected: "127.0.0.1:9001",
+		},
+		{
+			name:     "hostname",
+			host:     "example.com",
+			port:     "80",
+			expected: "example.com:80",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := formatSocketAddress(tt.host, tt.port)
+			if result != tt.expected {
+				t.Errorf("formatSocketAddress(%q, %q) = %q, want %q", tt.host, tt.port, result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestAvailabilityTest_IPv6Socket(t *testing.T) {
+	// Test that IPv6 socket addresses are properly formatted
+	// This test verifies that the socket string passed to net.DialTimeout is correct
+
+	// Test with properly formatted IPv6 addresses
+	sockets := []string{
+		"[::1]:8080",     // IPv6 localhost with port
+		"127.0.0.1:8080", // IPv4 for comparison
+	}
+
+	// This should not panic or cause parsing errors
+	results := AvailabilityTest(sockets)
+
+	// Verify we get results for both sockets (even if they're offline)
+	if len(results) != 2 {
+		t.Errorf("Expected 2 results, got %d", len(results))
+	}
+
+	// Check that the keys are preserved correctly
+	for _, socket := range sockets {
+		if _, exists := results[socket]; !exists {
+			t.Errorf("Expected result for socket %q", socket)
+		}
+	}
+}
+
+func TestTCPLatency_IPv6Support(t *testing.T) {
+	// Test that testTCPLatency can handle IPv6 addresses correctly
+	// Note: This test verifies the function doesn't panic with IPv6 addresses
+	// The actual connection will likely fail since we're testing non-existent services
+
+	tests := []struct {
+		name   string
+		socket string
+	}{
+		{
+			name:   "IPv6 localhost",
+			socket: "[::1]:8080",
+		},
+		{
+			name:   "IPv6 all addresses",
+			socket: "[::]:9001",
+		},
+		{
+			name:   "IPv4 for comparison",
+			socket: "127.0.0.1:8080",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var wg sync.WaitGroup
+			status := &Status{}
+
+			wg.Add(1)
+
+			// This should not panic even if the connection fails
+			defer func() {
+				if r := recover(); r != nil {
+					t.Errorf("testTCPLatency panicked with socket %q: %v", tt.socket, r)
+				}
+			}()
+
+			testTCPLatency(&wg, tt.socket, status)
+			wg.Wait()
+
+			// We don't check if it's online since the service likely doesn't exist
+			// We just verify the function completed without panicking
+		})
+	}
+}

+ 20 - 3
internal/upstream/service.go

@@ -3,6 +3,7 @@ package upstream
 import (
 	"maps"
 	"slices"
+	"strings"
 	"sync"
 	"time"
 
@@ -44,6 +45,22 @@ var (
 	serviceOnce     sync.Once
 )
 
+// formatSocketAddress formats a host:port combination into a proper socket address
+// For IPv6 addresses, it adds brackets around the host if they're not already present
+func formatSocketAddress(host, port string) string {
+	// Check if this is an IPv6 address by looking for colons
+	if strings.Contains(host, ":") {
+		// IPv6 address - check if it already has brackets
+		if !strings.HasPrefix(host, "[") {
+			return "[" + host + "]:" + port
+		}
+		// Already has brackets, just append port
+		return host + ":" + port
+	}
+	// IPv4 address or hostname
+	return host + ":" + port
+}
+
 // GetUpstreamService returns the singleton upstream service instance
 func GetUpstreamService() *UpstreamService {
 	serviceOnce.Do(func() {
@@ -117,7 +134,7 @@ func (s *UpstreamService) updateTargetsFromConfig(configPath string, targets []P
 	// Add/update new targets
 	newTargetKeys := make([]string, 0, len(targets))
 	for _, target := range targets {
-		key := target.Host + ":" + target.Port
+		key := formatSocketAddress(target.Host, target.Port)
 		newTargetKeys = append(newTargetKeys, key)
 
 		if existingTarget, exists := s.targets[key]; exists {
@@ -228,8 +245,8 @@ func (s *UpstreamService) PerformAvailabilityTest() {
 		if targetInfo.ProxyTarget.IsConsul {
 			consulTargets = append(consulTargets, targetInfo.ProxyTarget)
 		} else {
-			// Traditional target - use host:port key format
-			key := targetInfo.ProxyTarget.Host + ":" + targetInfo.ProxyTarget.Port
+			// Traditional target - use properly formatted socket address
+			key := formatSocketAddress(targetInfo.ProxyTarget.Host, targetInfo.ProxyTarget.Port)
 			regularTargetKeys = append(regularTargetKeys, key)
 		}
 	}

+ 3 - 1
internal/upstream/upstream_parser.go

@@ -319,7 +319,9 @@ func deduplicateTargets(targets []ProxyTarget) []ProxyTarget {
 
 	for _, target := range targets {
 		// Create a unique key that includes resolver and consul information
-		key := target.Host + ":" + target.Port + ":" + target.Type + ":" + target.Resolver
+		// Use formatSocketAddress for proper IPv6 handling in the key
+		socketAddr := formatSocketAddress(target.Host, target.Port)
+		key := socketAddr + ":" + target.Type + ":" + target.Resolver
 		if target.IsConsul {
 			key += ":consul:" + target.ServiceURL
 		}

+ 7 - 7
internal/upstream/upstream_parser_test.go

@@ -201,12 +201,12 @@ server {
 	// Create a map for easier comparison
 	targetMap := make(map[string]ProxyTarget)
 	for _, target := range targets {
-		key := target.Host + ":" + target.Port + ":" + target.Type
+		key := formatSocketAddress(target.Host, target.Port) + ":" + target.Type
 		targetMap[key] = target
 	}
 
 	for _, expected := range expectedTargets {
-		key := expected.Host + ":" + expected.Port + ":" + expected.Type
+		key := formatSocketAddress(expected.Host, expected.Port) + ":" + expected.Type
 		if _, found := targetMap[key]; !found {
 			t.Errorf("Expected target not found: %+v", expected)
 		}
@@ -258,12 +258,12 @@ server {
 	// Create a map for easier comparison
 	targetMap := make(map[string]ProxyTarget)
 	for _, target := range targets {
-		key := target.Host + ":" + target.Port + ":" + target.Type
+		key := formatSocketAddress(target.Host, target.Port) + ":" + target.Type
 		targetMap[key] = target
 	}
 
 	for _, expected := range expectedTargets {
-		key := expected.Host + ":" + expected.Port + ":" + expected.Type
+		key := formatSocketAddress(expected.Host, expected.Port) + ":" + expected.Type
 		if _, found := targetMap[key]; !found {
 			t.Errorf("Expected target not found: %+v", expected)
 		}
@@ -332,12 +332,12 @@ server {
 	// Create a map for easier comparison
 	targetMap := make(map[string]ProxyTarget)
 	for _, target := range targets {
-		key := target.Host + ":" + target.Port + ":" + target.Type
+		key := formatSocketAddress(target.Host, target.Port) + ":" + target.Type
 		targetMap[key] = target
 	}
 
 	for _, expected := range expectedTargets {
-		key := expected.Host + ":" + expected.Port + ":" + expected.Type
+		key := formatSocketAddress(expected.Host, expected.Port) + ":" + expected.Type
 		if _, found := targetMap[key]; !found {
 			t.Errorf("Expected target not found: %+v", expected)
 		}
@@ -680,7 +680,7 @@ server {
 	// Verify specific targets exist
 	found := make(map[string]bool)
 	for _, target := range targets {
-		key := target.Host + ":" + target.Port + ":" + target.Type
+		key := formatSocketAddress(target.Host, target.Port) + ":" + target.Type
 		found[key] = true
 	}