Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 58 additions & 8 deletions rust/src/mqtt/mqtt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::applayer::*;
use crate::applayer::{self, LoggerFlags};
use crate::conf::conf_get;
use crate::core::*;
use crate::frames::*;
use nom7::Err;
use std;
use std::collections::VecDeque;
Expand All @@ -41,6 +42,13 @@ static mut MQTT_MAX_TX: usize = 1024;

static mut ALPROTO_MQTT: AppProto = ALPROTO_UNKNOWN;

#[derive(AppLayerFrameType)]
pub enum MQTTFrameType {
Pdu,
Header,
Data,
}

#[derive(FromPrimitive, Debug, AppLayerEvent)]
pub enum MQTTEvent {
MissingConnect,
Expand Down Expand Up @@ -422,8 +430,10 @@ impl MQTTState {
}
}

fn parse_request(&mut self, input: &[u8]) -> AppLayerResult {
fn parse_request(&mut self, flow: *const Flow, stream_slice: StreamSlice) -> AppLayerResult {
let input = stream_slice.as_slice();
let mut current = input;

if input.is_empty() {
return AppLayerResult::ok();
}
Expand Down Expand Up @@ -455,6 +465,13 @@ impl MQTTState {
SCLogDebug!("request: handling {}", current.len());
match parse_message(current, self.protocol_version, self.max_msg_len) {
Ok((rem, msg)) => {
let _pdu = Frame::new(
flow,
&stream_slice,
input,
current.len() as i64,
MQTTFrameType::Pdu as u8,
);
SCLogDebug!("request msg {:?}", msg);
if let MQTTOperation::TRUNCATED(ref trunc) = msg.op {
SCLogDebug!(
Expand All @@ -474,6 +491,8 @@ impl MQTTState {
continue;
}
}

self.mqtt_hdr_and_data_frames(flow, &stream_slice, &msg);
self.handle_msg(msg, false);
consumed += current.len() - rem.len();
current = rem;
Expand All @@ -497,8 +516,10 @@ impl MQTTState {
return AppLayerResult::ok();
}

fn parse_response(&mut self, input: &[u8]) -> AppLayerResult {
fn parse_response(&mut self, flow: *const Flow, stream_slice: StreamSlice) -> AppLayerResult {
let input = stream_slice.as_slice();
let mut current = input;

if input.is_empty() {
return AppLayerResult::ok();
}
Expand Down Expand Up @@ -529,6 +550,14 @@ impl MQTTState {
SCLogDebug!("response: handling {}", current.len());
match parse_message(current, self.protocol_version, self.max_msg_len) {
Ok((rem, msg)) => {
let _pdu = Frame::new(
flow,
&stream_slice,
input,
input.len() as i64,
MQTTFrameType::Pdu as u8,
);

SCLogDebug!("response msg {:?}", msg);
if let MQTTOperation::TRUNCATED(ref trunc) = msg.op {
SCLogDebug!(
Expand All @@ -549,6 +578,8 @@ impl MQTTState {
continue;
}
}

self.mqtt_hdr_and_data_frames(flow, &stream_slice, &msg);
self.handle_msg(msg, true);
consumed += current.len() - rem.len();
current = rem;
Expand Down Expand Up @@ -589,6 +620,25 @@ impl MQTTState {
tx.tx_data.set_event(event as u8);
self.transactions.push_back(tx);
}

fn mqtt_hdr_and_data_frames(
&mut self, flow: *const Flow, stream_slice: &StreamSlice, input: &MQTTMessage,
) {
let hdr = stream_slice.as_slice();
//MQTT payload has a fixed header of 2 bytes
let _mqtt_hdr = Frame::new(flow, stream_slice, hdr, 2, MQTTFrameType::Header as u8);
SCLogDebug!("mqtt_hdr Frame {:?}", _mqtt_hdr);
let rem_length = input.header.remaining_length as usize;
let data = &hdr[2..rem_length + 2];
let _mqtt_data = Frame::new(
flow,
stream_slice,
data,
rem_length as i64,
MQTTFrameType::Data as u8,
);
SCLogDebug!("mqtt_data Frame {:?}", _mqtt_data);
}
}

// C exports.
Expand Down Expand Up @@ -637,20 +687,20 @@ pub unsafe extern "C" fn rs_mqtt_state_tx_free(state: *mut std::os::raw::c_void,

#[no_mangle]
pub unsafe extern "C" fn rs_mqtt_parse_request(
_flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
) -> AppLayerResult {
let state = cast_pointer!(state, MQTTState);
return state.parse_request(stream_slice.as_slice());
return state.parse_request(flow, stream_slice);
}

#[no_mangle]
pub unsafe extern "C" fn rs_mqtt_parse_response(
_flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
) -> AppLayerResult {
let state = cast_pointer!(state, MQTTState);
return state.parse_response(stream_slice.as_slice());
return state.parse_response(flow, stream_slice);
}

#[no_mangle]
Expand Down Expand Up @@ -761,8 +811,8 @@ pub unsafe extern "C" fn rs_mqtt_register_parser(cfg_max_msg_len: u32) {
apply_tx_config: None,
flags: APP_LAYER_PARSER_OPT_UNIDIR_TXS,
truncate: None,
get_frame_id_by_name: None,
get_frame_name_by_id: None,
get_frame_id_by_name: Some(MQTTFrameType::ffi_id_from_name),
get_frame_name_by_id: Some(MQTTFrameType::ffi_name_from_id),
};

let ip_proto_str = CString::new("tcp").unwrap();
Expand Down