From 8ac9c88ce6ff267b16078b2d2d0b670dc53a48a2 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 8 Dec 2024 22:37:37 +0100 Subject: [PATCH] feat: add working revision with atomic replace --- src/config.rs | 52 +++++++++++---- src/fw.rs | 167 +++++++++++++++++++++++++++++----------------- src/ip.rs | 12 ++-- src/ipblc.rs | 57 ++++++++++------ src/monitoring.rs | 4 +- src/utils.rs | 5 +- src/websocket.rs | 9 ++- 7 files changed, 197 insertions(+), 109 deletions(-) diff --git a/src/config.rs b/src/config.rs index 1313e68..9e11ffb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,12 @@ use crate::ip::{BlockIpData, IpData, IpEvent}; use crate::utils::{gethostname, sleep_s}; +use std::{ + collections::HashMap, + hash::{Hash, Hasher}, + path::Path, +}; + use chrono::prelude::*; use chrono::Duration; use clap::{Arg, ArgAction, ArgMatches, Command}; @@ -10,11 +16,9 @@ use nix::sys::inotify::{AddWatchFlags, Inotify, WatchDescriptor}; use regex::Regex; use reqwest::{Client, Error as ReqError, Response}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::hash::{Hash, Hasher}; -use std::path::Path; pub const GIT_VERSION: &str = git_version!(args = ["--always", "--dirty="]); + const MASTERSERVER: &str = "ipbl.paulbsd.com"; const WSSUBSCRIPTION: &str = "ipbl"; const CONFIG_RETRY_INTERVAL: u64 = 2; @@ -28,7 +32,7 @@ pub struct Context { pub flags: Flags, pub sas: HashMap, pub hashwd: HashMap, - pub reloadinterval: isize, + pub reloadinterval: u64, } #[derive(Debug, Clone)] @@ -134,6 +138,25 @@ impl Context { } }; } + + let mut last_in_err = false; + loop { + let res = self.discovery().await; + match res { + Ok(o) => { + self.discovery = o; + if last_in_err { + println!("loaded discovery"); + } + break; + } + Err(e) => { + println!("error loading disvoery: {e}, retrying in {CONFIG_RETRY_INTERVAL}s"); + last_in_err = true; + sleep_s(CONFIG_RETRY_INTERVAL).await; + } + }; + } if last_in_err { println!("creating sas"); } @@ -145,7 +168,7 @@ impl Context { } #[cfg(test)] - pub async fn get_blocklist_pending(&self) -> Vec { + pub async fn get_blocklist_pending(&self) -> Vec { let mut res: Vec = vec![]; for (_, v) in self.blocklist.iter() { res.push(v.ipdata.clone()); @@ -153,13 +176,13 @@ impl Context { res } - pub async fn get_blocklist_toblock(&self) -> Vec { - let mut res: Vec = vec![]; + pub async fn get_blocklist_toblock(&self, all: bool) -> Vec { + let mut res: Vec = vec![]; for (_, ipblock) in self.blocklist.iter() { match self.cfg.sets.get(&ipblock.ipdata.src) { Some(set) => { - if ipblock.tryfail >= set.tryfail && !ipblock.blocked { - res.push(ipblock.ipdata.clone()); + if ipblock.tryfail >= set.tryfail && (!ipblock.blocked || all) { + res.push(ipblock.clone()); } } None => {} @@ -177,6 +200,7 @@ impl Context { .with_timezone(&chrono::Local); let blocktime = set.blocktime; let blocked = false; + let handle = u64::MIN; if ipevent.mode == "file".to_string() && gethostname(true) == ipevent.hostname { let block = self.blocklist @@ -187,6 +211,7 @@ impl Context { starttime, blocktime, blocked, + handle, }); block.tryfail += 1; block.blocktime = blocktime; @@ -202,6 +227,7 @@ impl Context { starttime, blocktime, blocked, + handle, }); } } @@ -212,8 +238,8 @@ impl Context { None } - pub async fn gc_blocklist(&mut self) -> Vec { - let mut removed: Vec = vec![]; + pub async fn gc_blocklist(&mut self) -> Vec { + let mut removed: Vec = vec![]; let now: DateTime = Local::now().trunc_subsecs(0); // nightly, future use // let drained: HashMap = ctx.blocklist.drain_filter(|k,v| v.parse_date() < mindate) @@ -228,7 +254,7 @@ impl Context { let mindate = now - Duration::minutes(blocked.blocktime); if blocked.starttime < mindate { self.blocklist.remove(&ip.clone()).unwrap(); - removed.push(blocked.ipdata.clone()); + removed.push(blocked.clone()); } } removed @@ -640,7 +666,7 @@ mod test { pub async fn test_blocklist_toblock() { let mut ctx = prepare_test_data().await; ctx.gc_blocklist().await; - let toblock = ctx.get_blocklist_toblock().await; + let toblock = ctx.get_blocklist_toblock(false).await; assert_eq!(toblock.len(), 3); } diff --git a/src/fw.rs b/src/fw.rs index 92c8bdf..2ca6b2e 100644 --- a/src/fw.rs +++ b/src/fw.rs @@ -1,11 +1,11 @@ -use crate::{ip::IpData, ipblc::PKG_NAME}; +use crate::{config::Context, ip::BlockIpData, ipblc::PKG_NAME}; use std::{ io::Error, - net::{Ipv4Addr, Ipv6Addr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; -use rustables::*; +use rustables::{expr::*, *}; pub enum FwTableType { IPv4, @@ -30,7 +30,7 @@ macro_rules! initrules { macro_rules! makerules { ($ipdata:ident, $chain:ident, $batch:ident, $t:ty, $ip_t:ident,$action:ty) => { - let ip = $ipdata.ip.parse::<$t>().unwrap(); + let ip = $ipdata.ipdata.ip.parse::<$t>().unwrap(); Rule::new(&$chain) .unwrap() .saddr(ip.into()) @@ -67,91 +67,132 @@ pub fn fwglobalinit(t: FwTableType, reset: bool) -> (Batch, Chain) { (batch, chain) } -pub fn fwblock<'a>(ip_add: &IpData) -> std::result::Result<(), error::QueryError> { +pub fn fwblock<'a>(ip_add: &BlockIpData) -> std::result::Result<&String, error::QueryError> { let (mut batch4, chain4) = fwglobalinit(FwTableType::IPv4, false); let (mut batch6, chain6) = fwglobalinit(FwTableType::IPv6, false); - match ip_add.t { + match ip_add.ipdata.t { 4 => { makerules!(ip_add, chain4, batch4, Ipv4Addr, ipv4, FwAction::Add); + match batch4.send() { + Ok(_) => {} + Err(e) => { + println!("block not ok {e} {ip_add:?}") + } + } } 6 => { makerules!(ip_add, chain6, batch6, Ipv6Addr, ipv6, FwAction::Add); + match batch6.send() { + Ok(_) => {} + Err(e) => { + println!("block not ok {e} {ip_add:?}") + } + } } _ => {} } - // validate and send batch - for b in [batch4, batch6] { - match b.send() { - Ok(_) => {} - Err(e) => { - println!("error sending batch: {e}"); - return Err(e); - } - }; - } - - Ok(()) + Ok(&ip_add.ipdata.ip) } -pub fn fwunblock<'a>(ips_del: &Vec) -> std::result::Result<(), Error> { +pub fn fwunblock<'a>(ip_del: &BlockIpData) -> std::result::Result<&String, error::QueryError> { let (mut batch4, chain4) = fwglobalinit(FwTableType::IPv4, false); let (mut batch6, chain6) = fwglobalinit(FwTableType::IPv6, false); - // to implement - /*for ip_del in ips_del { - match ip_del.t { - 4 => { - makerules!(ip_del, chain4, batch4, Ipv4Addr, ipv4, FwAction::Del); + match ip_del.ipdata.t { + 4 => { + let r = Rule::new(&chain4).unwrap().with_handle(ip_del.handle); + batch4.add(&r, MsgType::Del); + match batch4.send() { + Ok(_) => {} + Err(e) => { + println!("delete not ok {e} {ip_del:?}") + } } - 6 => { - makerules!(ip_del, chain6, batch6, Ipv6Addr, ipv6, FwAction::Del); - } - _ => {} } - }*/ - Ok(()) + 6 => { + let r = Rule::new(&chain6).unwrap().with_handle(ip_del.handle); + batch6.add(&r, MsgType::Del); + match batch6.send() { + Ok(_) => {} + Err(e) => { + println!("delete not ok {e} {ip_del:?}") + } + } + } + _ => {} + } + Ok(&ip_del.ipdata.ip) } -pub fn get_current_rules(table_name: &str, chain_name: &str) -> Result<(), Error> { - let get_table = || -> Result, Error> { - let tables = list_tables().unwrap(); - for table in tables { - if let Some(name) = table.get_name() { - println!("Found table {}", name); +pub fn get_current_rules( + ctx: &mut Context, + ret: &mut Vec, + fwlen: &mut usize, +) -> Result<(), Error> { + let mut ips_all_count = 0; + let tables = vec![format!("{PKG_NAME}4"), format!("{PKG_NAME}6")]; + for table_name in tables { + let get_table = || -> Result, Error> { + let tables = list_tables().unwrap(); + for table in tables { + if let Some(name) = table.get_name() { + if *name == table_name { + return Ok(Some(table)); + } + } + } - if *name == table_name { - return Ok(Some(table)); + Ok(None) + }; + + let get_chain = |table: &Table| -> Result, Error> { + let chains = list_chains_for_table(table).unwrap(); + for chain in chains { + if let Some(name) = chain.get_name() { + if *name == "ipblc" { + return Ok(Some(chain)); + } + } + } + + Ok(None) + }; + + let table = get_table()?.expect("no table?"); + let chain = get_chain(&table)?.expect("no chain?"); + + for (ip, c) in ctx.blocklist.iter_mut() { + let ip_parsed: IpAddr = ip.parse().unwrap(); + + let cmprule = Rule::new(&chain).unwrap().saddr(ip_parsed).drop(); + let mut gexpr = RawExpression::default(); + for e in cmprule.get_expressions().unwrap().iter() { + if let Some(ExpressionVariant::Cmp(_ip)) = e.get_data() { + gexpr = e.clone(); + } + } + + let rules = list_rules_for_chain(&chain).unwrap(); + for rule in rules { + for expr in rule.get_expressions().unwrap().iter() { + if let Some(expr::ExpressionVariant::Cmp(_)) = expr.get_data() { + if gexpr == expr.clone() { + ips_all_count += 1; + c.handle = *rule.get_handle().unwrap(); + } + } } } } - - Ok(None) - }; - - let get_chain = |table: &Table| -> Result, Error> { - let chains = list_chains_for_table(table).unwrap(); - for chain in chains { - if let Some(name) = chain.get_name() { - println!("Found chain {}", name); - - if *name == chain_name { - return Ok(Some(chain)); - } - } - } - - Ok(None) - }; - - let table = get_table()?.expect("no table?"); - let chain = get_chain(&table)?.expect("no chain?"); - - let rules = list_rules_for_chain(&chain).unwrap(); - for mut rule in rules { - println!("{:?}", rule); } + + if *fwlen != ips_all_count { + ret.push(format!("{length} ip in firewall", length = ips_all_count)); + } + *fwlen = ips_all_count; + Ok(()) } diff --git a/src/ip.rs b/src/ip.rs index 93a705b..63afa07 100644 --- a/src/ip.rs +++ b/src/ip.rs @@ -1,15 +1,18 @@ use crate::utils::gethostname; +use std::{ + cmp::Ordering, + fmt::{Display, Formatter}, + io::{BufRead, BufReader, Read}, + net::IpAddr, +}; + use chrono::offset::LocalResult; use chrono::prelude::*; use ipnet::IpNet; use lazy_static::lazy_static; use regex::Regex; use serde::{Deserialize, Serialize}; -use std::cmp::Ordering; -use std::fmt::{Display, Formatter}; -use std::io::{BufRead, BufReader, Read}; -use std::net::IpAddr; lazy_static! { static ref R_IPV4: Regex = Regex::new(include_str!("regexps/ipv4.txt")).unwrap(); @@ -52,6 +55,7 @@ pub struct BlockIpData { pub blocktime: i64, pub starttime: DateTime, pub blocked: bool, + pub handle: u64, } #[derive(Clone, Debug, Serialize, Deserialize, Eq)] diff --git a/src/ipblc.rs b/src/ipblc.rs index 9f5d004..883cd74 100644 --- a/src/ipblc.rs +++ b/src/ipblc.rs @@ -7,20 +7,19 @@ use crate::utils::{gethostname, read_lines, sleep_s}; use crate::webservice::send_to_ipbl_api; use crate::websocket::{send_to_ipbl_websocket, websocketpubsub, websocketreqrep}; +use std::{collections::HashMap, sync::Arc}; + use chrono::prelude::*; use chrono::prelude::{DateTime, Local}; use chrono::Duration; use nix::sys::inotify::{InitFlags, Inotify, InotifyEvent}; use sd_notify::*; -use std::collections::HashMap; -use std::sync::Arc; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::RwLock; pub const PKG_NAME: &str = env!("CARGO_PKG_NAME"); const BL_CHAN_SIZE: usize = 32; const WS_CHAN_SIZE: usize = 64; -const LOOP_MAX_WAIT: u64 = 5; macro_rules! log_with_systemd { ($msg:expr) => { @@ -75,21 +74,26 @@ pub async fn run() { loop { let mut ret: Vec = Vec::new(); - let ctxclone = Arc::clone(&ctxarc); + let reloadinterval; + { + let ctx = ctxclone.read().await; + reloadinterval = ctx.reloadinterval; + } + tokio::select! { ipevent = ipeventrx.recv() => { let received_ip = ipevent.unwrap(); let (toblock,server) = { let ctx = ctxclone.read().await; - (ctx.get_blocklist_toblock().await,ctx.flags.server.clone()) + (ctx.get_blocklist_toblock(true).await,ctx.flags.server.clone()) }; if received_ip.msgtype == "bootstrap".to_string() { for ip_to_send in toblock { - let ipe = ipevent!("init","ws",gethostname(true),Some(ip_to_send)); + let ipe = ipevent!("init","ws",gethostname(true),Some(ip_to_send.ipdata)); if !send_to_ipbl_websocket(&mut wssocketrr, &ipe).await { wssocketrr.close(None).unwrap(); wssocketrr = websocketreqrep(&ctxwsrr).await; @@ -118,7 +122,7 @@ pub async fn run() { } } } - _val = sleep_s(LOOP_MAX_WAIT) => { + _val = sleep_s(reloadinterval) => { let ipe = ipevent!("ping", "ws", gethostname(true)); if !send_to_ipbl_websocket(&mut wssocketrr, &ipe).await { wssocketrr.close(None).unwrap(); @@ -134,15 +138,20 @@ pub async fn run() { }; let toblock = { let ctx = ctxclone.read().await; - ctx.get_blocklist_toblock().await + ctx.get_blocklist_toblock(false).await }; - // apply firewall blocking + { + let mut ctx = ctxclone.write().await; + get_current_rules(&mut ctx, &mut ret, &mut fwlen).unwrap(); + get_current_rules(&mut ctx, &mut ret, &mut fwlen).unwrap(); + } + for b in toblock { match fwblock(&b) { - Ok(_) => { + Ok(ip) => { let mut ctx = ctxclone.write().await; - if let Some(x) = ctx.blocklist.get_mut(&b.ip) { + if let Some(x) = ctx.blocklist.get_mut(ip) { x.blocked = true; } } @@ -151,13 +160,18 @@ pub async fn run() { } }; } - get_current_rules("ipblc4", "ipblc").unwrap(); - match fwunblock(&tounblock) { - Ok(_) => {} - Err(e) => { - println!("err: {e}, unable to push firewall rules, use super user") + + for ub in tounblock { + if ub.blocked { + let res = fwunblock(&ub); + match res { + Ok(_) => {} + Err(e) => { + println!("err: {e}, unable to push firewall rules, use super user") + } + }; } - }; + } // log lines if ret.len() > 0 { @@ -167,17 +181,18 @@ pub async fn run() { let ctxclone = Arc::clone(&ctxarc); let inoclone = Arc::clone(&inoarc); - handle_cfg_reload(&ctxclone, &mut last_cfg_reload, inoclone).await; + handle_cfg_reload(&ctxclone, reloadinterval, &mut last_cfg_reload, inoclone).await; } } async fn handle_cfg_reload( ctxclone: &Arc>, + reloadinterval: u64, last_cfg_reload: &mut DateTime, inoarc: Arc>, ) { let now_cfg_reload = Local::now().trunc_subsecs(0); - if (now_cfg_reload - *last_cfg_reload) > Duration::seconds(LOOP_MAX_WAIT as i64) { + if (now_cfg_reload - *last_cfg_reload) > Duration::seconds(reloadinterval as i64) { let inotify; loop { inotify = match inoarc.try_read() { @@ -190,14 +205,14 @@ async fn handle_cfg_reload( }; break; } - let mut ctxtest = match ctxclone.try_write() { + let mut ctx = match ctxclone.try_write() { Ok(o) => o, Err(e) => { println!("{e}"); return; } }; - match ctxtest.load(&inotify).await { + match ctx.load(&inotify).await { Ok(_) => { *last_cfg_reload = Local::now().trunc_subsecs(0); } diff --git a/src/monitoring.rs b/src/monitoring.rs index 4ca1634..ead461f 100644 --- a/src/monitoring.rs +++ b/src/monitoring.rs @@ -1,8 +1,8 @@ use crate::config::Context; +use std::{io, sync::Arc}; + use serde_json; -use std::io; -use std::sync::Arc; use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; use tokio::sync::RwLock; diff --git a/src/utils.rs b/src/utils.rs index 8066012..9fe6bb7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,7 +1,6 @@ +use std::{boxed::Box, fs::File, io::*}; + use nix::unistd; -use std::boxed::Box; -use std::fs::File; -use std::io::*; use tokio::time::{sleep, Duration}; pub fn read_lines(filename: &String, offset: u64) -> Option> { diff --git a/src/websocket.rs b/src/websocket.rs index 436239d..aac0f79 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -2,10 +2,13 @@ use crate::config::{Context, WebSocketCfg}; use crate::ip::IpEvent; use crate::utils::{gethostname, sleep_s}; +use std::{ + io::{self, Write}, + net::TcpStream, + sync::Arc, +}; + use serde_json::json; -use std::io::{self, Write}; -use std::net::TcpStream; -use std::sync::Arc; use tokio::sync::mpsc::Sender; use tokio::sync::RwLock; use tungstenite::stream::*;