package sitecheck

import (
	"context"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"io"
	"maps"
	"net"
	"net/http"
	"net/url"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/0xJacky/Nginx-UI/internal/site"
	"github.com/0xJacky/Nginx-UI/model"
	"github.com/0xJacky/Nginx-UI/query"
	"github.com/uozi-tech/cosy/logger"
)

// SiteChecker collects the URLs of enabled indexed sites and probes them,
// keeping the most recent result per URL.
type SiteChecker struct {
	sites            map[string]*SiteInfo // latest result keyed by site URL; guarded by mu
	mu               sync.RWMutex
	options          CheckOptions
	client           *http.Client
	onUpdateCallback func([]*SiteInfo) // Callback for notifying updates
}

// NewSiteChecker creates a new site checker.
//
// The HTTP client it builds uses options.Timeout as its overall timeout and
// applies the redirect policy described by options.FollowRedirects and
// options.MaxRedirects. When FollowRedirects is true and MaxRedirects is not
// positive, the net/http default redirect policy is left in place.
func NewSiteChecker(options CheckOptions) *SiteChecker {
	dialer := &net.Dialer{Timeout: 5 * time.Second}
	transport := &http.Transport{
		// DialContext replaces the deprecated Dial field so that connection
		// attempts are aborted when the request context is cancelled.
		DialContext:         dialer.DialContext,
		TLSHandshakeTimeout: 5 * time.Second,
		TLSClientConfig: &tls.Config{
			InsecureSkipVerify: true, // Skip SSL verification for internal sites
		},
	}

	client := &http.Client{
		Transport: transport,
		Timeout:   options.Timeout,
	}

	switch {
	case !options.FollowRedirects:
		// Hand the redirect response itself back to the caller.
		client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		}
	case options.MaxRedirects > 0:
		client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
			if len(via) >= options.MaxRedirects {
				return fmt.Errorf("stopped after %d redirects", options.MaxRedirects)
			}
			return nil
		}
	}

	return &SiteChecker{
		sites:   make(map[string]*SiteInfo),
		options: options,
		client:  client,
	}
}

// SetUpdateCallback sets the callback function for site updates.
//
// The field is written without taking sc.mu, so set it before any concurrent
// call to CheckAllSites.
func (sc *SiteChecker) SetUpdateCallback(callback func([]*SiteInfo)) {
	sc.onUpdateCallback = callback
}

// CollectSites collects URLs from enabled indexed sites only.
//
// The existing site map is discarded and rebuilt; every collected entry starts
// in StatusChecking. The write lock is held for the whole rebuild, including
// the per-URL config lookups.
func (sc *SiteChecker) CollectSites() {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	// Clear existing sites
	sc.sites = make(map[string]*SiteInfo)

	// Debug: log indexed sites count
	logger.Infof("Found %d indexed sites", len(site.IndexedSites))

	// Collect URLs from indexed sites, but only from enabled sites
	for siteName, indexedSite := range site.IndexedSites {
		// Check site status - only collect from enabled sites
		siteStatus := site.GetSiteStatus(siteName)
		if siteStatus != site.SiteStatusEnabled {
			logger.Debugf("Skipping site %s (status: %s) - only collecting from enabled sites", siteName, siteStatus)
			continue
		}

		logger.Debugf("Processing enabled site: %s with %d URLs", siteName, len(indexedSite.Urls))

		// The loop variable is deliberately not called "url" so it does not
		// shadow the imported net/url package.
		for _, siteURL := range indexedSite.Urls {
			if siteURL == "" {
				continue
			}
			logger.Debugf("Adding site URL: %s", siteURL)

			// Load site config to determine display URL
			config, err := LoadSiteConfig(siteURL)
			protocol := "http" // default protocol
			if err == nil && config != nil && config.HealthCheckConfig != nil && config.HealthCheckConfig.Protocol != "" {
				protocol = config.HealthCheckConfig.Protocol
				logger.Debugf("Site %s using protocol: %s", siteURL, protocol)
			} else {
				logger.Debugf("Site %s using default protocol: %s (config error: %v)", siteURL, protocol, err)
			}

			// Parse URL components for legacy fields
			_, hostPort := parseURLComponents(siteURL, protocol)

			// Get or create site config to get ID
			siteConfig := getOrCreateSiteConfigForURL(siteURL)

			sc.sites[siteURL] = &SiteInfo{
				ID:          siteConfig.ID,
				Host:        siteConfig.Host,
				Port:        siteConfig.Port,
				Scheme:      siteConfig.Scheme,
				DisplayURL:  siteConfig.GetURL(),
				Name:        extractDomainName(siteURL),
				Status:      StatusChecking,
				LastChecked: time.Now().Unix(),
				// Legacy fields for backward compatibility
				URL:                 siteURL,
				HealthCheckProtocol: protocol,
				HostPort:            hostPort,
			}
		}
	}

	logger.Infof("Collected %d sites for checking (enabled sites only)", len(sc.sites))
}

// getOrCreateSiteConfigForURL gets or creates a site config for the given URL.
//
// The lookup is keyed on the host parsed from url. If no record can be loaded,
// a new one with default health-check settings is created. If creation fails,
// an unsaved config with ID 0 is returned so callers always get a non-nil
// value.
func getOrCreateSiteConfigForURL(url string) *model.SiteConfig {
	// Parse URL to get host:port
	tempConfig := &model.SiteConfig{}
	tempConfig.SetFromURL(url)

	q := query.SiteConfig
	siteConfig, err := q.Where(q.Host.Eq(tempConfig.Host)).First()
	if err != nil {
		// NOTE(review): every lookup error is treated as "record doesn't
		// exist"; a transient database error would also land here and trigger
		// a create attempt.
		newConfig := &model.SiteConfig{
			Host:               tempConfig.Host,
			Port:               tempConfig.Port,
			Scheme:             tempConfig.Scheme,
			DisplayURL:         url,
			HealthCheckEnabled: true,
			CheckInterval:      300,
			Timeout:            10,
			UserAgent:          "Nginx-UI Site Checker/1.0",
			MaxRedirects:       3,
			FollowRedirects:    true,
			CheckFavicon:       true,
		}

		// Create the record in database
		if err := q.Create(newConfig); err != nil {
			logger.Errorf("Failed to create site config for %s: %v", url, err)
			// Return temp config with a fake ID to avoid crashes
			tempConfig.ID = 0
			return tempConfig
		}
		return newConfig
	}

	// Record exists, ensure it has the correct URL information
	if siteConfig.DisplayURL == "" {
		siteConfig.DisplayURL = url
		siteConfig.SetFromURL(url)
		// Saving is best effort: the in-memory config is still returned, but
		// the failure is logged instead of being silently dropped.
		if err := q.Save(siteConfig); err != nil {
			logger.Warnf("Failed to save site config for %s: %v", url, err)
		}
	}

	return siteConfig
}

// CheckSite checks a single site's availability.
//
// When a health-check config can be loaded for siteURL, the enhanced checker
// is tried first; if it fails or returns nothing, the basic HTTP check is used
// with the configured protocol (or "http" when none is configured).
func (sc *SiteChecker) CheckSite(ctx context.Context, siteURL string) (*SiteInfo, error) {
	// Try enhanced health check first if config exists
	config, err := LoadSiteConfig(siteURL)
	if err == nil && config != nil && config.HealthCheckConfig != nil {
		protocol := config.HealthCheckConfig.Protocol

		enhancedChecker := NewEnhancedSiteChecker()
		siteInfo, err := enhancedChecker.CheckSiteWithConfig(ctx, siteURL, config.HealthCheckConfig)
		if err == nil && siteInfo != nil {
			// Fill in additional details
			siteInfo.Name = extractDomainName(siteURL)
			siteInfo.LastChecked = time.Now().Unix()

			// Set health check protocol and display URL
			siteInfo.HealthCheckProtocol = protocol
			siteInfo.DisplayURL = generateDisplayURL(siteURL, protocol)

			// Parse URL components
			scheme, hostPort := parseURLComponents(siteURL, protocol)
			siteInfo.Scheme = scheme
			siteInfo.HostPort = hostPort

			// Try to get favicon if enabled and not a gRPC check
			if sc.options.CheckFavicon && !isGRPCProtocol(protocol) {
				faviconURL, faviconData := sc.tryGetFavicon(ctx, siteURL)
				siteInfo.FaviconURL = faviconURL
				siteInfo.FaviconData = faviconData
			}

			return siteInfo, nil
		}
	}

	// Fallback to basic HTTP check, but preserve original protocol if available
	originalProtocol := "http" // default
	if config != nil && config.HealthCheckConfig != nil && config.HealthCheckConfig.Protocol != "" {
		originalProtocol = config.HealthCheckConfig.Protocol
	}

	return sc.checkSiteBasic(ctx, siteURL, originalProtocol)
}

// newBasicSiteInfo builds the SiteInfo fields shared by every outcome of the
// basic HTTP check; the caller fills in status, timing and error details.
func newBasicSiteInfo(siteURL, protocol string) *SiteInfo {
	// Parse URL components for legacy fields
	_, hostPort := parseURLComponents(siteURL, protocol)

	// Get or create site config to get ID
	siteConfig := getOrCreateSiteConfigForURL(siteURL)

	return &SiteInfo{
		ID:          siteConfig.ID,
		Host:        siteConfig.Host,
		Port:        siteConfig.Port,
		Scheme:      siteConfig.Scheme,
		DisplayURL:  siteConfig.GetURL(),
		Name:        extractDomainName(siteURL),
		LastChecked: time.Now().Unix(),
		// Legacy fields for backward compatibility
		URL:                 siteURL,
		HealthCheckProtocol: protocol,
		HostPort:            hostPort,
	}
}

// readHTMLBody reads at most 2 MiB from r. The result is only used to look for
// a page title and favicon link, so capping the read keeps an oversized or
// endless response from being buffered in full.
func readHTMLBody(r io.Reader) ([]byte, error) {
	const maxBodyBytes = 2 << 20
	return io.ReadAll(io.LimitReader(r, maxBodyBytes))
}

// checkSiteBasic performs basic HTTP health check.
//
// A transport-level failure is reported as a SiteInfo with StatusOffline and a
// nil error; a non-nil error is returned only when the request cannot be
// built. Status codes in [200, 400) map to StatusOnline, everything else to
// StatusError.
func (sc *SiteChecker) checkSiteBasic(ctx context.Context, siteURL string, originalProtocol string) (*SiteInfo, error) {
	start := time.Now()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, siteURL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("User-Agent", sc.options.UserAgent)

	resp, err := sc.client.Do(req)
	// Take the measurement right after the round trip so the site-config
	// lookup below is not counted as response time on either path.
	responseTime := time.Since(start).Milliseconds()
	if err != nil {
		siteInfo := newBasicSiteInfo(siteURL, originalProtocol)
		siteInfo.Status = StatusOffline
		siteInfo.ResponseTime = responseTime
		siteInfo.Error = err.Error()
		return siteInfo, nil
	}
	defer resp.Body.Close()

	siteInfo := newBasicSiteInfo(siteURL, originalProtocol)
	siteInfo.StatusCode = resp.StatusCode
	siteInfo.ResponseTime = responseTime

	// Determine status based on status code
	if resp.StatusCode >= 200 && resp.StatusCode < 400 {
		siteInfo.Status = StatusOnline
	} else {
		siteInfo.Status = StatusError
		siteInfo.Error = fmt.Sprintf("HTTP %d", resp.StatusCode)
	}

	// Read response body for title and favicon extraction
	body, err := readHTMLBody(resp.Body)
	if err != nil {
		logger.Warnf("Failed to read response body for %s: %v", siteURL, err)
		return siteInfo, nil
	}
	html := string(body)

	// Extract title
	siteInfo.Title = extractTitle(html)

	// Extract favicon if enabled
	if sc.options.CheckFavicon {
		faviconURL, faviconData := sc.extractFavicon(ctx, siteURL, html)
		siteInfo.FaviconURL = faviconURL
		siteInfo.FaviconData = faviconData
	}

	return siteInfo, nil
}

// tryGetFavicon attempts to get favicon for enhanced checks.
//
// It fetches siteURL with a plain GET and passes the HTML to extractFavicon.
// Any failure (request error, status outside [200, 400), read error) yields
// two empty strings.
func (sc *SiteChecker) tryGetFavicon(ctx context.Context, siteURL string) (string, string) {
	// Make a simple GET request to get the HTML
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, siteURL, nil)
	if err != nil {
		return "", ""
	}
	req.Header.Set("User-Agent", sc.options.UserAgent)

	resp, err := sc.client.Do(req)
	if err != nil {
		return "", ""
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 400 {
		return "", ""
	}

	body, err := readHTMLBody(resp.Body)
	if err != nil {
		return "", ""
	}

	return sc.extractFavicon(ctx, siteURL, string(body))
}

// CheckAllSites checks all collected sites concurrently.
//
// At most 10 checks run at a time. It blocks until every check has finished,
// then invokes the update callback (if set) with the current site list.
func (sc *SiteChecker) CheckAllSites(ctx context.Context) {
	sc.mu.RLock()
	urls := make([]string, 0, len(sc.sites))
	for siteURL := range sc.sites {
		urls = append(urls, siteURL)
	}
	sc.mu.RUnlock()

	// Use a semaphore to limit concurrent requests
	semaphore := make(chan struct{}, 10) // Max 10 concurrent requests
	var wg sync.WaitGroup

	for _, u := range urls {
		wg.Add(1)
		go func(siteURL string) {
			defer wg.Done()

			semaphore <- struct{}{}        // Acquire semaphore
			defer func() { <-semaphore }() // Release semaphore

			siteInfo, err := sc.CheckSite(ctx, siteURL)
			if err != nil {
				logger.Errorf("Failed to check site %s: %v", siteURL, err)
				return
			}

			sc.mu.Lock()
			sc.sites[siteURL] = siteInfo
			sc.mu.Unlock()
		}(u)
	}

	wg.Wait()
	logger.Infof("Completed checking %d sites", len(urls))

	// Notify WebSocket clients of the update
	if sc.onUpdateCallback != nil {
		// GetSitesList takes the read lock for the whole snapshot, including
		// the len() used to size the slice.
		sc.onUpdateCallback(sc.GetSitesList())
	}
}

// GetSites returns all checked sites.
//
// The returned map is a fresh copy, but the *SiteInfo values are shared with
// the checker.
func (sc *SiteChecker) GetSites() map[string]*SiteInfo {
	sc.mu.RLock()
	defer sc.mu.RUnlock()

	// Create a copy to avoid race conditions
	result := make(map[string]*SiteInfo, len(sc.sites))
	maps.Copy(result, sc.sites)
	return result
}

// GetSitesList returns sites as a slice, in map iteration (unspecified) order.
func (sc *SiteChecker) GetSitesList() []*SiteInfo {
	sc.mu.RLock()
	defer sc.mu.RUnlock()

	result := make([]*SiteInfo, 0, len(sc.sites))
	// "info" rather than "site" so the imported site package is not shadowed.
	for _, info := range sc.sites {
		result = append(result, info)
	}
	return result
}

// extractDomainName extracts domain name from URL.
//
// It returns the Host component (which includes the port when one is present)
// or, if siteURL cannot be parsed, siteURL itself.
func extractDomainName(siteURL string) string {
	parsed, err := url.Parse(siteURL)
	if err != nil {
		return siteURL
	}
	return parsed.Host
}

// extractTitle extracts title from HTML content
func extractTitle(html string) string {
	titleRegex := regexp.MustCompile(`(?i)