feat: add working revision with atomic replace
Some checks failed
continuous-integration/drone/push Build is failing

This commit is contained in:
Paul 2024-12-08 22:37:37 +01:00
parent e968539df5
commit 8ac9c88ce6
7 changed files with 197 additions and 109 deletions

View File

@ -1,6 +1,12 @@
use crate::ip::{BlockIpData, IpData, IpEvent}; use crate::ip::{BlockIpData, IpData, IpEvent};
use crate::utils::{gethostname, sleep_s}; use crate::utils::{gethostname, sleep_s};
use std::{
collections::HashMap,
hash::{Hash, Hasher},
path::Path,
};
use chrono::prelude::*; use chrono::prelude::*;
use chrono::Duration; use chrono::Duration;
use clap::{Arg, ArgAction, ArgMatches, Command}; use clap::{Arg, ArgAction, ArgMatches, Command};
@ -10,11 +16,9 @@ use nix::sys::inotify::{AddWatchFlags, Inotify, WatchDescriptor};
use regex::Regex; use regex::Regex;
use reqwest::{Client, Error as ReqError, Response}; use reqwest::{Client, Error as ReqError, Response};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::path::Path;
pub const GIT_VERSION: &str = git_version!(args = ["--always", "--dirty="]); pub const GIT_VERSION: &str = git_version!(args = ["--always", "--dirty="]);
const MASTERSERVER: &str = "ipbl.paulbsd.com"; const MASTERSERVER: &str = "ipbl.paulbsd.com";
const WSSUBSCRIPTION: &str = "ipbl"; const WSSUBSCRIPTION: &str = "ipbl";
const CONFIG_RETRY_INTERVAL: u64 = 2; const CONFIG_RETRY_INTERVAL: u64 = 2;
@ -28,7 +32,7 @@ pub struct Context {
pub flags: Flags, pub flags: Flags,
pub sas: HashMap<String, SetMap>, pub sas: HashMap<String, SetMap>,
pub hashwd: HashMap<String, WatchDescriptor>, pub hashwd: HashMap<String, WatchDescriptor>,
pub reloadinterval: isize, pub reloadinterval: u64,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -134,6 +138,25 @@ impl Context {
} }
}; };
} }
let mut last_in_err = false;
loop {
let res = self.discovery().await;
match res {
Ok(o) => {
self.discovery = o;
if last_in_err {
println!("loaded discovery");
}
break;
}
Err(e) => {
println!("error loading disvoery: {e}, retrying in {CONFIG_RETRY_INTERVAL}s");
last_in_err = true;
sleep_s(CONFIG_RETRY_INTERVAL).await;
}
};
}
if last_in_err { if last_in_err {
println!("creating sas"); println!("creating sas");
} }
@ -145,7 +168,7 @@ impl Context {
} }
#[cfg(test)] #[cfg(test)]
pub async fn get_blocklist_pending(&self) -> Vec<IpData> { pub async fn get_blocklist_pending(&self) -> Vec<BlockIpData> {
let mut res: Vec<IpData> = vec![]; let mut res: Vec<IpData> = vec![];
for (_, v) in self.blocklist.iter() { for (_, v) in self.blocklist.iter() {
res.push(v.ipdata.clone()); res.push(v.ipdata.clone());
@ -153,13 +176,13 @@ impl Context {
res res
} }
pub async fn get_blocklist_toblock(&self) -> Vec<IpData> { pub async fn get_blocklist_toblock(&self, all: bool) -> Vec<BlockIpData> {
let mut res: Vec<IpData> = vec![]; let mut res: Vec<BlockIpData> = vec![];
for (_, ipblock) in self.blocklist.iter() { for (_, ipblock) in self.blocklist.iter() {
match self.cfg.sets.get(&ipblock.ipdata.src) { match self.cfg.sets.get(&ipblock.ipdata.src) {
Some(set) => { Some(set) => {
if ipblock.tryfail >= set.tryfail && !ipblock.blocked { if ipblock.tryfail >= set.tryfail && (!ipblock.blocked || all) {
res.push(ipblock.ipdata.clone()); res.push(ipblock.clone());
} }
} }
None => {} None => {}
@ -177,6 +200,7 @@ impl Context {
.with_timezone(&chrono::Local); .with_timezone(&chrono::Local);
let blocktime = set.blocktime; let blocktime = set.blocktime;
let blocked = false; let blocked = false;
let handle = u64::MIN;
if ipevent.mode == "file".to_string() && gethostname(true) == ipevent.hostname { if ipevent.mode == "file".to_string() && gethostname(true) == ipevent.hostname {
let block = let block =
self.blocklist self.blocklist
@ -187,6 +211,7 @@ impl Context {
starttime, starttime,
blocktime, blocktime,
blocked, blocked,
handle,
}); });
block.tryfail += 1; block.tryfail += 1;
block.blocktime = blocktime; block.blocktime = blocktime;
@ -202,6 +227,7 @@ impl Context {
starttime, starttime,
blocktime, blocktime,
blocked, blocked,
handle,
}); });
} }
} }
@ -212,8 +238,8 @@ impl Context {
None None
} }
pub async fn gc_blocklist(&mut self) -> Vec<IpData> { pub async fn gc_blocklist(&mut self) -> Vec<BlockIpData> {
let mut removed: Vec<IpData> = vec![]; let mut removed: Vec<BlockIpData> = vec![];
let now: DateTime<Local> = Local::now().trunc_subsecs(0); let now: DateTime<Local> = Local::now().trunc_subsecs(0);
// nightly, future use // nightly, future use
// let drained: HashMap<String,IpData> = ctx.blocklist.drain_filter(|k,v| v.parse_date() < mindate) // let drained: HashMap<String,IpData> = ctx.blocklist.drain_filter(|k,v| v.parse_date() < mindate)
@ -228,7 +254,7 @@ impl Context {
let mindate = now - Duration::minutes(blocked.blocktime); let mindate = now - Duration::minutes(blocked.blocktime);
if blocked.starttime < mindate { if blocked.starttime < mindate {
self.blocklist.remove(&ip.clone()).unwrap(); self.blocklist.remove(&ip.clone()).unwrap();
removed.push(blocked.ipdata.clone()); removed.push(blocked.clone());
} }
} }
removed removed
@ -640,7 +666,7 @@ mod test {
pub async fn test_blocklist_toblock() { pub async fn test_blocklist_toblock() {
let mut ctx = prepare_test_data().await; let mut ctx = prepare_test_data().await;
ctx.gc_blocklist().await; ctx.gc_blocklist().await;
let toblock = ctx.get_blocklist_toblock().await; let toblock = ctx.get_blocklist_toblock(false).await;
assert_eq!(toblock.len(), 3); assert_eq!(toblock.len(), 3);
} }

111
src/fw.rs
View File

@ -1,11 +1,11 @@
use crate::{ip::IpData, ipblc::PKG_NAME}; use crate::{config::Context, ip::BlockIpData, ipblc::PKG_NAME};
use std::{ use std::{
io::Error, io::Error,
net::{Ipv4Addr, Ipv6Addr}, net::{IpAddr, Ipv4Addr, Ipv6Addr},
}; };
use rustables::*; use rustables::{expr::*, *};
pub enum FwTableType { pub enum FwTableType {
IPv4, IPv4,
@ -30,7 +30,7 @@ macro_rules! initrules {
macro_rules! makerules { macro_rules! makerules {
($ipdata:ident, $chain:ident, $batch:ident, $t:ty, $ip_t:ident,$action:ty) => { ($ipdata:ident, $chain:ident, $batch:ident, $t:ty, $ip_t:ident,$action:ty) => {
let ip = $ipdata.ip.parse::<$t>().unwrap(); let ip = $ipdata.ipdata.ip.parse::<$t>().unwrap();
Rule::new(&$chain) Rule::new(&$chain)
.unwrap() .unwrap()
.saddr(ip.into()) .saddr(ip.into())
@ -67,60 +67,77 @@ pub fn fwglobalinit(t: FwTableType, reset: bool) -> (Batch, Chain) {
(batch, chain) (batch, chain)
} }
pub fn fwblock<'a>(ip_add: &IpData) -> std::result::Result<(), error::QueryError> { pub fn fwblock<'a>(ip_add: &BlockIpData) -> std::result::Result<&String, error::QueryError> {
let (mut batch4, chain4) = fwglobalinit(FwTableType::IPv4, false); let (mut batch4, chain4) = fwglobalinit(FwTableType::IPv4, false);
let (mut batch6, chain6) = fwglobalinit(FwTableType::IPv6, false); let (mut batch6, chain6) = fwglobalinit(FwTableType::IPv6, false);
match ip_add.t { match ip_add.ipdata.t {
4 => { 4 => {
makerules!(ip_add, chain4, batch4, Ipv4Addr, ipv4, FwAction::Add); makerules!(ip_add, chain4, batch4, Ipv4Addr, ipv4, FwAction::Add);
match batch4.send() {
Ok(_) => {}
Err(e) => {
println!("block not ok {e} {ip_add:?}")
}
}
} }
6 => { 6 => {
makerules!(ip_add, chain6, batch6, Ipv6Addr, ipv6, FwAction::Add); makerules!(ip_add, chain6, batch6, Ipv6Addr, ipv6, FwAction::Add);
match batch6.send() {
Ok(_) => {}
Err(e) => {
println!("block not ok {e} {ip_add:?}")
}
}
} }
_ => {} _ => {}
} }
// validate and send batch Ok(&ip_add.ipdata.ip)
for b in [batch4, batch6] {
match b.send() {
Ok(_) => {}
Err(e) => {
println!("error sending batch: {e}");
return Err(e);
}
};
}
Ok(())
} }
pub fn fwunblock<'a>(ips_del: &Vec<IpData>) -> std::result::Result<(), Error> { pub fn fwunblock<'a>(ip_del: &BlockIpData) -> std::result::Result<&String, error::QueryError> {
let (mut batch4, chain4) = fwglobalinit(FwTableType::IPv4, false); let (mut batch4, chain4) = fwglobalinit(FwTableType::IPv4, false);
let (mut batch6, chain6) = fwglobalinit(FwTableType::IPv6, false); let (mut batch6, chain6) = fwglobalinit(FwTableType::IPv6, false);
// to implement match ip_del.ipdata.t {
/*for ip_del in ips_del {
match ip_del.t {
4 => { 4 => {
makerules!(ip_del, chain4, batch4, Ipv4Addr, ipv4, FwAction::Del); let r = Rule::new(&chain4).unwrap().with_handle(ip_del.handle);
batch4.add(&r, MsgType::Del);
match batch4.send() {
Ok(_) => {}
Err(e) => {
println!("delete not ok {e} {ip_del:?}")
}
}
} }
6 => { 6 => {
makerules!(ip_del, chain6, batch6, Ipv6Addr, ipv6, FwAction::Del); let r = Rule::new(&chain6).unwrap().with_handle(ip_del.handle);
batch6.add(&r, MsgType::Del);
match batch6.send() {
Ok(_) => {}
Err(e) => {
println!("delete not ok {e} {ip_del:?}")
}
}
} }
_ => {} _ => {}
} }
}*/ Ok(&ip_del.ipdata.ip)
Ok(())
} }
pub fn get_current_rules(table_name: &str, chain_name: &str) -> Result<(), Error> { pub fn get_current_rules(
ctx: &mut Context,
ret: &mut Vec<String>,
fwlen: &mut usize,
) -> Result<(), Error> {
let mut ips_all_count = 0;
let tables = vec![format!("{PKG_NAME}4"), format!("{PKG_NAME}6")];
for table_name in tables {
let get_table = || -> Result<Option<Table>, Error> { let get_table = || -> Result<Option<Table>, Error> {
let tables = list_tables().unwrap(); let tables = list_tables().unwrap();
for table in tables { for table in tables {
if let Some(name) = table.get_name() { if let Some(name) = table.get_name() {
println!("Found table {}", name);
if *name == table_name { if *name == table_name {
return Ok(Some(table)); return Ok(Some(table));
} }
@ -134,9 +151,7 @@ pub fn get_current_rules(table_name: &str, chain_name: &str) -> Result<(), Error
let chains = list_chains_for_table(table).unwrap(); let chains = list_chains_for_table(table).unwrap();
for chain in chains { for chain in chains {
if let Some(name) = chain.get_name() { if let Some(name) = chain.get_name() {
println!("Found chain {}", name); if *name == "ipblc" {
if *name == chain_name {
return Ok(Some(chain)); return Ok(Some(chain));
} }
} }
@ -148,10 +163,36 @@ pub fn get_current_rules(table_name: &str, chain_name: &str) -> Result<(), Error
let table = get_table()?.expect("no table?"); let table = get_table()?.expect("no table?");
let chain = get_chain(&table)?.expect("no chain?"); let chain = get_chain(&table)?.expect("no chain?");
let rules = list_rules_for_chain(&chain).unwrap(); for (ip, c) in ctx.blocklist.iter_mut() {
for mut rule in rules { let ip_parsed: IpAddr = ip.parse().unwrap();
println!("{:?}", rule);
let cmprule = Rule::new(&chain).unwrap().saddr(ip_parsed).drop();
let mut gexpr = RawExpression::default();
for e in cmprule.get_expressions().unwrap().iter() {
if let Some(ExpressionVariant::Cmp(_ip)) = e.get_data() {
gexpr = e.clone();
} }
}
let rules = list_rules_for_chain(&chain).unwrap();
for rule in rules {
for expr in rule.get_expressions().unwrap().iter() {
if let Some(expr::ExpressionVariant::Cmp(_)) = expr.get_data() {
if gexpr == expr.clone() {
ips_all_count += 1;
c.handle = *rule.get_handle().unwrap();
}
}
}
}
}
}
if *fwlen != ips_all_count {
ret.push(format!("{length} ip in firewall", length = ips_all_count));
}
*fwlen = ips_all_count;
Ok(()) Ok(())
} }

View File

@ -1,15 +1,18 @@
use crate::utils::gethostname; use crate::utils::gethostname;
use std::{
cmp::Ordering,
fmt::{Display, Formatter},
io::{BufRead, BufReader, Read},
net::IpAddr,
};
use chrono::offset::LocalResult; use chrono::offset::LocalResult;
use chrono::prelude::*; use chrono::prelude::*;
use ipnet::IpNet; use ipnet::IpNet;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt::{Display, Formatter};
use std::io::{BufRead, BufReader, Read};
use std::net::IpAddr;
lazy_static! { lazy_static! {
static ref R_IPV4: Regex = Regex::new(include_str!("regexps/ipv4.txt")).unwrap(); static ref R_IPV4: Regex = Regex::new(include_str!("regexps/ipv4.txt")).unwrap();
@ -52,6 +55,7 @@ pub struct BlockIpData {
pub blocktime: i64, pub blocktime: i64,
pub starttime: DateTime<Local>, pub starttime: DateTime<Local>,
pub blocked: bool, pub blocked: bool,
pub handle: u64,
} }
#[derive(Clone, Debug, Serialize, Deserialize, Eq)] #[derive(Clone, Debug, Serialize, Deserialize, Eq)]

View File

@ -7,20 +7,19 @@ use crate::utils::{gethostname, read_lines, sleep_s};
use crate::webservice::send_to_ipbl_api; use crate::webservice::send_to_ipbl_api;
use crate::websocket::{send_to_ipbl_websocket, websocketpubsub, websocketreqrep}; use crate::websocket::{send_to_ipbl_websocket, websocketpubsub, websocketreqrep};
use std::{collections::HashMap, sync::Arc};
use chrono::prelude::*; use chrono::prelude::*;
use chrono::prelude::{DateTime, Local}; use chrono::prelude::{DateTime, Local};
use chrono::Duration; use chrono::Duration;
use nix::sys::inotify::{InitFlags, Inotify, InotifyEvent}; use nix::sys::inotify::{InitFlags, Inotify, InotifyEvent};
use sd_notify::*; use sd_notify::*;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::RwLock; use tokio::sync::RwLock;
pub const PKG_NAME: &str = env!("CARGO_PKG_NAME"); pub const PKG_NAME: &str = env!("CARGO_PKG_NAME");
const BL_CHAN_SIZE: usize = 32; const BL_CHAN_SIZE: usize = 32;
const WS_CHAN_SIZE: usize = 64; const WS_CHAN_SIZE: usize = 64;
const LOOP_MAX_WAIT: u64 = 5;
macro_rules! log_with_systemd { macro_rules! log_with_systemd {
($msg:expr) => { ($msg:expr) => {
@ -75,21 +74,26 @@ pub async fn run() {
loop { loop {
let mut ret: Vec<String> = Vec::new(); let mut ret: Vec<String> = Vec::new();
let ctxclone = Arc::clone(&ctxarc); let ctxclone = Arc::clone(&ctxarc);
let reloadinterval;
{
let ctx = ctxclone.read().await;
reloadinterval = ctx.reloadinterval;
}
tokio::select! { tokio::select! {
ipevent = ipeventrx.recv() => { ipevent = ipeventrx.recv() => {
let received_ip = ipevent.unwrap(); let received_ip = ipevent.unwrap();
let (toblock,server) = { let (toblock,server) = {
let ctx = ctxclone.read().await; let ctx = ctxclone.read().await;
(ctx.get_blocklist_toblock().await,ctx.flags.server.clone()) (ctx.get_blocklist_toblock(true).await,ctx.flags.server.clone())
}; };
if received_ip.msgtype == "bootstrap".to_string() { if received_ip.msgtype == "bootstrap".to_string() {
for ip_to_send in toblock { for ip_to_send in toblock {
let ipe = ipevent!("init","ws",gethostname(true),Some(ip_to_send)); let ipe = ipevent!("init","ws",gethostname(true),Some(ip_to_send.ipdata));
if !send_to_ipbl_websocket(&mut wssocketrr, &ipe).await { if !send_to_ipbl_websocket(&mut wssocketrr, &ipe).await {
wssocketrr.close(None).unwrap(); wssocketrr.close(None).unwrap();
wssocketrr = websocketreqrep(&ctxwsrr).await; wssocketrr = websocketreqrep(&ctxwsrr).await;
@ -118,7 +122,7 @@ pub async fn run() {
} }
} }
} }
_val = sleep_s(LOOP_MAX_WAIT) => { _val = sleep_s(reloadinterval) => {
let ipe = ipevent!("ping", "ws", gethostname(true)); let ipe = ipevent!("ping", "ws", gethostname(true));
if !send_to_ipbl_websocket(&mut wssocketrr, &ipe).await { if !send_to_ipbl_websocket(&mut wssocketrr, &ipe).await {
wssocketrr.close(None).unwrap(); wssocketrr.close(None).unwrap();
@ -134,15 +138,20 @@ pub async fn run() {
}; };
let toblock = { let toblock = {
let ctx = ctxclone.read().await; let ctx = ctxclone.read().await;
ctx.get_blocklist_toblock().await ctx.get_blocklist_toblock(false).await
}; };
// apply firewall blocking {
let mut ctx = ctxclone.write().await;
get_current_rules(&mut ctx, &mut ret, &mut fwlen).unwrap();
get_current_rules(&mut ctx, &mut ret, &mut fwlen).unwrap();
}
for b in toblock { for b in toblock {
match fwblock(&b) { match fwblock(&b) {
Ok(_) => { Ok(ip) => {
let mut ctx = ctxclone.write().await; let mut ctx = ctxclone.write().await;
if let Some(x) = ctx.blocklist.get_mut(&b.ip) { if let Some(x) = ctx.blocklist.get_mut(ip) {
x.blocked = true; x.blocked = true;
} }
} }
@ -151,13 +160,18 @@ pub async fn run() {
} }
}; };
} }
get_current_rules("ipblc4", "ipblc").unwrap();
match fwunblock(&tounblock) { for ub in tounblock {
if ub.blocked {
let res = fwunblock(&ub);
match res {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
println!("err: {e}, unable to push firewall rules, use super user") println!("err: {e}, unable to push firewall rules, use super user")
} }
}; };
}
}
// log lines // log lines
if ret.len() > 0 { if ret.len() > 0 {
@ -167,17 +181,18 @@ pub async fn run() {
let ctxclone = Arc::clone(&ctxarc); let ctxclone = Arc::clone(&ctxarc);
let inoclone = Arc::clone(&inoarc); let inoclone = Arc::clone(&inoarc);
handle_cfg_reload(&ctxclone, &mut last_cfg_reload, inoclone).await; handle_cfg_reload(&ctxclone, reloadinterval, &mut last_cfg_reload, inoclone).await;
} }
} }
async fn handle_cfg_reload( async fn handle_cfg_reload(
ctxclone: &Arc<RwLock<Context>>, ctxclone: &Arc<RwLock<Context>>,
reloadinterval: u64,
last_cfg_reload: &mut DateTime<Local>, last_cfg_reload: &mut DateTime<Local>,
inoarc: Arc<RwLock<Inotify>>, inoarc: Arc<RwLock<Inotify>>,
) { ) {
let now_cfg_reload = Local::now().trunc_subsecs(0); let now_cfg_reload = Local::now().trunc_subsecs(0);
if (now_cfg_reload - *last_cfg_reload) > Duration::seconds(LOOP_MAX_WAIT as i64) { if (now_cfg_reload - *last_cfg_reload) > Duration::seconds(reloadinterval as i64) {
let inotify; let inotify;
loop { loop {
inotify = match inoarc.try_read() { inotify = match inoarc.try_read() {
@ -190,14 +205,14 @@ async fn handle_cfg_reload(
}; };
break; break;
} }
let mut ctxtest = match ctxclone.try_write() { let mut ctx = match ctxclone.try_write() {
Ok(o) => o, Ok(o) => o,
Err(e) => { Err(e) => {
println!("{e}"); println!("{e}");
return; return;
} }
}; };
match ctxtest.load(&inotify).await { match ctx.load(&inotify).await {
Ok(_) => { Ok(_) => {
*last_cfg_reload = Local::now().trunc_subsecs(0); *last_cfg_reload = Local::now().trunc_subsecs(0);
} }

View File

@ -1,8 +1,8 @@
use crate::config::Context; use crate::config::Context;
use std::{io, sync::Arc};
use serde_json; use serde_json;
use std::io;
use std::sync::Arc;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::RwLock; use tokio::sync::RwLock;

View File

@ -1,7 +1,6 @@
use std::{boxed::Box, fs::File, io::*};
use nix::unistd; use nix::unistd;
use std::boxed::Box;
use std::fs::File;
use std::io::*;
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
pub fn read_lines(filename: &String, offset: u64) -> Option<Box<dyn Read>> { pub fn read_lines(filename: &String, offset: u64) -> Option<Box<dyn Read>> {

View File

@ -2,10 +2,13 @@ use crate::config::{Context, WebSocketCfg};
use crate::ip::IpEvent; use crate::ip::IpEvent;
use crate::utils::{gethostname, sleep_s}; use crate::utils::{gethostname, sleep_s};
use std::{
io::{self, Write},
net::TcpStream,
sync::Arc,
};
use serde_json::json; use serde_json::json;
use std::io::{self, Write};
use std::net::TcpStream;
use std::sync::Arc;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tungstenite::stream::*; use tungstenite::stream::*;