websocket.go 7.5 KB


  1. package cluster
  2. import (
  3. "context"
  4. "crypto/sha256"
  5. "encoding/hex"
  6. "encoding/json"
  7. "net/http"
  8. "sync"
  9. "time"
  10. "github.com/0xJacky/Nginx-UI/internal/analytic"
  11. "github.com/0xJacky/Nginx-UI/internal/helper"
  12. "github.com/0xJacky/Nginx-UI/internal/kernel"
  13. "github.com/0xJacky/Nginx-UI/model"
  14. "github.com/gin-gonic/gin"
  15. "github.com/gorilla/websocket"
  16. "github.com/uozi-tech/cosy/logger"
  17. )
  // WebSocketMessage is the envelope for every frame pushed to a cluster
// monitoring client: an event name plus an arbitrary JSON payload.
type WebSocketMessage struct {
	Event string `json:"event"` // event name, e.g. "message" or "heartbeat"
	Data  any    `json:"data"`  // payload, encoded with encoding/json
}
  // Client is one WebSocket connection subscribed to cluster node monitoring.
// Its goroutines (writePump, readPump, handleNodeMonitoring) all stop when ctx
// is cancelled.
type Client struct {
	conn   *websocket.Conn       // underlying connection; JSON frames are written by writePump
	send   chan WebSocketMessage // outbound queue drained by writePump
	ctx    context.Context       // cancelled to stop this client's goroutines
	cancel context.CancelFunc    // cancels ctx; safe to call more than once
}
  // Hub tracks the connected monitoring clients and fans broadcast messages out
// to them. Its channels are serviced by the run loop that GetHub starts.
type Hub struct {
	clients    map[*Client]bool      // set of registered clients
	broadcast  chan WebSocketMessage // messages to fan out, fed by BroadcastMessage
	register   chan *Client          // clients to add to the set
	unregister chan *Client          // clients to remove from the set
	mutex      sync.RWMutex          // guards clients
}
  38. var (
  39. hub *Hub
  40. hubOnce sync.Once
  41. )
  42. // GetHub returns the singleton hub instance
  43. func GetHub() *Hub {
  44. hubOnce.Do(func() {
  45. hub = &Hub{
  46. clients: make(map[*Client]bool),
  47. broadcast: make(chan WebSocketMessage, 1024), // Increased buffer size
  48. register: make(chan *Client),
  49. unregister: make(chan *Client),
  50. }
  51. go hub.run()
  52. })
  53. return hub
  54. }
  55. // run handles the main hub loop
  56. func (h *Hub) run() {
  57. for {
  58. select {
  59. case client := <-h.register:
  60. h.mutex.Lock()
  61. h.clients[client] = true
  62. h.mutex.Unlock()
  63. logger.Debug("Cluster node client connected, total clients:", len(h.clients))
  64. case client := <-h.unregister:
  65. h.mutex.Lock()
  66. if _, ok := h.clients[client]; ok {
  67. delete(h.clients, client)
  68. close(client.send)
  69. }
  70. h.mutex.Unlock()
  71. logger.Debug("Cluster node client disconnected, total clients:", len(h.clients))
  72. case message := <-h.broadcast:
  73. h.mutex.RLock()
  74. deadClients := make([]*Client, 0)
  75. for client := range h.clients {
  76. select {
  77. case client.send <- message:
  78. case <-time.After(100 * time.Millisecond):
  79. // Client is too slow, mark for removal
  80. logger.Debug("Client send channel timeout, marking for removal")
  81. deadClients = append(deadClients, client)
  82. default:
  83. // Channel is full, mark for removal
  84. logger.Debug("Client send channel full, marking for removal")
  85. deadClients = append(deadClients, client)
  86. }
  87. }
  88. h.mutex.RUnlock()
  89. // Clean up dead clients
  90. if len(deadClients) > 0 {
  91. h.mutex.Lock()
  92. for _, client := range deadClients {
  93. if _, ok := h.clients[client]; ok {
  94. close(client.send)
  95. delete(h.clients, client)
  96. client.cancel() // Trigger client cleanup
  97. }
  98. }
  99. h.mutex.Unlock()
  100. logger.Info("Cleaned up slow/unresponsive clients", "count", len(deadClients))
  101. }
  102. }
  103. }
  104. }
  105. // BroadcastMessage sends a message to all connected clients
  106. func (h *Hub) BroadcastMessage(event string, data any) {
  107. message := WebSocketMessage{
  108. Event: event,
  109. Data: data,
  110. }
  111. select {
  112. case h.broadcast <- message:
  113. default:
  114. logger.Warn("Cluster node broadcast channel full, message dropped")
  115. }
  116. }
  117. // WebSocket upgrader configuration
  118. var upgrader = websocket.Upgrader{
  119. CheckOrigin: func(r *http.Request) bool {
  120. return true
  121. },
  122. ReadBufferSize: 1024,
  123. WriteBufferSize: 1024,
  124. }
  // respNode is the per-node element of the "message" payload: the node's own
// fields (encoding/json promotes the fields of the untagged embedded
// *model.Node) plus a "status" flag, filled from analytic.GetNode(...).Status
// in handleNodeMonitoring.
type respNode struct {
	*model.Node
	Status bool `json:"status"`
}
  129. // GetAllEnabledNodeWS handles WebSocket connections for real-time node monitoring
  130. func GetAllEnabledNodeWS(c *gin.Context) {
  131. ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  132. if err != nil {
  133. logger.Error("Failed to upgrade connection:", err)
  134. return
  135. }
  136. defer ws.Close()
  137. ctx, cancel := context.WithCancel(context.Background())
  138. defer cancel()
  139. client := &Client{
  140. conn: ws,
  141. send: make(chan WebSocketMessage, 1024), // Increased buffer size
  142. ctx: ctx,
  143. cancel: cancel,
  144. }
  145. hub := GetHub()
  146. hub.register <- client
  147. // Start goroutines for handling node monitoring
  148. go client.handleNodeMonitoring()
  149. // Start write and read pumps
  150. go client.writePump()
  151. client.readPump()
  152. }
  153. // handleNodeMonitoring monitors node status and sends updates
  154. func (c *Client) handleNodeMonitoring() {
  155. interval := 10 * time.Second
  156. heartbeatInterval := 30 * time.Second
  157. getNodeData := func() (interface{}, bool) {
  158. // Query nodes directly from database
  159. var nodes []model.Node
  160. err := model.UseDB().Where("enabled = ?", true).Find(&nodes).Error
  161. if err != nil {
  162. logger.Error("Failed to query nodes:", err)
  163. return nil, false
  164. }
  165. // Transform nodes to response format
  166. var result []respNode
  167. for _, node := range nodes {
  168. result = append(result, respNode{
  169. Node: &node,
  170. Status: analytic.GetNode(&node).Status,
  171. })
  172. }
  173. return result, true
  174. }
  175. getHash := func(data interface{}) string {
  176. bytes, _ := json.Marshal(data)
  177. hash := sha256.New()
  178. hash.Write(bytes)
  179. hashSum := hash.Sum(nil)
  180. return hex.EncodeToString(hashSum)
  181. }
  182. var dataHash string
  183. // Send initial data
  184. data, ok := getNodeData()
  185. if ok {
  186. dataHash = getHash(data)
  187. c.sendMessage("message", data)
  188. }
  189. ticker := time.NewTicker(interval)
  190. heartbeatTicker := time.NewTicker(heartbeatInterval)
  191. defer ticker.Stop()
  192. defer heartbeatTicker.Stop()
  193. for {
  194. select {
  195. case <-ticker.C:
  196. data, ok := getNodeData()
  197. if !ok {
  198. return
  199. }
  200. newHash := getHash(data)
  201. if dataHash != newHash {
  202. dataHash = newHash
  203. c.sendMessage("message", data)
  204. }
  205. case <-heartbeatTicker.C:
  206. c.sendMessage("heartbeat", "")
  207. case <-c.ctx.Done():
  208. return
  209. }
  210. }
  211. }
  212. // sendMessage sends a message to the client with timeout and better error handling
  213. func (c *Client) sendMessage(event string, data any) {
  214. message := WebSocketMessage{
  215. Event: event,
  216. Data: data,
  217. }
  218. select {
  219. case c.send <- message:
  220. case <-time.After(5 * time.Second):
  221. logger.Warn("Client send channel full, message dropped after timeout", "event", event)
  222. // Force disconnect slow clients to prevent resource leakage
  223. c.cancel()
  224. default:
  225. logger.Warn("Client send channel full, message dropped immediately", "event", event)
  226. // For non-critical messages, we can drop them immediately
  227. if event != "heartbeat" {
  228. logger.Info("Dropping non-critical message due to full channel", "event", event)
  229. }
  230. }
  231. }
  232. // writePump pumps messages from the hub to the websocket connection
  233. func (c *Client) writePump() {
  234. ticker := time.NewTicker(54 * time.Second)
  235. defer ticker.Stop()
  236. for {
  237. select {
  238. case message, ok := <-c.send:
  239. c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
  240. if !ok {
  241. c.conn.WriteMessage(websocket.CloseMessage, []byte{})
  242. return
  243. }
  244. if err := c.conn.WriteJSON(message); err != nil {
  245. logger.Error("Error writing message to websocket:", err)
  246. return
  247. }
  248. case <-ticker.C:
  249. c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
  250. if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
  251. return
  252. }
  253. case <-kernel.Context.Done():
  254. return
  255. case <-c.ctx.Done():
  256. return
  257. }
  258. }
  259. }
  260. // readPump pumps messages from the websocket connection to the hub
  261. func (c *Client) readPump() {
  262. defer func() {
  263. hub := GetHub()
  264. hub.unregister <- c
  265. c.conn.Close()
  266. c.cancel()
  267. }()
  268. c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
  269. c.conn.SetPongHandler(func(string) error {
  270. c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
  271. return nil
  272. })
  273. go func() {
  274. for {
  275. _, _, err := c.conn.ReadMessage()
  276. if err != nil {
  277. if helper.IsUnexpectedWebsocketError(err) {
  278. logger.Error("Websocket error:", err)
  279. }
  280. return
  281. }
  282. }
  283. }()
  284. select {
  285. case <-kernel.Context.Done():
  286. return
  287. case <-c.ctx.Done():
  288. return
  289. }
  290. }