package ws

import (
	"net/http"
	"runtime"
	"sync"
	"time"

	"github.com/fox/fox/ksync"
	"github.com/fox/fox/log"
	"github.com/gorilla/websocket"
)

const (
	// Maximum time allowed to wait for a write (currently disabled).
	//writeWait = 10 * time.Second

	// pongWait is the pong wait interval.
	pongWait = 60 * time.Second

	// pingPeriod is the ping interval: 90% of pongWait, presumably so a ping
	// is sent before the pong wait expires.
	pingPeriod = (pongWait * 9) / 10

	// maxMessageSize is the maximum message length (1 MiB).
	maxMessageSize = 1024 * 1024
)

// upGrader upgrades incoming HTTP requests to WebSocket connections.
var upGrader = websocket.Upgrader{
	// CheckOrigin accepts every origin, i.e. all cross-origin requests are
	// allowed. Production deployments should restrict this, e.g.:
	//
	//	allowedOrigins := map[string]bool{
	//		"https://yourdomain.com":     true,
	//		"https://api.yourdomain.com": true,
	//	}
	//	return allowedOrigins[r.Header.Get("Origin")]
	CheckOrigin: func(r *http.Request) bool {
		return true
	},
	HandshakeTimeout: 10 * time.Second, // handshake timeout
	ReadBufferSize:   4096,             // read buffer size in bytes
	WriteBufferSize:  4096,             // write buffer size in bytes
	// websocket.Upgrader has no message-size option; a per-connection limit
	// such as maxMessageSize has to be applied on the *websocket.Conn
	// (Conn.SetReadLimit) — presumably done by the connection code, verify.
}

// wsConnIdMu serializes connection-id allocation in wsHandle. net/http runs
// each handler invocation on its own goroutine, so the shared nextConnId
// counter must not be incremented concurrently.
var wsConnIdMu sync.Mutex

// WsServer accepts WebSocket connections on addr and hands their messages and
// disconnects to the callbacks supplied to NewWsServer.
type WsServer struct {
	addr         string // listen address, e.g. 0.0.0.0:8888
	onMessage    func(IConn, []byte)
	onDisconnect func(IConn)
	connMgr      *connManager
	userMgr      *userManager
}

// NewWsServer creates a server that will listen on addr once Run is called.
// onMessage is passed to every connection's message loop and onDisconnect to
// every new connection. The connection manager and user manager are created
// here and linked to each other.
func NewWsServer(addr string, onMessage func(IConn, []byte), onDisconnect func(IConn)) *WsServer {
	wss := &WsServer{
		addr:         addr,
		onMessage:    onMessage,
		onDisconnect: onDisconnect,
		connMgr:      newConnManager(nil),
		userMgr:      newUserManager(nil),
	}
	// The two managers reference each other, so wire them after construction.
	wss.connMgr.userMgr = wss.userMgr
	wss.userMgr.connMgr = wss.connMgr
	return wss
}

// wsHandle upgrades the request to a WebSocket connection, registers it with
// the connection manager and starts its message-handling, read and write
// loops on separate goroutines.
func (s *WsServer) wsHandle(w http.ResponseWriter, r *http.Request) {
	conn, err := upGrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		log.ErrorF("升级到WebSocket失败:%v", err)
		return
	}

	// Handlers for concurrent handshakes run in parallel; the id counter and
	// the connection construction (which presumably reads nextConnId — verify
	// in newWsConnect) must not interleave. The lock is released by defer so a
	// panic while registering cannot leave it held; the remaining work is
	// presumably short bookkeeping. NOTE(review): any other writer of
	// nextConnId in this package must take wsConnIdMu as well.
	wsConnIdMu.Lock()
	defer wsConnIdMu.Unlock()

	nextConnId++
	wsConn := newWsConnect(conn, s.onDisconnect, s.connMgr)
	s.connMgr.Add(wsConn)
	log.DebugF("当前连接数:%v", s.connMgr.Count())
	log.DebugF("新连接id:%v %v", wsConn.Id(), wsConn.Name())

	ksync.GoSafe(func() { wsConn.handle(s.onMessage) }, nil)
	ksync.GoSafe(wsConn.readWsLoop, nil)
	ksync.GoSafe(wsConn.writeWsLoop, nil)
}

// debugGoroutineNum logs the current goroutine count every two minutes. The
// logging goroutine runs for the lifetime of the process and is started via
// ksync.GoSafe like the server's other goroutines.
func (s *WsServer) debugGoroutineNum() {
	ksync.GoSafe(func() {
		ticker := time.NewTicker(2 * time.Minute)
		defer ticker.Stop()
		for range ticker.C {
			log.DebugF("当前协程数量:%v", runtime.NumGoroutine())
		}
	}, nil)
}

// Run starts serving WebSocket upgrades on s.addr in the background and
// returns immediately. Every request path is routed to the upgrade handler.
// A failure to listen is logged, not returned.
func (s *WsServer) Run() {
	router := http.NewServeMux()
	router.HandleFunc("/", s.wsHandle)

	// Only ReadHeaderTimeout is set: it stops clients that trickle request
	// headers from holding sockets forever, and net/http clears that deadline
	// once the headers are read. ReadTimeout/WriteTimeout are deliberately
	// left unset because their deadlines stay on the hijacked net.Conn and
	// would cut off long-lived WebSocket connections.
	srv := &http.Server{
		Addr:              s.addr,
		Handler:           router,
		ReadHeaderTimeout: 10 * time.Second,
	}

	// s.addr already includes host and port, so no extra ':' is prefixed.
	log.DebugF("websocket server listening on %v", s.addr)
	ksync.GoSafe(func() {
		if err := srv.ListenAndServe(); err != nil {
			log.Error(err.Error())
		}
	}, nil)
	s.debugGoroutineNum()
}

// SetUserId associates userId with the connection identified by connId
// (delegates to the connection manager).
func (s *WsServer) SetUserId(connId uint32, userId int64) {
	s.connMgr.SetUserId(connId, userId)
}

// FindConnByUserId returns the connection associated with userId; the bool
// reports whether one was found.
func (s *WsServer) FindConnByUserId(userId int64) (IConn, bool) {
	return s.connMgr.FindByUserId(userId)
}

// FindConnByConnId returns the connection with id connId; the bool reports
// whether one was found.
func (s *WsServer) FindConnByConnId(connId uint32) (IConn, bool) {
	return s.connMgr.Get(connId)
}

// Rang calls cb for the managed connections. The meaning of cb's bool result
// (presumably "continue iterating") is defined by connManager.Rang — verify.
func (s *WsServer) Rang(cb func(conn IConn) bool) {
	s.connMgr.Rang(cb)
}