package ws import ( "context" "github.com/fox/fox/ksync" "github.com/fox/fox/log" "github.com/gorilla/websocket" "net/http" "sync" "time" ) type IOnFunc interface { OnMessage(msg []byte) error } type Client struct { conn *websocket.Conn sendChan chan *wsMessage ctx context.Context cancel context.CancelFunc wg sync.WaitGroup onFunc IOnFunc } func NewClient(url string, onFunc IOnFunc) (*Client, error) { dialer := websocket.DefaultDialer dialer.HandshakeTimeout = 30 * time.Second conn, _, err := dialer.Dial(url, http.Header{"User-Agent": {"MyClient/1.0"}}) if err != nil { return nil, err } ctx, cancel := context.WithCancel(context.Background()) return &Client{ conn: conn, sendChan: make(chan *wsMessage, 100), ctx: ctx, cancel: cancel, onFunc: onFunc, }, nil } func (c *Client) Start() { c.wg.Add(2) ksync.GoSafe(c.readLoop, nil) ksync.GoSafe(c.writeLoop, nil) ksync.GoSafe(c.heartbeatLoop, nil) } /* readLoop暂时没有好的办法及时退出协程,c.conn.ReadMessage()是阻塞式,导致协程无法及时catch到关闭信号 如果在ReadMessage前调用SetReadDeadline设置超时,它会在超时后将底部连接状态标记为已损坏,后续ReadMessage会触发崩溃 */ func (c *Client) readLoop() { //defer c.wg.Done() for { select { case <-c.ctx.Done(): log.Debug("readLoop 收到关闭信号") return default: messageType, message, err := c.conn.ReadMessage() if err != nil { //log.Error(fmt.Sprintf("读取错误:%v", err)) c.NotifyStop() return } switch messageType { case websocket.PingMessage: c.sendChan <- &wsMessage{messageType: websocket.PongMessage, data: []byte("pong")} case websocket.PongMessage: case websocket.TextMessage, websocket.BinaryMessage: _ = c.onFunc.OnMessage(message) case websocket.CloseMessage: log.Debug("收到关闭帧") c.NotifyStop() return } } } } func (c *Client) SendMsg(data []byte) { c.sendChan <- &wsMessage{messageType: websocket.BinaryMessage, data: data} } func (c *Client) writeLoop() { defer c.wg.Done() for { select { case msg := <-c.sendChan: switch msg.messageType { case websocket.PingMessage: _ = c.conn.WriteMessage(websocket.PingMessage, []byte("ping")) case websocket.PongMessage: _ = c.conn.WriteMessage(websocket.PongMessage, []byte("pong")) default: _ = c.conn.WriteMessage(msg.messageType, msg.data) } case <-c.ctx.Done(): //log.Debug("writeLoop 收到关闭信号") // 发送关闭帧 _ = c.conn.WriteControl( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(10*time.Second), ) return } } } func (c *Client) heartbeatLoop() { defer c.wg.Done() ticker := time.NewTicker(25 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: c.sendChan <- &wsMessage{messageType: websocket.PingMessage, data: []byte("ping")} case <-c.ctx.Done(): //log.Debug("heartbeatLoop 收到关闭信号") return } } } func (c *Client) WaitStop() { c.wg.Wait() } func (c *Client) NotifyStop() { c.cancel() _ = c.conn.Close() }