diff --git a/etcd/etcd.go b/etcd/etcd.go index 24dda6f..b5bf9cf 100644 --- a/etcd/etcd.go +++ b/etcd/etcd.go @@ -9,18 +9,18 @@ import ( type resultT[T any] struct { Value T - Err error } type Registry[T INode] struct { *etcdRegistryImpl - nodes sync.Map + nodes *sync.Map } func NewRegistry[T INode](endpoints []string, username, password string) (*Registry[T], error) { var err error e := &Registry[T]{} - e.etcdRegistryImpl, err = newServiceRegistryImpl(endpoints, resultT[T]{}.Value.EtcdRootKey(), username, password, e.saveNode) + e.etcdRegistryImpl, err = newServiceRegistryImpl(endpoints, resultT[T]{}.Value.EtcdRootKey(), username, password, e.saveNode, e.replace) + e.nodes = &sync.Map{} return e, err } @@ -32,23 +32,28 @@ func (r *Registry[T]) Register(node INode) error { return r.etcdRegistryImpl.Register(node.EtcdKey(), string(bs)) } -// 获取当前服务 -func (r *Registry[T]) saveNode(jsonBytes []byte) { - var tmp = resultT[T]{Err: nil} +// 保存当前服务 +func (r *Registry[T]) saveNode(newNodes *sync.Map, jsonBytes []byte) { + var tmp = resultT[T]{} if err := json.Unmarshal(jsonBytes, &tmp.Value); err != nil { log.ErrorF(err.Error()) } - r.nodes.Store(tmp.Value.MapKey(), tmp.Value) + newNodes.Store(tmp.Value.MapKey(), tmp.Value) +} + +// 保存当前服务 +func (r *Registry[T]) replace(newNodes *sync.Map) { + r.nodes = newNodes } // 获取当前根节点下所有节点信息 func (r *Registry[T]) GetNodes() *sync.Map { - return &r.nodes + return r.nodes } // 根据inode的mapKey()查找对应的节点 func (r *Registry[T]) FindNode(key string) (T, error) { - var tmp = resultT[T]{Err: nil} + var tmp = resultT[T]{} v, ok := r.nodes.Load(key) if !ok { return tmp.Value, fmt.Errorf("%v not exist", key) diff --git a/etcd/etcdImpl.go b/etcd/etcdImpl.go index b919ed7..3e65022 100644 --- a/etcd/etcdImpl.go +++ b/etcd/etcdImpl.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/fox/fox/ksync" "github.com/fox/fox/log" + "sync" "time" clientv3 "go.etcd.io/etcd/client/v3" @@ -23,11 +24,12 @@ type etcdRegistryImpl struct { nodeKey string cancelFunc context.CancelFunc rootKey string - saveNodeFunc func(jsonBytes []byte) + saveNodeFunc func(*sync.Map, []byte) + replaceFunc func(*sync.Map) } // 创建服务注册中心 -func newServiceRegistryImpl(endpoints []string, rootKey, username, password string, saveNode func([]byte)) (*etcdRegistryImpl, error) { +func newServiceRegistryImpl(endpoints []string, rootKey, username, password string, saveNode func(*sync.Map, []byte), replace func(*sync.Map)) (*etcdRegistryImpl, error) { cli, err := clientv3.New(clientv3.Config{ Endpoints: endpoints, DialTimeout: DefaultDialTimeout, @@ -42,6 +44,7 @@ func newServiceRegistryImpl(endpoints []string, rootKey, username, password stri cli: cli, rootKey: rootKey, saveNodeFunc: saveNode, + replaceFunc: replace, }, nil } @@ -143,9 +146,11 @@ func (sr *etcdRegistryImpl) discoverServices() error { } // log.Debug(fmt.Sprintf("discoverServices srv:%s", srv)) + newNodes := &sync.Map{} for _, kv := range resp.Kvs { - sr.saveNodeFunc(kv.Value) + sr.saveNodeFunc(newNodes, kv.Value) } + sr.replaceFunc(newNodes) return nil }