用sync.Map替换第三方并发map

This commit is contained in:
liuxiaobo 2025-05-28 19:43:56 +08:00
parent d061a2bc50
commit 80d513739c
4 changed files with 48 additions and 20 deletions

View File

@ -107,6 +107,10 @@ func (s *NatsService) Call(topic string, timeout time.Duration, msg []byte) ([]b
return nil, nil return nil, nil
} }
func (s *NatsService) ServiceEtcd() *etcd.Registry[etcd.ServiceNode] {
return s.registry
}
// 从etcd中获取所有服务节点 // 从etcd中获取所有服务节点
func (s *NatsService) GetServiceNodes() *sync.Map { func (s *NatsService) GetServiceNodes() *sync.Map {
return s.registry.GetNodes() return s.registry.GetNodes()

View File

@ -1,18 +1,20 @@
package ws package ws
import cmap "github.com/orcaman/concurrent-map/v2" import (
"sync"
)
var userMgr = newUserManager() var userMgr = newUserManager()
type userManager struct { type userManager struct {
users cmap.ConcurrentMap[int64, uint32] users sync.Map // cmap.ConcurrentMap[int64, uint32]
} }
func newUserManager() *userManager { func newUserManager() *userManager {
return &userManager{ return &userManager{
users: cmap.NewWithCustomShardingFunction[int64, uint32](func(key int64) uint32 { //users: cmap.NewWithCustomShardingFunction[int64, uint32](func(key int64) uint32 {
return uint32(key) // return uint32(key)
}), //}),
} }
} }
@ -22,20 +24,22 @@ func (m *userManager) Add(connId uint32, userId int64) bool {
} }
if conn, ok := wsMgr.Get(connId); ok { if conn, ok := wsMgr.Get(connId); ok {
conn.setUserId(userId) conn.setUserId(userId)
m.users.Set(userId, connId) m.users.Store(userId, connId)
return true return true
} }
return false return false
} }
func (m *userManager) GetConnId(userId int64) uint32 { func (m *userManager) GetConnId(userId int64) uint32 {
connId, _ := m.users.Get(userId) if connId, ok := m.users.Load(userId); ok {
return connId return connId.(uint32)
}
return 0
} }
func (m *userManager) Remove(userId int64) { func (m *userManager) Remove(userId int64) {
if userId < 1 { if userId < 1 {
return return
} }
m.users.Remove(userId) m.users.Delete(userId)
} }

View File

@ -1,23 +1,25 @@
package ws package ws
import cmap "github.com/orcaman/concurrent-map/v2" import (
"sync"
)
var wsMgr = newManager() var wsMgr = newManager()
type wsManager struct { type wsManager struct {
wsConnAll cmap.ConcurrentMap[uint32, *wsConnect] wsConnAll sync.Map // cmap.ConcurrentMap[uint32, *wsConnect]
} }
func newManager() *wsManager { func newManager() *wsManager {
return &wsManager{ return &wsManager{
wsConnAll: cmap.NewWithCustomShardingFunction[uint32, *wsConnect](func(key uint32) uint32 { //wsConnAll: cmap.NewWithCustomShardingFunction[uint32, *wsConnect](func(key uint32) uint32 {
return key // return key
}), //}),
} }
} }
func (m *wsManager) Add(conn *wsConnect) { func (m *wsManager) Add(conn *wsConnect) {
m.wsConnAll.Set(conn.id, conn) m.wsConnAll.Store(conn.id, conn)
} }
func (m *wsManager) SetUserId(connId uint32, userId int64) { func (m *wsManager) SetUserId(connId uint32, userId int64) {
@ -28,18 +30,32 @@ func (m *wsManager) Remove(conn *wsConnect) {
if conn.UserId() > 0 { if conn.UserId() > 0 {
userMgr.Remove(conn.UserId()) userMgr.Remove(conn.UserId())
} }
m.wsConnAll.Remove(conn.id) m.wsConnAll.Delete(conn.id)
} }
func (m *wsManager) Get(connId uint32) (*wsConnect, bool) { func (m *wsManager) Get(connId uint32) (*wsConnect, bool) {
return m.wsConnAll.Get(connId) v, ok := m.wsConnAll.Load(connId)
if ok {
conn, ok := v.(*wsConnect)
return conn, ok
}
return nil, false
} }
func (m *wsManager) FindByUserId(userId int64) (*wsConnect, bool) { func (m *wsManager) FindByUserId(userId int64) (*wsConnect, bool) {
connId := userMgr.GetConnId(userId) connId := userMgr.GetConnId(userId)
return m.wsConnAll.Get(connId) return m.Get(connId)
}
func (m *wsManager) Rang(cb func(k, v any) bool) {
m.wsConnAll.Range(cb)
} }
func (m *wsManager) Count() int { func (m *wsManager) Count() int {
return m.wsConnAll.Count() count := 0
m.wsConnAll.Range(func(k, v interface{}) bool {
count++
return true
})
return count
} }

View File

@ -55,7 +55,7 @@ func (s *WsServer) wsHandle(w http.ResponseWriter, r *http.Request) {
nextConnId++ nextConnId++
wsConn := newWsConnect(conn, s.onDisconnect) wsConn := newWsConnect(conn, s.onDisconnect)
wsMgr.Add(wsConn) wsMgr.Add(wsConn)
log.DebugF("当前在线人数:%v", wsMgr.Count()) log.DebugF("当前连接数:%v", wsMgr.Count())
ksync.GoSafe(func() { wsConn.handle(s.onMessage) }, nil) ksync.GoSafe(func() { wsConn.handle(s.onMessage) }, nil)
ksync.GoSafe(wsConn.readWsLoop, nil) ksync.GoSafe(wsConn.readWsLoop, nil)
ksync.GoSafe(wsConn.writeWsLoop, nil) ksync.GoSafe(wsConn.writeWsLoop, nil)
@ -79,3 +79,7 @@ func (s *WsServer) SetUserId(connId uint32, userId int64) {
func (s *WsServer) FindConnByUserId(userId int64) (IConn, bool) { func (s *WsServer) FindConnByUserId(userId int64) (IConn, bool) {
return wsMgr.FindByUserId(userId) return wsMgr.FindByUserId(userId)
} }
func (s *WsServer) FindConnByConnId(connId uint32) (IConn, bool) {
return wsMgr.Get(connId)
}