package ws

import (
	"net/http"
	"sync"
	"time"

	"github.com/fox/fox/ksync"
	"github.com/fox/fox/log"
	"github.com/gorilla/websocket"
)

const (
	// writeWait is the time allowed to wait for a write to complete.
	writeWait = 10 * time.Second

	// pongWait is the pong interval (presumably the read deadline extended
	// each time a pong arrives — confirm in the connection's read loop).
	pongWait = 60 * time.Second

	// pingPeriod is the ping interval. It is 90% of pongWait so that a ping
	// is sent before the pong wait expires.
	pingPeriod = (pongWait * 9) / 10

	// maxMessageSize is the maximum message length (presumably in bytes —
	// confirm where it is applied).
	maxMessageSize = 512

	// httpReadHeaderTimeout bounds how long the HTTP server waits for the
	// request headers of the upgrade request, so slow-header clients cannot
	// hold connections open indefinitely.
	httpReadHeaderTimeout = 10 * time.Second
)

// upGrader upgrades incoming HTTP requests to WebSocket connections.
var upGrader = websocket.Upgrader{
	// CheckOrigin accepts every Origin, i.e. cross-origin requests from any
	// site are allowed. This should be restricted in production, e.g.:
	//
	//	allowedOrigins := map[string]bool{
	//		"https://yourdomain.com":     true,
	//		"https://api.yourdomain.com": true,
	//	}
	//	return allowedOrigins[r.Header.Get("Origin")]
	CheckOrigin: func(r *http.Request) bool {
		return true
	},
	HandshakeTimeout: 10 * time.Second, // handshake timeout
	ReadBufferSize:   4096,             // read buffer size, in bytes
	WriteBufferSize:  4096,             // write buffer size, in bytes
	// Upgrader has no MaxMessageSize field; a per-connection read limit is
	// set with (*websocket.Conn).SetReadLimit instead.
}

// nextConnIdMu serializes the update of nextConnId (and the construction of
// the connection that follows it) in acceptConn. net/http runs every handler
// on its own goroutine, so an unguarded nextConnId++ is a data race.
// NOTE(review): any other code reading or writing nextConnId should take this
// lock too.
var nextConnIdMu sync.Mutex

// WsServer is a WebSocket server listening on addr. It forwards incoming
// messages and disconnect events to the callbacks given to NewWsServer.
type WsServer struct {
	addr         string              // listen address, e.g. "0.0.0.0:8888"
	onMessage    func(IConn, []byte) // message callback, handed to each connection's handle loop
	onDisconnect func(IConn)         // disconnect callback, handed to newWsConnect
}

// NewWsServer creates a WsServer that will listen on addr and deliver events
// to onMessage and onDisconnect. Call Run to start serving.
func NewWsServer(addr string, onMessage func(IConn, []byte), onDisconnect func(IConn)) *WsServer {
	return &WsServer{addr: addr, onMessage: onMessage, onDisconnect: onDisconnect}
}

// wsHandle is the HTTP handler that upgrades a request to a WebSocket
// connection and hands it to acceptConn.
func (s *WsServer) wsHandle(w http.ResponseWriter, r *http.Request) {
	conn, err := upGrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		log.ErrorF("升级到WebSocket失败:%v", err)
		return
	}
	// conn is deliberately not closed here: it must outlive this handler,
	// because the goroutines started in acceptConn keep using it.
	s.acceptConn(conn)
}

// acceptConn performs the following steps:
//   - allocates the next connection id and wraps conn;
//   - registers the wrapper with wsMgr;
//   - starts the connection's message-handling, read and write goroutines.
//
// The whole sequence runs under nextConnIdMu, because wsHandle may be
// executing concurrently for other clients. The deferred unlock guarantees
// the mutex is released even if something here panics (net/http recovers
// handler panics, so a leaked lock would block every later connection).
func (s *WsServer) acceptConn(conn *websocket.Conn) {
	nextConnIdMu.Lock()
	defer nextConnIdMu.Unlock()

	nextConnId++
	wsConn := newWsConnect(conn, s.onDisconnect)
	wsMgr.Add(wsConn)
	log.DebugF("当前在线人数:%v", wsMgr.Count())

	ksync.GoSafe(func() { wsConn.handle(s.onMessage) }, nil)
	ksync.GoSafe(wsConn.readWsLoop, nil)
	ksync.GoSafe(wsConn.writeWsLoop, nil)
}

// Run does the following and then returns immediately:
//   - registers the WebSocket handler on "/" of http.DefaultServeMux;
//   - starts serving on s.addr in a background goroutine.
//
// A failure to listen or serve is only logged.
func (s *WsServer) Run() {
	http.HandleFunc("/", s.wsHandle)
	log.DebugF("websocket server listening on %v", s.addr)

	srv := &http.Server{
		Addr: s.addr, // nil Handler: serve http.DefaultServeMux, as ListenAndServe(addr, nil) did
		// Only the header read is bounded. ReadTimeout/WriteTimeout deadlines
		// would remain set on the hijacked net.Conn and cut off long-lived
		// WebSocket connections.
		ReadHeaderTimeout: httpReadHeaderTimeout,
	}
	ksync.GoSafe(func() {
		if err := srv.ListenAndServe(); err != nil {
			log.Error(err.Error())
		}
	}, nil)
}

// SetUserId associates userId with the connection identified by connId
// (delegates to wsMgr).
func (s *WsServer) SetUserId(connId uint32, userId int64) {
	wsMgr.SetUserId(connId, userId)
}

// FindConnByUserId returns the connection associated with userId and whether
// one was found (delegates to wsMgr).
func (s *WsServer) FindConnByUserId(userId int64) (IConn, bool) {
	return wsMgr.FindByUserId(userId)
}