WIP: Bulk loads #163

Closed · wants to merge 3 commits

3 changes: 3 additions & 0 deletions Cargo.toml
@@ -127,6 +127,9 @@ path = "./runtimes-macro"
names = "0.11"
anyhow = "1"
env_logger = "0.7"
indicatif = "0.16"
console = "0.14"
paste = "1.0"

[package.metadata.docs.rs]
features = ["all", "docs"]
54 changes: 54 additions & 0 deletions examples/bulk.rs
@@ -0,0 +1,54 @@
use indicatif::ProgressBar;
use once_cell::sync::Lazy;
use std::env;
use tiberius::{BulkLoadMetadata, Client, ColumnFlag, Config, IntoSql, TokenRow, TypeInfo};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;

static CONN_STR: Lazy<String> = Lazy::new(|| {
env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or_else(|_| {
"server=tcp:localhost,1433;IntegratedSecurity=true;TrustServerCertificate=true".to_owned()
})
});

#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::init();

let config = Config::from_ado_string(&CONN_STR)?;

let tcp = TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?;

let mut client = Client::connect(config, tcp.compat_write()).await?;

client
.execute(
"CREATE TABLE ##bulk_test1 (id INT IDENTITY PRIMARY KEY, content INT)",
&[],
)
.await?;

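// Describe each column we intend to send; the types must match the target table.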
let mut meta = BulkLoadMetadata::new();
meta.add_column("content", TypeInfo::int(), ColumnFlag::Nullable.into());

let mut req = client.bulk_insert("##bulk_test1", meta).await?;
let count = 2000i32;

let pb = ProgressBar::new(count as u64);

for i in 0..count {
let mut row = TokenRow::new();
row.push(i.into_sql());
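// `send` buffers rows internally, flushing them to the wire as packets fill.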
req.send(row).await?;
pb.inc(1);
}

pb.finish_with_message("waiting...");

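// Finalizing flushes any remaining buffered rows and returns the server's result.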
let res = req.finalize().await?;

dbg!(res);

Ok(())
}
28 changes: 3 additions & 25 deletions examples/tokio.rs
@@ -10,42 +10,20 @@ static CONN_STR: Lazy<String> = Lazy::new(|| {
})
});

#[cfg(not(all(windows, feature = "sql-browser-tokio")))]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::init();
let config = Config::from_ado_string(&CONN_STR)?;

let tcp = TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?;

let mut client = Client::connect(config, tcp.compat_write()).await?;

let stream = client.query("SELECT @P1", &[&1i32]).await?;
let row = stream.into_row().await?.unwrap();
let stream = client.query("SELECT * from test", &[]).await?;
let row = stream.into_row().await?;

println!("{:?}", row);
assert_eq!(Some(1), row.get(0));

Ok(())
}

#[cfg(all(windows, feature = "sql-browser-tokio"))]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
use tiberius::SqlBrowser;

let config = Config::from_ado_string(&CONN_STR)?;

let tcp = TcpStream::connect_named(&config).await?;
tcp.set_nodelay(true)?;

let mut client = Client::connect(config, tcp.compat_write()).await?;

let stream = client.query("SELECT @P1", &[&1i32]).await?;
let row = stream.into_row().await?.unwrap();

println!("{:?}", row);
assert_eq!(Some(1), row.get(0));

Ok(())
}
75 changes: 73 additions & 2 deletions src/client.rs
@@ -10,10 +10,10 @@ pub(crate) use connection::*;
use crate::{
result::ExecuteResult,
tds::{
codec,
codec::{self, IteratorJoin},
stream::{QueryStream, TokenStream},
},
SqlReadBytes, ToSql,
BulkLoadMetadata, BulkLoadRequest, SqlReadBytes, ToSql,
};
use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest};
use enumflags2::BitFlags;
@@ -230,6 +230,77 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
Ok(result)
}

/// Execute a `BULK INSERT` statement, efficiently storing a large number of
/// rows to a specified table.
///
/// # Example
///
/// ```
/// # use tiberius::{Config, BulkLoadMetadata, ColumnFlag, TypeInfo, TokenRow, IntoSql};
/// # use tokio_util::compat::TokioAsyncWriteCompatExt;
/// # use std::env;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
/// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
/// # );
/// # let config = Config::from_ado_string(&c_str)?;
/// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
/// # tcp.set_nodelay(true)?;
/// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
/// let create_table = r#"
/// CREATE TABLE ##bulk_test (
/// id INT IDENTITY PRIMARY KEY,
/// val INT
/// )
/// "#;
///
/// client.simple_query(create_table).await?;
///
/// // The request must have correct typing.
/// let mut meta = BulkLoadMetadata::new();
/// meta.add_column("val", TypeInfo::int(), ColumnFlag::Nullable.into());
///
/// // Start the bulk insert with the client.
/// let mut req = client.bulk_insert("##bulk_test", meta).await?;
///
/// for i in [0i32, 1i32, 2i32] {
/// let mut row = TokenRow::new();
/// row.push(i.into_sql());
///
/// // The request will handle flushing to the wire in an optimal way,
/// // balancing between memory usage and IO performance.
/// req.send(row).await?;
/// }
///
/// // The request must be finalized.
/// let res = req.finalize().await?;
/// assert_eq!(3, res.total());
/// # Ok(())
/// # }
/// ```
pub async fn bulk_insert<'a>(
&'a mut self,
table: &str,
meta: BulkLoadMetadata<'a>,
) -> crate::Result<BulkLoadRequest<'a, S>> {
// Start the bulk request
self.connection.flush_stream().await?;

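// `INSERT BULK` is the TDS bulk-copy statement; once the server acknowledges
// it, the connection switches to streaming raw row data.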
let col_data = meta.column_descriptions().join(", ");
let query = format!("INSERT BULK {} ({})", table, col_data);

let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
let id = self.connection.context_mut().next_packet_id();

self.connection.send(PacketHeader::batch(id), req).await?;

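// Wait for the server to acknowledge the bulk statement before sending rows.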
let ts = TokenStream::new(&mut self.connection);
ts.flush_done().await?;

BulkLoadRequest::new(&mut self.connection, meta)
}

fn rpc_params<'a>(query: impl Into<Cow<'a, str>>) -> Vec<RpcParam<'a>> {
vec![
RpcParam {
30 changes: 26 additions & 4 deletions src/client/connection.rs
@@ -167,16 +167,38 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
split_payload.len() + HEADER_BYTES,
);

let packet = Packet::new(header, split_payload);
self.transport.send(packet).await?;
self.write_to_wire(header, split_payload).await?;
}

// Rai rai says the turbofish goodbye
SinkExt::<Packet>::flush(&mut self.transport).await?;
self.flush_sink().await?;

Ok(())
}

/// Sends a packet of data to the database.
///
/// # Warning
///
/// Please be sure the packet size doesn't exceed the largest allowed size
/// dictated by the server.
pub async fn write_to_wire(
&mut self,
header: PacketHeader,
data: BytesMut,
) -> crate::Result<()> {
self.flushed = false;
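// Mark the stream as dirty; it must be flushed clean before the next request.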

let packet = Packet::new(header, data);
self.transport.send(packet).await?;

Ok(())
}

/// Sends all pending packets to the wire.
pub async fn flush_sink(&mut self) -> crate::Result<()> {
SinkExt::<Packet>::flush(&mut self.transport).await
}

/// Cleans the packet stream from previous use. It is important to use the
/// whole stream before using the connection again. Flushing the stream
/// makes sure we don't have any old data causing undefined behaviour after
8 changes: 4 additions & 4 deletions src/client/tls.rs
@@ -140,10 +140,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;

// We only get pre-login packets in the handshake process.
assert_eq!(header.ty, PacketType::PreLogin);
assert_eq!(header.r#type(), PacketType::PreLogin);

// And we know from this point on how much data we should expect
inner.read_remaining = header.length as usize - HEADER_BYTES;
inner.read_remaining = header.length() as usize - HEADER_BYTES;

event!(
Level::TRACE,
@@ -196,8 +196,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPreloginWrapper
if !inner.header_written {
let mut header = PacketHeader::new(inner.wr_buf.len(), 0);

header.ty = PacketType::PreLogin;
header.status = PacketStatus::EndOfMessage;
header.set_type(PacketType::PreLogin);
header.set_status(PacketStatus::EndOfMessage);

header
.encode(&mut &mut inner.wr_buf[0..HEADER_BYTES])
9 changes: 8 additions & 1 deletion src/lib.rs
@@ -196,7 +196,14 @@ pub use from_sql::{FromSql, FromSqlOwned};
pub use result::*;
pub use row::{Column, ColumnType, Row};
pub use sql_browser::SqlBrowser;
pub use tds::{codec::ColumnData, numeric, stream::QueryStream, time, xml, EncryptionLevel};
pub use tds::{
codec::{
BulkLoadMetadata, BulkLoadRequest, ColumnData, ColumnFlag, TokenRow, TypeInfo, TypeLength,
},
numeric,
stream::QueryStream,
time, xml, EncryptionLevel,
};
pub use to_sql::{IntoSql, ToSql};
pub use uuid::Uuid;

1 change: 1 addition & 0 deletions src/result.rs
@@ -58,6 +58,7 @@ impl<'a> ExecuteResult
ReceivedToken::DoneProc(done) if done.is_final() => (),
ReceivedToken::DoneProc(done) => acc.push(done.rows()),
ReceivedToken::DoneInProc(done) => acc.push(done.rows()),
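// Plain batch requests, such as bulk loads, report their row counts in a `Done` token.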
ReceivedToken::Done(done) => acc.push(done.rows()),
_ => (),
}
Ok(acc)
14 changes: 7 additions & 7 deletions src/row.rs
@@ -1,6 +1,6 @@
use crate::{
error::Error,
tds::codec::{ColumnData, FixedLenType, TokenRow, TypeInfo, VarLenType},
tds::codec::{ColumnData, FixedLenType, TokenRow, TypeInfo, TypeInfoInner, VarLenType},
FromSql,
};
use std::{fmt::Display, sync::Arc};
@@ -101,8 +101,8 @@ pub enum ColumnType

impl From<&TypeInfo> for ColumnType {
fn from(ti: &TypeInfo) -> Self {
match ti {
TypeInfo::FixedLen(flt) => match flt {
match &ti.inner {
TypeInfoInner::FixedLen(flt) => match flt {
FixedLenType::Int1 => Self::Int1,
FixedLenType::Bit => Self::Bit,
FixedLenType::Int2 => Self::Int2,
@@ -116,7 +116,7 @@ impl From<&TypeInfo> for ColumnType {
FixedLenType::Int8 => Self::Int8,
FixedLenType::Null => Self::Null,
},
TypeInfo::VarLenSized(cx) => match cx.r#type() {
TypeInfoInner::VarLenSized(cx) => match cx.r#type() {
VarLenType::Guid => Self::Guid,
VarLenType::Intn => Self::Intn,
VarLenType::Bitn => Self::Bitn,
@@ -146,7 +146,7 @@ impl From<&TypeInfo> for ColumnType {
VarLenType::NText => Self::NText,
VarLenType::SSVariant => Self::SSVariant,
},
TypeInfo::VarLenSizedPrecision { ty, .. } => match ty {
TypeInfoInner::VarLenSizedPrecision { ty, .. } => match ty {
VarLenType::Guid => Self::Guid,
VarLenType::Intn => Self::Intn,
VarLenType::Bitn => Self::Bitn,
@@ -176,7 +176,7 @@ impl From<&TypeInfo> for ColumnType {
VarLenType::NText => Self::NText,
VarLenType::SSVariant => Self::SSVariant,
},
TypeInfo::Xml { .. } => Self::Xml,
TypeInfoInner::Xml { .. } => Self::Xml,
}
}
}
@@ -233,7 +233,7 @@ impl From<&TypeInfo> for ColumnType {
#[derive(Debug)]
pub struct Row {
pub(crate) columns: Arc<Vec<Column>>,
pub(crate) data: TokenRow,
pub(crate) data: TokenRow<'static>,
pub(crate) result_index: usize,
}

4 changes: 4 additions & 0 deletions src/tds/codec.rs
@@ -1,9 +1,11 @@
mod batch_request;
mod bulk_load;
mod column_data;
mod decode;
mod encode;
mod guid;
mod header;
mod iterator_ext;
mod login;
mod packet;
mod pre_login;
@@ -12,12 +14,14 @@ mod token;
mod type_info;

pub use batch_request::*;
pub use bulk_load::*;
use bytes::BytesMut;
pub use column_data::*;
pub use decode::*;
pub(crate) use encode::*;
use futures::{Stream, TryStreamExt};
pub use header::*;
pub(crate) use iterator_ext::*;
pub use login::*;
pub use packet::*;
pub use pre_login::*;