replaced classic map by sync.Map
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Paul 2023-12-24 10:22:26 +01:00
parent 2a52ef3157
commit f1f6b2dbc9
2 changed files with 30 additions and 25 deletions

View File

@ -1,20 +1,20 @@
package ws package ws
import ( import (
"sync"
"time" "time"
"git.paulbsd.com/paulbsd/ipbl/src/config" "git.paulbsd.com/paulbsd/ipbl/src/config"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
var listeners map[string]*connectionInfo var listeners sync.Map
func Init(cfg *config.Config) { func Init(cfg *config.Config) {
listeners = make(map[string]*connectionInfo)
} }
func welcomeAgents(ws *websocket.Conn, welcome wsWelcome, t string) { func welcomeAgents(ws *websocket.Conn, welcome wsWelcome, t string) {
connectinfo, ok := listeners[welcome.Hostname] connectinfoVal, ok := listeners.Load(welcome.Hostname)
if !ok { if !ok {
switch t { switch t {
@ -23,15 +23,16 @@ func welcomeAgents(ws *websocket.Conn, welcome wsWelcome, t string) {
ConnectionPS: ws, ConnectionPS: ws,
InitDate: time.Now(), InitDate: time.Now(),
} }
listeners[welcome.Hostname] = &connectinfo listeners.Store(welcome.Hostname, &connectinfo)
case "rr": case "rr":
connectinfo := connectionInfo{ connectinfo := connectionInfo{
ConnectionRR: ws, ConnectionRR: ws,
InitDate: time.Now(), InitDate: time.Now(),
} }
listeners[welcome.Hostname] = &connectinfo listeners.Store(welcome.Hostname, &connectinfo)
} }
} else { } else {
connectinfo := connectinfoVal.(*connectionInfo)
switch t { switch t {
case "ps": case "ps":
connectinfo.ConnectionPS = ws connectinfo.ConnectionPS = ws
@ -42,15 +43,16 @@ func welcomeAgents(ws *websocket.Conn, welcome wsWelcome, t string) {
} }
func gcConnOnError(ws *websocket.Conn) (err error) { func gcConnOnError(ws *websocket.Conn) (err error) {
for index, value := range listeners { listeners.Range(func(index, value interface{}) bool {
if value.ConnectionPS == ws { if value.(*connectionInfo).ConnectionPS == ws {
value.ConnectionPS.Close() value.(*connectionInfo).ConnectionPS.Close()
delete(listeners, index) listeners.Delete(index)
} else if value.ConnectionRR == ws { } else if value.(*connectionInfo).ConnectionRR == ws {
value.ConnectionRR.Close() value.(*connectionInfo).ConnectionRR.Close()
delete(listeners, index) listeners.Delete(index)
}
} }
return true
})
return err return err
} }

View File

@ -52,15 +52,16 @@ func HandleWSRR(c *echo.Context, cfg *config.Config) error {
switch apievent.MsgType { switch apievent.MsgType {
case "bootstrap": case "bootstrap":
log.Printf("bootstrap: %s\n", apievent.Hostname) log.Printf("bootstrap: %s\n", apievent.Hostname)
for index, value := range listeners { listeners.Range(func(index, value interface{}) bool {
if index != apievent.Hostname && value.ConnectionPS != nil { if index != apievent.Hostname && value.(*connectionInfo).ConnectionPS != nil {
err = websocket.Message.Send(value.ConnectionPS, msg) err = websocket.Message.Send(value.(*connectionInfo).ConnectionPS, msg)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
gcConnOnError(ws) gcConnOnError(ws)
} }
} }
} return true
})
case "add": case "add":
session := cfg.Db.NewSession() session := cfg.Db.NewSession()
event.APIParse(session, apievent) event.APIParse(session, apievent)
@ -71,26 +72,28 @@ func HandleWSRR(c *echo.Context, cfg *config.Config) error {
log.Println(err) log.Println(err)
} }
for _, value := range listeners { listeners.Range(func(index, value interface{}) bool {
if value.ConnectionPS != nil { if value.(*connectionInfo).ConnectionPS != nil {
err = websocket.Message.Send(value.ConnectionPS, msg) err = websocket.Message.Send(value.(*connectionInfo).ConnectionPS, msg)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
gcConnOnError(ws) gcConnOnError(ws)
} }
} }
} return true
})
log.Printf("ws: Inserted event") log.Printf("ws: Inserted event")
case "init": case "init":
for _, value := range listeners { listeners.Range(func(index, value interface{}) bool {
if value.ConnectionPS != nil { if value.(*connectionInfo).ConnectionPS != nil {
err = websocket.Message.Send(value.ConnectionPS, msg) err = websocket.Message.Send(value.(*connectionInfo).ConnectionPS, msg)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
gcConnOnError(ws) gcConnOnError(ws)
} }
} }
} return true
})
default: default:
} }