use crate::ip::IpData; use nftnl::{nft_expr, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table}; use std::{ffi::CString, io::Error, net::Ipv4Addr}; pub fn init(tablename: &String) -> (Batch, Table) { let mut batch = Batch::new(); let table = Table::new( &CString::new(tablename.as_str()).unwrap(), ProtoFamily::Ipv4, ); batch.add(&table, nftnl::MsgType::Add); batch.add(&table, nftnl::MsgType::Del); batch.add(&table, nftnl::MsgType::Add); (batch, table) } pub fn block( tablename: &String, ips_add: &Vec, ret: &mut Vec, fwlen: &mut usize, ) -> std::result::Result<(), Error> { // convert chain let ips_add = convert(ips_add); let (mut batch, table) = init(tablename); // build chain let mut chain = Chain::new(&CString::new(tablename.as_str()).unwrap(), &table); chain.set_hook(nftnl::Hook::In, 1); chain.set_policy(nftnl::Policy::Accept); // add chain batch.add(&chain, nftnl::MsgType::Add); let mut rule = Rule::new(&chain); rule.add_expr(&nft_expr!(ct state)); rule.add_expr(&nft_expr!(bitwise mask 4u32, xor 0u32)); rule.add_expr(&nft_expr!(cmp != 0u32)); rule.add_expr(&nft_expr!(counter)); rule.add_expr(&nft_expr!(verdict accept)); batch.add(&rule, nftnl::MsgType::Add); // build and add rules for ip in ips_add.clone() { let mut rule = Rule::new(&chain); rule.add_expr(&nft_expr!(payload ipv4 saddr)); rule.add_expr(&nft_expr!(cmp == ip)); rule.add_expr(&nft_expr!(ct state)); rule.add_expr(&nft_expr!(bitwise mask 10u32, xor 0u32)); rule.add_expr(&nft_expr!(cmp != 0u32)); rule.add_expr(&nft_expr!(counter)); rule.add_expr(&nft_expr!(verdict drop)); batch.add(&rule, nftnl::MsgType::Add); } // validate and send batch let finalized_batch = batch.finalize(); send_and_process(&finalized_batch)?; if fwlen != &mut ips_add.len() { ret.push(format!("{length} ip in firewall", length = ips_add.len())); } *fwlen = ips_add.len(); Ok(()) } fn send_and_process(batch: &FinalizedBatch) -> std::result::Result<(), Error> { let seq: u32 = 2; let socket = mnl::Socket::new(mnl::Bus::Netfilter)?; socket.send_all(batch)?; let mut buffer = vec![0; nftnl::nft_nlmsg_maxsize() as usize]; while let Some(message) = socket_recv(&socket, &mut buffer[..])? { match mnl::cb_run(message, seq, socket.portid())? { mnl::CbResult::Stop => { break; } mnl::CbResult::Ok => (), } } Ok(()) } fn socket_recv<'a>( socket: &mnl::Socket, buf: &'a mut [u8], ) -> std::result::Result, Error> { let ret = socket.recv(buf)?; if ret > 0 { Ok(Some(&buf[..ret])) } else { Ok(None) } } fn convert(input: &Vec) -> Vec { let mut output: Vec = vec![]; for val in input { output.push(val.ip.parse::().unwrap()); } output }