Skip to content

Commit

Permalink
WIP: Bulk loads
Browse files Browse the repository at this point in the history
  • Loading branch information
Julius de Bruijn committed Jul 27, 2021
1 parent 59248b4 commit d58ce24
Show file tree
Hide file tree
Showing 20 changed files with 929 additions and 123 deletions.
40 changes: 40 additions & 0 deletions examples/bulk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use once_cell::sync::Lazy;
use std::env;
use tiberius::{BulkLoadRequest, Client, Config, IntoSql, TokenRow, TypeInfo};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;

// Connection string for the example: taken from the environment when set,
// otherwise a local SQL Server instance with integrated security.
static CONN_STR: Lazy<String> = Lazy::new(|| {
    match env::var("TIBERIUS_TEST_CONNECTION_STRING") {
        Ok(conn) => conn,
        Err(_) => String::from(
            "server=tcp:localhost,1433;IntegratedSecurity=true;TrustServerCertificate=true",
        ),
    }
});

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    env_logger::init();

    // Build the client configuration from an ADO.NET-style connection string
    // and open the underlying TCP socket ourselves.
    let config = Config::from_ado_string(&CONN_STR)?;

    let tcp = TcpStream::connect(config.get_addr()).await?;
    tcp.set_nodelay(true)?;

    let mut client = Client::connect(config, tcp.compat_write()).await?;

    // Declare the column layout of the bulk load, then queue one row per value.
    let mut req = BulkLoadRequest::new();
    req.add_column("val", TypeInfo::int());

    for i in 0..3 {
        let mut row = TokenRow::new();
        row.push(i.into_sql());
        req.add_row(row);
    }

    // Stream the buffered rows into the target table in a single bulk insert.
    let res = client.bulk_insert("bulk_test1", req).await?;

    dbg!(res);

    Ok(())
}
35 changes: 34 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ pub(crate) use connection::*;

use crate::{
result::{ExecuteResult, QueryResult},
tds::{codec, stream::TokenStream},
tds::{
codec::{self, BulkLoadRequest, IteratorJoin},
stream::TokenStream,
},
SqlReadBytes, ToSql,
};
use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest};
Expand Down Expand Up @@ -228,6 +231,36 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
Ok(result)
}

/// Performs a bulk insert (`INSERT BULK`) of the buffered rows in `request`
/// into the given `table`.
///
/// The method first sends an `INSERT BULK` batch announcing the column
/// layout, waits for the server to acknowledge it, and then streams the
/// column metadata and row data in a bulk-load packet. The returned
/// [`ExecuteResult`] exposes the affected row counts.
///
/// # Errors
///
/// Returns any error raised while communicating with the server, e.g. when
/// the target table is missing or the row data does not match the declared
/// columns.
pub async fn bulk_insert<'a>(
    &'a mut self,
    table: &str,
    request: BulkLoadRequest<'a>,
) -> crate::Result<ExecuteResult> {
    // Make sure no previous result stream is still pending on the wire.
    self.connection.flush_stream().await?;

    // `INSERT BULK <table> (<column descriptions>)` announces the upcoming
    // bulk-load data stream to the server.
    let col_data = request.column_descriptions().join(", ");
    let query = format!("INSERT BULK {} ({})", table, col_data);

    let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
    let id = self.connection.context_mut().next_packet_id();

    self.connection.send(PacketHeader::batch(id), req).await?;

    // Wait until the server has acknowledged the batch before sending data.
    let ts = TokenStream::new(&mut self.connection);
    ts.flush_done().await?;

    // Send the column metadata, the rows and the terminating DONE token.
    let id = self.connection.context_mut().next_packet_id();

    self.connection
        .send(PacketHeader::bulk_load(id), request)
        .await?;

    ExecuteResult::new(&mut self.connection).await
}

fn rpc_params<'a>(query: impl Into<Cow<'a, str>>) -> Vec<RpcParam<'a>> {
vec![
RpcParam {
Expand Down
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ pub use from_sql::{FromSql, FromSqlOwned};
pub use result::*;
pub use row::{Column, ColumnType, Row};
pub use sql_browser::SqlBrowser;
pub use tds::{codec::ColumnData, numeric, time, xml, EncryptionLevel};
pub use tds::{
codec::{BulkLoadRequest, ColumnData, TokenRow, TypeInfo, TypeLength},
numeric, time, xml, EncryptionLevel,
};
pub use to_sql::{IntoSql, ToSql};
pub use uuid::Uuid;

Expand Down
1 change: 1 addition & 0 deletions src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ impl<'a> ExecuteResult {
ReceivedToken::DoneProc(done) if done.is_final() => (),
ReceivedToken::DoneProc(done) => acc.push(done.rows()),
ReceivedToken::DoneInProc(done) => acc.push(done.rows()),
ReceivedToken::Done(done) => acc.push(done.rows()),
_ => (),
}
Ok(acc)
Expand Down
14 changes: 7 additions & 7 deletions src/row.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
error::Error,
tds::codec::{ColumnData, FixedLenType, TokenRow, TypeInfo, VarLenType},
tds::codec::{ColumnData, FixedLenType, TokenRow, TypeInfo, TypeInfoInner, VarLenType},
FromSql,
};
use std::{fmt::Display, sync::Arc};
Expand Down Expand Up @@ -101,8 +101,8 @@ pub enum ColumnType {

impl From<&TypeInfo> for ColumnType {
fn from(ti: &TypeInfo) -> Self {
match ti {
TypeInfo::FixedLen(flt) => match flt {
match &ti.inner {
TypeInfoInner::FixedLen(flt) => match flt {
FixedLenType::Int1 => Self::Int1,
FixedLenType::Bit => Self::Bit,
FixedLenType::Int2 => Self::Int2,
Expand All @@ -116,7 +116,7 @@ impl From<&TypeInfo> for ColumnType {
FixedLenType::Int8 => Self::Int8,
FixedLenType::Null => Self::Null,
},
TypeInfo::VarLenSized(cx) => match cx.r#type() {
TypeInfoInner::VarLenSized(cx) => match cx.r#type() {
VarLenType::Guid => Self::Guid,
VarLenType::Intn => Self::Intn,
VarLenType::Bitn => Self::Bitn,
Expand Down Expand Up @@ -146,7 +146,7 @@ impl From<&TypeInfo> for ColumnType {
VarLenType::NText => Self::NText,
VarLenType::SSVariant => Self::SSVariant,
},
TypeInfo::VarLenSizedPrecision { ty, .. } => match ty {
TypeInfoInner::VarLenSizedPrecision { ty, .. } => match ty {
VarLenType::Guid => Self::Guid,
VarLenType::Intn => Self::Intn,
VarLenType::Bitn => Self::Bitn,
Expand Down Expand Up @@ -176,7 +176,7 @@ impl From<&TypeInfo> for ColumnType {
VarLenType::NText => Self::NText,
VarLenType::SSVariant => Self::SSVariant,
},
TypeInfo::Xml { .. } => Self::Xml,
TypeInfoInner::Xml { .. } => Self::Xml,
}
}
}
Expand Down Expand Up @@ -233,7 +233,7 @@ impl From<&TypeInfo> for ColumnType {
#[derive(Debug)]
pub struct Row {
pub(crate) columns: Arc<Vec<Column>>,
pub(crate) data: TokenRow,
pub(crate) data: TokenRow<'static>,
}

pub trait QueryIdx
Expand Down
4 changes: 4 additions & 0 deletions src/tds/codec.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
mod batch_request;
mod bulk_load;
mod column_data;
mod decode;
mod encode;
mod guid;
mod header;
mod iterator_ext;
mod login;
mod packet;
mod pre_login;
Expand All @@ -12,12 +14,14 @@ mod token;
mod type_info;

pub use batch_request::*;
pub use bulk_load::*;
use bytes::BytesMut;
pub use column_data::*;
pub use decode::*;
pub(crate) use encode::*;
use futures::{Stream, TryStreamExt};
pub use header::*;
pub(crate) use iterator_ext::*;
pub use login::*;
pub use packet::*;
pub use pre_login::*;
Expand Down
59 changes: 59 additions & 0 deletions src/tds/codec/bulk_load.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use asynchronous_codec::BytesMut;
use enumflags2::BitFlags;

use super::{
BaseMetaDataColumn, Encode, MetaDataColumn, TokenColMetaData, TokenDone, TokenRow, TypeInfo,
};

/// A bulk load request, holding the column metadata and the rows to be
/// sent to the server in a single `INSERT BULK` operation.
#[derive(Debug, Default, Clone)]
pub struct BulkLoadRequest<'a> {
    // Column definitions, in the same order the data is pushed to each row.
    columns: Vec<MetaDataColumn>,
    // Buffered rows; encoded after the column metadata when the request is sent.
    rows: Vec<TokenRow<'a>>,
}

impl<'a> BulkLoadRequest<'a> {
    /// Creates an empty request with no columns or rows specified.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a column to the request. Order should be same as the order of data
    /// in the rows.
    pub fn add_column(&mut self, name: &str, ty: TypeInfo) {
        self.columns.push(MetaDataColumn {
            base: BaseMetaDataColumn {
                // No special column flags are needed for a plain bulk load.
                flags: BitFlags::empty(),
                ty,
            },
            col_name: name.into(),
        });
    }

    /// Add a row of data to the request.
    pub fn add_row(&mut self, row: TokenRow<'a>) {
        self.rows.push(row);
    }

    /// `Display`-rendered descriptions of the columns, joined by the caller
    /// into the `INSERT BULK` statement.
    //
    // NOTE: borrows only for the duration of the call (`+ '_`) instead of the
    // struct's own `'a`, so the caller is not over-constrained.
    pub(crate) fn column_descriptions(&self) -> impl Iterator<Item = String> + '_ {
        // `to_string` goes through `MetaDataColumn`'s `Display` impl; same as
        // `format!("{}", c)` but idiomatic.
        self.columns.iter().map(|c| c.to_string())
    }
}

impl<'a> Encode<BytesMut> for BulkLoadRequest<'a> {
    fn encode(self, dst: &mut BytesMut) -> crate::Result<()> {
        // Wire order: column metadata first, then every buffered row,
        // finished with a DONE token.
        TokenColMetaData {
            columns: self.columns,
        }
        .encode(dst)?;

        self.rows
            .into_iter()
            .try_for_each(|row| row.encode(dst))?;

        TokenDone::default().encode(dst)
    }
}
Loading

0 comments on commit d58ce24

Please sign in to comment.