Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions rust/benchmarks/src/bin/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ struct BenchmarkOpt {
/// Load the data into a MemTable before executing the query
#[structopt(short = "m", long = "mem-table")]
mem_table: bool,

/// Number of partitions to create when using MemTable as input
#[structopt(short = "n", long = "partitions", default_value = "8")]
partitions: usize,
}

#[derive(Debug, StructOpt)]
Expand Down Expand Up @@ -134,8 +138,12 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<Vec<arrow::record_batch::RecordB
println!("Loading table '{}' into memory", table);
let start = Instant::now();

let memtable =
MemTable::load(table_provider.as_ref(), opt.batch_size).await?;
let memtable = MemTable::load(
table_provider.as_ref(),
opt.batch_size,
Some(opt.partitions),
)
.await?;
println!(
"Loaded table '{}' into memory in {} ms",
table,
Expand Down Expand Up @@ -1589,6 +1597,7 @@ mod tests {
path: PathBuf::from(path.to_string()),
file_format: "tbl".to_string(),
mem_table: false,
partitions: 16,
};
let actual = benchmark(opt).await?;

Expand Down
7 changes: 6 additions & 1 deletion rust/datafusion/benches/sort_limit_query_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ fn create_context() -> Arc<Mutex<ExecutionContext>> {

let ctx_holder: Arc<Mutex<Vec<Arc<Mutex<ExecutionContext>>>>> =
Arc::new(Mutex::new(vec![]));

let partitions = 16;

rt.block_on(async {
let mem_table = MemTable::load(&csv, 16 * 1024).await.unwrap();
let mem_table = MemTable::load(&csv, 16 * 1024, Some(partitions))
.await
.unwrap();

// create local execution context
let mut ctx = ExecutionContext::new();
Expand Down
34 changes: 32 additions & 2 deletions rust/datafusion/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,24 @@
//! queried by DataFusion. This allows data to be pre-loaded into memory and then
//! repeatedly queried without incurring additional file I/O overhead.

use futures::StreamExt;
use log::debug;
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;

use crate::datasource::datasource::Statistics;
use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::Expr;
use crate::physical_plan::common;
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::ExecutionPlan;
use crate::{
datasource::datasource::Statistics,
physical_plan::{repartition::RepartitionExec, Partitioning},
};

use super::datasource::ColumnStatistics;

Expand Down Expand Up @@ -102,7 +106,11 @@ impl MemTable {
}

/// Create a mem table by reading from another data source
pub async fn load(t: &dyn TableProvider, batch_size: usize) -> Result<Self> {
pub async fn load(
t: &dyn TableProvider,
batch_size: usize,
output_partitions: Option<usize>,
) -> Result<Self> {
let schema = t.schema();
let exec = t.scan(&None, batch_size, &[])?;
let partition_count = exec.output_partitioning().partition_count();
Expand All @@ -126,6 +134,28 @@ impl MemTable {
data.push(result);
}

let exec = MemoryExec::try_new(&data, schema.clone(), None)?;

if let Some(num_partitions) = output_partitions {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let exec = RepartitionExec::try_new(
Arc::new(exec),
Partitioning::RoundRobinBatch(num_partitions),
)?;

// execute and collect results
let mut output_partitions = vec![];
for i in 0..exec.output_partitioning().partition_count() {
// execute this *output* partition and collect all batches
let mut stream = exec.execute(i).await?;
let mut batches = vec![];
while let Some(result) = stream.next().await {
batches.push(result?);
}
output_partitions.push(batches);
}

return MemTable::try_new(schema.clone(), output_partitions);
}
MemTable::try_new(schema.clone(), data)
}
}
Expand Down