diff --git a/controllers/system/sendMsg.go b/controllers/system/sendMsg.go new file mode 100644 index 0000000..02d7413 --- /dev/null +++ b/controllers/system/sendMsg.go @@ -0,0 +1,32 @@ +package system + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "net/http" + "toutoukan/model/systemmodel" + "toutoukan/socket" +) + +func SendMsg(c *gin.Context) { + var req systemmodel.MsgSend + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + } + // 1. 将Msg结构体序列化为JSON字符串 + msgJson, err := json.Marshal(req.Msg) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "消息序列化失败:" + err.Error()}) + return + } + + // 2. 将JSON字符串转换为[]byte,传给SendToUser + err = socket.SendToUser(req.Username, msgJson) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "发送消息失败:" + err.Error()}) + return + } + + // 发送成功的响应 + c.JSON(http.StatusOK, gin.H{"code": 200, "message": "消息发送成功"}) +} diff --git a/model/systemmodel/msgsend.go b/model/systemmodel/msgsend.go new file mode 100644 index 0000000..6ed64d8 --- /dev/null +++ b/model/systemmodel/msgsend.go @@ -0,0 +1,9 @@ +package systemmodel + +type MsgSend struct { + Username string `json:"username"` + Msg struct { + Type string `json:"type"` + Content string `json:"content"` + } `json:"msg"` +} diff --git a/router/setupRouter.go b/router/setupRouter.go index 762f92a..ec63a56 100644 --- a/router/setupRouter.go +++ b/router/setupRouter.go @@ -2,6 +2,7 @@ package router import ( "github.com/gin-gonic/gin" + "toutoukan/controllers/system" "toutoukan/controllers/test" "toutoukan/controllers/user" "toutoukan/socket" @@ -16,9 +17,13 @@ func SetupRouter() *gin.Engine { apiGroup.POST("/login", user.UserLogin) apiGroup.POST("/test", utill.JWTAuthMiddleware(), test.Testjwt) } - r.GET("/socket", func(c *gin.Context) { + r.GET("/socket", utill.JWTAuthMiddleware(), func(c *gin.Context) { socket.WebsocketHandler(c) }) + systemGroup := r.Group("/system") + { + systemGroup.POST("/sendMsg", system.SendMsg) + } return r } diff --git a/socket/connect.go b/socket/connect.go index a817ead..b88d7ad 100644 --- a/socket/connect.go +++ b/socket/connect.go @@ -1,60 +1,138 @@ package socket import ( + "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "log" "net/http" + "strings" + "sync" ) -// 1. 配置WebSocket升级器(处理HTTP到WebSocket的握手) +// 全局连接管理器:维护用户名与WebSocket连接的映射 +var ( + userConnections = make(map[string]*websocket.Conn) // 用户名 -> 连接 + connMutex sync.RWMutex // 读写锁,保证并发安全 +) + +// WebSocket升级器(gorilla/websocket专用) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, - // 允许跨域(生产环境需根据实际情况限制) + // 允许跨域(生产环境需根据实际情况限制Origin) CheckOrigin: func(r *http.Request) bool { return true }, } -// 2. WebSocket处理器(处理实时消息) +// 添加用户连接到管理器 +func addUserConnection(username string, conn *websocket.Conn) { + connMutex.Lock() + defer connMutex.Unlock() + userConnections[username] = conn +} + +// 从管理器移除用户连接 +func removeUserConnection(username string) { + connMutex.Lock() + defer connMutex.Unlock() + delete(userConnections, username) +} + +// 向特定用户发送消息 +func SendToUser(username string, message []byte) error { + // 读锁:仅读取连接,不修改映射 + connMutex.RLock() + conn, exists := userConnections[username] + connMutex.RUnlock() + + if !exists { + return fmt.Errorf("用户 %s 不在线", username) + } + + // 使用gorilla/websocket的WriteMessage方法发送消息 + err := conn.WriteMessage(websocket.TextMessage, message) + if err != nil { + // 发送失败时移除无效连接 + removeUserConnection(username) + return fmt.Errorf("向用户 %s 发送消息失败: %v", username, err) + } + return nil +} + +// WebsocketHandler 处理WebSocket连接和消息 func WebsocketHandler(c *gin.Context) { - // 从Gin上下文中提取标准库的ResponseWriter和Request w := c.Writer r := c.Request - // 将HTTP连接升级为WebSocket - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Printf("升级WebSocket失败: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "WebSocket连接失败", + // 1. 从Header提取用户名(假设Header键为"Username") + username := r.Header.Get("Username") + if username == "" { + log.Printf("连接失败:请求Header中未包含Username") + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Header中必须包含用户名(Username)", }) return } - defer conn.Close() // 确保连接关闭 - // 客户端IP,用于日志 - clientIP := r.RemoteAddr - log.Printf("新的WebSocket连接: %s", clientIP) + // 2. 将HTTP连接升级为WebSocket(gorilla/websocket的升级方法) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket升级失败: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "WebSocket连接建立失败", + }) + return + } + // 连接关闭时清理资源 + defer func() { + conn.Close() + removeUserConnection(username) + log.Printf("用户 %s 已断开连接", username) + }() - // 循环处理消息 + // 3. 将连接加入全局管理器 + addUserConnection(username, conn) + log.Printf("用户 %s 建立WebSocket连接,当前在线用户: %d", username, getOnlineUserCount()) + + // 4. 循环处理消息 for { - // 读取消息类型(文本/二进制)和内容 - msgType, p, err := conn.ReadMessage() + // 读取消息类型和内容(gorilla/websocket的ReadMessage方法) + msgType, msg, err := conn.ReadMessage() if err != nil { - log.Printf("WebSocket读取错误(%s): %v", clientIP, err) + log.Printf("用户 %s 消息读取错误: %v", username, err) break } - // 处理消息内容(示例:打印收到的消息) - log.Printf("收到来自%s的消息(类型:%d): %s", clientIP, msgType, string(p)) + // 打印收到的消息 + log.Printf("收到用户 %s 的消息(类型: %d): %s", username, msgType, string(msg)) - // 示例:根据消息类型回复(保持原类型) - response := []byte("已收到:" + string(p)) - if err := conn.WriteMessage(msgType, response); err != nil { - log.Printf("WebSocket发送错误(%s): %v", clientIP, err) - break + // 5. 解析消息并转发给指定用户(示例格式:"目标用户名:消息内容") + parts := strings.SplitN(string(msg), ":", 2) + if len(parts) == 2 { + targetUser := parts[0] + content := parts[1] + + // 构造转发消息 + forwardMsg := []byte(fmt.Sprintf("[%s] 对你说: %s", username, content)) + // 发送给目标用户 + if err := SendToUser(targetUser, forwardMsg); err != nil { + // 发送失败时,向当前用户反馈错误 + errorMsg := []byte(fmt.Sprintf("系统提示: %v", err)) + conn.WriteMessage(websocket.TextMessage, errorMsg) + } + } else { + // 消息格式错误时的提示 + helpMsg := []byte("消息格式错误,请使用: 目标用户名:消息内容") + conn.WriteMessage(websocket.TextMessage, helpMsg) } } } + +// 获取当前在线用户数(辅助函数) +func getOnlineUserCount() int { + connMutex.RLock() + defer connMutex.RUnlock() + return len(userConnections) +}