修改websocket

This commit is contained in:
liuxiaobo 2025-06-02 15:29:03 +08:00
parent 61e26c7e94
commit a8bb1d01b4
4 changed files with 84 additions and 104 deletions

View File

@ -1,7 +1,5 @@
package ipb package ipb
import "sync/atomic"
func MakeMsg(serviceName string, connId uint32, userId int64, msgId int32, msg []byte) *InternalMsg { func MakeMsg(serviceName string, connId uint32, userId int64, msgId int32, msg []byte) *InternalMsg {
return &InternalMsg{ return &InternalMsg{
ServiceName: serviceName, ServiceName: serviceName,
@ -12,28 +10,14 @@ func MakeMsg(serviceName string, connId uint32, userId int64, msgId int32, msg [
} }
} }
func MakeRpcMsg(serviceName string, connId uint32, userId int64, msgId int32, msg []byte) *InternalMsg { //func MakeRpcMsg(serviceName string, connId uint32, userId int64, msgId int32, msg []byte) *InternalMsg {
return &InternalMsg{ // return &InternalMsg{
ServiceName: serviceName, // ServiceName: serviceName,
ConnId: connId, // ConnId: connId,
UserId: userId, // UserId: userId,
MsgId: msgId, // MsgId: msgId,
Msg: msg, // Msg: msg,
Type: MsgType_RpcMsg, // Type: MsgType_RpcMsg,
RetRpcMsgId: genRpcId(), // RetRpcMsgId: genRpcId(),
} // }
} //}
const (
rpcBeginId = -500000
rpcEndId = -100000
)
var rpcId int32
func genRpcId() int32 {
if atomic.LoadInt32(&rpcId) > rpcEndId {
atomic.StoreInt32(&rpcId, rpcBeginId)
}
return atomic.AddInt32(&rpcId, 1)
}

View File

@ -2,7 +2,7 @@ package ws
type IConn interface { type IConn interface {
Addr() string Addr() string
NotifyClose() Close()
SendMsg(data []byte) error SendMsg(data []byte) error
Name() string Name() string
Id() uint32 Id() uint32

View File

@ -2,6 +2,7 @@ package ws
import ( import (
"context" "context"
"fmt"
"github.com/fox/fox/ksync" "github.com/fox/fox/ksync"
"github.com/fox/fox/log" "github.com/fox/fox/log"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -21,6 +22,7 @@ type Client struct {
cancel context.CancelFunc cancel context.CancelFunc
wg sync.WaitGroup wg sync.WaitGroup
onFunc IOnFunc onFunc IOnFunc
uid int64
} }
func NewClient(url string, onFunc IOnFunc) (*Client, error) { func NewClient(url string, onFunc IOnFunc) (*Client, error) {
@ -49,35 +51,47 @@ func (c *Client) Start() {
ksync.GoSafe(c.heartbeatLoop, nil) ksync.GoSafe(c.heartbeatLoop, nil)
} }
func (c *Client) Log(format string, v ...interface{}) string {
s := fmt.Sprintf("连接:%v, uid:%v ", c.conn.RemoteAddr().String(), c.uid)
return s + fmt.Sprintf(format, v...)
}
/* /*
readLoop暂时没有好的办法及时退出协程c.conn.ReadMessage()是阻塞式导致协程无法及时catch到关闭信号 readLoop暂时没有好的办法及时退出协程c.conn.ReadMessage()是阻塞式导致协程无法及时catch到关闭信号
如果在ReadMessage前调用SetReadDeadline设置超时它会在超时后将底部连接状态标记为已损坏后续ReadMessage会触发崩溃 如果在ReadMessage前调用SetReadDeadline设置超时它会在超时后将底部连接状态标记为已损坏后续ReadMessage会触发崩溃
*/ */
func (c *Client) readLoop() { func (c *Client) readLoop() {
//defer c.wg.Done() //defer c.wg.Done()
_ = c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
// 设置 Pong 处理器
c.conn.SetPongHandler(func(string) error {
_ = c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
for { for {
select { select {
case <-c.ctx.Done(): case <-c.ctx.Done():
log.Debug("readLoop 收到关闭信号") log.Debug(c.Log("readLoop 收到关闭信号"))
return return
default: default:
messageType, message, err := c.conn.ReadMessage() messageType, message, err := c.conn.ReadMessage()
if err != nil { if err != nil {
//log.Error(fmt.Sprintf("读取错误:%v", err)) log.Error(c.Log("读取错误:%v", err))
c.NotifyStop() c.NotifyStop()
return return
} }
switch messageType { switch messageType {
case websocket.PingMessage: case websocket.PingMessage:
log.Debug(c.Log("receive ping message"))
c.sendChan <- &wsMessage{messageType: websocket.PongMessage, data: []byte("pong")} c.sendChan <- &wsMessage{messageType: websocket.PongMessage, data: []byte("pong")}
case websocket.PongMessage: case websocket.PongMessage:
log.Debug(c.Log("receive pong message"))
case websocket.TextMessage, websocket.BinaryMessage: case websocket.TextMessage, websocket.BinaryMessage:
_ = c.onFunc.OnMessage(message) _ = c.onFunc.OnMessage(message)
case websocket.CloseMessage: case websocket.CloseMessage:
log.Debug("收到关闭帧") log.Debug(c.Log("收到关闭帧"))
c.NotifyStop() c.NotifyStop()
return return
} }
@ -96,14 +110,16 @@ func (c *Client) writeLoop() {
case msg := <-c.sendChan: case msg := <-c.sendChan:
switch msg.messageType { switch msg.messageType {
case websocket.PingMessage: case websocket.PingMessage:
log.Debug(c.Log("send ping message"))
_ = c.conn.WriteMessage(websocket.PingMessage, []byte("ping")) _ = c.conn.WriteMessage(websocket.PingMessage, []byte("ping"))
case websocket.PongMessage: case websocket.PongMessage:
log.Debug(c.Log("send pong message"))
_ = c.conn.WriteMessage(websocket.PongMessage, []byte("pong")) _ = c.conn.WriteMessage(websocket.PongMessage, []byte("pong"))
default: default:
_ = c.conn.WriteMessage(msg.messageType, msg.data) _ = c.conn.WriteMessage(msg.messageType, msg.data)
} }
case <-c.ctx.Done(): case <-c.ctx.Done():
//log.Debug("writeLoop 收到关闭信号") log.Debug(c.Log("writeLoop 收到关闭信号"))
// 发送关闭帧 // 发送关闭帧
_ = c.conn.WriteControl( _ = c.conn.WriteControl(
websocket.CloseMessage, websocket.CloseMessage,
@ -139,3 +155,7 @@ func (c *Client) NotifyStop() {
c.cancel() c.cancel()
_ = c.conn.Close() _ = c.conn.Close()
} }
func (c *Client) SetUid(uid int64) {
c.uid = uid
}

View File

@ -3,7 +3,10 @@ package ws
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/fox/fox/ipb"
"github.com/fox/fox/log" "github.com/fox/fox/log"
"github.com/fox/fox/safeChan"
"github.com/golang/protobuf/proto"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"sync" "sync"
"time" "time"
@ -23,73 +26,45 @@ type wsMessage struct {
// 客户端连接 // 客户端连接
type wsConnect struct { type wsConnect struct {
wsConn *websocket.Conn // 底层websocket wsConn *websocket.Conn // 底层websocket
inChan chan *wsMessage // 读队列 inChan *safeChan.SafeChan[*wsMessage] // 读队列
outChan chan *wsMessage // 写队列 outChan *safeChan.SafeChan[*wsMessage] // 写队列
mutex sync.Mutex // 避免重复关闭管道,加锁处理
isClosed bool
closeCh chan struct{} // 关闭通知
id uint32 id uint32
userId int64 userId int64
onDisconnect func(IConn) onDisconnect func(IConn)
once sync.Once
} }
func newWsConnect(wsConn *websocket.Conn, onDisconnect func(IConn)) *wsConnect { func newWsConnect(wsConn *websocket.Conn, onDisconnect func(IConn)) *wsConnect {
return &wsConnect{ c := &wsConnect{
wsConn: wsConn, wsConn: wsConn,
inChan: make(chan *wsMessage, 1000), inChan: safeChan.NewSafeChan[*wsMessage](1000),
outChan: make(chan *wsMessage, 1000), outChan: safeChan.NewSafeChan[*wsMessage](1000),
closeCh: make(chan struct{}),
isClosed: false,
id: nextConnId, id: nextConnId,
userId: 0, userId: 0,
onDisconnect: onDisconnect, onDisconnect: onDisconnect,
} }
} return c
// 从读队列读取消息
func (c *wsConnect) readFromChan() (*wsMessage, error) {
select {
case msg := <-c.inChan:
return msg, nil
case <-c.closeCh:
return nil, fmt.Errorf("连接已关闭")
}
}
// 把消息放进写队列
func (c *wsConnect) sendMsg(msgType int, data []byte) error {
select {
case c.outChan <- &wsMessage{messageType: msgType, data: data}:
case <-c.closeCh:
return fmt.Errorf("连接已关闭")
}
return nil
} }
// 把消息放进写队列 // 把消息放进写队列
func (c *wsConnect) SendMsg(data []byte) error { func (c *wsConnect) SendMsg(data []byte) error {
return c.sendMsg(wsMsgType, data) return c.outChan.Write(&wsMessage{messageType: wsMsgType, data: data})
} }
// 关闭链接 // 关闭链接
func (c *wsConnect) NotifyClose() { func (c *wsConnect) Close() {
c.closeCh <- struct{}{} c.once.Do(func() {
} log.Debug(c.Log("关闭链接"))
_ = c.wsConn.WriteMessage(websocket.CloseMessage, []byte{})
// 关闭链接 c.inChan.Close()
func (c *wsConnect) close() { c.outChan.Close()
log.Debug(c.Log("关闭链接")) _ = c.wsConn.Close()
c.mutex.Lock()
defer c.mutex.Unlock()
if c.isClosed == false {
c.isClosed = true
if c.onDisconnect != nil { if c.onDisconnect != nil {
c.onDisconnect(c) c.onDisconnect(c)
} }
wsMgr.Remove(c) wsMgr.Remove(c)
//close(c.closeCh) })
}
} }
// 循环从websocket中读取消息放入到读队列中 // 循环从websocket中读取消息放入到读队列中
@ -109,14 +84,16 @@ func (c *wsConnect) readWsLoop() {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
log.Error(c.Log("消息读取出现错误:%v", err)) log.Error(c.Log("消息读取出现错误:%v", err))
} }
c.close() log.Debug(c.Log("关闭连接:%v", err))
c.Close()
return return
} }
switch msgType { switch msgType {
case websocket.PingMessage: //case websocket.PingMessage:
_ = c.sendMsg(websocket.PongMessage, []byte("pong")) // log.Debug(c.Log("received ping message"))
case websocket.PongMessage: // _ = c.sendMsg(websocket.PongMessage, []byte("pong"))
log.Debug(c.Log("received pong from client")) //case websocket.PongMessage:
// log.Debug(c.Log("received pong from client"))
case websocket.CloseMessage: case websocket.CloseMessage:
code := websocket.CloseNormalClosure code := websocket.CloseNormalClosure
reason := "" reason := ""
@ -128,17 +105,13 @@ func (c *wsConnect) readWsLoop() {
// 发送响应关闭帧(必须回传相同状态码) // 发送响应关闭帧(必须回传相同状态码)
rspMsg := websocket.FormatCloseMessage(code, reason) rspMsg := websocket.FormatCloseMessage(code, reason)
_ = c.wsConn.WriteControl(websocket.CloseMessage, rspMsg, time.Now().Add(5*time.Second)) _ = c.wsConn.WriteControl(websocket.CloseMessage, rspMsg, time.Now().Add(5*time.Second))
c.close() c.Close()
default: default:
if msgType != wsMsgType { if msgType != wsMsgType {
continue continue
} }
msg := &wsMessage{messageType: msgType, data: data} msg := &wsMessage{messageType: msgType, data: data}
select { _ = c.inChan.Write(msg)
case c.inChan <- msg:
case <-c.closeCh:
return
}
} }
} }
} }
@ -149,20 +122,22 @@ func (c *wsConnect) writeWsLoop() {
for { for {
select { select {
// 取一个消息发送给客户端 // 取一个消息发送给客户端
case msg := <-c.outChan: case msg := <-c.outChan.Reader():
if err := c.wsConn.WriteMessage(msg.messageType, msg.data); err != nil { if err := c.wsConn.WriteMessage(msg.messageType, msg.data); err != nil {
log.Error(c.Log("发送消息错误:%v", err)) iMsg := &ipb.InternalMsg{}
_ = proto.Unmarshal(msg.data, iMsg)
log.Error(c.Log("发送消息错误:%v 消息长度:%v 消息内容:%v", err, len(msg.data), iMsg))
// 关闭连接 // 关闭连接
c.close() c.Close()
return return
} }
case <-c.closeCh:
// 收到关闭通知
return
case <-ticker.C: case <-ticker.C:
_ = c.wsConn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.wsConn.WriteMessage(websocket.PingMessage, []byte("ping")); err != nil {
if err := c.wsConn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { log.Error(c.Log("发送心跳失败:%v", err))
c.Close()
return return
} else {
log.Debug(c.Log("发送心跳"))
} }
} }
} }
@ -170,13 +145,14 @@ func (c *wsConnect) writeWsLoop() {
func (c *wsConnect) handle(process func(IConn, []byte)) { func (c *wsConnect) handle(process func(IConn, []byte)) {
for { for {
msg, err := c.readFromChan() select {
if err != nil { case msg, ok := <-c.inChan.Reader():
// log.Error(c.Log("获取消息错误:%v", err)) if ok {
break process(c, msg.data)
} else {
c.Close()
}
} }
// Log.Debug(c.Log("接收消息:%v", msg.data))
process(c, msg.data)
} }
} }
@ -204,6 +180,6 @@ func (c *wsConnect) Addr() string {
} }
func (c *wsConnect) Log(format string, v ...interface{}) string { func (c *wsConnect) Log(format string, v ...interface{}) string {
s := fmt.Sprintf("连接:%v, id:%v ", c.wsConn.RemoteAddr().String(), c.id) s := fmt.Sprintf("连接:%v, id:%v uid:%v", c.wsConn.RemoteAddr().String(), c.id, c.userId)
return s + fmt.Sprintf(format, v...) return s + fmt.Sprintf(format, v...)
} }