Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions src/common/base/src/base/barrier.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Mutex;
use std::sync::PoisonError;

use tokio::sync::watch;

#[derive(Debug)]
struct BarrierState {
waker: watch::Sender<usize>,
arrived: usize,
generation: usize,

n: usize,
}

pub struct Barrier {
state: Mutex<BarrierState>,
wait: watch::Receiver<usize>,
}

impl Barrier {
pub fn new(mut n: usize) -> Barrier {
let (waker, wait) = watch::channel(0);

if n == 0 {
n = 1;
}

Barrier {
state: Mutex::new(BarrierState {
n,
waker,
arrived: 0,
generation: 1,
}),
wait,
}
}

pub async fn wait(&self) -> BarrierWaitResult {
let (generation, is_leader) = {
let locked = self.state.lock();
let mut state = locked.unwrap_or_else(PoisonError::into_inner);

let is_leader = state.arrived == 0;
let generation = state.generation;
state.arrived += 1;

if state.arrived == state.n {
state
.waker
.send(state.generation)
.expect("there is at least one receiver");
state.arrived = 0;
state.generation += 1;
return BarrierWaitResult(is_leader);
}

(generation, is_leader)
};

let mut wait = self.wait.clone();

loop {
let _ = wait.changed().await;

if *wait.borrow() >= generation {
break;
}
}

BarrierWaitResult(is_leader)
}

pub fn reduce_quorum(&self, n: usize) {
let locked = self.state.lock();
let mut state = locked.unwrap_or_else(PoisonError::into_inner);
state.n -= n;

if state.arrived >= state.n {
state
.waker
.send(state.generation)
.expect("there is at least one receiver");
state.arrived = 0;
state.generation += 1;
}
}
}

#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
pub fn is_leader(&self) -> bool {
self.0
}
}
2 changes: 2 additions & 0 deletions src/common/base/src/base/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod barrier;
mod build_info;
mod dma;
mod drop_callback;
Expand All @@ -30,6 +31,7 @@ mod take_mut;
mod uniq_id;
mod watch_notify;

pub use barrier::Barrier;
pub use build_info::*;
pub use dma::*;
pub use drop_callback::DropCallback;
Expand Down
14 changes: 12 additions & 2 deletions src/query/service/src/physical_plans/physical_cache_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,19 @@ impl IPhysicalPlan for CacheScan {
max_block_size,
))
}
Some(HashJoinStateRef::NewHashJoinState(hash_join_state)) => {
Some(HashJoinStateRef::NewHashJoinState(hash_join_state, column_map)) => {
let mut column_offsets = Vec::with_capacity(column_indexes.len());
for index in column_indexes {
let Some(offset) = column_map.get(index) else {
return Err(ErrorCode::Internal(format!(
"Hash join cache column {} not found in build projection",
index
)));
};
column_offsets.push(*offset);
}
CacheSourceState::NewHashJoinCacheState(NewHashJoinCacheState::new(
column_indexes.clone(),
column_offsets,
hash_join_state.clone(),
))
}
Expand Down
29 changes: 22 additions & 7 deletions src/query/service/src/physical_plans/physical_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ use crate::pipelines::processors::transforms::RuntimeFiltersDesc;
use crate::pipelines::processors::transforms::TransformHashJoin;
use crate::pipelines::processors::transforms::TransformHashJoinBuild;
use crate::pipelines::processors::transforms::TransformHashJoinProbe;
use crate::sessions::QueryContext;

// Type aliases to simplify complex return types
type JoinConditionsResult = (
Expand Down Expand Up @@ -270,19 +271,29 @@ impl IPhysicalPlan for HashJoin {
let (enable_optimization, _) = builder.merge_into_get_optimization_flag(self);

if desc.single_to_inner.is_none()
&& (self.join_type == JoinType::Inner || self.join_type == JoinType::Left)
&& matches!(
self.join_type,
JoinType::Inner
| JoinType::Left
| JoinType::LeftSemi
| JoinType::LeftAnti
| JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti
)
&& experimental_new_join
&& !enable_optimization
&& !self.need_hold_hash_table
{
return self.build_new_join_pipeline(builder, desc);
}

// Create the join state with optimization flags
let state = self.build_state(builder)?;

if let Some((build_cache_index, _)) = self.build_side_cache_info {
if let Some((build_cache_index, _)) = &self.build_side_cache_info {
builder.hash_join_states.insert(
build_cache_index,
*build_cache_index,
HashJoinStateRef::OldHashJoinState(state.clone()),
);
}
Expand Down Expand Up @@ -413,15 +424,18 @@ impl HashJoin {
{
let state = factory.create_basic_state(0)?;

if let Some((build_cache_index, _)) = self.build_side_cache_info {
if let Some((build_cache_index, column_map)) = &self.build_side_cache_info {
builder.hash_join_states.insert(
build_cache_index,
HashJoinStateRef::NewHashJoinState(state.clone()),
*build_cache_index,
HashJoinStateRef::NewHashJoinState(state.clone(), column_map.clone()),
);
}
}

let mut sub_query_ctx = QueryContext::create_from(&builder.ctx);
std::mem::swap(&mut builder.ctx, &mut sub_query_ctx);
self.build.build_pipeline(builder)?;
std::mem::swap(&mut builder.ctx, &mut sub_query_ctx);
let mut build_sinks = builder.main_pipeline.take_sinks();

self.probe.build_pipeline(builder)?;
Expand All @@ -440,7 +454,8 @@ impl HashJoin {

debug_assert_eq!(build_sinks.len(), probe_sinks.len());

let stage_sync_barrier = Arc::new(Barrier::new(output_len));
let barrier = databend_common_base::base::Barrier::new(output_len);
let stage_sync_barrier = Arc::new(barrier);
let mut join_sinks = Vec::with_capacity(output_len * 2);
let mut join_pipe_items = Vec::with_capacity(output_len);
for (build_sink, probe_sink) in build_sinks.into_iter().zip(probe_sinks.into_iter()) {
Expand Down
3 changes: 2 additions & 1 deletion src/query/service/src/pipelines/pipeline_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use databend_common_pipeline::core::ExecutionInfo;
use databend_common_pipeline::core::Pipeline;
use databend_common_pipeline::core::always_callback;
use databend_common_settings::Settings;
use databend_common_sql::IndexType;

use super::PipelineBuilderData;
use crate::interpreters::CreateTableInterpreter;
Expand All @@ -38,7 +39,7 @@ use crate::sessions::QueryContext;
#[derive(Clone)]
pub enum HashJoinStateRef {
OldHashJoinState(Arc<HashJoinState>),
NewHashJoinState(Arc<BasicHashJoinState>),
NewHashJoinState(Arc<BasicHashJoinState>, HashMap<IndexType, usize>),
}

pub struct PipelineBuilder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,14 @@ pub fn wrap_true_validity(
NullableColumn::new_column(col, validity).into()
}
}

pub fn wrap_nullable_block(input: &DataBlock) -> DataBlock {
let input_num_rows = input.num_rows();
let true_validity = Bitmap::new_constant(true, input_num_rows);
let nullable_columns = input
.columns()
.iter()
.map(|c| wrap_true_validity(c, input_num_rows, &true_validity))
.collect::<Vec<_>>();
DataBlock::new(nullable_columns, input_num_rows)
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ pub struct HashJoinDesc {
pub(crate) probe_projections: ColumnSet,
pub(crate) probe_to_build: Vec<(usize, (bool, bool))>,
pub(crate) build_schema: DataSchemaRef,
pub(crate) probe_schema: DataSchemaRef,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -138,6 +139,7 @@ impl HashJoinDesc {
build_projection: join.build_projections.clone(),
probe_projections: join.probe_projections.clone(),
build_schema: join.build.output_schema()?,
probe_schema: join.probe.output_schema()?,
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mod transform_hash_join_build;
mod transform_hash_join_probe;
mod util;

pub use common::wrap_nullable_block;
pub use common::wrap_true_validity;
pub use desc::HashJoinDesc;
pub use desc::RuntimeFilterDesc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ use crate::pipelines::processors::transforms::BasicHashJoinState;
use crate::pipelines::processors::transforms::HashJoinHashTable;
use crate::pipelines::processors::transforms::InnerHashJoin;
use crate::pipelines::processors::transforms::Join;
use crate::pipelines::processors::transforms::memory::outer_left_join::OuterLeftHashJoin;
use crate::pipelines::processors::transforms::memory::AntiLeftHashJoin;
use crate::pipelines::processors::transforms::memory::AntiRightHashJoin;
use crate::pipelines::processors::transforms::memory::OuterRightHashJoin;
use crate::pipelines::processors::transforms::memory::SemiLeftHashJoin;
use crate::pipelines::processors::transforms::memory::SemiRightHashJoin;
use crate::pipelines::processors::transforms::memory::left_join::OuterLeftHashJoin;

pub trait GraceMemoryJoin: Join {
fn reset_memory(&mut self);
Expand Down Expand Up @@ -52,6 +57,14 @@ fn reset_basic_state(state: &BasicHashJoinState) {
state.build_queue.as_mut().clear();
}

if !state.scan_map.is_empty() {
state.scan_map.as_mut().clear();
}

if !state.scan_queue.is_empty() {
state.scan_queue.as_mut().clear();
}

*state.hash_table.as_mut() = HashJoinHashTable::Null;
}

Expand All @@ -68,3 +81,41 @@ impl GraceMemoryJoin for OuterLeftHashJoin {
reset_basic_state(&self.basic_state);
}
}

impl GraceMemoryJoin for SemiLeftHashJoin {
fn reset_memory(&mut self) {
self.performance_context.clear();
reset_basic_state(&self.basic_state);
}
}

impl GraceMemoryJoin for AntiLeftHashJoin {
fn reset_memory(&mut self) {
self.performance_context.clear();
reset_basic_state(&self.basic_state);
}
}

impl GraceMemoryJoin for OuterRightHashJoin {
fn reset_memory(&mut self) {
self.finished = false;
self.performance_context.clear();
reset_basic_state(&self.basic_state);
}
}

impl GraceMemoryJoin for SemiRightHashJoin {
fn reset_memory(&mut self) {
self.finished = false;
self.performance_context.clear();
reset_basic_state(&self.basic_state);
}
}

impl GraceMemoryJoin for AntiRightHashJoin {
fn reset_memory(&mut self) {
self.finished = false;
self.performance_context.clear();
reset_basic_state(&self.basic_state);
}
}
Loading