added transactions, cleaned code
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Paul 2022-05-10 13:00:22 +02:00
parent 632c4ba101
commit 9d70341f41
4 changed files with 35 additions and 8 deletions

View File

@ -16,6 +16,7 @@ var ipv4_cidr_regex = `^(((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|)){4}\/([1-3
func GetTrustlists(cfg config.Config) (res []string, err error) { func GetTrustlists(cfg config.Config) (res []string, err error) {
var w []CfgTrustlist var w []CfgTrustlist
err = cfg.Db.Find(&w) err = cfg.Db.Find(&w)
if len(w) > 0 { if len(w) > 0 {
for _, a := range w { for _, a := range w {
res = append(res, a.IP) res = append(res, a.IP)
@ -27,6 +28,7 @@ func GetTrustlists(cfg config.Config) (res []string, err error) {
func (wl CfgTrustlist) InsertOrUpdate(cfg config.Config) (err error) { func (wl CfgTrustlist) InsertOrUpdate(cfg config.Config) (err error) {
var w = Cfg{Key: "trustlist"} var w = Cfg{Key: "trustlist"}
exists, _ := cfg.Db.Get(&w) exists, _ := cfg.Db.Get(&w)
if exists { if exists {
existing, _ := GetTrustlists(cfg) existing, _ := GetTrustlists(cfg)
for _, j := range existing { for _, j := range existing {
@ -44,6 +46,7 @@ func (wl CfgTrustlist) InsertOrUpdate(cfg config.Config) (err error) {
func (wl CfgTrustlist) Delete(cfg config.Config, ip string) (affected int64, err error) { func (wl CfgTrustlist) Delete(cfg config.Config, ip string) (affected int64, err error) {
var w = CfgTrustlist{IP: ip} var w = CfgTrustlist{IP: ip}
exists, _ := cfg.Db.Get(&w) exists, _ := cfg.Db.Get(&w)
if exists { if exists {
affected, err = cfg.Db.ID(w.ID).Update(&w) affected, err = cfg.Db.ID(w.ID).Update(&w)
return return
@ -58,6 +61,7 @@ func (wl CfgTrustlist) Verify() bool {
func GetSets(cfg config.Config) (res []CfgSet, err error) { func GetSets(cfg config.Config) (res []CfgSet, err error) {
var w = []CfgSet{} var w = []CfgSet{}
if err := cfg.Db.Find(&w); err == nil { if err := cfg.Db.Find(&w); err == nil {
return w, err return w, err
} }
@ -66,6 +70,7 @@ func GetSets(cfg config.Config) (res []CfgSet, err error) {
func GetGlobalConfig(cfg config.Config) (res []Cfg, err error) { func GetGlobalConfig(cfg config.Config) (res []Cfg, err error) {
var w = []Cfg{} var w = []Cfg{}
if err := cfg.Db.Find(&w); err == nil { if err := cfg.Db.Find(&w); err == nil {
return w, err return w, err
} }
@ -74,6 +79,7 @@ func GetGlobalConfig(cfg config.Config) (res []Cfg, err error) {
func InsertOrUpdateSets(cfg config.Config, folders []CfgSet) (res string, err error) { func InsertOrUpdateSets(cfg config.Config, folders []CfgSet) (res string, err error) {
var w = Cfg{Key: "folders"} var w = Cfg{Key: "folders"}
if exists, _ := cfg.Db.Get(&w); exists { if exists, _ := cfg.Db.Get(&w); exists {
resbytes, err := json.Marshal(folders) resbytes, err := json.Marshal(folders)
if err != nil { if err != nil {
@ -90,6 +96,7 @@ func InsertOrUpdateSets(cfg config.Config, folders []CfgSet) (res string, err er
func GetZMQ(cfg config.Config) (res []CfgZMQ, err error) { func GetZMQ(cfg config.Config) (res []CfgZMQ, err error) {
var w = []CfgZMQ{} var w = []CfgZMQ{}
if err = cfg.Db.Find(&w); err == nil { if err = cfg.Db.Find(&w); err == nil {
return w, err return w, err
} }
@ -115,6 +122,7 @@ type CfgTrustlist struct {
ID int `xorm:"pk autoincr" json:"-"` ID int `xorm:"pk autoincr" json:"-"`
IP string `xorm:"text notnull" json:"ip"` IP string `xorm:"text notnull" json:"ip"`
} }
type CfgZMQ struct { type CfgZMQ struct {
ID int `xorm:"pk autoincr" json:"-"` ID int `xorm:"pk autoincr" json:"-"`
Type string `xorm:"text notnull" json:"type"` Type string `xorm:"text notnull" json:"type"`

View File

@ -31,9 +31,11 @@ func GetIPsLast(ctx *context.Context, config *config.Config, interval string) (a
return return
} }
func GetIP(ctx *context.Context, config *config.Config, ipquery interface{}) (apiip *APIIP, err error) { func GetIP(ctx *context.Context, cfg *config.Config, ipquery interface{}) (apiip *APIIP, err error) {
session := cfg.Db.NewSession()
defer session.Close()
var ip IP var ip IP
has, err := config.Db.Where("ip = ?", ipquery).Get(&ip) has, err := session.Where("ip = ?", ipquery).Get(&ip)
if !has { if !has {
err = fmt.Errorf("not found") err = fmt.Errorf("not found")
return nil, err return nil, err
@ -42,6 +44,7 @@ func GetIP(ctx *context.Context, config *config.Config, ipquery interface{}) (ap
return return
} }
apiip = ip.APIFormat() apiip = ip.APIFormat()
session.Commit()
return return
} }
@ -56,19 +59,22 @@ func (ip *IP) UpdateRDNS() (result string, err error) {
} }
func (ip *IP) InsertOrUpdate(cfg *config.Config) (numinsert int64, numupdate int64, err error) { func (ip *IP) InsertOrUpdate(cfg *config.Config) (numinsert int64, numupdate int64, err error) {
session := cfg.Db.NewSession()
defer session.Close()
var ips = []IP{} var ips = []IP{}
err = cfg.Db.Where("ip = ?", ip.IP).And("src = ?", ip.Src).And("hostname = ?", ip.Hostname).Find(&ips) err = session.Where("ip = ?", ip.IP).And("src = ?", ip.Src).And("hostname = ?", ip.Hostname).Find(&ips)
if len(ips) > 0 { if len(ips) > 0 {
numupdate, err = cfg.Db.Where("id = ?", ips[0].ID).Cols("updated").Update(&IP{}) numupdate, err = session.Where("id = ?", ips[0].ID).Cols("updated").Update(&IP{})
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
} else { } else {
numinsert, err = cfg.Db.Insert(ip) numinsert, err = session.Insert(ip)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
} }
session.Commit()
return return
} }
@ -122,13 +128,16 @@ func ScanIP(cfg *config.Config) (err error) {
} }
func Cleanup(cfg *config.Config) (err error) { func Cleanup(cfg *config.Config) (err error) {
results, _ := cfg.Db.Query("select * from ip where ip in (select ip from ip group by ip having count(ip) > 1) and hostname is null order by updated desc;") session := cfg.Db.NewSession()
defer session.Close()
results, _ := session.Query("select * from ip where ip in (select ip from ip group by ip having count(ip) > 1) and hostname is null order by updated desc;")
if len(results) > 0 { if len(results) > 0 {
_, err := cfg.Db.Query("delete from ip where ip in (select ip from ip group by ip having count(ip) > 1) and hostname is null;") _, err := session.Query("delete from ip where ip in (select ip from ip group by ip having count(ip) > 1) and hostname is null;")
if err != nil { if err != nil {
log.Println("error deleting orphans") log.Println("error deleting orphans")
} }
} }
session.Commit()
return return
} }

View File

@ -13,7 +13,7 @@ func differ(sl1 []IP, sl2 []IP) (toinsert []IP, err error) {
m[v2.IP] = IPDiffer{IP: v2, Num: 1} m[v2.IP] = IPDiffer{IP: v2, Num: 1}
} else { } else {
if this, ok := m[v2.IP]; ok { if this, ok := m[v2.IP]; ok {
this.Num += 1 this.Num++
m[v2.IP] = this m[v2.IP] = this
} }
} }

View File

@ -12,10 +12,12 @@ import (
func Init(cfg *config.Config) (err error) { func Init(cfg *config.Config) (err error) {
log.Println("Initiating ZMQ sockets") log.Println("Initiating ZMQ sockets")
reqsock, err := InitRep() reqsock, err := InitRep()
if err != nil { if err != nil {
return return
} }
pubsock, err := InitPub() pubsock, err := InitPub()
if err != nil { if err != nil {
return return
@ -27,23 +29,28 @@ func Init(cfg *config.Config) (err error) {
func Handle(cfg *config.Config, reqsock *goczmq.Sock, pubsock *goczmq.Sock, channel string) (err error) { func Handle(cfg *config.Config, reqsock *goczmq.Sock, pubsock *goczmq.Sock, channel string) (err error) {
log.Println("Start handling zmq sockets") log.Println("Start handling zmq sockets")
for { for {
var msg = "err" var msg = "err"
var req, err = reqsock.RecvMessage() var req, err = reqsock.RecvMessage()
if err != nil { if err != nil {
log.Println("unable to receive message from req socket") log.Println("unable to receive message from req socket")
continue continue
} }
var topub [][]byte var topub [][]byte
for _, val := range req { for _, val := range req {
var apiip = models.APIIP{} var apiip = models.APIIP{}
var ip = models.IP{} var ip = models.IP{}
err = json.Unmarshal(val, &apiip) err = json.Unmarshal(val, &apiip)
if err != nil { if err != nil {
log.Println("unable to parse ip address", err) log.Println("unable to parse ip address", err)
continue continue
} }
ip = *apiip.APIConvert() ip = *apiip.APIConvert()
numinsert, numupdate, err := ip.InsertOrUpdate(cfg) numinsert, numupdate, err := ip.InsertOrUpdate(cfg)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
@ -53,13 +60,16 @@ func Handle(cfg *config.Config, reqsock *goczmq.Sock, pubsock *goczmq.Sock, chan
val = []byte(tmpval) val = []byte(tmpval)
topub = append(topub, val) topub = append(topub, val)
} }
err = pubsock.SendMessage(topub) err = pubsock.SendMessage(topub)
if err != nil { if err != nil {
log.Println("error sending message to pub socket") log.Println("error sending message to pub socket")
continue continue
} }
msg = "ok" msg = "ok"
var resp [][]byte = [][]byte{[]byte(msg)} var resp [][]byte = [][]byte{[]byte(msg)}
err = reqsock.SendMessage(resp) err = reqsock.SendMessage(resp)
if err != nil { if err != nil {
log.Println("error replying message to req socket") log.Println("error replying message to req socket")