package socket import ( "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "log" "net/http" "strings" "sync" ) // 全局连接管理器:维护用户名与WebSocket连接的映射 var ( userConnections = make(map[string]*websocket.Conn) // 用户名 -> 连接 connMutex sync.RWMutex // 读写锁,保证并发安全 ) // WebSocket升级器(gorilla/websocket专用) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, // 允许跨域(生产环境需根据实际情况限制Origin) CheckOrigin: func(r *http.Request) bool { return true }, } // 添加用户连接到管理器 func addUserConnection(username string, conn *websocket.Conn) { connMutex.Lock() defer connMutex.Unlock() userConnections[username] = conn } // 从管理器移除用户连接 func removeUserConnection(username string) { connMutex.Lock() defer connMutex.Unlock() delete(userConnections, username) } // 向特定用户发送消息 func SendToUser(username string, message []byte) error { // 读锁:仅读取连接,不修改映射 connMutex.RLock() conn, exists := userConnections[username] connMutex.RUnlock() if !exists { return fmt.Errorf("用户 %s 不在线", username) } // 使用gorilla/websocket的WriteMessage方法发送消息 err := conn.WriteMessage(websocket.TextMessage, message) if err != nil { // 发送失败时移除无效连接 removeUserConnection(username) return fmt.Errorf("向用户 %s 发送消息失败: %v", username, err) } return nil } // WebsocketHandler 处理WebSocket连接和消息 func WebsocketHandler(c *gin.Context) { w := c.Writer r := c.Request // 1. 从Header提取用户名(假设Header键为"Username") username := r.Header.Get("Username") if username == "" { log.Printf("连接失败:请求Header中未包含Username") c.JSON(http.StatusBadRequest, gin.H{ "error": "Header中必须包含用户名(Username)", }) return } // 2. 将HTTP连接升级为WebSocket(gorilla/websocket的升级方法) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("WebSocket升级失败: %v", err) c.JSON(http.StatusInternalServerError, gin.H{ "error": "WebSocket连接建立失败", }) return } // 连接关闭时清理资源 defer func() { conn.Close() removeUserConnection(username) log.Printf("用户 %s 已断开连接", username) }() // 3. 将连接加入全局管理器 addUserConnection(username, conn) log.Printf("用户 %s 建立WebSocket连接,当前在线用户: %d", username, getOnlineUserCount()) // 4. 循环处理消息 for { // 读取消息类型和内容(gorilla/websocket的ReadMessage方法) msgType, msg, err := conn.ReadMessage() if err != nil { log.Printf("用户 %s 消息读取错误: %v", username, err) break } // 打印收到的消息 log.Printf("收到用户 %s 的消息(类型: %d): %s", username, msgType, string(msg)) // 5. 解析消息并转发给指定用户(示例格式:"目标用户名:消息内容") parts := strings.SplitN(string(msg), ":", 2) if len(parts) == 2 { targetUser := parts[0] content := parts[1] // 构造转发消息 forwardMsg := []byte(fmt.Sprintf("[%s] 对你说: %s", username, content)) // 发送给目标用户 if err := SendToUser(targetUser, forwardMsg); err != nil { // 发送失败时,向当前用户反馈错误 errorMsg := []byte(fmt.Sprintf("系统提示: %v", err)) conn.WriteMessage(websocket.TextMessage, errorMsg) } } else { // 消息格式错误时的提示 helpMsg := []byte("消息格式错误,请使用: 目标用户名:消息内容") conn.WriteMessage(websocket.TextMessage, helpMsg) } } } // 获取当前在线用户数(辅助函数) func getOnlineUserCount() int { connMutex.RLock() defer connMutex.RUnlock() return len(userConnections) }