package ws import ( "context" "fmt" "github.com/fox/fox/ksync" "github.com/fox/fox/log" "github.com/gorilla/websocket" "net/http" "sync" "time" ) type IOnFunc interface { OnMessage(msg []byte) error } type Client struct { conn *websocket.Conn sendChan chan *wsMessage ctx context.Context cancel context.CancelFunc wg sync.WaitGroup onFunc IOnFunc uid int64 } func NewClient(url string, onFunc IOnFunc) (*Client, error) { dialer := websocket.DefaultDialer dialer.HandshakeTimeout = 30 * time.Second conn, _, err := dialer.Dial(url, http.Header{"User-Agent": {"MyClient/1.0"}}) if err != nil { return nil, err } ctx, cancel := context.WithCancel(context.Background()) return &Client{ conn: conn, sendChan: make(chan *wsMessage, 100), ctx: ctx, cancel: cancel, onFunc: onFunc, }, nil } func (c *Client) Start() { c.wg.Add(2) ksync.GoSafe(c.readLoop, nil) ksync.GoSafe(c.writeLoop, nil) ksync.GoSafe(c.heartbeatLoop, nil) } func (c *Client) Log(format string, v ...interface{}) string { s := fmt.Sprintf("连接:%v, uid:%v ", c.conn.RemoteAddr().String(), c.uid) return s + fmt.Sprintf(format, v...) } /* readLoop暂时没有好的办法及时退出协程,c.conn.ReadMessage()是阻塞式,导致协程无法及时catch到关闭信号 如果在ReadMessage前调用SetReadDeadline设置超时,它会在超时后将底部连接状态标记为已损坏,后续ReadMessage会触发崩溃 */ func (c *Client) readLoop() { //defer c.wg.Done() _ = c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) // 设置 Pong 处理器 c.conn.SetPongHandler(func(string) error { _ = c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) for { select { case <-c.ctx.Done(): log.Debug(c.Log("readLoop 收到关闭信号")) return default: messageType, message, err := c.conn.ReadMessage() if err != nil { log.Error(c.Log("读取错误:%v", err)) c.NotifyStop() return } switch messageType { case websocket.PingMessage: log.Debug(c.Log("receive ping message")) c.sendChan <- &wsMessage{messageType: websocket.PongMessage, data: []byte("pong")} case websocket.PongMessage: log.Debug(c.Log("receive pong message")) case websocket.TextMessage, websocket.BinaryMessage: _ = c.onFunc.OnMessage(message) case websocket.CloseMessage: log.Debug(c.Log("收到关闭帧")) c.NotifyStop() return } } } } func (c *Client) SendMsg(data []byte) { c.sendChan <- &wsMessage{messageType: websocket.BinaryMessage, data: data} } func (c *Client) writeLoop() { defer c.wg.Done() for { select { case msg := <-c.sendChan: switch msg.messageType { case websocket.PingMessage: log.Debug(c.Log("send ping message")) _ = c.conn.WriteMessage(websocket.PingMessage, []byte("ping")) case websocket.PongMessage: log.Debug(c.Log("send pong message")) _ = c.conn.WriteMessage(websocket.PongMessage, []byte("pong")) default: _ = c.conn.WriteMessage(msg.messageType, msg.data) } case <-c.ctx.Done(): log.Debug(c.Log("writeLoop 收到关闭信号")) // 发送关闭帧 _ = c.conn.WriteControl( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(10*time.Second), ) return } } } func (c *Client) heartbeatLoop() { defer c.wg.Done() ticker := time.NewTicker(25 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: c.sendChan <- &wsMessage{messageType: websocket.PingMessage, data: []byte("ping")} case <-c.ctx.Done(): //log.Debug("heartbeatLoop 收到关闭信号") return } } } func (c *Client) WaitStop() { c.wg.Wait() } func (c *Client) NotifyStop() { c.cancel() _ = c.conn.Close() } func (c *Client) SetUid(uid int64) { c.uid = uid }