diff --git a/src/ws/init.go b/src/ws/init.go index e0d923e..3dbbfd5 100644 --- a/src/ws/init.go +++ b/src/ws/init.go @@ -1,20 +1,20 @@ package ws import ( + "sync" "time" "git.paulbsd.com/paulbsd/ipbl/src/config" "golang.org/x/net/websocket" ) -var listeners map[string]*connectionInfo +var listeners sync.Map func Init(cfg *config.Config) { - listeners = make(map[string]*connectionInfo) } func welcomeAgents(ws *websocket.Conn, welcome wsWelcome, t string) { - connectinfo, ok := listeners[welcome.Hostname] + connectinfoVal, ok := listeners.Load(welcome.Hostname) if !ok { switch t { @@ -23,15 +23,16 @@ func welcomeAgents(ws *websocket.Conn, welcome wsWelcome, t string) { ConnectionPS: ws, InitDate: time.Now(), } - listeners[welcome.Hostname] = &connectinfo + listeners.Store(welcome.Hostname, &connectinfo) case "rr": connectinfo := connectionInfo{ ConnectionRR: ws, InitDate: time.Now(), } - listeners[welcome.Hostname] = &connectinfo + listeners.Store(welcome.Hostname, &connectinfo) } } else { + connectinfo := connectinfoVal.(*connectionInfo) switch t { case "ps": connectinfo.ConnectionPS = ws @@ -42,15 +43,16 @@ func welcomeAgents(ws *websocket.Conn, welcome wsWelcome, t string) { } func gcConnOnError(ws *websocket.Conn) (err error) { - for index, value := range listeners { - if value.ConnectionPS == ws { - value.ConnectionPS.Close() - delete(listeners, index) - } else if value.ConnectionRR == ws { - value.ConnectionRR.Close() - delete(listeners, index) + listeners.Range(func(index, value interface{}) bool { + if value.(*connectionInfo).ConnectionPS == ws { + value.(*connectionInfo).ConnectionPS.Close() + listeners.Delete(index) + } else if value.(*connectionInfo).ConnectionRR == ws { + value.(*connectionInfo).ConnectionRR.Close() + listeners.Delete(index) } - } + return true + }) return err } diff --git a/src/ws/reqrep.go b/src/ws/reqrep.go index 09186d3..742e72a 100644 --- a/src/ws/reqrep.go +++ b/src/ws/reqrep.go @@ -52,15 +52,16 @@ func HandleWSRR(c *echo.Context, cfg *config.Config) error { switch apievent.MsgType { case "bootstrap": log.Printf("bootstrap: %s\n", apievent.Hostname) - for index, value := range listeners { - if index != apievent.Hostname && value.ConnectionPS != nil { - err = websocket.Message.Send(value.ConnectionPS, msg) + listeners.Range(func(index, value interface{}) bool { + if index != apievent.Hostname && value.(*connectionInfo).ConnectionPS != nil { + err = websocket.Message.Send(value.(*connectionInfo).ConnectionPS, msg) if err != nil { log.Println(err) gcConnOnError(ws) } } - } + return true + }) case "add": session := cfg.Db.NewSession() event.APIParse(session, apievent) @@ -71,26 +72,28 @@ func HandleWSRR(c *echo.Context, cfg *config.Config) error { log.Println(err) } - for _, value := range listeners { - if value.ConnectionPS != nil { - err = websocket.Message.Send(value.ConnectionPS, msg) + listeners.Range(func(index, value interface{}) bool { + if value.(*connectionInfo).ConnectionPS != nil { + err = websocket.Message.Send(value.(*connectionInfo).ConnectionPS, msg) if err != nil { log.Println(err) gcConnOnError(ws) } } - } + return true + }) log.Printf("ws: Inserted event") case "init": - for _, value := range listeners { - if value.ConnectionPS != nil { - err = websocket.Message.Send(value.ConnectionPS, msg) + listeners.Range(func(index, value interface{}) bool { + if value.(*connectionInfo).ConnectionPS != nil { + err = websocket.Message.Send(value.(*connectionInfo).ConnectionPS, msg) if err != nil { log.Println(err) gcConnOnError(ws) } } - } + return true + }) default: }