diff --git a/src/fw.rs b/src/fw.rs index 2ca6b2e..d687c6c 100644 --- a/src/fw.rs +++ b/src/fw.rs @@ -3,8 +3,11 @@ use crate::{config::Context, ip::BlockIpData, ipblc::PKG_NAME}; use std::{ io::Error, net::{IpAddr, Ipv4Addr, Ipv6Addr}, + sync::Arc, }; +use tokio::sync::RwLock; + use rustables::{expr::*, *}; pub enum FwTableType { @@ -126,8 +129,8 @@ pub fn fwunblock<'a>(ip_del: &BlockIpData) -> std::result::Result<&String, error Ok(&ip_del.ipdata.ip) } -pub fn get_current_rules( - ctx: &mut Context, +pub async fn get_current_rules( + ctx: &Arc>, ret: &mut Vec, fwlen: &mut usize, ) -> Result<(), Error> { @@ -163,6 +166,9 @@ pub fn get_current_rules( let table = get_table()?.expect("no table?"); let chain = get_chain(&table)?.expect("no chain?"); + let mut ctx = { ctx.write().await }; + let rules = list_rules_for_chain(&chain).unwrap().clone(); + for (ip, c) in ctx.blocklist.iter_mut() { let ip_parsed: IpAddr = ip.parse().unwrap(); @@ -174,8 +180,7 @@ pub fn get_current_rules( } } - let rules = list_rules_for_chain(&chain).unwrap(); - for rule in rules { + for rule in rules.iter() { for expr in rule.get_expressions().unwrap().iter() { if let Some(expr::ExpressionVariant::Cmp(_)) = expr.get_data() { if gexpr == expr.clone() { diff --git a/src/ipblc.rs b/src/ipblc.rs index 883cd74..f026cb6 100644 --- a/src/ipblc.rs +++ b/src/ipblc.rs @@ -132,23 +132,21 @@ pub async fn run() { }; let ctxclone = Arc::clone(&ctxarc); - let tounblock = { + let ipstounblock = { let mut ctx = ctxclone.write().await; ctx.gc_blocklist().await }; - let toblock = { + let ipstoblock = { let ctx = ctxclone.read().await; ctx.get_blocklist_toblock(false).await }; - { - let mut ctx = ctxclone.write().await; - get_current_rules(&mut ctx, &mut ret, &mut fwlen).unwrap(); - get_current_rules(&mut ctx, &mut ret, &mut fwlen).unwrap(); - } + get_current_rules(&ctxarc, &mut ret, &mut fwlen) + .await + .unwrap(); - for b in toblock { - match fwblock(&b) { + for ip in ipstoblock { + match fwblock(&ip) { Ok(ip) => { let mut ctx = ctxclone.write().await; if let Some(x) = ctx.blocklist.get_mut(ip) { @@ -161,10 +159,9 @@ pub async fn run() { }; } - for ub in tounblock { - if ub.blocked { - let res = fwunblock(&ub); - match res { + for ip in ipstounblock { + if ip.blocked { + match fwunblock(&ip) { Ok(_) => {} Err(e) => { println!("err: {e}, unable to push firewall rules, use super user")