4 changes: 4 additions & 0 deletions src/query/expression/src/block.rs
@@ -368,6 +368,10 @@ pub trait BlockMetaInfo: Debug + Send + Sync + Any + 'static {
"The reason for not implementing clone_self is usually because the higher-level logic doesn't allow/need the associated block to be cloned."
)
}

fn override_block_schema(&self) -> Option<DataSchemaRef> {
None
}
}

pub trait BlockMetaInfoDowncast: Sized + BlockMetaInfo {
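A hedged sketch (not part of this diff) of how a block meta could use the new hook: `ExampleMeta` is a hypothetical type that mirrors the `FlightJoinRuntimeFilterPacket` implementation added later in this PR; metas that keep the default implementation still return `None`.

#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq)]
struct ExampleMeta {
    schema: DataSchemaRef,
}

#[typetag::serde(name = "example_meta")]
impl BlockMetaInfo for ExampleMeta {
    fn equals(&self, info: &Box<dyn BlockMetaInfo>) -> bool {
        ExampleMeta::downcast_ref_from(info).is_some_and(|other| self == other)
    }

    fn clone_self(&self) -> Box<dyn BlockMetaInfo> {
        Box::new(self.clone())
    }

    // Ask consumers of the block to use this schema instead of one
    // inferred from the block's columns.
    fn override_block_schema(&self) -> Option<DataSchemaRef> {
        Some(self.schema.clone())
    }
}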
@@ -20,7 +20,6 @@ use async_channel::Sender;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::BlockMetaInfoDowncast;
- use databend_common_expression::BlockMetaInfoPtr;
use databend_common_expression::DataBlock;
use databend_common_pipeline::core::Event;
use databend_common_pipeline::core::InputPort;
@@ -70,8 +69,8 @@ pub struct SortSampleState<C: BroadcastChannel> {
}

pub trait BroadcastChannel: Clone + Send + 'static {
- fn sender(&self) -> Sender<BlockMetaInfoPtr>;
- fn receiver(&self) -> Receiver<BlockMetaInfoPtr>;
+ fn sender(&self) -> Sender<DataBlock>;
+ fn receiver(&self) -> Receiver<DataBlock>;
}
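For illustration only (not part of this diff): with the trait now exchanging whole `DataBlock`s, a channel implementation backed by `async_channel` could look like the hypothetical sketch below; the real implementation in the codebase may differ.

#[derive(Clone)]
struct LocalBroadcastChannel {
    tx: Sender<DataBlock>,
    rx: Receiver<DataBlock>,
}

impl BroadcastChannel for LocalBroadcastChannel {
    // async_channel endpoints are cheap, cloneable handles.
    fn sender(&self) -> Sender<DataBlock> {
        self.tx.clone()
    }

    fn receiver(&self) -> Receiver<DataBlock> {
        self.rx.clone()
    }
}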

impl<C: BroadcastChannel> SortSampleState<C> {
@@ -91,16 +90,16 @@ impl<C: BroadcastChannel> SortSampleState<C> {
let is_empty = meta.is_none();
let meta = meta.map(|meta| meta.boxed()).unwrap_or(().boxed());
sender
- .send(meta)
+ .send(DataBlock::empty_with_meta(meta))
.await
.map_err(|_| ErrorCode::TokioError("send sort bounds failed"))?;
sender.close();
log::debug!(is_empty; "sample has sent");

let receiver = self.channel.receiver();
let mut all = Vec::new();
- while let Ok(r) = receiver.recv().await {
- match SortExchangeMeta::downcast_from_err(r) {
+ while let Ok(mut r) = receiver.recv().await {
+ match SortExchangeMeta::downcast_from_err(r.take_meta().unwrap()) {
Ok(meta) => all.push(meta),
Err(r) => {
debug_assert!(().boxed().equals(&r))
1 change: 0 additions & 1 deletion src/query/service/src/physical_plans/physical_hash_join.rs
@@ -452,7 +452,6 @@ impl HashJoin {
let joined_output = OutputPort::create();

let hash_join = TransformHashJoin::create(
- self.get_id(),
build_input.clone(),
probe_input.clone(),
joined_output.clone(),
27 changes: 10 additions & 17 deletions src/query/service/src/pipelines/processors/transforms/broadcast.rs
@@ -19,7 +19,6 @@ use async_channel::Sender;
use databend_common_catalog::table_context::TableContext;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
- use databend_common_expression::BlockMetaInfoPtr;
use databend_common_expression::DataBlock;
use databend_common_pipeline::core::InputPort;
use databend_common_pipeline::core::OutputPort;
@@ -30,13 +29,13 @@ use databend_common_pipeline::sources::AsyncSource;
use databend_common_pipeline::sources::AsyncSourcer;

pub struct BroadcastSourceProcessor {
- pub receiver: Receiver<BlockMetaInfoPtr>,
+ pub receiver: Receiver<DataBlock>,
}

impl BroadcastSourceProcessor {
pub fn create(
ctx: Arc<dyn TableContext>,
- receiver: Receiver<BlockMetaInfoPtr>,
+ receiver: Receiver<DataBlock>,
output_port: Arc<OutputPort>,
) -> Result<ProcessorPtr> {
AsyncSourcer::create(ctx.get_scan_progress(), output_port, Self { receiver })
@@ -50,23 +49,20 @@ impl AsyncSource for BroadcastSourceProcessor {

#[async_backtrace::framed]
async fn generate(&mut self) -> Result<Option<DataBlock>> {
- let received = self.receiver.recv().await;
- match received {
- Ok(meta) => Ok(Some(DataBlock::empty_with_meta(meta))),
- Err(_) => {
- // The channel is closed, we should return None to stop generating
- Ok(None)
- }
+ match self.receiver.recv().await {
+ Ok(block) => Ok(Some(block)),
+ // The channel is closed, we should return None to stop generating
+ Err(_) => Ok(None),
}
}
}

pub struct BroadcastSinkProcessor {
- sender: Sender<BlockMetaInfoPtr>,
+ sender: Sender<DataBlock>,
}

impl BroadcastSinkProcessor {
- pub fn create(input: Arc<InputPort>, sender: Sender<BlockMetaInfoPtr>) -> Result<ProcessorPtr> {
+ pub fn create(input: Arc<InputPort>, sender: Sender<DataBlock>) -> Result<ProcessorPtr> {
Ok(ProcessorPtr::create(AsyncSinker::create(input, Self {
sender,
})))
@@ -82,12 +78,9 @@ impl AsyncSink for BroadcastSinkProcessor {
Ok(())
}

- async fn consume(&mut self, mut data_block: DataBlock) -> Result<bool> {
- let meta = data_block
- .take_meta()
- .ok_or_else(|| ErrorCode::Internal("Cannot downcast meta to BroadcastMeta"))?;
+ async fn consume(&mut self, data_block: DataBlock) -> Result<bool> {
self.sender
- .send(meta)
+ .send(data_block)
.await
.map_err(|_| ErrorCode::Internal("BroadcastSinkProcessor send error"))?;
Ok(false)
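A hedged wiring sketch (not part of this diff): the sink end now forwards whole `DataBlock`s into the channel and the source end replays them into the pipeline. `build_broadcast_pair` is a hypothetical helper; the real wiring lives in the pipeline builder.

fn build_broadcast_pair(
    ctx: Arc<dyn TableContext>,
    input: Arc<InputPort>,
    output: Arc<OutputPort>,
) -> Result<(ProcessorPtr, ProcessorPtr)> {
    // The channel carries DataBlock instead of bare BlockMetaInfoPtr,
    // so broadcast payloads may contain real columns, not only metas.
    let (tx, rx) = async_channel::unbounded::<DataBlock>();
    let sink = BroadcastSinkProcessor::create(input, tx)?;
    let source = BroadcastSourceProcessor::create(ctx, rx, output)?;
    Ok((sink, source))
}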
@@ -14,7 +14,6 @@

use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
- use databend_common_expression::BlockMetaInfoDowncast;

use super::merge::merge_join_runtime_filter_packets;
use super::packet::JoinRuntimeFilterPacket;
@@ -30,13 +29,13 @@ pub async fn get_global_runtime_filter_packet(
let mut received = vec![];

sender
- .send(Box::new(local_packet))
+ .send(local_packet.try_into()?)
.await
.map_err(|_| ErrorCode::TokioError("send runtime filter shards failed"))?;
sender.close();

- while let Ok(r) = receiver.recv().await {
- received.push(JoinRuntimeFilterPacket::downcast_from(r).unwrap());
+ while let Ok(data_block) = receiver.recv().await {
+ received.push(JoinRuntimeFilterPacket::try_from(data_block)?);
}
merge_join_runtime_filter_packets(received)
}
@@ -16,10 +16,20 @@ use std::collections::HashMap;
use std::fmt;
use std::fmt::Debug;

use databend_common_column::buffer::Buffer;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::BlockMetaInfo;
use databend_common_expression::BlockMetaInfoDowncast;
use databend_common_expression::Column;
use databend_common_expression::ColumnBuilder;
use databend_common_expression::DataBlock;
use databend_common_expression::DataSchemaRef;
use databend_common_expression::Scalar;
use databend_common_expression::types::ArrayColumn;
use databend_common_expression::types::NumberColumn;
use databend_common_expression::types::NumberColumnBuilder;
use databend_common_expression::types::array::ArrayColumnBuilder;

use crate::pipelines::processors::transforms::RuntimeFilterDesc;

@@ -84,15 +94,153 @@ impl JoinRuntimeFilterPacket {
}
}

#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Default, PartialEq)]
struct FlightRuntimeFilterPacket {
pub id: usize,
pub bloom: Option<usize>,
pub inlist: Option<usize>,
pub min_max: Option<SerializableDomain>,
}

#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Default, PartialEq)]
struct FlightJoinRuntimeFilterPacket {
#[serde(default)]
pub build_rows: usize,
#[serde(default)]
pub packets: Option<HashMap<usize, FlightRuntimeFilterPacket>>,

pub schema: DataSchemaRef,
}

impl TryInto<DataBlock> for JoinRuntimeFilterPacket {
type Error = ErrorCode;

fn try_into(mut self) -> Result<DataBlock> {
let mut entities = vec![];
let mut join_flight_packets = None;

if let Some(packets) = self.packets.take() {
let mut flight_packets = HashMap::with_capacity(packets.len());

for (id, packet) in packets {
let mut inlist_pos = None;
if let Some(in_list) = packet.inlist {
let len = in_list.len() as u64;
inlist_pos = Some(entities.len());
entities.push(Column::Array(Box::new(ArrayColumn::new(
in_list,
Buffer::from(vec![0, len]),
))));
}

let mut bloom_pos = None;
if let Some(bloom_filter) = packet.bloom {
let len = bloom_filter.len() as u64;
bloom_pos = Some(entities.len());

let builder = ArrayColumnBuilder {
builder: ColumnBuilder::Number(NumberColumnBuilder::UInt64(bloom_filter)),
offsets: vec![0, len],
};
entities.push(Column::Array(Box::new(builder.build())));
}

flight_packets.insert(id, FlightRuntimeFilterPacket {
id,
bloom: bloom_pos,
inlist: inlist_pos,
min_max: packet.min_max,
});
}

join_flight_packets = Some(flight_packets);
}

let data_block = match entities.is_empty() {
true => DataBlock::empty(),
false => DataBlock::new_from_columns(entities),
};

let schema = DataSchemaRef::new(data_block.infer_schema());

data_block.add_meta(Some(Box::new(FlightJoinRuntimeFilterPacket {
build_rows: self.build_rows,
packets: join_flight_packets,
schema,
})))
}
}
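The conversion above hinges on one trick: each filter column is stored as the single row of an `Array` column so it can travel inside an ordinary `DataBlock`, and its position in the block is recorded in the meta. A hedged, self-contained illustration with made-up values (not part of the diff):

fn wrap_and_unwrap() -> Column {
    // Wrap a u64 column as the only row of an Array column ...
    let values = Column::Number(NumberColumn::UInt64(Buffer::from(vec![1u64, 2, 3])));
    let offsets = Buffer::from(vec![0u64, values.len() as u64]);
    let wrapped = Column::Array(Box::new(ArrayColumn::new(values, offsets)));

    // ... and read it back with `index(0)`, as the `TryFrom<DataBlock>`
    // implementation below does for both the inlist and bloom columns.
    wrapped.into_array().unwrap().index(0).unwrap()
}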

impl TryFrom<DataBlock> for JoinRuntimeFilterPacket {
type Error = ErrorCode;

fn try_from(mut block: DataBlock) -> Result<Self> {
if let Some(meta) = block.take_meta() {
let flight_join_rf = FlightJoinRuntimeFilterPacket::downcast_from(meta)
.ok_or_else(|| ErrorCode::Internal("It's a bug"))?;

let Some(packet) = flight_join_rf.packets else {
return Ok(JoinRuntimeFilterPacket {
packets: None,
build_rows: flight_join_rf.build_rows,
});
};

let mut flight_packets = HashMap::with_capacity(packet.len());
for (id, flight_packet) in packet {
let mut inlist = None;
if let Some(column_idx) = flight_packet.inlist {
let column = block.get_by_offset(column_idx).clone();
let column = column.into_column().unwrap();
let array_column = column.into_array().expect("it's a bug");
inlist = Some(array_column.index(0).expect("It's a bug"));
}

let mut bloom = None;
if let Some(column_idx) = flight_packet.bloom {
let column = block.get_by_offset(column_idx).clone();
let column = column.into_column().unwrap();
let array_column = column.into_array().expect("it's a bug");
let bloom_value_column = array_column.index(0).expect("It's a bug");
bloom = Some(match bloom_value_column {
Column::Number(NumberColumn::UInt64(v)) => v.to_vec(),
_ => unreachable!("Unexpected runtime bloom filter column type"),
})
}

flight_packets.insert(id, RuntimeFilterPacket {
bloom,
inlist,
id: flight_packet.id,
min_max: flight_packet.min_max,
});
}

return Ok(JoinRuntimeFilterPacket {
packets: Some(flight_packets),
build_rows: flight_join_rf.build_rows,
});
}

Err(ErrorCode::Internal(
"Unexpected runtime filter packet meta type. It's a bug",
))
}
}

#[typetag::serde(name = "join_runtime_filter_packet")]
- impl BlockMetaInfo for JoinRuntimeFilterPacket {
+ impl BlockMetaInfo for FlightJoinRuntimeFilterPacket {
fn equals(&self, info: &Box<dyn BlockMetaInfo>) -> bool {
- JoinRuntimeFilterPacket::downcast_ref_from(info).is_some_and(|other| self == other)
+ FlightJoinRuntimeFilterPacket::downcast_ref_from(info).is_some_and(|other| self == other)
}

fn clone_self(&self) -> Box<dyn BlockMetaInfo> {
Box::new(self.clone())
}

fn override_block_schema(&self) -> Option<DataSchemaRef> {
Some(self.schema.clone())
}
}

#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq)]
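Taken together, the two conversions give the runtime-filter broadcast a lossless round trip through a `DataBlock`. A hedged sketch of that property (there is no such test in this diff):

fn round_trip(packet: JoinRuntimeFilterPacket) -> Result<JoinRuntimeFilterPacket> {
    // Serialize: inlist/bloom filters become columns, while ids, min/max
    // domains and the inferred schema ride in a FlightJoinRuntimeFilterPacket meta.
    let block: DataBlock = packet.try_into()?;
    // Deserialize: read the meta back and rebuild each RuntimeFilterPacket
    // from the column offsets it references.
    JoinRuntimeFilterPacket::try_from(block)
}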