Skip to content

Commit dfa6101

Browse files
committed
rust/dns: example of catching rust panics
Wrap DNS probing/parsing in catch_unwind as an example how to gracefully handle panics from Rust.
1 parent e717c2e commit dfa6101

File tree

1 file changed

+55
-32
lines changed

1 file changed

+55
-32
lines changed

rust/src/dns/dns.rs

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use std;
1919
use std::collections::HashMap;
2020
use std::collections::VecDeque;
2121
use std::ffi::CString;
22+
use std::panic::catch_unwind;
2223

2324
use crate::applayer::*;
2425
use crate::core::{self, *};
@@ -769,8 +770,10 @@ pub unsafe extern "C" fn rs_dns_parse_request(
769770
flow: *const core::Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
770771
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
771772
) -> AppLayerResult {
772-
let state = cast_pointer!(state, DNSState);
773-
state.parse_request_udp(flow, stream_slice);
773+
let _ = catch_unwind(|| {
774+
let state = cast_pointer!(state, DNSState);
775+
state.parse_request_udp(flow, stream_slice);
776+
});
774777
AppLayerResult::ok()
775778
}
776779

@@ -779,8 +782,10 @@ pub unsafe extern "C" fn rs_dns_parse_response(
779782
flow: *const core::Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
780783
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
781784
) -> AppLayerResult {
782-
let state = cast_pointer!(state, DNSState);
783-
state.parse_response_udp(flow, stream_slice);
785+
let _ = catch_unwind(|| {
786+
let state = cast_pointer!(state, DNSState);
787+
state.parse_response_udp(flow, stream_slice);
788+
});
784789
AppLayerResult::ok()
785790
}
786791

@@ -790,27 +795,40 @@ pub unsafe extern "C" fn rs_dns_parse_request_tcp(
790795
flow: *const core::Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
791796
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
792797
) -> AppLayerResult {
793-
let state = cast_pointer!(state, DNSState);
794-
if stream_slice.is_gap() {
795-
state.request_gap(stream_slice.gap_size());
796-
} else if !stream_slice.is_empty() {
797-
return state.parse_request_tcp(flow, stream_slice);
798+
match catch_unwind(|| {
799+
let state = cast_pointer!(state, DNSState);
800+
if stream_slice.is_gap() {
801+
state.request_gap(stream_slice.gap_size());
802+
} else if !stream_slice.is_empty() {
803+
return state.parse_request_tcp(flow, stream_slice);
804+
}
805+
return AppLayerResult::ok();
806+
}) {
807+
Ok(result) => result,
808+
Result::Err(_) => AppLayerResult::err(),
798809
}
799-
AppLayerResult::ok()
800810
}
801811

802812
#[no_mangle]
803813
pub unsafe extern "C" fn rs_dns_parse_response_tcp(
804814
flow: *const core::Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
805815
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
806816
) -> AppLayerResult {
807-
let state = cast_pointer!(state, DNSState);
808-
if stream_slice.is_gap() {
809-
state.response_gap(stream_slice.gap_size());
810-
} else if !stream_slice.is_empty() {
811-
return state.parse_response_tcp(flow, stream_slice);
817+
match catch_unwind(|| {
818+
let state = cast_pointer!(state, DNSState);
819+
if stream_slice.is_gap() {
820+
state.response_gap(stream_slice.gap_size());
821+
} else if !stream_slice.is_empty() {
822+
return state.parse_response_tcp(flow, stream_slice);
823+
}
824+
return AppLayerResult::ok();
825+
}) {
826+
Ok(result) => result,
827+
Result::Err(_) => {
828+
println!("caught error!");
829+
return AppLayerResult::err();
830+
}
812831
}
813-
AppLayerResult::ok()
814832
}
815833

816834
#[no_mangle]
@@ -938,24 +956,29 @@ pub unsafe extern "C" fn rs_dns_probe(
938956
pub unsafe extern "C" fn rs_dns_probe_tcp(
939957
_flow: *const core::Flow, direction: u8, input: *const u8, len: u32, rdir: *mut u8,
940958
) -> AppProto {
941-
if len == 0 || len < std::mem::size_of::<DNSHeader>() as u32 + 2 {
942-
return core::ALPROTO_UNKNOWN;
943-
}
944-
let slice: &[u8] = std::slice::from_raw_parts(input as *mut u8, len as usize);
945-
//is_incomplete is checked by caller
946-
let (is_dns, is_request, _) = probe_tcp(slice);
947-
if is_dns {
948-
let dir = if is_request {
949-
Direction::ToServer
950-
} else {
951-
Direction::ToClient
952-
};
953-
if (direction & DIR_BOTH) != dir.into() {
954-
*rdir = dir as u8;
959+
match catch_unwind(|| {
960+
if len == 0 || len < std::mem::size_of::<DNSHeader>() as u32 + 2 {
961+
return core::ALPROTO_UNKNOWN;
955962
}
956-
return ALPROTO_DNS;
963+
let slice: &[u8] = std::slice::from_raw_parts(input as *mut u8, len as usize);
964+
//is_incomplete is checked by caller
965+
let (is_dns, is_request, _) = probe_tcp(slice);
966+
if is_dns {
967+
let dir = if is_request {
968+
Direction::ToServer
969+
} else {
970+
Direction::ToClient
971+
};
972+
if (direction & DIR_BOTH) != dir.into() {
973+
*rdir = dir as u8;
974+
}
975+
return ALPROTO_DNS;
976+
}
977+
return 0;
978+
}) {
979+
Ok(proto) => proto,
980+
Result::Err(_) => ALPROTO_UNKNOWN,
957981
}
958-
return 0;
959982
}
960983

961984
#[no_mangle]

0 commit comments

Comments
 (0)