From a5c9b8cdb72885f8e43990875c8e7061b07da0b1 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Tue, 27 Jul 2021 17:22:59 +0200 Subject: [PATCH 1/3] WIP: Bulk loads --- examples/bulk.rs | 41 +++ src/client.rs | 75 ++++- src/client/connection.rs | 30 +- src/client/tls.rs | 8 +- src/lib.rs | 2 +- src/result.rs | 1 + src/row.rs | 14 +- src/tds/codec.rs | 4 + src/tds/codec/bulk_load.rs | 138 +++++++++ src/tds/codec/column_data.rs | 191 ++++++++----- src/tds/codec/column_data/buf.rs | 81 ++++++ src/tds/codec/decode.rs | 10 +- src/tds/codec/header.rs | 36 ++- src/tds/codec/iterator_ext.rs | 27 ++ src/tds/codec/packet.rs | 2 +- src/tds/codec/rpc_request.rs | 3 +- src/tds/codec/token/token_col_metadata.rs | 143 +++++++++- src/tds/codec/token/token_done.rs | 18 +- src/tds/codec/token/token_row.rs | 80 ++++-- src/tds/codec/type_info.rs | 326 +++++++++++++++++++++- src/tds/collation.rs | 4 +- src/tds/stream/token.rs | 2 +- src/to_sql.rs | 8 + tests/query.rs | 1 + 24 files changed, 1105 insertions(+), 140 deletions(-) create mode 100644 examples/bulk.rs create mode 100644 src/tds/codec/bulk_load.rs create mode 100644 src/tds/codec/column_data/buf.rs create mode 100644 src/tds/codec/iterator_ext.rs diff --git a/examples/bulk.rs b/examples/bulk.rs new file mode 100644 index 00000000..51b5845d --- /dev/null +++ b/examples/bulk.rs @@ -0,0 +1,41 @@ +use once_cell::sync::Lazy; +use std::env; +use tiberius::{BulkLoadMetadata, Client, Config, IntoSql, TokenRow, TypeInfo}; +use tokio::net::TcpStream; +use tokio_util::compat::TokioAsyncWriteCompatExt; + +static CONN_STR: Lazy = Lazy::new(|| { + env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or_else(|_| { + "server=tcp:localhost,1433;IntegratedSecurity=true;TrustServerCertificate=true".to_owned() + }) +}); + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + env_logger::init(); + + let config = Config::from_ado_string(&CONN_STR)?; + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let mut client = Client::connect(config, tcp.compat_write()).await?; + + let mut meta = BulkLoadMetadata::new(); + meta.add_column("val", TypeInfo::int()); + + let mut req = client.bulk_insert("bulk_test1", meta).await?; + + for i in [0, 1, 2, 3, 4, 5] { + let mut row = TokenRow::new(); + row.push(i.into_sql()); + + req.send(row).await?; + } + + let res = req.finalize().await?; + + dbg!(res); + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs index e9ee7893..e70c1a08 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,10 +10,10 @@ pub(crate) use connection::*; use crate::{ result::ExecuteResult, tds::{ - codec, + codec::{self, IteratorJoin}, stream::{QueryStream, TokenStream}, }, - SqlReadBytes, ToSql, + BulkLoadMetadata, BulkLoadRequest, SqlReadBytes, ToSql, }; use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest}; use enumflags2::BitFlags; @@ -230,6 +230,77 @@ impl Client { Ok(result) } + /// Execute a `BULK INSERT` statement, efficiantly storing a large number of + /// rows to a specified table. 
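+ ///
+ /// Rows are buffered internally and flushed to the wire one full packet at
+ /// a time, balancing memory usage against IO performance.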
+ /// + /// # Example + /// + /// ``` + /// # use tiberius::{Config, BulkLoadMetadata, TypeInfo, TokenRow, IntoSql}; + /// # use tokio_util::compat::TokioAsyncWriteCompatExt; + /// # use std::env; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or( + /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(), + /// # ); + /// # let config = Config::from_ado_string(&c_str)?; + /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; + /// # tcp.set_nodelay(true)?; + /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; + /// let create_table = r#" + /// CREATE TABLE ##bulk_test ( + /// id INT IDENTITY PRIMARY KEY, + /// val INT NOT NULL + /// ) + /// "#; + /// + /// client.simple_query(create_table).await?; + /// + /// // The request must have correct typing. + /// let mut meta = BulkLoadMetadata::new(); + /// meta.add_column("val", TypeInfo::int()); + /// + /// // Start the bulk insert with the client. + /// let mut req = client.bulk_insert("##bulk_test", meta).await?; + /// + /// for i in [0i32, 1i32, 2i32] { + /// let mut row = TokenRow::new(); + /// row.push(i.into_sql()); + /// + /// // The request will handle flushing to the wire in an optimal way, + /// // balancing between memory usage and IO performance. + /// req.send(row).await?; + /// } + /// + /// // The request must be finalized. + /// let res = req.finalize().await?; + /// assert_eq!(3, res.total()); + /// # Ok(()) + /// # } + /// ``` + pub async fn bulk_insert( + &mut self, + table: &str, + meta: BulkLoadMetadata, + ) -> crate::Result> { + // Start the bulk request + self.connection.flush_stream().await?; + + let col_data = meta.column_descriptions().join(", "); + let query = format!("INSERT BULK {} ({})", table, col_data); + + let req = BatchRequest::new(query, self.connection.context().transaction_descriptor()); + let id = self.connection.context_mut().next_packet_id(); + + self.connection.send(PacketHeader::batch(id), req).await?; + + let ts = TokenStream::new(&mut self.connection); + ts.flush_done().await?; + + BulkLoadRequest::new(&mut self.connection, meta) + } + fn rpc_params<'a>(query: impl Into>) -> Vec> { vec![ RpcParam { diff --git a/src/client/connection.rs b/src/client/connection.rs index 2aa7b3b7..14467fbf 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -167,16 +167,38 @@ impl Connection { split_payload.len() + HEADER_BYTES, ); - let packet = Packet::new(header, split_payload); - self.transport.send(packet).await?; + self.write_to_wire(header, split_payload).await?; } - // Rai rai says the turbofish goodbye - SinkExt::::flush(&mut self.transport).await?; + self.flush_sink().await?; Ok(()) } + /// Sends a packet of data to the database. + /// + /// # Warning + /// + /// Please be sure the packet size doesn't exceed the largest allowed size + /// dictaded by the server. + pub async fn write_to_wire( + &mut self, + header: PacketHeader, + data: BytesMut, + ) -> crate::Result<()> { + self.flushed = false; + + let packet = Packet::new(header, data); + self.transport.send(packet).await?; + + Ok(()) + } + + /// Sends all pending packages to the wire. + pub async fn flush_sink(&mut self) -> crate::Result<()> { + SinkExt::::flush(&mut self.transport).await + } + /// Cleans the packet stream from previous use. It is important to use the /// whole stream before using the connection again. 
Flushing the stream /// makes sure we don't have any old data causing undefined behaviour after diff --git a/src/client/tls.rs b/src/client/tls.rs index 65b0f1d8..3902c108 100644 --- a/src/client/tls.rs +++ b/src/client/tls.rs @@ -140,10 +140,10 @@ impl AsyncRead for TlsPreloginWrapper< .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; // We only get pre-login packets in the handshake process. - assert_eq!(header.ty, PacketType::PreLogin); + assert_eq!(header.r#type(), PacketType::PreLogin); // And we know from this point on how much data we should expect - inner.read_remaining = header.length as usize - HEADER_BYTES; + inner.read_remaining = header.length() as usize - HEADER_BYTES; event!( Level::TRACE, @@ -196,8 +196,8 @@ impl AsyncWrite for TlsPreloginWrapper if !inner.header_written { let mut header = PacketHeader::new(inner.wr_buf.len(), 0); - header.ty = PacketType::PreLogin; - header.status = PacketStatus::EndOfMessage; + header.set_type(PacketType::PreLogin); + header.set_status(PacketStatus::EndOfMessage); header .encode(&mut &mut inner.wr_buf[0..HEADER_BYTES]) diff --git a/src/lib.rs b/src/lib.rs index db7ee852..b8806656 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -196,7 +196,7 @@ pub use from_sql::{FromSql, FromSqlOwned}; pub use result::*; pub use row::{Column, ColumnType, Row}; pub use sql_browser::SqlBrowser; -pub use tds::{codec::ColumnData, numeric, stream::QueryStream, time, xml, EncryptionLevel}; +pub use tds::{EncryptionLevel, codec::{BulkLoadMetadata, BulkLoadRequest, ColumnData, TokenRow, TypeInfo, TypeLength}, numeric, stream::QueryStream, time, xml}; pub use to_sql::{IntoSql, ToSql}; pub use uuid::Uuid; diff --git a/src/result.rs b/src/result.rs index 3f804bcc..5c037b24 100644 --- a/src/result.rs +++ b/src/result.rs @@ -58,6 +58,7 @@ impl<'a> ExecuteResult { ReceivedToken::DoneProc(done) if done.is_final() => (), ReceivedToken::DoneProc(done) => acc.push(done.rows()), ReceivedToken::DoneInProc(done) => acc.push(done.rows()), + ReceivedToken::Done(done) => acc.push(done.rows()), _ => (), } Ok(acc) diff --git a/src/row.rs b/src/row.rs index bcd58947..5854ed25 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,6 +1,6 @@ use crate::{ error::Error, - tds::codec::{ColumnData, FixedLenType, TokenRow, TypeInfo, VarLenType}, + tds::codec::{ColumnData, FixedLenType, TokenRow, TypeInfo, TypeInfoInner, VarLenType}, FromSql, }; use std::{fmt::Display, sync::Arc}; @@ -101,8 +101,8 @@ pub enum ColumnType { impl From<&TypeInfo> for ColumnType { fn from(ti: &TypeInfo) -> Self { - match ti { - TypeInfo::FixedLen(flt) => match flt { + match &ti.inner { + TypeInfoInner::FixedLen(flt) => match flt { FixedLenType::Int1 => Self::Int1, FixedLenType::Bit => Self::Bit, FixedLenType::Int2 => Self::Int2, @@ -116,7 +116,7 @@ impl From<&TypeInfo> for ColumnType { FixedLenType::Int8 => Self::Int8, FixedLenType::Null => Self::Null, }, - TypeInfo::VarLenSized(cx) => match cx.r#type() { + TypeInfoInner::VarLenSized(cx) => match cx.r#type() { VarLenType::Guid => Self::Guid, VarLenType::Intn => Self::Intn, VarLenType::Bitn => Self::Bitn, @@ -146,7 +146,7 @@ impl From<&TypeInfo> for ColumnType { VarLenType::NText => Self::NText, VarLenType::SSVariant => Self::SSVariant, }, - TypeInfo::VarLenSizedPrecision { ty, .. } => match ty { + TypeInfoInner::VarLenSizedPrecision { ty, .. 
} => match ty { VarLenType::Guid => Self::Guid, VarLenType::Intn => Self::Intn, VarLenType::Bitn => Self::Bitn, @@ -176,7 +176,7 @@ impl From<&TypeInfo> for ColumnType { VarLenType::NText => Self::NText, VarLenType::SSVariant => Self::SSVariant, }, - TypeInfo::Xml { .. } => Self::Xml, + TypeInfoInner::Xml { .. } => Self::Xml, } } } @@ -233,7 +233,7 @@ impl From<&TypeInfo> for ColumnType { #[derive(Debug)] pub struct Row { pub(crate) columns: Arc>, - pub(crate) data: TokenRow, + pub(crate) data: TokenRow<'static>, pub(crate) result_index: usize, } diff --git a/src/tds/codec.rs b/src/tds/codec.rs index 7874bc50..03356be2 100644 --- a/src/tds/codec.rs +++ b/src/tds/codec.rs @@ -1,9 +1,11 @@ mod batch_request; +mod bulk_load; mod column_data; mod decode; mod encode; mod guid; mod header; +mod iterator_ext; mod login; mod packet; mod pre_login; @@ -12,12 +14,14 @@ mod token; mod type_info; pub use batch_request::*; +pub use bulk_load::*; use bytes::BytesMut; pub use column_data::*; pub use decode::*; pub(crate) use encode::*; use futures::{Stream, TryStreamExt}; pub use header::*; +pub(crate) use iterator_ext::*; pub use login::*; pub use packet::*; pub use pre_login::*; diff --git a/src/tds/codec/bulk_load.rs b/src/tds/codec/bulk_load.rs new file mode 100644 index 00000000..c8a484a7 --- /dev/null +++ b/src/tds/codec/bulk_load.rs @@ -0,0 +1,138 @@ +use asynchronous_codec::BytesMut; +use enumflags2::BitFlags; +use futures::{AsyncRead, AsyncWrite}; +use tracing::{event, Level}; + +use crate::{client::Connection, sql_read_bytes::SqlReadBytes, ExecuteResult}; + +use super::{ + BaseMetaDataColumn, Encode, MetaDataColumn, PacketHeader, PacketStatus, TokenColMetaData, + TokenDone, TokenRow, TypeInfo, HEADER_BYTES, +}; + +/// Column metadata for a bulk load request. +#[derive(Debug, Default, Clone)] +pub struct BulkLoadMetadata { + columns: Vec, +} + +impl BulkLoadMetadata { + /// Creates a metadata with no columns specified. + pub fn new() -> Self { + Self::default() + } + + /// Add a column to the request. Order should be same as the order of data + /// in the rows. + pub fn add_column(&mut self, name: &str, ty: TypeInfo) { + self.columns.push(MetaDataColumn { + base: BaseMetaDataColumn { + flags: BitFlags::empty(), + ty, + }, + col_name: name.into(), + }); + } + + pub(crate) fn column_descriptions(&self) -> impl Iterator + '_ { + self.columns.iter().map(|c| format!("{}", c)) + } +} + +impl Encode for BulkLoadMetadata { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + let cmd = TokenColMetaData { + columns: self.columns, + }; + + cmd.encode(dst) + } +} + +/// A handler for a bulk insert data flow. +#[derive(Debug)] +pub struct BulkLoadRequest<'a, S> +where + S: AsyncRead + AsyncWrite + Unpin + Send, +{ + connection: &'a mut Connection, + packet_id: u8, + buf: BytesMut, +} + +impl<'a, S> BulkLoadRequest<'a, S> +where + S: AsyncRead + AsyncWrite + Unpin + Send, +{ + pub(crate) fn new( + connection: &'a mut Connection, + meta: BulkLoadMetadata, + ) -> crate::Result { + let packet_id = connection.context_mut().next_packet_id(); + let mut buf = BytesMut::new(); + + meta.encode(&mut buf)?; + + let this = Self { + connection, + packet_id, + buf, + }; + + Ok(this) + } + + /// Adds a new row to the bulk insert, flushing only when having a full packet of data. + /// + /// # Warning + /// + /// After the last row, [`finalize`] must be called to flush the buffered + /// data and for the data to actually be available in the table. 
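+ ///
+ /// A full packet of buffered data is written to the wire on each flush, so
+ /// calling this method in a loop keeps memory usage bounded.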
+ /// + /// [`finalize`]: #method.finalize + pub async fn send(&mut self, row: TokenRow<'a>) -> crate::Result<()> { + let packet_size = (self.connection.context().packet_size() as usize) - HEADER_BYTES; + + row.encode(&mut self.buf)?; + + while self.buf.len() >= packet_size { + let header = PacketHeader::bulk_load(self.packet_id); + let data = self.buf.split_to(packet_size); + + event!( + Level::TRACE, + "Bulk insert packet ({} bytes)", + data.len() + HEADER_BYTES, + ); + + self.connection.write_to_wire(header, data).await?; + } + + Ok(()) + } + + /// Ends the bulk load, flushing all pending data to the wire. + /// + /// This method must be called after sending all the data to flush all + /// pending data and to get the server actually to store the rows to the + /// table. + pub async fn finalize(mut self) -> crate::Result { + TokenDone::default().encode(&mut self.buf)?; + + let mut header = PacketHeader::bulk_load(self.packet_id); + header.set_status(PacketStatus::EndOfMessage); + + let data = self.buf.split(); + + event!( + Level::TRACE, + "Finalizing a bulk insert ({} bytes)", + data.len() + HEADER_BYTES, + ); + + self.connection.write_to_wire(header, data).await?; + self.connection.flush_sink().await?; + + ExecuteResult::new(self.connection).await + } +} diff --git a/src/tds/codec/column_data.rs b/src/tds/codec/column_data.rs index 70871b3e..764bbcf6 100644 --- a/src/tds/codec/column_data.rs +++ b/src/tds/codec/column_data.rs @@ -1,5 +1,6 @@ mod binary; mod bit; +mod buf; #[cfg(feature = "tds73")] mod date; #[cfg(feature = "tds73")] @@ -25,10 +26,11 @@ use super::{Encode, FixedLenType, TypeInfo, VarLenType}; #[cfg(feature = "tds73")] use crate::tds::{Date, DateTime2, DateTimeOffset, Time}; use crate::{ - tds::{xml::XmlData, DateTime, Numeric, SmallDateTime}, + tds::{codec::TypeInfoInner, xml::XmlData, DateTime, Numeric, SmallDateTime}, SqlReadBytes, }; -use bytes::{BufMut, BytesMut}; +pub(crate) use buf::BufColumnData; +use bytes::BufMut; use std::borrow::{BorrowMut, Cow}; use uuid::Uuid; @@ -124,79 +126,97 @@ impl<'a> ColumnData<'a> { where R: SqlReadBytes + Unpin, { - let res = match ctx { - TypeInfo::FixedLen(fixed_ty) => fixed_len::decode(src, fixed_ty).await?, - TypeInfo::VarLenSized(cx) => var_len::decode(src, cx).await?, - TypeInfo::VarLenSizedPrecision { ty, scale, .. } => match ty { + let res = match &ctx.inner { + TypeInfoInner::FixedLen(fixed_ty) => fixed_len::decode(src, fixed_ty).await?, + TypeInfoInner::VarLenSized(cx) => var_len::decode(src, cx).await?, + TypeInfoInner::VarLenSizedPrecision { ty, scale, .. } => match ty { VarLenType::Decimaln | VarLenType::Numericn => { ColumnData::Numeric(Numeric::decode(src, *scale).await?) 
} _ => todo!(), }, - TypeInfo::Xml { schema, size } => xml::decode(src, *size, schema.clone()).await?, + TypeInfoInner::Xml { schema, size } => xml::decode(src, *size, schema.clone()).await?, }; Ok(res) } } -impl<'a> Encode for ColumnData<'a> { - fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { +impl<'a> Encode> for ColumnData<'a> { + fn encode(self, dst: &mut BufColumnData<'a>) -> crate::Result<()> { match self { ColumnData::Bit(Some(val)) => { - let header = [&[VarLenType::Bitn as u8, 1, 1][..]].concat(); + if dst.write_headers { + let header = [&[VarLenType::Bitn as u8, 1, 1][..]].concat(); + dst.extend_from_slice(&header); + } - dst.extend_from_slice(&header); dst.put_u8(val as u8); } ColumnData::U8(Some(val)) => { - let header = [&[VarLenType::Intn as u8, 1, 1][..]].concat(); + if dst.write_headers { + let header = [&[VarLenType::Intn as u8, 1, 1][..]].concat(); + dst.extend_from_slice(&header); + } - dst.extend_from_slice(&header); dst.put_u8(val); } ColumnData::I16(Some(val)) => { - let header = [&[VarLenType::Intn as u8, 2, 2][..]].concat(); + if dst.write_headers { + let header = [&[VarLenType::Intn as u8, 2, 2][..]].concat(); + dst.extend_from_slice(&header); + } - dst.extend_from_slice(&header); dst.put_i16_le(val); } ColumnData::I32(Some(val)) => { - let header = [&[VarLenType::Intn as u8, 4, 4][..]].concat(); + if dst.write_headers { + let header = [&[VarLenType::Intn as u8, 4, 4][..]].concat(); + dst.extend_from_slice(&header); + } - dst.extend_from_slice(&header); dst.put_i32_le(val); } ColumnData::I64(Some(val)) => { - let header = [&[VarLenType::Intn as u8, 8, 8][..]].concat(); + if dst.write_headers { + let header = [&[VarLenType::Intn as u8, 8, 8][..]].concat(); + dst.extend_from_slice(&header); + } - dst.extend_from_slice(&header); dst.put_i64_le(val); } ColumnData::F32(Some(val)) => { - let header = [&[VarLenType::Floatn as u8, 4, 4][..]].concat(); + if dst.write_headers { + let header = [&[VarLenType::Floatn as u8, 4, 4][..]].concat(); + dst.extend_from_slice(&header); + } - dst.extend_from_slice(&header); dst.put_f32_le(val); } ColumnData::F64(Some(val)) => { - let header = [&[VarLenType::Floatn as u8, 8, 8][..]].concat(); + if dst.write_headers { + let header = [&[VarLenType::Floatn as u8, 8, 8][..]].concat(); + dst.extend_from_slice(&header); + } - dst.extend_from_slice(&header); dst.put_f64_le(val); } ColumnData::Guid(Some(uuid)) => { - let header = [&[VarLenType::Guid as u8, 16, 16][..]].concat(); - dst.extend_from_slice(&header); + if dst.write_headers { + let header = [&[VarLenType::Guid as u8, 16, 16][..]].concat(); + dst.extend_from_slice(&header); + } let mut data = *uuid.as_bytes(); super::guid::reorder_bytes(&mut data); dst.extend_from_slice(&data); } ColumnData::String(Some(ref s)) if s.len() <= 4000 => { - dst.put_u8(VarLenType::NVarchar as u8); - dst.put_u16_le(8000); - dst.extend_from_slice(&[0u8; 5][..]); + if dst.write_headers { + dst.put_u8(VarLenType::NVarchar as u8); + dst.put_u16_le(8000); + dst.extend_from_slice(&[0u8; 5][..]); + } let mut length = 0u16; let len_pos = dst.len(); @@ -216,14 +236,16 @@ impl<'a> Encode for ColumnData<'a> { } } ColumnData::String(Some(ref s)) => { - // length: 0xffff and raw collation - dst.put_u8(VarLenType::NVarchar as u8); - dst.extend_from_slice(&[0xff_u8; 2][..]); - dst.extend_from_slice(&[0u8; 5][..]); - - // we cannot cheaply predetermine the length of the UCS2 string beforehand - // (2 * bytes(UTF8) is not always right) - so just let the SQL server handle it - 
dst.put_u64_le(0xfffffffffffffffe_u64); + if dst.write_headers { + // length: 0xffff and raw collation + dst.put_u8(VarLenType::NVarchar as u8); + dst.extend_from_slice(&[0xff_u8; 2][..]); + dst.extend_from_slice(&[0u8; 5][..]); + + // we cannot cheaply predetermine the length of the UCS2 string beforehand + // (2 * bytes(UTF8) is not always right) - so just let the SQL server handle it + dst.put_u64_le(0xfffffffffffffffe_u64); + } // Write the varchar length let mut length = 0u32; @@ -247,17 +269,23 @@ impl<'a> Encode for ColumnData<'a> { } } ColumnData::Binary(Some(bytes)) if bytes.len() <= 8000 => { - dst.put_u8(VarLenType::BigVarBin as u8); - dst.put_u16_le(8000); + if dst.write_headers { + dst.put_u8(VarLenType::BigVarBin as u8); + dst.put_u16_le(8000); + } + dst.put_u16_le(bytes.len() as u16); dst.extend(bytes.into_owned()); } ColumnData::Binary(Some(bytes)) => { - dst.put_u8(VarLenType::BigVarBin as u8); - // Max length - dst.put_u16_le(0xffff_u16); - // Also the length is unknown - dst.put_u64_le(0xfffffffffffffffe_u64); + if dst.write_headers { + dst.put_u8(VarLenType::BigVarBin as u8); + // Max length + dst.put_u16_le(0xffff_u16); + // Also the length is unknown + dst.put_u64_le(0xfffffffffffffffe_u64); + } + // We'll write in one chunk, length is the whole bytes length dst.put_u32_le(bytes.len() as u32); // Payload @@ -266,54 +294,77 @@ impl<'a> Encode for ColumnData<'a> { dst.put_u32_le(0); } ColumnData::DateTime(Some(dt)) => { - dst.extend_from_slice(&[VarLenType::Datetimen as u8, 8, 8]); - dt.encode(dst)?; + if dst.write_headers { + dst.extend_from_slice(&[VarLenType::Datetimen as u8, 8, 8]); + } + + dt.encode(&mut *dst)?; } ColumnData::SmallDateTime(Some(dt)) => { - dst.extend_from_slice(&[VarLenType::Datetimen as u8, 4, 4]); - dt.encode(dst)?; + if dst.write_headers { + dst.extend_from_slice(&[VarLenType::Datetimen as u8, 4, 4]); + } + + dt.encode(&mut *dst)?; } #[cfg(feature = "tds73")] ColumnData::Time(Some(time)) => { - dst.extend_from_slice(&[VarLenType::Timen as u8, time.scale(), time.len()?]); + if dst.write_headers { + dst.extend_from_slice(&[VarLenType::Timen as u8, time.scale(), time.len()?]); + } - time.encode(dst)?; + time.encode(&mut *dst)?; } #[cfg(feature = "tds73")] ColumnData::Date(Some(date)) => { - dst.extend_from_slice(&[VarLenType::Daten as u8, 3]); - date.encode(dst)?; + if dst.write_headers { + dst.extend_from_slice(&[VarLenType::Daten as u8, 3]); + } + + date.encode(&mut *dst)?; } #[cfg(feature = "tds73")] ColumnData::DateTime2(Some(dt)) => { - let len = dt.time().len()? + 3; - - dst.extend_from_slice(&[VarLenType::Datetime2 as u8, dt.time().scale(), len]); + if dst.write_headers { + let len = dt.time().len()? + 3; + dst.extend_from_slice(&[VarLenType::Datetime2 as u8, dt.time().scale(), len]); + } - dt.encode(dst)?; + dt.encode(&mut *dst)?; } #[cfg(feature = "tds73")] ColumnData::DateTimeOffset(Some(dto)) => { - dst.extend_from_slice(&[ - VarLenType::DatetimeOffsetn as u8, - dto.datetime2().time().scale(), - dto.datetime2().time().len()? + 5, - ]); + if dst.write_headers { + let headers = &[ + VarLenType::DatetimeOffsetn as u8, + dto.datetime2().time().scale(), + dto.datetime2().time().len()? 
+ 5, + ]; + + dst.extend_from_slice(headers); + } - dto.encode(dst)?; + dto.encode(&mut *dst)?; } ColumnData::Xml(Some(xml)) => { - dst.put_u8(VarLenType::Xml as u8); - xml.into_owned().encode(dst)?; + if dst.write_headers { + dst.put_u8(VarLenType::Xml as u8); + } + xml.into_owned().encode(&mut *dst)?; } ColumnData::Numeric(Some(num)) => { - dst.extend_from_slice(&[ - VarLenType::Numericn as u8, - num.len(), - num.precision(), - num.scale(), - ]); - num.encode(dst)?; + if dst.write_headers { + let headers = &[ + VarLenType::Numericn as u8, + num.len(), + num.precision(), + num.scale(), + ]; + + dst.extend_from_slice(headers); + } + + num.encode(&mut *dst)?; } _ => { // None/null diff --git a/src/tds/codec/column_data/buf.rs b/src/tds/codec/column_data/buf.rs new file mode 100644 index 00000000..b2dccf14 --- /dev/null +++ b/src/tds/codec/column_data/buf.rs @@ -0,0 +1,81 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + ops::{Deref, DerefMut}, +}; + +use asynchronous_codec::BytesMut; +use bytes::BufMut; + +pub(crate) struct BufColumnData<'a> { + buf: &'a mut BytesMut, + pub(crate) write_headers: bool, +} + +impl<'a> BufColumnData<'a> { + pub(crate) fn with_headers(buf: &'a mut BytesMut) -> Self { + Self { + buf, + write_headers: true, + } + } + + #[allow(dead_code)] + pub(crate) fn without_headers(buf: &'a mut BytesMut) -> Self { + Self { + buf, + write_headers: false, + } + } + + pub(crate) fn extend(&mut self, other: Vec) { + self.buf.extend(other) + } + + pub(crate) fn extend_from_slice(&mut self, other: &[u8]) { + self.buf.extend_from_slice(other); + } + + pub(crate) fn len(&mut self) -> usize { + self.buf.len() + } +} + +unsafe impl<'a> BufMut for BufColumnData<'a> { + fn remaining_mut(&self) -> usize { + self.buf.remaining_mut() + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + self.buf.advance_mut(cnt) + } + + fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { + self.buf.chunk_mut() + } +} + +impl<'a> Borrow<[u8]> for BufColumnData<'a> { + fn borrow(&self) -> &[u8] { + self.buf.deref() + } +} + +impl<'a> BorrowMut<[u8]> for BufColumnData<'a> { + fn borrow_mut(&mut self) -> &mut [u8] { + self.buf.borrow_mut() + } +} + +impl<'a> Deref for BufColumnData<'a> { + type Target = BytesMut; + + fn deref(&self) -> &Self::Target { + self.buf + } +} + +impl<'a> DerefMut for BufColumnData<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.buf + } +} diff --git a/src/tds/codec/decode.rs b/src/tds/codec/decode.rs index a29f4e6e..d19fec0c 100644 --- a/src/tds/codec/decode.rs +++ b/src/tds/codec/decode.rs @@ -21,16 +21,22 @@ impl Decoder for PacketCodec { } let header = PacketHeader::decode(&mut BytesMut::from(&src[0..HEADER_BYTES]))?; - let length = header.length as usize; + let length = header.length() as usize; if src.len() < length { src.reserve(length); return Ok(None); } - event!(Level::TRACE, "Reading a {:?} ({} bytes)", header.ty, length,); + event!( + Level::TRACE, + "Reading a {:?} ({} bytes)", + header.r#type(), + length, + ); let header = PacketHeader::decode(src)?; + if length < HEADER_BYTES { return Err(Error::Protocol("Invalid packet length".into())); } diff --git a/src/tds/codec/header.rs b/src/tds/codec/header.rs index 5f4cb1ac..719fc158 100644 --- a/src/tds/codec/header.rs +++ b/src/tds/codec/header.rs @@ -41,18 +41,18 @@ uint_enum! 
{ /// packet header consisting of 8 bytes [2.2.3.1] #[derive(Debug, Clone, Copy)] pub(crate) struct PacketHeader { - pub ty: PacketType, - pub status: PacketStatus, + ty: PacketType, + status: PacketStatus, /// [BE] the length of the packet (including the 8 header bytes) /// must match the negotiated size sending from client to server [since TDSv7.3] after login /// (only if not EndOfMessage) - pub length: u16, + length: u16, /// [BE] the process ID on the server, for debugging purposes only - pub spid: u16, + spid: u16, /// packet id - pub id: u8, + id: u8, /// currently unused - pub window: u8, + window: u8, } impl PacketHeader { @@ -100,9 +100,33 @@ impl PacketHeader { } } + pub fn bulk_load(id: u8) -> Self { + Self { + ty: PacketType::BulkLoad, + status: PacketStatus::NormalMessage, + ..Self::new(0, id) + } + } + pub fn set_status(&mut self, status: PacketStatus) { self.status = status; } + + pub fn set_type(&mut self, ty: PacketType) { + self.ty = ty; + } + + pub fn status(&self) -> PacketStatus { + self.status + } + + pub fn r#type(&self) -> PacketType { + self.ty + } + + pub fn length(&self) -> u16 { + self.length + } } impl Encode for PacketHeader diff --git a/src/tds/codec/iterator_ext.rs b/src/tds/codec/iterator_ext.rs new file mode 100644 index 00000000..aecdd6d5 --- /dev/null +++ b/src/tds/codec/iterator_ext.rs @@ -0,0 +1,27 @@ +use std::fmt::{Display, Write}; + +pub(crate) trait IteratorJoin { + fn join(self, sep: &str) -> String; +} + +impl IteratorJoin for T +where + T: Iterator, + I: Display, +{ + fn join(mut self, sep: &str) -> String { + let (lower_bound, _) = self.size_hint(); + let mut out = String::with_capacity(sep.len() * lower_bound); + + if let Some(first_item) = self.next() { + write!(out, "{}", first_item).unwrap(); + } + + for item in self { + out.push_str(sep); + write!(out, "{}", item).unwrap(); + } + + out + } +} diff --git a/src/tds/codec/packet.rs b/src/tds/codec/packet.rs index 6cb0892c..9927ed35 100644 --- a/src/tds/codec/packet.rs +++ b/src/tds/codec/packet.rs @@ -13,7 +13,7 @@ impl Packet { } pub(crate) fn is_last(&self) -> bool { - self.header.status == PacketStatus::EndOfMessage + self.header.status() == PacketStatus::EndOfMessage } pub(crate) fn into_parts(self) -> (PacketHeader, BytesMut) { diff --git a/src/tds/codec/rpc_request.rs b/src/tds/codec/rpc_request.rs index 35beb75e..6c9d28f2 100644 --- a/src/tds/codec/rpc_request.rs +++ b/src/tds/codec/rpc_request.rs @@ -1,3 +1,4 @@ +use super::BufColumnData; use super::{AllHeaderTy, Encode, ALL_HEADERS_LEN_TX}; use crate::{tds::codec::ColumnData, Result}; use bytes::{BufMut, BytesMut}; @@ -133,7 +134,7 @@ impl<'a> Encode for RpcParam<'a> { } dst.put_u8(self.flags.bits()); - self.value.encode(dst)?; + self.value.encode(&mut BufColumnData::with_headers(dst))?; let dst: &mut [u8] = dst.borrow_mut(); dst[len_pos] = length; diff --git a/src/tds/codec/token/token_col_metadata.rs b/src/tds/codec/token/token_col_metadata.rs index 9800090d..db0085a0 100644 --- a/src/tds/codec/token/token_col_metadata.rs +++ b/src/tds/codec/token/token_col_metadata.rs @@ -1,22 +1,101 @@ +use std::{borrow::BorrowMut, fmt::Display}; + use crate::{ error::Error, - tds::codec::{FixedLenType, TypeInfo, VarLenType}, + tds::codec::{Encode, FixedLenType, TokenType, TypeInfo, TypeInfoInner, VarLenType}, Column, ColumnData, ColumnType, SqlReadBytes, }; +use asynchronous_codec::BytesMut; +use bytes::BufMut; use enumflags2::{bitflags, BitFlags}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct TokenColMetaData { pub columns: Vec, } 
-#[derive(Debug)] +#[derive(Debug, Clone)] pub struct MetaDataColumn { pub base: BaseMetaDataColumn, pub col_name: String, } -#[derive(Debug)] +impl Display for MetaDataColumn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} ", self.col_name)?; + + match &self.base.ty.inner { + TypeInfoInner::FixedLen(fixed) => match fixed { + FixedLenType::Int1 => write!(f, "tinyint")?, + FixedLenType::Bit => write!(f, "bit")?, + FixedLenType::Int2 => write!(f, "smallint")?, + FixedLenType::Int4 => write!(f, "int")?, + FixedLenType::Datetime4 => write!(f, "smalldatetime")?, + FixedLenType::Float4 => write!(f, "real")?, + FixedLenType::Money => write!(f, "money")?, + FixedLenType::Datetime => write!(f, "datetime")?, + FixedLenType::Float8 => write!(f, "float")?, + FixedLenType::Money4 => write!(f, "smallmoney")?, + FixedLenType::Int8 => write!(f, "bigint")?, + FixedLenType::Null => unreachable!(), + }, + TypeInfoInner::VarLenSized(ctx) => match ctx.r#type() { + VarLenType::Guid => write!(f, "uniqueidentifier")?, + #[cfg(feature = "tds73")] + VarLenType::Daten => write!(f, "date")?, + #[cfg(feature = "tds73")] + VarLenType::Timen => write!(f, "time")?, + #[cfg(feature = "tds73")] + VarLenType::Datetime2 => write!(f, "datetime2")?, + #[cfg(feature = "tds73")] + VarLenType::DatetimeOffsetn => write!(f, "datetimeoffset")?, + VarLenType::BigVarBin => { + if ctx.len() <= 8000 { + write!(f, "varbinary({})", ctx.len())? + } else { + write!(f, "varbinary(max)")? + } + } + VarLenType::BigVarChar => { + if ctx.len() <= 8000 { + write!(f, "varchar({})", ctx.len())? + } else { + write!(f, "varchar(max)")? + } + } + VarLenType::BigBinary => write!(f, "binary({})", ctx.len())?, + VarLenType::BigChar => write!(f, "char({})", ctx.len())?, + VarLenType::NVarchar => { + if ctx.len() <= 4000 { + write!(f, "nvarchar({})", ctx.len())? + } else { + write!(f, "nvarchar(max)")? + } + } + VarLenType::NChar => write!(f, "nchar({})", ctx.len())?, + VarLenType::Text => write!(f, "text")?, + VarLenType::Image => write!(f, "image")?, + VarLenType::NText => write!(f, "ntext")?, + _ => unreachable!(), + }, + TypeInfoInner::VarLenSizedPrecision { + ty, + size: _, + precision, + scale, + } => match ty { + VarLenType::Decimaln => write!(f, "decimal({},{})", precision, scale)?, + VarLenType::Numericn => write!(f, "numeric({},{})", precision, scale)?, + _ => unreachable!(), + }, + TypeInfoInner::Xml { .. } => write!(f, "xml")?, + } + + Ok(()) + } +} + +#[derive(Debug, Clone)] pub struct BaseMetaDataColumn { pub flags: BitFlags, pub ty: TypeInfo, @@ -24,8 +103,8 @@ pub struct BaseMetaDataColumn { impl BaseMetaDataColumn { pub(crate) fn null_value(&self) -> ColumnData<'static> { - match self.ty { - TypeInfo::FixedLen(ty) => match ty { + match &self.ty.inner { + TypeInfoInner::FixedLen(ty) => match ty { FixedLenType::Null => ColumnData::I32(None), FixedLenType::Int1 => ColumnData::U8(None), FixedLenType::Bit => ColumnData::Bit(None), @@ -39,7 +118,7 @@ impl BaseMetaDataColumn { FixedLenType::Money4 => ColumnData::F32(None), FixedLenType::Int8 => ColumnData::I64(None), }, - TypeInfo::VarLenSized(cx) => match cx.r#type() { + TypeInfoInner::VarLenSized(cx) => match cx.r#type() { VarLenType::Guid => ColumnData::Guid(None), VarLenType::Intn => ColumnData::I32(None), VarLenType::Bitn => ColumnData::Bit(None), @@ -69,7 +148,7 @@ impl BaseMetaDataColumn { VarLenType::NText => ColumnData::String(None), VarLenType::SSVariant => todo!(), }, - TypeInfo::VarLenSizedPrecision { ty, .. 
} => match ty { + TypeInfoInner::VarLenSizedPrecision { ty, .. } => match ty { VarLenType::Guid => ColumnData::Guid(None), VarLenType::Intn => ColumnData::I32(None), VarLenType::Bitn => ColumnData::Bit(None), @@ -99,8 +178,52 @@ impl BaseMetaDataColumn { VarLenType::NText => ColumnData::String(None), VarLenType::SSVariant => todo!(), }, - TypeInfo::Xml { .. } => ColumnData::Xml(None), + TypeInfoInner::Xml { .. } => ColumnData::Xml(None), + } + } +} + +impl Encode for TokenColMetaData { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + dst.put_u8(TokenType::ColMetaData as u8); + dst.put_u16_le(self.columns.len() as u16); + + for col in self.columns.into_iter() { + col.encode(dst)?; + } + + Ok(()) + } +} + +impl Encode for MetaDataColumn { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + dst.put_u32_le(0); + self.base.encode(dst)?; + + let len_pos = dst.len(); + let mut length = 0u8; + + dst.put_u8(length); + + for chr in self.col_name.encode_utf16() { + length += 1; + dst.put_u16_le(chr); } + + let dst: &mut [u8] = dst.borrow_mut(); + dst[len_pos] = length; + + Ok(()) + } +} + +impl Encode for BaseMetaDataColumn { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + dst.put_u16_le(BitFlags::bits(self.flags)); + self.ty.encode(dst)?; + + Ok(()) } } @@ -183,7 +306,7 @@ impl BaseMetaDataColumn { let ty = TypeInfo::decode(src).await?; - if let TypeInfo::VarLenSized(cx) = ty { + if let TypeInfoInner::VarLenSized(cx) = ty.inner { if let Text | NText | Image = cx.r#type() { let num_of_parts = src.read_u8().await?; diff --git a/src/tds/codec/token/token_done.rs b/src/tds/codec/token/token_done.rs index 472a49ed..1b41a2d0 100644 --- a/src/tds/codec/token/token_done.rs +++ b/src/tds/codec/token/token_done.rs @@ -1,8 +1,10 @@ -use crate::{Error, SqlReadBytes}; +use crate::{tds::codec::Encode, Error, SqlReadBytes, TokenType}; +use asynchronous_codec::BytesMut; +use bytes::BufMut; use enumflags2::{bitflags, BitFlags}; use std::fmt; -#[derive(Debug)] +#[derive(Debug, Default)] pub struct TokenDone { status: BitFlags, cur_cmd: u16, @@ -57,6 +59,18 @@ impl TokenDone { } } +impl Encode for TokenDone { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + dst.put_u8(TokenType::Done as u8); + dst.put_u16_le(BitFlags::bits(self.status)); + + dst.put_u16_le(self.cur_cmd); + dst.put_u64_le(self.done_rows); + + Ok(()) + } +} + impl fmt::Display for TokenDone { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.done_rows == 0 { diff --git a/src/tds/codec/token/token_row.rs b/src/tds/codec/token/token_row.rs index 4e9ee57c..de72a3ec 100644 --- a/src/tds/codec/token/token_row.rs +++ b/src/tds/codec/token/token_row.rs @@ -1,13 +1,19 @@ -use crate::{tds::codec::ColumnData, SqlReadBytes}; +use crate::{ + tds::codec::{BufColumnData, ColumnData, Encode}, + SqlReadBytes, TokenType, +}; +use asynchronous_codec::BytesMut; +use bytes::BufMut; use futures::io::AsyncReadExt; -#[derive(Debug)] -pub struct TokenRow { - data: Vec>, +/// A row of data. 
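+///
+/// Values are pushed in the order of the columns in the result set or bulk
+/// load. A minimal sketch, using the `IntoSql` conversions added in this
+/// patch:
+///
+/// ```
+/// # use tiberius::{IntoSql, TokenRow};
+/// let mut row = TokenRow::new();
+/// row.push(1i32.into_sql());
+/// row.push("two".to_string().into_sql());
+/// assert_eq!(2, row.len());
+/// ```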
+#[derive(Debug, Default, Clone)] +pub struct TokenRow<'a> { + data: Vec>, } -impl IntoIterator for TokenRow { - type Item = ColumnData<'static>; +impl<'a> IntoIterator for TokenRow<'a> { + type Item = ColumnData<'a>; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { @@ -15,7 +21,56 @@ impl IntoIterator for TokenRow { } } -impl TokenRow { +impl<'a> Encode for TokenRow<'a> { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + dst.put_u8(TokenType::Row as u8); + + let mut col_buf = BufColumnData::without_headers(dst); + + for value in self.data.into_iter() { + value.encode(&mut col_buf)? + } + + Ok(()) + } +} + +impl<'a> TokenRow<'a> { + /// Creates a new empty row. + pub const fn new() -> Self { + Self { data: Vec::new() } + } + + /// Creates a new empty row with allocated capacity. + pub fn with_capacity(&self, capacity: usize) -> Self { + Self { + data: Vec::with_capacity(capacity), + } + } + + /// The number of columns. + pub fn len(&self) -> usize { + self.data.len() + } + + /// True if row has no columns. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Gets the columnar data with the given index. `None` if index out of + /// bounds. + pub fn get(&self, index: usize) -> Option<&ColumnData<'a>> { + self.data.get(index) + } + + /// Adds a new value to the row. + pub fn push(&mut self, value: ColumnData<'a>) { + self.data.push(value); + } +} + +impl TokenRow<'static> { /// Normal row. We'll read the metadata what we've cached and parse columns /// based on that. pub(crate) async fn decode(src: &mut R) -> crate::Result @@ -61,17 +116,6 @@ impl TokenRow { Ok(row) } - - /// The number of columns. - pub fn len(&self) -> usize { - self.data.len() - } - - /// Gives the columnar data with the given index. `None` if index out of - /// bounds. - pub fn get(&self, index: usize) -> Option<&ColumnData<'static>> { - self.data.get(index) - } } /// A bitmap of null values in the row. Sometimes SQL Server decides to pack the diff --git a/src/tds/codec/type_info.rs b/src/tds/codec/type_info.rs index 9914dd61..8b5f625e 100644 --- a/src/tds/codec/type_info.rs +++ b/src/tds/codec/type_info.rs @@ -1,8 +1,28 @@ +use asynchronous_codec::BytesMut; +use bytes::BufMut; + use crate::{tds::Collation, xml::XmlSchema, Error, SqlReadBytes}; -use std::{convert::TryFrom, sync::Arc}; +use std::{convert::TryFrom, sync::Arc, usize}; + +use super::Encode; + +/// Describes a type of a column. +#[derive(Debug, Clone)] +pub struct TypeInfo { + pub(crate) inner: TypeInfoInner, +} + +/// A length of a column in bytes or characters. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TypeLength { + /// The number of bytes (or characters) reserved in the column. + Limited(u16), + /// Unlimited, stored in the heap outside of the row. 
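+ /// Rendered as `varchar(max)`, `nvarchar(max)` or `varbinary(max)` in the
+ /// generated column description.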
+ Max, +} -#[derive(Debug)] -pub enum TypeInfo { +#[derive(Debug, Clone)] +pub(crate) enum TypeInfoInner { FixedLen(FixedLenType), VarLenSized(VarLenContext), VarLenSizedPrecision { @@ -49,6 +69,53 @@ impl VarLenContext { } } +impl Encode for VarLenContext { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + dst.put_u8(self.r#type() as u8); + + // length + match self.r#type { + #[cfg(feature = "tds73")] + VarLenType::Daten + | VarLenType::Timen + | VarLenType::DatetimeOffsetn + | VarLenType::Datetime2 => { + dst.put_u8(self.len() as u8); + } + VarLenType::Bitn + | VarLenType::Intn + | VarLenType::Floatn + | VarLenType::Decimaln + | VarLenType::Numericn + | VarLenType::Guid + | VarLenType::Money + | VarLenType::Datetimen => { + dst.put_u8(self.len() as u8); + } + VarLenType::NChar + | VarLenType::BigChar + | VarLenType::NVarchar + | VarLenType::BigVarChar + | VarLenType::BigBinary + | VarLenType::BigVarBin => { + dst.put_u16_le(self.len() as u16); + } + VarLenType::Image | VarLenType::Text | VarLenType::NText => { + dst.put_u32_le(self.len() as u32); + } + VarLenType::Xml => (), + typ => todo!("encoding {:?} is not supported yet", typ), + } + + if let Some(collation) = self.collation() { + dst.put_u32_le(collation.info); + dst.put_u8(collation.sort_id); + } + + Ok(()) + } +} + uint_enum! { #[repr(u8)] pub enum FixedLenType { @@ -143,7 +210,241 @@ uint_enum! { } } +impl Encode for TypeInfo { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + match self.inner { + TypeInfoInner::FixedLen(ty) => { + dst.put_u8(ty as u8); + } + TypeInfoInner::VarLenSized(ctx) => ctx.encode(dst)?, + TypeInfoInner::VarLenSizedPrecision { + ty, + size, + precision, + scale, + } => { + dst.put_u8(ty as u8); + dst.put_u8(size as u8); + dst.put_u8(precision as u8); + dst.put_u8(scale as u8); + } + TypeInfoInner::Xml { .. } => { + unreachable!() + } + } + + Ok(()) + } +} + impl TypeInfo { + /// A bit, either zero or one. + pub fn bit() -> Self { + Self::fixed(FixedLenType::Bit) + } + + /// 8-bit integer, unsigned. + pub fn tinyint() -> Self { + Self::fixed(FixedLenType::Int1) + } + + /// 16-bit integer, signed. + pub fn smallint() -> Self { + Self::fixed(FixedLenType::Int2) + } + + /// 32-bit integer, signed. + pub fn int() -> Self { + Self::fixed(FixedLenType::Int4) + } + + /// 64-bit integer, signed. + pub fn bigint() -> Self { + Self::fixed(FixedLenType::Int8) + } + + /// 32-bit floating point number. + pub fn real() -> Self { + Self::fixed(FixedLenType::Float4) + } + + /// 64-bit floating point number. + pub fn float() -> Self { + Self::fixed(FixedLenType::Float8) + } + + /// 32-bit money type. + pub fn smallmoney() -> Self { + Self::fixed(FixedLenType::Money4) + } + + /// 64-bit money type. + pub fn money() -> Self { + Self::fixed(FixedLenType::Money) + } + + /// A small DateTime value. + pub fn smalldatetime() -> Self { + Self::fixed(FixedLenType::Datetime4) + } + + /// A datetime value. + pub fn datetime() -> Self { + Self::fixed(FixedLenType::Datetime) + } + + /// A datetime2 value. + #[cfg(feature = "tds73")] + pub fn datetime2() -> Self { + Self::varlen(VarLenType::Datetime2, 8, None) + } + + /// A uniqueidentifier value. + pub fn guid() -> Self { + Self::varlen(VarLenType::Guid, 16, None) + } + + /// A date value. + #[cfg(feature = "tds73")] + pub fn date() -> Self { + Self::varlen(VarLenType::Daten, 3, None) + } + + /// A time value. + #[cfg(feature = "tds73")] + pub fn time() -> Self { + Self::varlen(VarLenType::Timen, 5, None) + } + + /// A time value. 
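+ /// More precisely, a `datetimeoffset` value: a `datetime2` combined with a
+ /// time zone offset.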
+ #[cfg(feature = "tds73")] + pub fn datetimeoffset() -> Self { + Self::varlen(VarLenType::DatetimeOffsetn, 10, None) + } + + /// A variable binary value. If length is limited and larger than 8000 + /// bytes, the `MAX` variant is used instead. + pub fn varbinary(length: TypeLength) -> Self { + let length = match length { + TypeLength::Limited(n) if n <= 8000 => n, + _ => u16::MAX, + }; + + Self::varlen(VarLenType::BigVarBin, length as usize, None) + } + + /// A binary value. + /// + /// # Panics + /// + /// - If length is more than 8000 bytes. + pub fn binary(length: u16) -> Self { + assert!(length <= 8000); + Self::varlen(VarLenType::BigBinary, length as usize, None) + } + + /// A variable string value. If length is limited and larger than 8000 + /// characters, the `MAX` variant is used instead. + pub fn varchar(length: TypeLength) -> Self { + let length = match length { + TypeLength::Limited(n) if n <= 8000 => n, + _ => u16::MAX, + }; + + Self::varlen(VarLenType::BigVarChar, length as usize, None) + } + + /// A variable UTF-16 string value. If length is limited and larger than + /// 4000 characters, the `MAX` variant is used instead. + pub fn nvarchar(length: TypeLength) -> Self { + let length = match length { + TypeLength::Limited(n) if n <= 4000 => n, + _ => u16::MAX, + }; + + Self::varlen(VarLenType::BigVarChar, length as usize, None) + } + + /// A constant-size string value. + /// + /// # Panics + /// + /// - If length is more than 8000 characters. + pub fn char(length: u16) -> Self { + assert!(length <= 8000); + Self::varlen(VarLenType::BigChar, length as usize, None) + } + + /// A constant-size UTF-16 string value. + /// + /// # Panics + /// + /// - If length is more than 4000 characters. + pub fn nchar(length: u16) -> Self { + assert!(length <= 4000); + Self::varlen(VarLenType::NChar, length as usize, None) + } + + /// A (deprecated) heap-allocated text storage. + pub fn text() -> Self { + Self::varlen(VarLenType::Text, u32::MAX as usize, None) + } + + /// A (deprecated) heap-allocated UTF-16 text storage. + pub fn ntext() -> Self { + Self::varlen(VarLenType::NText, u32::MAX as usize, None) + } + + /// A (deprecated) heap-allocated binary storage. + pub fn image() -> Self { + Self::varlen(VarLenType::Image, u32::MAX as usize, None) + } + + /// Numeric data types that have fixed precision and scale. Decimal and + /// numeric are synonyms and can be used interchangeably. + pub fn decimal(precision: u8, scale: u8) -> Self { + Self::varlen_precision(VarLenType::Decimaln, precision, scale) + } + + /// Numeric data types that have fixed precision and scale. Decimal and + /// numeric are synonyms and can be used interchangeably. 
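+ ///
+ /// For example, a `numeric(18,4)` column (a sketch using this builder):
+ ///
+ /// ```
+ /// # use tiberius::TypeInfo;
+ /// let ty = TypeInfo::numeric(18, 4);
+ /// ```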
+ pub fn numeric(precision: u8, scale: u8) -> Self { + Self::varlen_precision(VarLenType::Numericn, precision, scale) + } + + fn varlen_precision(ty: VarLenType, precision: u8, scale: u8) -> Self { + let size = if precision <= 9 { + 5 + } else if precision <= 19 { + 9 + } else if precision <= 28 { + 13 + } else { + 17 + }; + + let inner = TypeInfoInner::VarLenSizedPrecision { + ty, + size, + precision, + scale, + }; + + Self { inner } + } + + fn varlen(ty: VarLenType, len: usize, collation: Option) -> Self { + let cx = VarLenContext::new(ty, len, collation); + let inner = TypeInfoInner::VarLenSized(cx); + + Self { inner } + } + + fn fixed(ty: FixedLenType) -> Self { + let inner = TypeInfoInner::FixedLen(ty); + Self { inner } + } + pub(crate) async fn decode(src: &mut R) -> crate::Result where R: SqlReadBytes + Unpin, @@ -151,7 +452,8 @@ impl TypeInfo { let ty = src.read_u8().await?; if let Ok(ty) = FixedLenType::try_from(ty) { - return Ok(TypeInfo::FixedLen(ty)); + let inner = TypeInfoInner::FixedLen(ty); + return Ok(TypeInfo { inner }); } match VarLenType::try_from(ty) { @@ -173,10 +475,12 @@ impl TypeInfo { None }; - Ok(TypeInfo::Xml { + let inner = TypeInfoInner::Xml { schema, size: 0xfffffffffffffffe_usize, - }) + }; + + Ok(TypeInfo { inner }) } Ok(ty) => { let len = match ty { @@ -226,16 +530,20 @@ impl TypeInfo { let precision = src.read_u8().await?; let scale = src.read_u8().await?; - TypeInfo::VarLenSizedPrecision { + let inner = TypeInfoInner::VarLenSizedPrecision { size: len, ty, precision, scale, - } + }; + + TypeInfo { inner } } _ => { let cx = VarLenContext::new(ty, len, collation); - TypeInfo::VarLenSized(cx) + let inner = TypeInfoInner::VarLenSized(cx); + + TypeInfo { inner } } }; diff --git a/src/tds/collation.rs b/src/tds/collation.rs index 69588caa..65d4abdf 100644 --- a/src/tds/collation.rs +++ b/src/tds/collation.rs @@ -12,9 +12,9 @@ use crate::error::Error; #[derive(Debug, Clone, Copy)] pub struct Collation { /// LCID ColFlags Version - info: u32, + pub(crate) info: u32, /// Sortid - sort_id: u8, + pub(crate) sort_id: u8, } impl Collation { diff --git a/src/tds/stream/token.rs b/src/tds/stream/token.rs index d9e80093..0b7dd3dd 100644 --- a/src/tds/stream/token.rs +++ b/src/tds/stream/token.rs @@ -14,7 +14,7 @@ use tracing::{event, Level}; #[derive(Debug)] pub enum ReceivedToken { NewResultset(Arc), - Row(TokenRow), + Row(TokenRow<'static>), Done(TokenDone), DoneInProc(TokenDone), DoneProc(TokenDone), diff --git a/src/to_sql.rs b/src/to_sql.rs index 235e6722..401d5d24 100644 --- a/src/to_sql.rs +++ b/src/to_sql.rs @@ -74,6 +74,14 @@ into_sql!(self_, String: (ColumnData::String, Cow::from(self_)); Vec: (ColumnData::Binary, Cow::from(self_)); XmlData: (ColumnData::Xml, Cow::Owned(self_)); + bool: (ColumnData::Bit, self_); + u8: (ColumnData::U8, self_); + i16: (ColumnData::I16, self_); + i32: (ColumnData::I32, self_); + i64: (ColumnData::I64, self_); + f32: (ColumnData::F32, self_); + f64: (ColumnData::F64, self_); + Uuid: (ColumnData::Guid, self_); ); to_sql!(self_, diff --git a/tests/query.rs b/tests/query.rs index 048732aa..5d79d3a5 100644 --- a/tests/query.rs +++ b/tests/query.rs @@ -1794,6 +1794,7 @@ where ) .await?; + assert_eq!(&[1], res.rows_affected()); assert_eq!(1, res.total()); let row = conn From 345d47c9c8c4b28c3703328ddcf283495586f19b Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Tue, 10 Aug 2021 16:24:39 +0200 Subject: [PATCH 2/3] Start testing, poking the API --- Cargo.toml | 3 + examples/bulk.rs | 23 ++- src/client.rs | 8 +- src/lib.rs | 9 +- 
src/tds/codec/bulk_load.rs | 21 ++- src/tds/codec/token/token_col_metadata.rs | 38 ++-- src/tds/codec/token/token_row.rs | 2 + src/tds/codec/token/token_row/into_row.rs | 205 ++++++++++++++++++++++ src/tds/codec/type_info.rs | 2 +- src/tds/context.rs | 6 +- src/tds/stream/query.rs | 2 +- src/tds/stream/token.rs | 2 +- src/to_sql.rs | 1 + tests/bulk.rs | 128 ++++++++++++++ 14 files changed, 412 insertions(+), 38 deletions(-) create mode 100644 src/tds/codec/token/token_row/into_row.rs create mode 100644 tests/bulk.rs diff --git a/Cargo.toml b/Cargo.toml index db2fbc74..fb57b873 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -127,6 +127,9 @@ path = "./runtimes-macro" names = "0.11" anyhow = "1" env_logger = "0.7" +indicatif = "0.16" +console = "0.14" +paste = "1.0" [package.metadata.docs.rs] features = ["all", "docs"] diff --git a/examples/bulk.rs b/examples/bulk.rs index 51b5845d..e2956c13 100644 --- a/examples/bulk.rs +++ b/examples/bulk.rs @@ -1,6 +1,7 @@ +use indicatif::ProgressBar; use once_cell::sync::Lazy; use std::env; -use tiberius::{BulkLoadMetadata, Client, Config, IntoSql, TokenRow, TypeInfo}; +use tiberius::{BulkLoadMetadata, Client, ColumnFlag, Config, IntoSql, TokenRow, TypeInfo}; use tokio::net::TcpStream; use tokio_util::compat::TokioAsyncWriteCompatExt; @@ -21,18 +22,30 @@ async fn main() -> anyhow::Result<()> { let mut client = Client::connect(config, tcp.compat_write()).await?; + client + .execute( + "CREATE TABLE ##bulk_test1 (id INT IDENTITY PRIMARY KEY, content INT)", + &[], + ) + .await?; + let mut meta = BulkLoadMetadata::new(); - meta.add_column("val", TypeInfo::int()); + meta.add_column("content", TypeInfo::int(), ColumnFlag::Nullable.into()); + + let mut req = client.bulk_insert("##bulk_test1", meta).await?; + let count = 2000i32; - let mut req = client.bulk_insert("bulk_test1", meta).await?; + let pb = ProgressBar::new(count as u64); - for i in [0, 1, 2, 3, 4, 5] { + for i in 0..count { let mut row = TokenRow::new(); row.push(i.into_sql()); - req.send(row).await?; + pb.inc(1); } + pb.finish_with_message("waiting..."); + let res = req.finalize().await?; dbg!(res); diff --git a/src/client.rs b/src/client.rs index e70c1a08..c5d996fb 100644 --- a/src/client.rs +++ b/src/client.rs @@ -279,11 +279,11 @@ impl Client { /// # Ok(()) /// # } /// ``` - pub async fn bulk_insert( - &mut self, + pub async fn bulk_insert<'a>( + &'a mut self, table: &str, - meta: BulkLoadMetadata, - ) -> crate::Result> { + meta: BulkLoadMetadata<'a>, + ) -> crate::Result> { // Start the bulk request self.connection.flush_stream().await?; diff --git a/src/lib.rs b/src/lib.rs index b8806656..05d75732 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -196,7 +196,14 @@ pub use from_sql::{FromSql, FromSqlOwned}; pub use result::*; pub use row::{Column, ColumnType, Row}; pub use sql_browser::SqlBrowser; -pub use tds::{EncryptionLevel, codec::{BulkLoadMetadata, BulkLoadRequest, ColumnData, TokenRow, TypeInfo, TypeLength}, numeric, stream::QueryStream, time, xml}; +pub use tds::{ + codec::{ + BulkLoadMetadata, BulkLoadRequest, ColumnData, ColumnFlag, TokenRow, TypeInfo, TypeLength, + }, + numeric, + stream::QueryStream, + time, xml, EncryptionLevel, +}; pub use to_sql::{IntoSql, ToSql}; pub use uuid::Uuid; diff --git a/src/tds/codec/bulk_load.rs b/src/tds/codec/bulk_load.rs index c8a484a7..fbc1c7c5 100644 --- a/src/tds/codec/bulk_load.rs +++ b/src/tds/codec/bulk_load.rs @@ -6,17 +6,17 @@ use tracing::{event, Level}; use crate::{client::Connection, sql_read_bytes::SqlReadBytes, ExecuteResult}; use super::{ - 
BaseMetaDataColumn, Encode, MetaDataColumn, PacketHeader, PacketStatus, TokenColMetaData, - TokenDone, TokenRow, TypeInfo, HEADER_BYTES, + BaseMetaDataColumn, ColumnFlag, Encode, MetaDataColumn, PacketHeader, PacketStatus, + TokenColMetaData, TokenDone, TokenRow, TypeInfo, HEADER_BYTES, }; /// Column metadata for a bulk load request. #[derive(Debug, Default, Clone)] -pub struct BulkLoadMetadata { - columns: Vec, +pub struct BulkLoadMetadata<'a> { + columns: Vec>, } -impl BulkLoadMetadata { +impl<'a> BulkLoadMetadata<'a> { /// Creates a metadata with no columns specified. pub fn new() -> Self { Self::default() @@ -24,10 +24,13 @@ impl BulkLoadMetadata { /// Add a column to the request. Order should be same as the order of data /// in the rows. - pub fn add_column(&mut self, name: &str, ty: TypeInfo) { + pub fn add_column(&mut self, name: &'a str, ty: TypeInfo, flags: C) + where + C: Into>, + { self.columns.push(MetaDataColumn { base: BaseMetaDataColumn { - flags: BitFlags::empty(), + flags: flags.into(), ty, }, col_name: name.into(), @@ -39,7 +42,7 @@ impl BulkLoadMetadata { } } -impl Encode for BulkLoadMetadata { +impl<'a> Encode for BulkLoadMetadata<'a> { fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { let cmd = TokenColMetaData { columns: self.columns, @@ -66,7 +69,7 @@ where { pub(crate) fn new( connection: &'a mut Connection, - meta: BulkLoadMetadata, + meta: BulkLoadMetadata<'a>, ) -> crate::Result { let packet_id = connection.context_mut().next_packet_id(); let mut buf = BytesMut::new(); diff --git a/src/tds/codec/token/token_col_metadata.rs b/src/tds/codec/token/token_col_metadata.rs index db0085a0..b8785752 100644 --- a/src/tds/codec/token/token_col_metadata.rs +++ b/src/tds/codec/token/token_col_metadata.rs @@ -1,4 +1,7 @@ -use std::{borrow::BorrowMut, fmt::Display}; +use std::{ + borrow::{BorrowMut, Cow}, + fmt::Display, +}; use crate::{ error::Error, @@ -10,17 +13,17 @@ use bytes::BufMut; use enumflags2::{bitflags, BitFlags}; #[derive(Debug, Clone)] -pub struct TokenColMetaData { - pub columns: Vec, +pub struct TokenColMetaData<'a> { + pub columns: Vec>, } #[derive(Debug, Clone)] -pub struct MetaDataColumn { +pub struct MetaDataColumn<'a> { pub base: BaseMetaDataColumn, - pub col_name: String, + pub col_name: Cow<'a, str>, } -impl Display for MetaDataColumn { +impl<'a> Display for MetaDataColumn<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{} ", self.col_name)?; @@ -76,6 +79,13 @@ impl Display for MetaDataColumn { VarLenType::Text => write!(f, "text")?, VarLenType::Image => write!(f, "image")?, VarLenType::NText => write!(f, "ntext")?, + VarLenType::Intn => match ctx.len() { + 1 => write!(f, "tinyint")?, + 2 => write!(f, "smallint")?, + 4 => write!(f, "int")?, + 8 => write!(f, "bigint")?, + _ => unreachable!(), + }, _ => unreachable!(), }, TypeInfoInner::VarLenSizedPrecision { @@ -183,7 +193,7 @@ impl BaseMetaDataColumn { } } -impl Encode for TokenColMetaData { +impl<'a> Encode for TokenColMetaData<'a> { fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { dst.put_u8(TokenType::ColMetaData as u8); dst.put_u16_le(self.columns.len() as u16); @@ -196,7 +206,7 @@ impl Encode for TokenColMetaData { } } -impl Encode for MetaDataColumn { +impl<'a> Encode for MetaDataColumn<'a> { fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { dst.put_u32_le(0); self.base.encode(dst)?; @@ -227,6 +237,7 @@ impl Encode for BaseMetaDataColumn { } } +/// A setting a column can hold. 
#[bitflags] #[repr(u16)] #[derive(Debug, Clone, Copy, PartialEq)] @@ -263,8 +274,7 @@ pub enum ColumnFlag { NullableUnknown = 1 << 15, } -#[allow(dead_code)] -impl TokenColMetaData { +impl TokenColMetaData<'static> { pub(crate) async fn decode(src: &mut R) -> crate::Result where R: SqlReadBytes + Unpin, @@ -275,7 +285,7 @@ impl TokenColMetaData { if column_count > 0 && column_count < 0xffff { for _ in 0..column_count { let base = BaseMetaDataColumn::decode(src).await?; - let col_name = src.read_b_varchar().await?; + let col_name = Cow::from(src.read_b_varchar().await?); columns.push(MetaDataColumn { base, col_name }); } @@ -283,10 +293,12 @@ impl TokenColMetaData { Ok(TokenColMetaData { columns }) } +} - pub(crate) fn columns<'a>(&'a self) -> impl Iterator + 'a { +impl<'a> TokenColMetaData<'a> { + pub(crate) fn columns(&'a self) -> impl Iterator + 'a { self.columns.iter().map(|x| Column { - name: x.col_name.clone(), + name: x.col_name.to_string(), column_type: ColumnType::from(&x.base.ty), }) } diff --git a/src/tds/codec/token/token_row.rs b/src/tds/codec/token/token_row.rs index de72a3ec..ce52b44b 100644 --- a/src/tds/codec/token/token_row.rs +++ b/src/tds/codec/token/token_row.rs @@ -1,3 +1,4 @@ +mod into_row; use crate::{ tds::codec::{BufColumnData, ColumnData, Encode}, SqlReadBytes, TokenType, @@ -5,6 +6,7 @@ use crate::{ use asynchronous_codec::BytesMut; use bytes::BufMut; use futures::io::AsyncReadExt; +pub use into_row::IntoRow; /// A row of data. #[derive(Debug, Default, Clone)] diff --git a/src/tds/codec/token/token_row/into_row.rs b/src/tds/codec/token/token_row/into_row.rs new file mode 100644 index 00000000..0989388f --- /dev/null +++ b/src/tds/codec/token/token_row/into_row.rs @@ -0,0 +1,205 @@ +use crate::{IntoSql, TokenRow}; + +pub trait IntoRow<'a> { + fn into_row(self) -> TokenRow<'a>; +} + +impl<'a, A> IntoRow<'a> for A +where + A: IntoSql, +{ + fn into_row(self) -> TokenRow<'a> { + let mut row = TokenRow::new(); + row.push(self.into_sql()); + row + } +} + +impl<'a, A, B> IntoRow<'a> for (A, B) +where + A: IntoSql, + B: IntoSql, +{ + fn into_row(self) -> TokenRow<'a> { + let mut row = TokenRow::new(); + row.push(self.0.into_sql()); + row.push(self.1.into_sql()); + row + } +} + +impl<'a, A, B, C> IntoRow<'a> for (A, B, C) +where + A: IntoSql, + B: IntoSql, + C: IntoSql, +{ + fn into_row(self) -> TokenRow<'a> { + let mut row = TokenRow::new(); + row.push(self.0.into_sql()); + row.push(self.1.into_sql()); + row.push(self.2.into_sql()); + row + } +} + +impl<'a, A, B, C, D> IntoRow<'a> for (A, B, C, D) +where + A: IntoSql, + B: IntoSql, + C: IntoSql, + D: IntoSql, +{ + fn into_row(self) -> TokenRow<'a> { + let mut row = TokenRow::new(); + row.push(self.0.into_sql()); + row.push(self.1.into_sql()); + row.push(self.2.into_sql()); + row.push(self.3.into_sql()); + row + } +} + +impl<'a, A, B, C, D, E> IntoRow<'a> for (A, B, C, D, E) +where + A: IntoSql, + B: IntoSql, + C: IntoSql, + D: IntoSql, + E: IntoSql, +{ + fn into_row(self) -> TokenRow<'a> { + let mut row = TokenRow::new(); + row.push(self.0.into_sql()); + row.push(self.1.into_sql()); + row.push(self.2.into_sql()); + row.push(self.3.into_sql()); + row.push(self.4.into_sql()); + row + } +} + +impl<'a, A, B, C, D, E, F> IntoRow<'a> for (A, B, C, D, E, F) +where + A: IntoSql, + B: IntoSql, + C: IntoSql, + D: IntoSql, + E: IntoSql, + F: IntoSql, +{ + fn into_row(self) -> TokenRow<'a> { + let mut row = TokenRow::new(); + row.push(self.0.into_sql()); + row.push(self.1.into_sql()); + row.push(self.2.into_sql()); + 
diff --git a/src/tds/codec/token/token_row/into_row.rs b/src/tds/codec/token/token_row/into_row.rs
new file mode 100644
index 00000000..0989388f
--- /dev/null
+++ b/src/tds/codec/token/token_row/into_row.rs
@@ -0,0 +1,205 @@
+use crate::{IntoSql, TokenRow};
+
+pub trait IntoRow<'a> {
+    fn into_row(self) -> TokenRow<'a>;
+}
+
+impl<'a, A> IntoRow<'a> for A
+where
+    A: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.into_sql());
+        row
+    }
+}
+
+impl<'a, A, B> IntoRow<'a> for (A, B)
+where
+    A: IntoSql,
+    B: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.0.into_sql());
+        row.push(self.1.into_sql());
+        row
+    }
+}
+
+impl<'a, A, B, C> IntoRow<'a> for (A, B, C)
+where
+    A: IntoSql,
+    B: IntoSql,
+    C: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.0.into_sql());
+        row.push(self.1.into_sql());
+        row.push(self.2.into_sql());
+        row
+    }
+}
+
+impl<'a, A, B, C, D> IntoRow<'a> for (A, B, C, D)
+where
+    A: IntoSql,
+    B: IntoSql,
+    C: IntoSql,
+    D: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.0.into_sql());
+        row.push(self.1.into_sql());
+        row.push(self.2.into_sql());
+        row.push(self.3.into_sql());
+        row
+    }
+}
+
+impl<'a, A, B, C, D, E> IntoRow<'a> for (A, B, C, D, E)
+where
+    A: IntoSql,
+    B: IntoSql,
+    C: IntoSql,
+    D: IntoSql,
+    E: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.0.into_sql());
+        row.push(self.1.into_sql());
+        row.push(self.2.into_sql());
+        row.push(self.3.into_sql());
+        row.push(self.4.into_sql());
+        row
+    }
+}
+
+impl<'a, A, B, C, D, E, F> IntoRow<'a> for (A, B, C, D, E, F)
+where
+    A: IntoSql,
+    B: IntoSql,
+    C: IntoSql,
+    D: IntoSql,
+    E: IntoSql,
+    F: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.0.into_sql());
+        row.push(self.1.into_sql());
+        row.push(self.2.into_sql());
+        row.push(self.3.into_sql());
+        row.push(self.4.into_sql());
+        row.push(self.5.into_sql());
+        row
+    }
+}
+
+impl<'a, A, B, C, D, E, F, G> IntoRow<'a> for (A, B, C, D, E, F, G)
+where
+    A: IntoSql,
+    B: IntoSql,
+    C: IntoSql,
+    D: IntoSql,
+    E: IntoSql,
+    F: IntoSql,
+    G: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.0.into_sql());
+        row.push(self.1.into_sql());
+        row.push(self.2.into_sql());
+        row.push(self.3.into_sql());
+        row.push(self.4.into_sql());
+        row.push(self.5.into_sql());
+        row.push(self.6.into_sql());
+        row
+    }
+}
+
+impl<'a, A, B, C, D, E, F, G, H> IntoRow<'a> for (A, B, C, D, E, F, G, H)
+where
+    A: IntoSql,
+    B: IntoSql,
+    C: IntoSql,
+    D: IntoSql,
+    E: IntoSql,
+    F: IntoSql,
+    G: IntoSql,
+    H: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.0.into_sql());
+        row.push(self.1.into_sql());
+        row.push(self.2.into_sql());
+        row.push(self.3.into_sql());
+        row.push(self.4.into_sql());
+        row.push(self.5.into_sql());
+        row.push(self.6.into_sql());
+        row.push(self.7.into_sql());
+        row
+    }
+}
+
+impl<'a, A, B, C, D, E, F, G, H, I> IntoRow<'a> for (A, B, C, D, E, F, G, H, I)
+where
+    A: IntoSql,
+    B: IntoSql,
+    C: IntoSql,
+    D: IntoSql,
+    E: IntoSql,
+    F: IntoSql,
+    G: IntoSql,
+    H: IntoSql,
+    I: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.0.into_sql());
+        row.push(self.1.into_sql());
+        row.push(self.2.into_sql());
+        row.push(self.3.into_sql());
+        row.push(self.4.into_sql());
+        row.push(self.5.into_sql());
+        row.push(self.6.into_sql());
+        row.push(self.7.into_sql());
+        row.push(self.8.into_sql());
+        row
+    }
+}
+
+impl<'a, A, B, C, D, E, F, G, H, I, J> IntoRow<'a> for (A, B, C, D, E, F, G, H, I, J)
+where
+    A: IntoSql,
+    B: IntoSql,
+    C: IntoSql,
+    D: IntoSql,
+    E: IntoSql,
+    F: IntoSql,
+    G: IntoSql,
+    H: IntoSql,
+    I: IntoSql,
+    J: IntoSql,
+{
+    fn into_row(self) -> TokenRow<'a> {
+        let mut row = TokenRow::new();
+        row.push(self.0.into_sql());
+        row.push(self.1.into_sql());
+        row.push(self.2.into_sql());
+        row.push(self.3.into_sql());
+        row.push(self.4.into_sql());
+        row.push(self.5.into_sql());
+        row.push(self.6.into_sql());
+        row.push(self.7.into_sql());
+        row.push(self.8.into_sql());
+        row.push(self.9.into_sql());
+        row
+    }
+}
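The tuple implementations above exist so a row can be built in one step instead of pushing each column by hand. A minimal sketch, assuming the trait is reachable from the crate root (only the `pub use` inside token_row.rs is shown in this series):

    use tiberius::{IntoRow, IntoSql, TokenRow};

    // Built column by column...
    let mut by_hand = TokenRow::new();
    by_hand.push(1i32.into_sql());
    by_hand.push(2i32.into_sql());

    // ...or in one step from a tuple (implemented for up to ten elements).
    let from_tuple = (1i32, 2i32).into_row();
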
diff --git a/src/tds/codec/type_info.rs b/src/tds/codec/type_info.rs
index 8b5f625e..3dcca1d6 100644
--- a/src/tds/codec/type_info.rs
+++ b/src/tds/codec/type_info.rs
@@ -212,7 +212,7 @@ uint_enum! {
 impl Encode<BytesMut> for TypeInfo {
     fn encode(self, dst: &mut BytesMut) -> crate::Result<()> {
-        match self.inner {
+        match dbg!(self.inner) {
             TypeInfoInner::FixedLen(ty) => {
                 dst.put_u8(ty as u8);
             }
diff --git a/src/tds/context.rs b/src/tds/context.rs
index ee6307a4..4c43b22b 100644
--- a/src/tds/context.rs
+++ b/src/tds/context.rs
@@ -8,7 +8,7 @@ pub(crate) struct Context {
     packet_size: u32,
     packet_id: u8,
     transaction_desc: [u8; 8],
-    last_meta: Option<Arc<TokenColMetaData>>,
+    last_meta: Option<Arc<TokenColMetaData<'static>>>,
     spn: Option<String>,
 }
 
@@ -30,11 +30,11 @@ impl Context {
         id
     }
 
-    pub fn set_last_meta(&mut self, meta: Arc<TokenColMetaData>) {
+    pub fn set_last_meta(&mut self, meta: Arc<TokenColMetaData<'static>>) {
         self.last_meta.replace(meta);
     }
 
-    pub fn last_meta(&self) -> Option<Arc<TokenColMetaData>> {
+    pub fn last_meta(&self) -> Option<Arc<TokenColMetaData<'static>>> {
         self.last_meta.as_ref().map(Arc::clone)
     }
 
diff --git a/src/tds/stream/query.rs b/src/tds/stream/query.rs
index bf60990b..fd41c0f9 100644
--- a/src/tds/stream/query.rs
+++ b/src/tds/stream/query.rs
@@ -371,7 +371,7 @@ impl<'a> Stream for QueryStream<'a> {
                     .columns
                     .iter()
                     .map(|x| Column {
-                        name: x.col_name.clone(),
+                        name: x.col_name.to_string(),
                         column_type: ColumnType::from(&x.base.ty),
                     })
                     .collect::<Vec<_>>();
diff --git a/src/tds/stream/token.rs b/src/tds/stream/token.rs
index 0b7dd3dd..4e325a05 100644
--- a/src/tds/stream/token.rs
+++ b/src/tds/stream/token.rs
@@ -13,7 +13,7 @@ use tracing::{event, Level};
 
 #[derive(Debug)]
 pub enum ReceivedToken {
-    NewResultset(Arc<TokenColMetaData>),
+    NewResultset(Arc<TokenColMetaData<'static>>),
     Row(TokenRow<'static>),
     Done(TokenDone),
     DoneInProc(TokenDone),
diff --git a/src/to_sql.rs b/src/to_sql.rs
index 401d5d24..5dee7461 100644
--- a/src/to_sql.rs
+++ b/src/to_sql.rs
@@ -72,6 +72,7 @@ pub trait IntoSql: Send + Sync {
 
 into_sql!(self_,
           String: (ColumnData::String, Cow::from(self_));
+          &'static str: (ColumnData::String, Cow::from(self_));
           Vec<u8>: (ColumnData::Binary, Cow::from(self_));
           XmlData: (ColumnData::Xml, Cow::Owned(self_));
           bool: (ColumnData::Bit, self_);
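With the `&'static str` arm added above, a string literal can go into a row without an intermediate `String` allocation. A small sketch (assuming the metadata declares a matching character column):

    let mut row = TokenRow::new();
    row.push("literal content".into_sql());
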
diff --git a/tests/bulk.rs b/tests/bulk.rs
new file mode 100644
index 00000000..070c1dda
--- /dev/null
+++ b/tests/bulk.rs
@@ -0,0 +1,128 @@
+use enumflags2::BitFlags;
+use futures::{lock::Mutex, AsyncRead, AsyncWrite};
+use names::{Generator, Name};
+use once_cell::sync::Lazy;
+use std::env;
+use std::sync::Once;
+use tiberius::{BulkLoadMetadata, ColumnFlag, IntoSql, Result, TokenRow, TypeInfo};
+
+use runtimes_macro::test_on_runtimes;
+
+// This is used in the testing macro :)
+#[allow(dead_code)]
+static LOGGER_SETUP: Once = Once::new();
+
+static CONN_STR: Lazy<String> = Lazy::new(|| {
+    env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or_else(|_| {
+        "server=tcp:localhost,1433;IntegratedSecurity=true;TrustServerCertificate=true".to_owned()
+    })
+});
+
+static NAMES: Lazy<Mutex<Generator<'static>>> =
+    Lazy::new(|| Mutex::new(Generator::with_naming(Name::Plain)));
+
+async fn random_table() -> String {
+    NAMES.lock().await.next().unwrap().replace('-', "")
+}
+
+macro_rules! test_bulk_type {
+    ($name:ident($sql_type:literal, $type_info:expr, $total_generated:expr, $generator:expr)) => {
+        paste::item! {
+            #[test_on_runtimes]
+            async fn [< bulk_load_optional_ $name >]<S>(mut conn: tiberius::Client<S>) -> Result<()>
+            where
+                S: AsyncRead + AsyncWrite + Unpin + Send,
+            {
+                let table = format!("##{}", random_table().await);
+
+                conn.execute(
+                    &format!(
+                        "CREATE TABLE {} (id INT IDENTITY PRIMARY KEY, content {} NULL)",
+                        table,
+                        $sql_type,
+                    ),
+                    &[],
+                )
+                .await?;
+
+                let mut meta = BulkLoadMetadata::new();
+                meta.add_column("content", $type_info, ColumnFlag::Nullable);
+
+                let mut req = conn.bulk_insert(&table, meta).await?;
+
+                for i in $generator {
+                    let mut row = TokenRow::new();
+                    row.push(i.into_sql());
+                    req.send(row).await?;
+                }
+
+                let res = req.finalize().await?;
+
+                assert_eq!($total_generated, res.total());
+
+                Ok(())
+            }
+
+            #[test_on_runtimes]
+            async fn [< bulk_load_required_ $name >]<S>(mut conn: tiberius::Client<S>) -> Result<()>
+            where
+                S: AsyncRead + AsyncWrite + Unpin + Send,
+            {
+                let table = format!("##{}", random_table().await);
+
+                conn.execute(
+                    &format!(
+                        "CREATE TABLE {} (id INT IDENTITY PRIMARY KEY, content {} NOT NULL)",
+                        table,
+                        $sql_type
+                    ),
+                    &[],
+                )
+                .await?;
+
+                let mut meta = BulkLoadMetadata::new();
+                meta.add_column("content", $type_info, BitFlags::empty());
+
+                let mut req = conn.bulk_insert(&table, meta).await?;
+
+                for i in $generator {
+                    let mut row = TokenRow::new();
+                    row.push(i.into_sql());
+                    req.send(row).await?;
+                }
+
+                let res = req.finalize().await?;
+
+                assert_eq!($total_generated, res.total());
+
+                Ok(())
+            }
+        }
+    };
+}
+
+test_bulk_type!(tinyint("TINYINT", TypeInfo::tinyint(), 256, 0..=255u8));
+test_bulk_type!(smallint("SMALLINT", TypeInfo::smallint(), 2000, 0..2000i16));
+test_bulk_type!(int("INT", TypeInfo::int(), 2000, 0..2000i32));
+test_bulk_type!(bigint("BIGINT", TypeInfo::bigint(), 2000, 0..2000i64));
+
+test_bulk_type!(real(
+    "REAL",
+    TypeInfo::real(),
+    1000,
+    vec![3.14f32; 1000].into_iter()
+));
+
+test_bulk_type!(float(
+    "FLOAT",
+    TypeInfo::float(),
+    1000,
+    vec![3.14f64; 1000].into_iter()
+));
+
+test_bulk_type!(varchar_limited(
+    "VARCHAR(255)",
+    TypeInfo::float(),
+    1000,
+    vec!["aaaaaaaaaaaaaaaaaaaaaaa"; 1000].into_iter()
+));

From e8f03bcd32045bc3adcc5e7d797d56da5bfa869c Mon Sep 17 00:00:00 2001
From: Julius de Bruijn
Date: Tue, 10 Aug 2021 22:34:22 +0200
Subject: [PATCH 3/3] WIP: collation tracking and bulk load fixes

---
 examples/bulk.rs                        |  4 +-
 examples/tokio.rs                       | 28 +------
 src/tds/codec/bulk_load.rs              |  1 +
 src/tds/codec/column_data.rs            |  2 +-
 src/tds/codec/token/token_env_change.rs | 98 +++++++++++--------------
 src/tds/codec/type_info.rs              | 34 ++++-----
 src/tds/collation.rs                    | 11 +++
 src/tds/context.rs                      | 12 ++-
 src/tds/stream/token.rs                 |  5 ++
 tests/bulk.rs                           |  4 +-
 10 files changed, 95 insertions(+), 104 deletions(-)

diff --git a/examples/bulk.rs b/examples/bulk.rs
index e2956c13..908054fd 100644
--- a/examples/bulk.rs
+++ b/examples/bulk.rs
@@ -24,7 +24,7 @@ async fn main() -> anyhow::Result<()> {
 
     client
         .execute(
-            "CREATE TABLE ##bulk_test1 (id INT IDENTITY PRIMARY KEY, content INT)",
+            "CREATE TABLE ##bulk_test1 (id INT IDENTITY PRIMARY KEY, content VARCHAR(255))",
             &[],
         )
         .await?;
@@ -37,7 +37,7 @@ async fn main() -> anyhow::Result<()> {
 
     let pb = ProgressBar::new(count as u64);
 
-    for i in 0..count {
+    for i in vec!["aaaaaaaaaaaaaaaaaaaa"; 1000].into_iter() {
         let mut row = TokenRow::new();
         row.push(i.into_sql());
         req.send(row).await?;
feature = "sql-browser-tokio")))] #[tokio::main] async fn main() -> anyhow::Result<()> { + env_logger::init(); let config = Config::from_ado_string(&CONN_STR)?; let tcp = TcpStream::connect(config.get_addr()).await?; @@ -20,32 +20,10 @@ async fn main() -> anyhow::Result<()> { let mut client = Client::connect(config, tcp.compat_write()).await?; - let stream = client.query("SELECT @P1", &[&1i32]).await?; - let row = stream.into_row().await?.unwrap(); + let stream = client.query("SELECT * from test", &[]).await?; + let row = stream.into_row().await?; println!("{:?}", row); - assert_eq!(Some(1), row.get(0)); - - Ok(()) -} - -#[cfg(all(windows, feature = "sql-browser-tokio"))] -#[tokio::main] -async fn main() -> anyhow::Result<()> { - use tiberius::SqlBrowser; - - let config = Config::from_ado_string(&CONN_STR)?; - - let tcp = TcpStream::connect_named(&config).await?; - tcp.set_nodelay(true)?; - - let mut client = Client::connect(config, tcp.compat_write()).await?; - - let stream = client.query("SELECT @P1", &[&1i32]).await?; - let row = stream.into_row().await?.unwrap(); - - println!("{:?}", row); - assert_eq!(Some(1), row.get(0)); Ok(()) } diff --git a/src/tds/codec/bulk_load.rs b/src/tds/codec/bulk_load.rs index fbc1c7c5..70626953 100644 --- a/src/tds/codec/bulk_load.rs +++ b/src/tds/codec/bulk_load.rs @@ -72,6 +72,7 @@ where meta: BulkLoadMetadata<'a>, ) -> crate::Result { let packet_id = connection.context_mut().next_packet_id(); + let collation = connection.context().collation(); let mut buf = BytesMut::new(); meta.encode(&mut buf)?; diff --git a/src/tds/codec/column_data.rs b/src/tds/codec/column_data.rs index 764bbcf6..c31ac603 100644 --- a/src/tds/codec/column_data.rs +++ b/src/tds/codec/column_data.rs @@ -126,7 +126,7 @@ impl<'a> ColumnData<'a> { where R: SqlReadBytes + Unpin, { - let res = match &ctx.inner { + let res = match dbg!(&ctx.inner) { TypeInfoInner::FixedLen(fixed_ty) => fixed_len::decode(src, fixed_ty).await?, TypeInfoInner::VarLenSized(cx) => var_len::decode(src, cx).await?, TypeInfoInner::VarLenSizedPrecision { ty, scale, .. } => match ty { diff --git a/src/tds/codec/token/token_env_change.rs b/src/tds/codec/token/token_env_change.rs index d6d5971b..003330de 100644 --- a/src/tds/codec/token/token_env_change.rs +++ b/src/tds/codec/token/token_env_change.rs @@ -1,14 +1,12 @@ -use crate::{ - tds::{lcid_to_encoding, sortid_to_encoding}, - Error, SqlReadBytes, -}; +use crate::{tds::Collation, Error, SqlReadBytes}; use byteorder::{LittleEndian, ReadBytesExt}; -use encoding::Encoding; use fmt::Debug; use futures::io::AsyncReadExt; -use std::io::Cursor; -use std::io::Read; -use std::{convert::TryFrom, fmt}; +use std::{ + convert::TryFrom, + fmt, + io::{Cursor, Read}, +}; uint_enum! 
diff --git a/src/tds/codec/token/token_env_change.rs b/src/tds/codec/token/token_env_change.rs
index d6d5971b..003330de 100644
--- a/src/tds/codec/token/token_env_change.rs
+++ b/src/tds/codec/token/token_env_change.rs
@@ -1,14 +1,12 @@
-use crate::{
-    tds::{lcid_to_encoding, sortid_to_encoding},
-    Error, SqlReadBytes,
-};
+use crate::{tds::Collation, Error, SqlReadBytes};
 
 use byteorder::{LittleEndian, ReadBytesExt};
-use encoding::Encoding;
 use fmt::Debug;
 use futures::io::AsyncReadExt;
-use std::io::Cursor;
-use std::io::Read;
-use std::{convert::TryFrom, fmt};
+use std::{
+    convert::TryFrom,
+    fmt,
+    io::{Cursor, Read},
+};
 
 uint_enum! {
     #[repr(u8)]
@@ -63,52 +61,13 @@ impl fmt::Display for EnvChangeTy {
     }
 }
 
-pub struct CollationInfo {
-    lcid_encoding: Option<&'static (dyn Encoding + Send + Sync)>,
-    sortid_encoding: Option<&'static (dyn Encoding + Send + Sync)>,
-}
-
-impl CollationInfo {
-    pub fn new(bytes: &[u8]) -> Self {
-        let lcid_encoding = match (bytes.get(0), bytes.get(1)) {
-            (Some(fst), Some(snd)) => lcid_to_encoding(u16::from_le_bytes([*fst, *snd])),
-            _ => None,
-        };
-
-        let sortid_encoding = match bytes.get(4) {
-            Some(byte) => sortid_to_encoding(*byte),
-            _ => None,
-        };
-
-        Self {
-            lcid_encoding,
-            sortid_encoding,
-        }
-    }
-}
-
-impl Debug for CollationInfo {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        fmt::Display::fmt(self, f)
-    }
-}
-
-impl fmt::Display for CollationInfo {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match (self.lcid_encoding, self.sortid_encoding) {
-            (Some(lcid), Some(sortid)) => write!(f, "{}/{}", lcid.name(), sortid.name()),
-            _ => write!(f, "None"),
-        }
-    }
-}
-
 #[derive(Debug)]
 pub enum TokenEnvChange {
     Database(String, String),
     PacketSize(u32, u32),
     SqlCollation {
-        old: CollationInfo,
-        new: CollationInfo,
+        old: Option<Collation>,
+        new: Option<Collation>,
     },
     BeginTransaction([u8; 8]),
     CommitTransaction,
@@ -131,9 +90,11 @@ impl fmt::Display for TokenEnvChange {
             Self::PacketSize(old, new) => {
                 write!(f, "Packet size change from '{}' to '{}'", old, new)
             }
-            Self::SqlCollation { old, new } => {
-                write!(f, "SQL collation change from {} to {}", old, new)
-            }
+            Self::SqlCollation { old, new } => match (old, new) {
+                (Some(old), Some(new)) => write!(f, "SQL collation change from {} to {}", old, new),
+                (_, Some(new)) => write!(f, "SQL collation changed to {}", new),
+                (_, _) => write!(f, "SQL collation change"),
+            },
             Self::BeginTransaction(_) => write!(f, "Begin transaction"),
             Self::CommitTransaction => write!(f, "Commit transaction"),
             Self::RollbackTransaction => write!(f, "Rollback transaction"),
@@ -216,14 +177,39 @@ impl TokenEnvChange {
                 let mut new_value = vec![0; len];
                 buf.read_exact(&mut new_value[0..len])?;
 
+                let new = if len == 5 {
+                    let new_sortid = new_value[4];
+                    let new_info = u32::from_le_bytes([
+                        new_value[0],
+                        new_value[1],
+                        new_value[2],
+                        new_value[3],
+                    ]);
+
+                    Some(Collation::new(new_info, new_sortid))
+                } else {
+                    None
+                };
+
                 let len = buf.read_u8()? as usize;
                 let mut old_value = vec![0; len];
                 buf.read_exact(&mut old_value[0..len])?;
 
-                TokenEnvChange::SqlCollation {
-                    new: CollationInfo::new(new_value.as_slice()),
-                    old: CollationInfo::new(old_value.as_slice()),
-                }
+                let old = if len == 5 {
+                    let old_sortid = old_value[4];
+                    let old_info = u32::from_le_bytes([
+                        old_value[0],
+                        old_value[1],
+                        old_value[2],
+                        old_value[3],
+                    ]);
+
+                    Some(Collation::new(old_info, old_sortid))
+                } else {
+                    None
+                };
+
+                TokenEnvChange::SqlCollation { new, old }
             }
             EnvChangeTy::BeginTransaction | EnvChangeTy::EnlistDTCTransaction => {
                 let len = buf.read_u8()?;
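The branch above turns the 5-byte collation value of an ENVCHANGE token into a `Collation`: the first four bytes are a little-endian info field, the fifth a sort id. The same unpacking on a standalone blob, sketched inside the crate (the byte values here are made up for illustration):

    // Hypothetical 5-byte collation payload from an ENVCHANGE token.
    let blob: [u8; 5] = [0x09, 0x04, 0xd0, 0x00, 0x34];

    let info = u32::from_le_bytes([blob[0], blob[1], blob[2], blob[3]]);
    let sortid = blob[4];

    // Same constructor the decoder calls; Collation lives in src/tds/collation.rs.
    let collation = Collation::new(info, sortid);
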
#[cfg(feature = "tds73")] pub fn datetime2() -> Self { - Self::varlen(VarLenType::Datetime2, 8, None) + Self::varlen(VarLenType::Datetime2, 8) } /// A uniqueidentifier value. pub fn guid() -> Self { - Self::varlen(VarLenType::Guid, 16, None) + Self::varlen(VarLenType::Guid, 16) } /// A date value. #[cfg(feature = "tds73")] pub fn date() -> Self { - Self::varlen(VarLenType::Daten, 3, None) + Self::varlen(VarLenType::Daten, 3) } /// A time value. #[cfg(feature = "tds73")] pub fn time() -> Self { - Self::varlen(VarLenType::Timen, 5, None) + Self::varlen(VarLenType::Timen, 5) } /// A time value. #[cfg(feature = "tds73")] pub fn datetimeoffset() -> Self { - Self::varlen(VarLenType::DatetimeOffsetn, 10, None) + Self::varlen(VarLenType::DatetimeOffsetn, 10) } /// A variable binary value. If length is limited and larger than 8000 @@ -330,7 +330,7 @@ impl TypeInfo { _ => u16::MAX, }; - Self::varlen(VarLenType::BigVarBin, length as usize, None) + Self::varlen(VarLenType::BigVarBin, length as usize) } /// A binary value. @@ -340,7 +340,7 @@ impl TypeInfo { /// - If length is more than 8000 bytes. pub fn binary(length: u16) -> Self { assert!(length <= 8000); - Self::varlen(VarLenType::BigBinary, length as usize, None) + Self::varlen(VarLenType::BigBinary, length as usize) } /// A variable string value. If length is limited and larger than 8000 @@ -351,7 +351,7 @@ impl TypeInfo { _ => u16::MAX, }; - Self::varlen(VarLenType::BigVarChar, length as usize, None) + Self::varlen(VarLenType::BigVarChar, length as usize) } /// A variable UTF-16 string value. If length is limited and larger than @@ -362,7 +362,7 @@ impl TypeInfo { _ => u16::MAX, }; - Self::varlen(VarLenType::BigVarChar, length as usize, None) + Self::varlen(VarLenType::BigVarChar, length as usize) } /// A constant-size string value. @@ -372,7 +372,7 @@ impl TypeInfo { /// - If length is more than 8000 characters. pub fn char(length: u16) -> Self { assert!(length <= 8000); - Self::varlen(VarLenType::BigChar, length as usize, None) + Self::varlen(VarLenType::BigChar, length as usize) } /// A constant-size UTF-16 string value. @@ -382,22 +382,22 @@ impl TypeInfo { /// - If length is more than 4000 characters. pub fn nchar(length: u16) -> Self { assert!(length <= 4000); - Self::varlen(VarLenType::NChar, length as usize, None) + Self::varlen(VarLenType::NChar, length as usize) } /// A (deprecated) heap-allocated text storage. pub fn text() -> Self { - Self::varlen(VarLenType::Text, u32::MAX as usize, None) + Self::varlen(VarLenType::Text, u32::MAX as usize) } /// A (deprecated) heap-allocated UTF-16 text storage. pub fn ntext() -> Self { - Self::varlen(VarLenType::NText, u32::MAX as usize, None) + Self::varlen(VarLenType::NText, u32::MAX as usize) } /// A (deprecated) heap-allocated binary storage. pub fn image() -> Self { - Self::varlen(VarLenType::Image, u32::MAX as usize, None) + Self::varlen(VarLenType::Image, u32::MAX as usize) } /// Numeric data types that have fixed precision and scale. Decimal and @@ -433,8 +433,8 @@ impl TypeInfo { Self { inner } } - fn varlen(ty: VarLenType, len: usize, collation: Option) -> Self { - let cx = VarLenContext::new(ty, len, collation); + fn varlen(ty: VarLenType, len: usize) -> Self { + let cx = VarLenContext::new(ty, len, None); let inner = TypeInfoInner::VarLenSized(cx); Self { inner } diff --git a/src/tds/collation.rs b/src/tds/collation.rs index 65d4abdf..d7bfc88e 100644 --- a/src/tds/collation.rs +++ b/src/tds/collation.rs @@ -1,3 +1,5 @@ +use std::fmt; + ///! 
diff --git a/src/tds/collation.rs b/src/tds/collation.rs
index 65d4abdf..d7bfc88e 100644
--- a/src/tds/collation.rs
+++ b/src/tds/collation.rs
@@ -1,3 +1,5 @@
+use std::fmt;
+
 ///! legacy implementation of collations (or codepages rather) for dealing with varchar's with legacy databases
 ///! references [1] which has some mappings from the katmai (SQL Server 2008) source code and is a TDS driver
 ///! directly from microsoft
@@ -52,6 +54,15 @@ impl Collation {
     }
 }
 
+impl fmt::Display for Collation {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self.encoding() {
+            Ok(encoding) => write!(f, "{}", encoding.name()),
+            _ => write!(f, "None"),
+        }
+    }
+}
+
 /// https://github.com/Microsoft/mssql-jdbc/blob/eb14f63077c47ef1fc1c690deb8cfab602baeb85/src/main/java/com/microsoft/sqlserver/jdbc/SQLCollation.java#L102-L310
 /// maps an LCID (it's locale part which is only 2 bytes) to a codepage
 ///
diff --git a/src/tds/context.rs b/src/tds/context.rs
index 4c43b22b..fc24fa9f 100644
--- a/src/tds/context.rs
+++ b/src/tds/context.rs
@@ -1,4 +1,4 @@
-use super::codec::*;
+use super::{codec::*, Collation};
 use std::sync::Arc;
 
 /// Context, that might be required to make sure we understand and are understood by the server
@@ -10,6 +10,7 @@ pub(crate) struct Context {
     transaction_desc: [u8; 8],
     last_meta: Option<Arc<TokenColMetaData<'static>>>,
     spn: Option<String>,
+    collation: Option<Collation>,
 }
 
 impl Context {
@@ -21,6 +22,7 @@ impl Context {
             transaction_desc: [0; 8],
             last_meta: None,
             spn: None,
+            collation: None,
         }
     }
 
@@ -54,6 +56,14 @@ impl Context {
         self.transaction_desc = desc;
     }
 
+    pub fn set_collation(&mut self, collation: Collation) {
+        self.collation = Some(collation);
+    }
+
+    pub fn collation(&self) -> Option<Collation> {
+        self.collation
+    }
+
     pub fn version(&self) -> FeatureLevel {
         self.version
     }
diff --git a/src/tds/stream/token.rs b/src/tds/stream/token.rs
index 4e325a05..3514e1c8 100644
--- a/src/tds/stream/token.rs
+++ b/src/tds/stream/token.rs
@@ -146,6 +146,11 @@
                 TokenEnvChange::BeginTransaction(desc) => {
                     self.conn.context_mut().set_transaction_descriptor(desc);
                 }
+                TokenEnvChange::SqlCollation { new, .. } => {
+                    if let Some(collation) = new {
+                        self.conn.context_mut().set_collation(collation);
+                    }
+                }
                 TokenEnvChange::CommitTransaction
                 | TokenEnvChange::RollbackTransaction
                 | TokenEnvChange::DefectTransaction => {
diff --git a/tests/bulk.rs b/tests/bulk.rs
index 070c1dda..570fc589 100644
--- a/tests/bulk.rs
+++ b/tests/bulk.rs
@@ -4,7 +4,7 @@ use names::{Generator, Name};
 use once_cell::sync::Lazy;
 use std::env;
 use std::sync::Once;
-use tiberius::{BulkLoadMetadata, ColumnFlag, IntoSql, Result, TokenRow, TypeInfo};
+use tiberius::{BulkLoadMetadata, ColumnFlag, IntoSql, Result, TokenRow, TypeInfo, TypeLength};
 
 use runtimes_macro::test_on_runtimes;
 
@@ -122,7 +122,7 @@ test_bulk_type!(float(
 
 test_bulk_type!(varchar_limited(
     "VARCHAR(255)",
-    TypeInfo::float(),
+    TypeInfo::varchar(TypeLength::Limited(255)),
     1000,
     vec!["aaaaaaaaaaaaaaaaaaaaaaa"; 1000].into_iter()
 ));
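With the varchar type fix above, a textual bulk load lines up end to end. A minimal sketch using only the API surface introduced in this series (client setup as in examples/bulk.rs; the string literal works thanks to the `&'static str` IntoSql addition):

    use tiberius::{BulkLoadMetadata, ColumnFlag, IntoSql, TokenRow, TypeInfo, TypeLength};

    let mut meta = BulkLoadMetadata::new();
    meta.add_column(
        "content",
        TypeInfo::varchar(TypeLength::Limited(255)),
        ColumnFlag::Nullable,
    );

    let mut req = client.bulk_insert("##bulk_test1", meta).await?;

    for s in vec!["hello"; 1000] {
        let mut row = TokenRow::new();
        row.push(s.into_sql());
        req.send(row).await?;
    }

    let res = req.finalize().await?;
    assert_eq!(1000, res.total());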