From 9d70341f4164260accf50059bd0d8e4f00cc8f27 Mon Sep 17 00:00:00 2001 From: Paul Lecuq Date: Tue, 10 May 2022 13:00:22 +0200 Subject: [PATCH] added transactions, cleaned code --- src/models/cfg.go | 8 ++++++++ src/models/ip.go | 23 ++++++++++++++++------- src/models/utils.go | 2 +- src/zmqrouter/main.go | 10 ++++++++++ 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/models/cfg.go b/src/models/cfg.go index 9764b62..beef041 100644 --- a/src/models/cfg.go +++ b/src/models/cfg.go @@ -16,6 +16,7 @@ var ipv4_cidr_regex = `^(((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|)){4}\/([1-3 func GetTrustlists(cfg config.Config) (res []string, err error) { var w []CfgTrustlist err = cfg.Db.Find(&w) + if len(w) > 0 { for _, a := range w { res = append(res, a.IP) @@ -27,6 +28,7 @@ func GetTrustlists(cfg config.Config) (res []string, err error) { func (wl CfgTrustlist) InsertOrUpdate(cfg config.Config) (err error) { var w = Cfg{Key: "trustlist"} exists, _ := cfg.Db.Get(&w) + if exists { existing, _ := GetTrustlists(cfg) for _, j := range existing { @@ -44,6 +46,7 @@ func (wl CfgTrustlist) InsertOrUpdate(cfg config.Config) (err error) { func (wl CfgTrustlist) Delete(cfg config.Config, ip string) (affected int64, err error) { var w = CfgTrustlist{IP: ip} exists, _ := cfg.Db.Get(&w) + if exists { affected, err = cfg.Db.ID(w.ID).Update(&w) return @@ -58,6 +61,7 @@ func (wl CfgTrustlist) Verify() bool { func GetSets(cfg config.Config) (res []CfgSet, err error) { var w = []CfgSet{} + if err := cfg.Db.Find(&w); err == nil { return w, err } @@ -66,6 +70,7 @@ func GetSets(cfg config.Config) (res []CfgSet, err error) { func GetGlobalConfig(cfg config.Config) (res []Cfg, err error) { var w = []Cfg{} + if err := cfg.Db.Find(&w); err == nil { return w, err } @@ -74,6 +79,7 @@ func GetGlobalConfig(cfg config.Config) (res []Cfg, err error) { func InsertOrUpdateSets(cfg config.Config, folders []CfgSet) (res string, err error) { var w = Cfg{Key: "folders"} + if exists, _ := cfg.Db.Get(&w); exists { resbytes, err := json.Marshal(folders) if err != nil { @@ -90,6 +96,7 @@ func InsertOrUpdateSets(cfg config.Config, folders []CfgSet) (res string, err er func GetZMQ(cfg config.Config) (res []CfgZMQ, err error) { var w = []CfgZMQ{} + if err = cfg.Db.Find(&w); err == nil { return w, err } @@ -115,6 +122,7 @@ type CfgTrustlist struct { ID int `xorm:"pk autoincr" json:"-"` IP string `xorm:"text notnull" json:"ip"` } + type CfgZMQ struct { ID int `xorm:"pk autoincr" json:"-"` Type string `xorm:"text notnull" json:"type"` diff --git a/src/models/ip.go b/src/models/ip.go index f581176..ad15c35 100644 --- a/src/models/ip.go +++ b/src/models/ip.go @@ -31,9 +31,11 @@ func GetIPsLast(ctx *context.Context, config *config.Config, interval string) (a return } -func GetIP(ctx *context.Context, config *config.Config, ipquery interface{}) (apiip *APIIP, err error) { +func GetIP(ctx *context.Context, cfg *config.Config, ipquery interface{}) (apiip *APIIP, err error) { + session := cfg.Db.NewSession() + defer session.Close() var ip IP - has, err := config.Db.Where("ip = ?", ipquery).Get(&ip) + has, err := session.Where("ip = ?", ipquery).Get(&ip) if !has { err = fmt.Errorf("not found") return nil, err @@ -42,6 +44,7 @@ func GetIP(ctx *context.Context, config *config.Config, ipquery interface{}) (ap return } apiip = ip.APIFormat() + session.Commit() return } @@ -56,19 +59,22 @@ func (ip *IP) UpdateRDNS() (result string, err error) { } func (ip *IP) InsertOrUpdate(cfg *config.Config) (numinsert int64, numupdate int64, err error) { + session := cfg.Db.NewSession() + defer session.Close() var ips = []IP{} - err = cfg.Db.Where("ip = ?", ip.IP).And("src = ?", ip.Src).And("hostname = ?", ip.Hostname).Find(&ips) + err = session.Where("ip = ?", ip.IP).And("src = ?", ip.Src).And("hostname = ?", ip.Hostname).Find(&ips) if len(ips) > 0 { - numupdate, err = cfg.Db.Where("id = ?", ips[0].ID).Cols("updated").Update(&IP{}) + numupdate, err = session.Where("id = ?", ips[0].ID).Cols("updated").Update(&IP{}) if err != nil { log.Println(err) } } else { - numinsert, err = cfg.Db.Insert(ip) + numinsert, err = session.Insert(ip) if err != nil { log.Println(err) } } + session.Commit() return } @@ -122,13 +128,16 @@ func ScanIP(cfg *config.Config) (err error) { } func Cleanup(cfg *config.Config) (err error) { - results, _ := cfg.Db.Query("select * from ip where ip in (select ip from ip group by ip having count(ip) > 1) and hostname is null order by updated desc;") + session := cfg.Db.NewSession() + defer session.Close() + results, _ := session.Query("select * from ip where ip in (select ip from ip group by ip having count(ip) > 1) and hostname is null order by updated desc;") if len(results) > 0 { - _, err := cfg.Db.Query("delete from ip where ip in (select ip from ip group by ip having count(ip) > 1) and hostname is null;") + _, err := session.Query("delete from ip where ip in (select ip from ip group by ip having count(ip) > 1) and hostname is null;") if err != nil { log.Println("error deleting orphans") } } + session.Commit() return } diff --git a/src/models/utils.go b/src/models/utils.go index 0ae7842..5d654e9 100644 --- a/src/models/utils.go +++ b/src/models/utils.go @@ -13,7 +13,7 @@ func differ(sl1 []IP, sl2 []IP) (toinsert []IP, err error) { m[v2.IP] = IPDiffer{IP: v2, Num: 1} } else { if this, ok := m[v2.IP]; ok { - this.Num += 1 + this.Num++ m[v2.IP] = this } } diff --git a/src/zmqrouter/main.go b/src/zmqrouter/main.go index 77847d3..e9ffe7c 100644 --- a/src/zmqrouter/main.go +++ b/src/zmqrouter/main.go @@ -12,10 +12,12 @@ import ( func Init(cfg *config.Config) (err error) { log.Println("Initiating ZMQ sockets") + reqsock, err := InitRep() if err != nil { return } + pubsock, err := InitPub() if err != nil { return @@ -27,23 +29,28 @@ func Init(cfg *config.Config) (err error) { func Handle(cfg *config.Config, reqsock *goczmq.Sock, pubsock *goczmq.Sock, channel string) (err error) { log.Println("Start handling zmq sockets") + for { var msg = "err" var req, err = reqsock.RecvMessage() + if err != nil { log.Println("unable to receive message from req socket") continue } + var topub [][]byte for _, val := range req { var apiip = models.APIIP{} var ip = models.IP{} + err = json.Unmarshal(val, &apiip) if err != nil { log.Println("unable to parse ip address", err) continue } ip = *apiip.APIConvert() + numinsert, numupdate, err := ip.InsertOrUpdate(cfg) if err != nil { log.Println(err) @@ -53,13 +60,16 @@ func Handle(cfg *config.Config, reqsock *goczmq.Sock, pubsock *goczmq.Sock, chan val = []byte(tmpval) topub = append(topub, val) } + err = pubsock.SendMessage(topub) if err != nil { log.Println("error sending message to pub socket") continue } + msg = "ok" var resp [][]byte = [][]byte{[]byte(msg)} + err = reqsock.SendMessage(resp) if err != nil { log.Println("error replying message to req socket")