ipblc/src/fw.rs
Paul Lecuq 1e2f047824
All checks were successful
continuous-integration/drone/push Build is passing
add ips in chunks to nftables
2024-01-03 21:44:00 +01:00

172 lines
4.9 KiB
Rust

use crate::ip::IpData;
use crate::ipblc::PKG_NAME;
use nftnl::{nft_expr, Batch, Chain, FinalizedBatch, ProtoFamily, Rule, Table};
use std::{
ffi::CString,
io::Error,
net::{Ipv4Addr, Ipv6Addr},
};
/// Address family of an nftables table managed by this module.
///
/// `IPv4` selects `ProtoFamily::Ipv4` and `IPv6` selects `ProtoFamily::Ipv6`
/// when the table is created.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FwTableType {
    IPv4,
    IPv6,
}
/// Build the IPv4 and IPv6 tables, each paired with the batch that (re)creates it.
///
/// Returns `((ipv4_batch, ipv4_table), (ipv6_batch, ipv6_table))`. Both pairs
/// come from `fwinit`; the IPv4 pair is built first.
pub fn fwglobalinit() -> ((Batch, Table), (Batch, Table)) {
    (fwinit(FwTableType::IPv4), fwinit(FwTableType::IPv6))
}
// Register `$chain` in `$batch` and seed it with its first rule.
//
// The expansion does, in order:
// 1. put the chain on the `nftnl::Hook::In` hook with priority 1 and a default
//    `Accept` policy, then add it to the batch;
// 2. queue a `Del` message for an empty `Rule` on the chain.
//    NOTE(review): a rule delete carrying no handle presumably flushes every
//    rule in the chain. Confirm against the nf_tables netlink API;
// 3. add a rule that loads the conntrack state, ANDs it with mask 4 and, when
//    the result is non-zero, bumps a counter and accepts the packet.
//
// `$chain` must be a mutable `Chain` binding (its hook and policy are set).
// `$table` is accepted but is not referenced by the expansion.
macro_rules! initrules {
    ($batch:expr, $table:expr, $chain:ident) => {
        $chain.set_hook(nftnl::Hook::In, 1);
        $chain.set_policy(nftnl::Policy::Accept);
        $batch.add(&$chain, nftnl::MsgType::Add);
        $batch.add(&Rule::new(&$chain), nftnl::MsgType::Del);
        let mut rule = Rule::new(&$chain);
        // Load the packet's conntrack state into the rule's register.
        rule.add_expr(&nft_expr!(ct state));
        // NOTE(review): in the kernel's ct-state bitmask, 4 appears to be
        // `related` (established is 2, new is 8). Confirm this is the
        // intended state set.
        rule.add_expr(&nft_expr!(bitwise mask 4u32, xor 0u32));
        rule.add_expr(&nft_expr!(cmp != 0u32));
        rule.add_expr(&nft_expr!(counter));
        rule.add_expr(&nft_expr!(verdict accept));
        $batch.add(&rule, nftnl::MsgType::Add);
    };
}
// Append one drop rule for a single source address to `$chain` in `$batch`.
//
// - `$ipdata`: binding whose `ip` field is parsed as `$t`.
// - `$chain` / `$batch`: the chain the rule belongs to and the batch it is
//   queued in.
// - `$t`: the address type to parse into (e.g. `Ipv4Addr`, `Ipv6Addr`).
// - `$ip_t`: the `nft_expr!` payload protocol (`ipv4` / `ipv6`) whose source
//   address is compared.
//
// The rule matches the source address, then ANDs the conntrack state with
// mask 10 (NOTE(review): presumably `new` | `established` = 8 | 2; confirm).
// When any of those bits is set, it bumps a counter and drops the packet.
//
// An address that fails to parse is reported on stdout and skipped, instead of
// panicking. One malformed block-list entry therefore cannot abort the whole
// firewall update.
macro_rules! createrules {
    ($ipdata:ident, $chain:ident, $batch:ident, $t:ty, $ip_t:ident) => {
        match $ipdata.ip.parse::<$t>() {
            Ok(ip) => {
                let mut rule = Rule::new(&$chain);
                rule.add_expr(&nft_expr!(payload $ip_t saddr));
                rule.add_expr(&nft_expr!(cmp == ip));
                rule.add_expr(&nft_expr!(ct state));
                rule.add_expr(&nft_expr!(bitwise mask 10u32, xor 0u32));
                rule.add_expr(&nft_expr!(cmp != 0u32));
                rule.add_expr(&nft_expr!(counter));
                rule.add_expr(&nft_expr!(verdict drop));
                $batch.add(&rule, nftnl::MsgType::Add);
            }
            Err(e) => {
                println!("invalid ip {}: {}", $ipdata.ip, e);
            }
        }
    }
}
/// Create the nftables table for address family `t` and a batch that (re)creates it.
///
/// The table is named `PKG_NAME` followed by `4` or `6`. It uses
/// `ProtoFamily::Ipv4` or `ProtoFamily::Ipv6` to match.
///
/// The returned batch is not finalized. It already holds three table messages:
/// `Add`, `Del`, `Add`. NOTE(review): the leading `Add` presumably lets the
/// `Del` succeed even when the table does not exist yet, so every run starts
/// from an empty table. Confirm against nftables transaction semantics.
///
/// # Panics
///
/// Panics if the table name contains an interior NUL byte, which cannot
/// happen unless `PKG_NAME` itself contains one.
fn fwinit(t: FwTableType) -> (Batch, Table) {
    let (suffix, family) = match t {
        FwTableType::IPv4 => ('4', ProtoFamily::Ipv4),
        FwTableType::IPv6 => ('6', ProtoFamily::Ipv6),
    };
    let table_name = CString::new(format!("{PKG_NAME}{suffix}"))
        .expect("table name derived from PKG_NAME must not contain NUL bytes");
    let table = Table::new(&table_name, family);

    let mut batch = Batch::new();
    batch.add(&table, nftnl::MsgType::Add);
    batch.add(&table, nftnl::MsgType::Del);
    batch.add(&table, nftnl::MsgType::Add);
    (batch, table)
}
/// Rebuild the IPv4 and IPv6 firewall tables so they block the addresses in `ips_add_all`.
///
/// Fresh tables come from `fwglobalinit`. Each gets a chain named `PKG_NAME`,
/// seeded by `initrules!`. Then one drop rule is added per entry, in input
/// order. Entries with `t == 4` go to the IPv4 chain, `t == 6` to the IPv6
/// chain, and any other `t` is ignored. All rules for one family go into a
/// single batch, and the two batches are sent one after the other.
///
/// Sending is best-effort. A failed batch is reported on stdout and does not
/// stop the other batch from being sent.
///
/// Afterwards, if `ips_add_all.len()` differs from `*fwlen`, a
/// `"<n> ip in firewall"` message is pushed to `ret`. `*fwlen` is then set to
/// that length. The count includes entries whose `t` was ignored.
///
/// # Errors
///
/// Currently never returns `Err`, because send failures are only printed. The
/// `Result` is kept so callers' error handling stays valid.
pub fn fwblock(
    ips_add_all: &[IpData],
    ret: &mut Vec<String>,
    fwlen: &mut usize,
) -> std::result::Result<(), Error> {
    let ((mut batch4, table4), (mut batch6, table6)) = fwglobalinit();

    let chain_name = CString::new(PKG_NAME).expect("PKG_NAME must not contain NUL bytes");
    let mut chain4 = Chain::new(&chain_name, &table4);
    let mut chain6 = Chain::new(&chain_name, &table6);
    initrules!(batch4, table4, chain4);
    initrules!(batch6, table6, chain6);

    // One drop rule per address, appended after the initial accept rule.
    for ipdata in ips_add_all {
        match ipdata.t {
            4 => {
                createrules!(ipdata, chain4, batch4, Ipv4Addr, ipv4);
            }
            6 => {
                createrules!(ipdata, chain6, batch6, Ipv6Addr, ipv6);
            }
            _ => {}
        }
    }

    // Validate and send each batch. A failure is reported, not propagated.
    for batch in [batch4, batch6] {
        let finalized = batch.finalize();
        if let Err(e) = send_and_process(&finalized) {
            println!("error sending batch: {e}");
        }
    }

    let count = ips_add_all.len();
    if *fwlen != count {
        ret.push(format!("{length} ip in firewall", length = count));
    }
    *fwlen = count;
    Ok(())
}
/// Send a finalized batch to netfilter and drain the replies.
///
/// Opens a new `mnl` netfilter socket and sends the whole batch. It then reads
/// replies into a buffer of `nft_nlmsg_maxsize()` bytes, passing each one to
/// `mnl::cb_run`. Reading stops when a receive returns no data or `cb_run`
/// reports `Stop`.
///
/// # Errors
///
/// Propagates any error from socket creation, sending, receiving, or
/// `mnl::cb_run`.
fn send_and_process(batch: &FinalizedBatch) -> std::result::Result<(), Error> {
    // NOTE(review): fixed sequence number handed to `cb_run`. Where this value
    // comes from is not visible here; confirm it matches the batch's messages.
    const EXPECTED_SEQ: u32 = 2;

    let socket = mnl::Socket::new(mnl::Bus::Netfilter)?;
    socket.send_all(batch)?;

    let port_id = socket.portid();
    let mut reply_buf = vec![0; nftnl::nft_nlmsg_maxsize() as usize];
    while let Some(reply) = socket_recv(&socket, &mut reply_buf[..])? {
        let outcome = mnl::cb_run(reply, EXPECTED_SEQ, port_id)?;
        if matches!(outcome, mnl::CbResult::Stop) {
            break;
        }
    }
    Ok(())
}
/// Receive one message from `socket` into `buf`.
///
/// Returns `Ok(None)` when `recv` reports zero bytes. Otherwise returns the
/// filled prefix of `buf`. Errors from `recv` are propagated.
fn socket_recv<'a>(
    socket: &mnl::Socket,
    buf: &'a mut [u8],
) -> std::result::Result<Option<&'a [u8]>, Error> {
    match socket.recv(buf)? {
        0 => Ok(None),
        received => Ok(Some(&buf[..received])),
    }
}