From 5f8bbd3b8e25940b44a70d145825f44795916a94 Mon Sep 17 00:00:00 2001 From: Andrew Billings Date: Tue, 16 May 2017 13:26:18 +0100 Subject: [PATCH] #45 #43 WIP Started work on nesting of batches and transaction support in batches. The SQL generation is mostly complete, but the code still needs to handle errors raised correctly. - A SqlBatch now stores a list of IBatchItem instead of just SqlBatchCommands. Both SqlBatch and SqlBatchCommand implement this interface. - Heavily refactored the pre-processing to support nesting of batches and transactions, using the new IBatchItem interface. - Added the ServerVersion to DatabaseSchema. - The Connection class caches the DatabaseSchema for its connection string. - Changed the common connection code to return Connection objects instead of just the connection string. This is used to get the ServerVersion from its DatabaseSchema, so better SQL can be generated if the version supports it. --- Database/BatchProcessArgs.cs | 84 ++ Database/Connection.cs | 12 +- Database/Constants.cs | 62 + Database/DbBatchDataReader.cs | 34 +- Database/IBatchItem.cs | 48 + Database/Resources.Designer.cs | 63 + Database/Resources.resx | 21 + Database/Schema/DatabaseSchema.Schema.cs | 18 +- Database/Schema/DatabaseSchema.cs | 36 +- Database/Schema/ISchema.cs | 9 + Database/SqlBatch.cs | 1211 +++++++++++++---- Database/SqlBatchCommand.cs | 231 +++- Database/SqlBatchParametersCollection.cs | 11 +- Database/SqlBatchResult.cs | 30 +- Database/SqlProgram.cs | 1 + Database/SqlStringBuilder.cs | 37 +- Database/Test/TestBatching.cs | 151 +- .../WebApplications.Utilities.Database.csproj | 3 + 18 files changed, 1643 insertions(+), 419 deletions(-) create mode 100644 Database/BatchProcessArgs.cs create mode 100644 Database/Constants.cs create mode 100644 Database/IBatchItem.cs diff --git a/Database/BatchProcessArgs.cs b/Database/BatchProcessArgs.cs new file mode 100644 index 00000000..31b70bdf --- /dev/null +++ b/Database/BatchProcessArgs.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using WebApplications.Utilities.Annotations; +using WebApplications.Utilities.Threading; + +namespace WebApplications.Utilities.Database +{ + /// + /// Arguments for processing a . + /// + internal class BatchProcessArgs + { + [NotNull] + public readonly Version ServerVersion; + + [NotNull] + public readonly SqlStringBuilder SqlBuilder = new SqlStringBuilder(); + + [NotNull] + public readonly List AllParameters = new List(); + + [NotNull] + public readonly Dictionary OutParameters = new Dictionary(); + + [NotNull] + public readonly Dictionary OutParameterCommands = new Dictionary(); + + [NotNull] + public readonly Dictionary> CommandOutParams = + new Dictionary>(); + + [NotNull] + public readonly HashSet ConnectionSemaphores = new HashSet(); + + [NotNull] + public readonly HashSet LoadBalConnectionSemaphores = new HashSet(); + + [NotNull] + public readonly HashSet DatabaseSemaphores = new HashSet(); + + [NotNull] + public readonly Stack TransactionStack = new Stack(); + + public bool InTransaction => TransactionStack.Count > 0; + + public CommandBehavior Behavior = CommandBehavior.SequentialAccess; + + public ushort CommandIndex; + + public BatchProcessArgs([NotNull] Version serverVersion) + { + ServerVersion = serverVersion; + } + + [NotNull] + public AsyncSemaphore[] GetSemaphores() + { + AsyncSemaphore[] semaphores; + // Concat the semaphores to a single array + int semaphoreCount = + ConnectionSemaphores.Count + + LoadBalConnectionSemaphores.Count + + DatabaseSemaphores.Count; + if (semaphoreCount < 1) + semaphores = Array.Empty; + else + { + semaphores = new AsyncSemaphore[semaphoreCount]; + int i = 0; + + // NOTE! Do NOT reorder these without also reordering the semaphores in SqlProgramCommand.WaitSemaphoresAsync + foreach (AsyncSemaphore semaphore in ConnectionSemaphores) + semaphores[i++] = semaphore; + foreach (AsyncSemaphore semaphore in LoadBalConnectionSemaphores) + semaphores[i++] = semaphore; + foreach (AsyncSemaphore semaphore in DatabaseSemaphores) + semaphores[i++] = semaphore; + } + return semaphores; + } + } +} \ No newline at end of file diff --git a/Database/Connection.cs b/Database/Connection.cs index 695689d9..cf7e89f1 100644 --- a/Database/Connection.cs +++ b/Database/Connection.cs @@ -177,6 +177,11 @@ public Connection([NotNull] string connectionString, double weight, AsyncSemapho [CanBeNull] internal readonly AsyncSemaphore Semaphore; + /// + /// The database schema for this connection. + /// + internal DatabaseSchema CachedSchema; + /// /// Returns a new connection with the increased by . /// @@ -248,7 +253,12 @@ public Task GetSchema( bool forceReload, CancellationToken cancellationToken = default(CancellationToken)) { - return DatabaseSchema.GetOrAdd(this, false, cancellationToken); + if (CachedSchema == null) + return DatabaseSchema.GetOrAdd(this, forceReload, cancellationToken); + + return forceReload + ? CachedSchema.ReLoad(cancellationToken) + : Task.FromResult(CachedSchema); } #region Equalities diff --git a/Database/Constants.cs b/Database/Constants.cs new file mode 100644 index 00000000..cf19da2d --- /dev/null +++ b/Database/Constants.cs @@ -0,0 +1,62 @@ +#region © Copyright Web Applications (UK) Ltd, 2017. All rights reserved. +// Copyright (c) 2017, Web Applications UK Ltd +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Web Applications UK Ltd nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL WEB APPLICATIONS UK LTD BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#endregion + +namespace WebApplications.Utilities.Database +{ + internal static class Constants + { + internal static class BatchState + { + internal const int Building = 0; + internal const int Executing = 1; + internal const int Completed = 2; + } + + internal static class ExecuteState + { + /// + /// Indicates the next record set will be for the output parameters for the command. + /// + internal const string Output = "Output"; + + /// + /// Indicates an error has occurred. + /// + internal const string Error = "Error"; + + /// + /// Indicates the last Error should be re-thrown. + /// + internal const string ReThrow = "ReThrow"; + + /// + /// Indicates the command has ended. + /// + internal const string End = "End"; + } + } +} \ No newline at end of file diff --git a/Database/DbBatchDataReader.cs b/Database/DbBatchDataReader.cs index e23726b8..01c53063 100644 --- a/Database/DbBatchDataReader.cs +++ b/Database/DbBatchDataReader.cs @@ -297,11 +297,7 @@ public override Task IsDBNullAsync(int ordinal, CancellationToken cancella => BaseReaderOpen().IsDBNullAsync(ordinal, cancellationToken); /// Closes the object. - public override void Close() - { - State = BatchReaderState.Closed; - // TODO event handler? - } + public override void Close() => State = BatchReaderState.Closed; /// Releases the managed resources used by the and optionally releases the unmanaged resources. /// true to release managed and unmanaged resources; false to release only unmanaged resources. @@ -702,20 +698,22 @@ internal SqlBatchDataReader([NotNull] SqlDataReader baseReader, CommandBehavior /// An for reading XML from the current record set. protected internal override XmlReader GetXmlReader() { - if (BaseReader.FieldCount == 1) - switch (BaseReader.GetDataTypeName(0)) - { - case "ntext": - case "nvarchar": - return XmlReader.Create(new SqlBatchDataTextReader(this), _xmlReaderSettings); - case "xml": - // ReSharper disable once AssignNullToNotNullAttribute - return BaseReader.Read() - ? BaseReader.GetXmlReader(0) - : XmlReader.Create(new StringReader(string.Empty), _xmlReaderSettings); - } + if (BaseReader.FieldCount != 1) + throw new InvalidOperationException("The command must return a single column."); + + switch (BaseReader.GetDataTypeName(0)) + { + case "ntext": + case "nvarchar": + return XmlReader.Create(new SqlBatchDataTextReader(this), _xmlReaderSettings); + case "xml": + // ReSharper disable once AssignNullToNotNullAttribute + return BaseReader.Read() + ? BaseReader.GetXmlReader(0) + : XmlReader.Create(new StringReader(string.Empty), _xmlReaderSettings); + } - throw new InvalidOperationException("TODO"); + throw new InvalidOperationException("The command must return an Xml result."); } } } \ No newline at end of file diff --git a/Database/IBatchItem.cs b/Database/IBatchItem.cs new file mode 100644 index 00000000..ed32dd85 --- /dev/null +++ b/Database/IBatchItem.cs @@ -0,0 +1,48 @@ +#region © Copyright Web Applications (UK) Ltd, 2017. All rights reserved. +// Copyright (c) 2017, Web Applications UK Ltd +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of Web Applications UK Ltd nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL WEB APPLICATIONS UK LTD BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#endregion + +using WebApplications.Utilities.Annotations; + +namespace WebApplications.Utilities.Database +{ + /// + /// Interface to an item in a batch. + /// + internal interface IBatchItem + { + /// + /// Processes the item to be executed. + /// + /// The uid. + /// The connection string. + /// The arguments. + void Process( + [NotNull] string uid, + [NotNull] string connectionString, + BatchProcessArgs args); + } +} \ No newline at end of file diff --git a/Database/Resources.Designer.cs b/Database/Resources.Designer.cs index 1f7ff11c..33ffca07 100644 --- a/Database/Resources.Designer.cs +++ b/Database/Resources.Designer.cs @@ -420,6 +420,15 @@ internal static string SqlBatch_AddCommand_NoCommonConnections { } } + /// + /// Looks up a localized string similar to Only allowed 65536 commands per batch.. + /// + internal static string SqlBatch_AddCommand_OnlyAllowed65536 { + get { + return ResourceManager.GetString("SqlBatch_AddCommand_OnlyAllowed65536", resourceCulture); + } + } + /// /// Looks up a localized string similar to Cannot add to a batch is completed.. /// @@ -438,6 +447,15 @@ internal static string SqlBatch_CheckState_Executing { } } + /// + /// Looks up a localized string similar to Cannot have an isolation level of Unspecified.. + /// + internal static string SqlBatch_CreateTransaction_UnspecifiedIsoLvl { + get { + return ResourceManager.GetString("SqlBatch_CreateTransaction_UnspecifiedIsoLvl", resourceCulture); + } + } + /// /// Looks up a localized string similar to Cannot execute an empty batch.. /// @@ -447,6 +465,42 @@ internal static string SqlBatch_ExecuteAsync_Empty { } } + /// + /// Looks up a localized string similar to The IsolationLevel enumeration value, {0}, is not supported.. + /// + internal static string SqlBatch_Process_IsolationLevelNotSupported { + get { + return ResourceManager.GetString("SqlBatch_Process_IsolationLevelNotSupported", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The batch has already been added to another batch.. + /// + internal static string SqlBatch_RootBatch_AlreadyAdded { + get { + return ResourceManager.GetString("SqlBatch_RootBatch_AlreadyAdded", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot add a batch which is completed.. + /// + internal static string SqlBatch_RootBatch_Completed { + get { + return ResourceManager.GetString("SqlBatch_RootBatch_Completed", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot add a batch which is executing.. + /// + internal static string SqlBatch_RootBatch_Executing { + get { + return ResourceManager.GetString("SqlBatch_RootBatch_Executing", resourceCulture); + } + } + /// /// Looks up a localized string similar to The reader is closed.. /// @@ -492,6 +546,15 @@ internal static string SqlBatchParametersCollection_GetOrAddParameter_MultipleRe } } + /// + /// Looks up a localized string similar to Only allowed 65536 parameters per command.. + /// + internal static string SqlBatchParametersCollection_GetOrAddParameter_OnlyAllowed65536 { + get { + return ResourceManager.GetString("SqlBatchParametersCollection_GetOrAddParameter_OnlyAllowed65536", resourceCulture); + } + } + /// /// Looks up a localized string similar to Unsupported SqlDbType '{0}'.. /// diff --git a/Database/Resources.resx b/Database/Resources.resx index 5bd7463e..202950f1 100644 --- a/Database/Resources.resx +++ b/Database/Resources.resx @@ -384,4 +384,25 @@ The reader is closed. + + Only allowed 65536 parameters per command. + + + Only allowed 65536 commands per batch. + + + Cannot add a batch which is executing. + + + Cannot add a batch which is completed. + + + The batch has already been added to another batch. + + + Cannot have an isolation level of Unspecified. + + + The IsolationLevel enumeration value, {0}, is not supported. + \ No newline at end of file diff --git a/Database/Schema/DatabaseSchema.Schema.cs b/Database/Schema/DatabaseSchema.Schema.cs index 7bba89be..19fc75cb 100644 --- a/Database/Schema/DatabaseSchema.Schema.cs +++ b/Database/Schema/DatabaseSchema.Schema.cs @@ -144,10 +144,19 @@ public class Schema : IEquatable, ISchema [PublicAPI] public Guid Guid { get; } + /// + /// Gets the server version. + /// + /// + /// The server version. + /// + public Version ServerVersion { get; } + /// /// Initializes a new instance of the class. /// /// The unique identifier. + /// The server version. /// The schemas by identifier. /// Name of the program definitions by. /// Name of the tables by. @@ -162,6 +171,7 @@ public class Schema : IEquatable, ISchema /// The database collation. private Schema( Guid guid, + [NotNull] Version serverVersion, [NotNull] IReadOnlyDictionary schemasByID, [NotNull] IReadOnlyDictionary programDefinitionsByName, [NotNull] IReadOnlyDictionary tablesByName, @@ -176,6 +186,7 @@ private Schema( [NotNull] SqlCollation databaseCollation) { Guid = guid; + ServerVersion = serverVersion; SchemasByID = schemasByID; ProgramsByName = programDefinitionsByName; TablesByName = tablesByName; @@ -193,6 +204,7 @@ private Schema( /// /// Gets or adds a schema. /// + /// The server version. /// The schemas by identifier. /// The programs by name. /// The tables by name. @@ -200,9 +212,12 @@ private Schema( /// The collations by name. /// The server collation. /// The database collation. - /// The schema. + /// + /// The schema. + /// [NotNull] protected internal static Schema GetOrAdd( + [NotNull] Version serverVersion, [NotNull] IReadOnlyDictionary schemasByID, [NotNull] IReadOnlyDictionary programsByName, [NotNull] IReadOnlyDictionary tablesByName, @@ -246,6 +261,7 @@ protected internal static Schema GetOrAdd( g => new Schema( g, + serverVersion, schemasByID, programsByName, tablesByName, diff --git a/Database/Schema/DatabaseSchema.cs b/Database/Schema/DatabaseSchema.cs index 462f8d60..0831d0b5 100644 --- a/Database/Schema/DatabaseSchema.cs +++ b/Database/Schema/DatabaseSchema.cs @@ -49,6 +49,12 @@ namespace WebApplications.Utilities.Database.Schema [PublicAPI] public partial class DatabaseSchema : ISchema { + /// + /// The minimum supported server version. + /// + [NotNull] + public static readonly Version MinimumSupportedServerVersion = new Version(9, 0); + /// /// Holds schemas against connections strings. /// @@ -377,6 +383,14 @@ public Instant Loaded /// Unique identity of the schema. /// public Guid Guid => Current.Guid; + + /// + /// Gets the server version. + /// + /// + /// The server version. + /// + public Version ServerVersion => Current.ServerVersion; #endregion /// @@ -394,12 +408,15 @@ public static Task GetOrAdd( { if (connection == null) throw new ArgumentNullException(nameof(connection)); - // ReSharper disable PossibleNullReferenceException - return _databaseSchemas.GetOrAdd( - connection.ConnectionString, - cs => new DatabaseSchema(connection.ConnectionString)) - .Load(forceReload, cancellationToken); - // ReSharper restore PossibleNullReferenceException + if (connection.CachedSchema == null) + { + connection.CachedSchema = _databaseSchemas.GetOrAdd( + connection.ConnectionString, + cs => new DatabaseSchema(connection.ConnectionString)); + Debug.Assert(connection.CachedSchema != null); + } + + return connection.CachedSchema.Load(forceReload, cancellationToken); } /// @@ -457,18 +474,20 @@ private async Task Load(bool forceReload, CancellationToken canc SqlCollation serverCollation; SqlCollation databaseCollation; + Version version; + // Open a connection using (SqlConnection sqlConnection = new SqlConnection(ConnectionString)) { // ReSharper disable once PossibleNullReferenceException await sqlConnection.OpenAsync(cancellationToken).ConfigureAwait(false); - if (!Version.TryParse(sqlConnection.ServerVersion, out Version version)) + if (!Version.TryParse(sqlConnection.ServerVersion, out version)) throw new DatabaseSchemaException( () => Resources.DatabaseSchema_Load_CouldNotParseVersionInformation); Debug.Assert(version != null); - if (version.Major < 9) + if (version < MinimumSupportedServerVersion) throw new DatabaseSchemaException( () => Resources.DatabaseSchema_Load_VersionNotSupported, version); @@ -862,6 +881,7 @@ private async Task Load(bool forceReload, CancellationToken canc // Update the current schema. _current = new CurrentSchema( Schema.GetOrAdd( + version, sqlSchemas, programDefinitions, tables, diff --git a/Database/Schema/ISchema.cs b/Database/Schema/ISchema.cs index 6d170246..18d14f3b 100644 --- a/Database/Schema/ISchema.cs +++ b/Database/Schema/ISchema.cs @@ -136,5 +136,14 @@ public interface ISchema : IEquatable /// Unique identity of the schema. /// Guid Guid { get; } + + /// + /// Gets the server version. + /// + /// + /// The server version. + /// + [NotNull] + Version ServerVersion { get; } } } \ No newline at end of file diff --git a/Database/SqlBatch.cs b/Database/SqlBatch.cs index 2b7778a6..e6eaaa21 100644 --- a/Database/SqlBatch.cs +++ b/Database/SqlBatch.cs @@ -40,46 +40,359 @@ using System.Threading.Tasks; using NodaTime; using WebApplications.Utilities.Annotations; +using WebApplications.Utilities.Database.Schema; using WebApplications.Utilities.Threading; namespace WebApplications.Utilities.Database { + /// + /// Delegate to a method for handling an exception. + /// + /// The type of the exception. + /// The exception to handle. + /// Set to to suppress the exception. + public delegate void ExceptionHandler(T exception, ref bool suppress) + where T : Exception; + /// /// Allows multiple SqlPrograms to be executed in a single database call. /// - public partial class SqlBatch : IReadOnlyList + public partial class SqlBatch : IEnumerable, IBatchItem { - private const int Building = 0; - private const int Executing = 1; - private const int Completed = 2; + /// + /// The type of a transaction. + /// + [Flags] + private enum TransactionType : byte + { + /// + /// No transaction + /// + None, + + /// + /// A transaction which commits if successfully executed + /// + Commit, + + /// + /// A transaction which always rolls back + /// + Rollback + } + + /// + /// Holds the state of a batch + /// + /// + private sealed class State : IDisposable + { + /// + /// The batch the state is for. + /// + [NotNull] + public readonly SqlBatch Batch; + /// + /// The execute lock + /// + [NotNull] + public readonly AsyncLock ExecuteLock = new AsyncLock(); + /// + /// The value of the state. + /// + public int Value = Constants.BatchState.Building; + /// + /// The command count. + /// + public int CommandCount; + + /// + /// Initializes a new instance of the class. + /// + /// The batch. + public State([NotNull] SqlBatch batch) + { + Batch = batch; + } + + /// + /// Checks the state is valid for adding to the batch. + /// + public void CheckBuildingState() + { + if (Value != Constants.BatchState.Building) + { + throw new InvalidOperationException( + Value == Constants.BatchState.Executing + ? Resources.SqlBatch_CheckState_Executing + : Resources.SqlBatch_CheckState_Completed); + } + } + + /// Releases the lock held on this state. + public void Dispose() => Monitor.Exit(this); + } + + /// + /// The states of two batches. + /// + /// + private struct States : IDisposable + { + public State StateA; + public State StateB; + + /// + /// Releases the locks held on the states + /// + public void Dispose() + { + try + { + try + { + } + finally + { + State stateB = Interlocked.Exchange(ref StateB, null); + if (stateB != null) Monitor.Exit(stateB); + } + } + finally + { + State stateA = Interlocked.Exchange(ref StateA, null); + if (stateA != null) Monitor.Exit(StateA); + } + } + } + /// + /// Gets the state. + /// [NotNull] - private readonly object _addLock = new object(); + private State GetState() + { + bool hasLock = false; + State state = _state; + Monitor.Enter(state, ref hasLock); + try + { + while (state.Batch._parent != null) + { + Monitor.Exit(Interlocked.Exchange(ref state, state.Batch._parent._state)); + hasLock = false; + Monitor.Enter(state, ref hasLock); + } + Debug.Assert(hasLock); + _state = state; + return state; + } + catch + { + if (hasLock) + Monitor.Exit(state); + throw; + } + } + /// + /// Gets the states of two batches. + /// + /// The batch a. + /// The batch b. + /// + private static States GetStates([NotNull] SqlBatch batchA, [NotNull] SqlBatch batchB) + { + States states = new States(); + try + { + if (batchA.ID.CompareTo(batchB.ID) < 0) + { + states.StateA = batchA.GetState(); + states.StateB = batchB.GetState(); + } + else + { + states.StateB = batchB.GetState(); + states.StateA = batchA.GetState(); + } + return states; + } + catch + { + states.Dispose(); + throw; + } + } + [NotNull] - private readonly AsyncLock _executeLock = new AsyncLock(); + private State _state; - private int _state = Building; + private SqlBatch _parent; [NotNull] [ItemNotNull] - private readonly List _commands = new List(); + private readonly List _items = new List(); private Duration _batchTimeout; + /// + /// The type of transaction to use. + /// + private readonly TransactionType _transaction; + + /// + /// The isolation level for the transaction. + /// + private readonly IsolationLevel _isolationLevel; + + /// + /// If then any errors that occur within this batch wont cause an exception to be thrown for the whole batch. + /// The command that failed will still throw an exception. + /// + private readonly bool _suppressErrors; + + /// + /// The exception handler for any errors that occur in the database. + /// Only used if there is a transaction or errors are suppressed. + /// + private readonly ExceptionHandler _exceptionHandler; + +#if DEBUG + /// + /// The SQL for the batch. For debugging purposes. + /// + private string _sql; +#endif + + /// + /// Creates a new batch. + /// + /// The batch timeout. Defaults to 30 seconds. + /// The new . [NotNull] - private readonly ResettableLazy> _commonConnectionStrings; + public static SqlBatch Create(Duration? batchTimeout = null) + { + return new SqlBatch(batchTimeout); + } /// - /// Initializes a new instance of the class. + /// Creates a new batch which is wrapped in a try ... catch block. + /// Any errors that occur within this batch wont cause an exception to be thrown for the whole batch, + /// unless an is specified and doesn't suppress the error. + /// The command that failed will still throw an exception. /// + /// The optional exception handler. /// The batch timeout. Defaults to 30 seconds. - public SqlBatch(Duration? batchTimeout = null) + /// The new . + [NotNull] + public static SqlBatch CreateTryCatch( + ExceptionHandler exceptionHandler = null, + Duration? batchTimeout = null) + { + return new SqlBatch(batchTimeout, suppressErrors: true, exceptionHandler: exceptionHandler); + } + + /// + /// Creates a new batch which is wrapped in a transaction. + /// If an error occurs within the batch, the transaction will rollback. + /// + /// The isolation level of the transaction. + /// if set to the transaction will always be rolled back. + /// if set to any errors that occur within this batch + /// wont cause an exception to be thrown for the whole batch. See . + /// The optional exception handler. + /// The batch timeout. Defaults to 30 seconds. + /// + /// The new . + /// + [NotNull] + public static SqlBatch CreateTransaction( + IsolationLevel isolationLevel, + bool rollback = false, + bool suppressErrors = false, + ExceptionHandler exceptionHandler = null, + Duration? batchTimeout = null) + { + return new SqlBatch( + batchTimeout, + rollback ? TransactionType.Rollback : TransactionType.Commit, + isolationLevel, + suppressErrors, + exceptionHandler); + } + + /// + /// Initializes a new instance of the class. + /// + /// The batch timeout. Defaults to 30 seconds. + /// The transaction. + /// The isolation level. + /// if set to errors are suppressed. + /// The exception handler. + private SqlBatch( + Duration? batchTimeout, + TransactionType transaction = TransactionType.None, + IsolationLevel isolationLevel = IsolationLevel.Unspecified, + bool suppressErrors = false, + ExceptionHandler exceptionHandler = null) { BatchTimeout = batchTimeout ?? Duration.FromSeconds(30); - _commonConnectionStrings = new ResettableLazy>(GetCommonConnections); + _state = new State(this); + _transaction = transaction; + _isolationLevel = isolationLevel; + _suppressErrors = suppressErrors; + _exceptionHandler = exceptionHandler; } + /// + /// Initializes a new instance of the class. + /// + /// The parent. + /// The transaction. + /// The isolation level. + /// if set to errors are suppressed. + /// The exception handler. + private SqlBatch( + [NotNull] SqlBatch parent, + TransactionType transaction = TransactionType.None, + IsolationLevel isolationLevel = IsolationLevel.Unspecified, + bool suppressErrors = false, + ExceptionHandler exceptionHandler = null) + { + BatchTimeout = parent._batchTimeout; + _state = parent._state; + _parent = parent; + _transaction = transaction; + _isolationLevel = isolationLevel; + _suppressErrors = suppressErrors; + _exceptionHandler = exceptionHandler; + } + + /// + /// Checks the state is valid for adding to the batch, without requiring a lock. + /// + private void CheckBuildingStateQuick() + { + State state = _state; + if (state.Batch._parent == null) + state.CheckBuildingState(); + } + + /// + /// Gets the identifier of the batch. + /// + /// The identifier. + public Guid ID { get; } = Guid.NewGuid(); + + /// + /// Gets a value indicating whether this is a root batch. + /// + /// + /// if this instance is root; otherwise, . + /// + public bool IsRoot => _parent == null; + /// /// Gets or sets the batch timeout. /// This is the time to wait for the batch to execute. @@ -106,76 +419,51 @@ public Duration BatchTimeout } } - /// - /// Gets the connection strings which are common to all the commands that have been added to this batch. - /// - /// - /// The common connection strings. - /// - [NotNull] - [ItemNotNull] - public IReadOnlyCollection CommonConnectionStrings - { - get - { - lock (_addLock) - { - return _commonConnectionStrings.Value -#if NET452 - .ToArray() -#endif - ; - } - } - } - - /// Gets the number of elements in the collection. - /// The number of elements in the collection. - public int Count => _commands.Count; - - /// Gets the element at the specified index in the read-only list. - /// The element at the specified index in the read-only list. - /// The zero-based index of the element to get. - public SqlBatchCommand this[int index] => _commands[index]; - /// Returns an enumerator that iterates through a collection. /// An object that can be used to iterate through the collection. IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); /// Returns an enumerator that iterates through the collection. /// A that can be used to iterate through the collection. - public IEnumerator GetEnumerator() => _commands.GetEnumerator(); - - /// - /// Checks the state is valid for adding to the batch. - /// - private void CheckState() + public IEnumerator GetEnumerator() { - if (_state != Building) + Stack.Enumerator> stack = new Stack.Enumerator>(); + stack.Push(_items.GetEnumerator()); + + while (stack.TryPop(out List.Enumerator enumerator)) { - throw new InvalidOperationException( - _state == Executing - ? Resources.SqlBatch_CheckState_Executing - : Resources.SqlBatch_CheckState_Completed); + while (enumerator.MoveNext()) + { + IBatchItem item = enumerator.Current; + + if (item is SqlBatchCommand command) yield return command; + else + { + Debug.Assert(item is SqlBatch); + + stack.Push(enumerator); + stack.Push(((SqlBatch)item)._items.GetEnumerator()); + break; + } + } } } - + /// /// Adds the command given to the batch. /// /// The command. private void AddCommand([NotNull] SqlBatchCommand command) { - lock (_addLock) + using (State state = GetState()) { - CheckState(); - - if (_commands.Count >= ushort.MaxValue) - throw new InvalidOperationException("Only allowed 65536 commands per batch."); + state.CheckBuildingState(); - command.Id = (ushort)_commands.Count; - _commands.Add(command); - _commonConnectionStrings.Reset(); + if (state.CommandCount >= ushort.MaxValue || + Interlocked.Increment(ref state.CommandCount) > ushort.MaxValue) + throw new InvalidOperationException(Resources.SqlBatch_AddCommand_OnlyAllowed65536); + + _items.Add(command); } } @@ -189,13 +477,14 @@ private void AddCommand([NotNull] SqlBatchCommand command) /// A which can be used to get the scalar value returned by the program. /// An optional method for setting the parameters to pass to the program. /// This instance. + [NotNull] public SqlBatch AddExecuteScalar( [NotNull] SqlProgram program, [NotNull] out SqlBatchResult result, [CanBeNull] SetBatchParametersDelegate setParameters = null) { if (program == null) throw new ArgumentNullException(nameof(program)); - CheckState(); + CheckBuildingStateQuick(); SqlBatchCommand.Scalar command = new SqlBatchCommand.Scalar(this, program, setParameters); AddCommand(command); @@ -212,13 +501,14 @@ public SqlBatch AddExecuteScalar( /// An optional method for setting the parameters to pass to the program. /// A which can be used to get the number of records affected by the program. /// This instance. + [NotNull] public SqlBatch AddExecuteNonQuery( [NotNull] SqlProgram program, [NotNull] out SqlBatchResult result, [CanBeNull] SetBatchParametersDelegate setParameters = null) { if (program == null) throw new ArgumentNullException(nameof(program)); - CheckState(); + CheckBuildingStateQuick(); SqlBatchCommand.NonQuery command = new SqlBatchCommand.NonQuery(this, program, setParameters); AddCommand(command); @@ -236,6 +526,7 @@ public SqlBatch AddExecuteNonQuery( /// An optional method for setting the parameters to pass to the program. /// A which can be used to wait for the program to finish executing. /// This instance. + [NotNull] public SqlBatch AddExecuteReader( [NotNull] SqlProgram program, [NotNull] ResultDelegateAsync resultAction, @@ -249,7 +540,7 @@ public SqlBatch AddExecuteReader( throw new ArgumentOutOfRangeException( nameof(behavior), "CommandBehavior.CloseConnection is not supported"); - CheckState(); + CheckBuildingStateQuick(); SqlBatchCommand.Reader command = new SqlBatchCommand.Reader( this, @@ -262,7 +553,7 @@ public SqlBatch AddExecuteReader( return this; } - + /// /// Adds the specified program to the batch. /// The value returned by the will be returned by the . @@ -275,6 +566,7 @@ public SqlBatch AddExecuteReader( /// A which can be used to wait for the program to finish executing /// and get the value returned by . /// This instance. + [NotNull] public SqlBatch AddExecuteReader( [NotNull] SqlProgram program, [NotNull] ResultDelegateAsync resultFunc, @@ -288,7 +580,7 @@ public SqlBatch AddExecuteReader( throw new ArgumentOutOfRangeException( nameof(behavior), "CommandBehavior.CloseConnection is not supported"); - CheckState(); + CheckBuildingStateQuick(); SqlBatchCommand.Reader command = new SqlBatchCommand.Reader( this, @@ -301,7 +593,7 @@ public SqlBatch AddExecuteReader( return this; } - + /// /// Adds the specified program to the batch. /// @@ -310,6 +602,7 @@ public SqlBatch AddExecuteReader( /// An optional method for setting the parameters to pass to the program. /// A which can be used to wait for the program to finish executing. /// This instance. + [NotNull] public SqlBatch AddExecuteXmlReader( [NotNull] SqlProgram program, [NotNull] XmlResultDelegateAsync resultAction, @@ -318,7 +611,7 @@ public SqlBatch AddExecuteXmlReader( { if (program == null) throw new ArgumentNullException(nameof(program)); if (resultAction == null) throw new ArgumentNullException(nameof(resultAction)); - CheckState(); + CheckBuildingStateQuick(); SqlBatchCommand.XmlReader command = new SqlBatchCommand.XmlReader( this, @@ -330,6 +623,7 @@ public SqlBatch AddExecuteXmlReader( return this; } + /// /// Adds the specified program to the batch. /// The value returned by the will be returned by the . @@ -341,6 +635,7 @@ public SqlBatch AddExecuteXmlReader( /// A which can be used to wait for the program to finish executing /// and get the value returned by . /// This instance. + [NotNull] public SqlBatch AddExecuteXmlReader( [NotNull] SqlProgram program, [NotNull] XmlResultDelegateAsync resultFunc, @@ -349,7 +644,7 @@ public SqlBatch AddExecuteXmlReader( { if (program == null) throw new ArgumentNullException(nameof(program)); if (resultFunc == null) throw new ArgumentNullException(nameof(resultFunc)); - CheckState(); + CheckBuildingStateQuick(); SqlBatchCommand.XmlReader command = new SqlBatchCommand.XmlReader( this, @@ -362,40 +657,215 @@ public SqlBatch AddExecuteXmlReader( return this; } - // TODO Add nested batch + /// + /// Adds the specified batch to this batch. + /// + /// The batch to add to this batch. + /// This instance. + [NotNull] + public SqlBatch AddBatch([NotNull] SqlBatch batch) + { + if (batch == null) throw new ArgumentNullException(nameof(batch)); + + // Can't add a batch which already has a parent + if (batch._parent != null) + throw new InvalidOperationException(Resources.SqlBatch_RootBatch_AlreadyAdded); + + // Check the states before taking the locks + State myState = _state, otherState = batch._state; + if (otherState.Value != Constants.BatchState.Building) + { + throw new InvalidOperationException( + otherState.Value == Constants.BatchState.Executing + ? Resources.SqlBatch_RootBatch_Executing + : Resources.SqlBatch_RootBatch_Completed); + } + + // If the state is for the root, check its in the building state and has enough command capacity + if (myState.Batch._parent == null) + { + myState.CheckBuildingState(); + if (myState.CommandCount + otherState.CommandCount > ushort.MaxValue) + throw new InvalidOperationException(Resources.SqlBatch_AddCommand_OnlyAllowed65536); + } + + // Get the states of both batches + using (States states = GetStates(this, batch)) + { + myState = states.StateA; + otherState = states.StateB; + + // Can't add a batch which already has a parent + if (batch._parent != null) + throw new InvalidOperationException(Resources.SqlBatch_RootBatch_AlreadyAdded); + Debug.Assert(otherState.Batch == batch); + + // Make sure the batches are in the Building state + myState.CheckBuildingState(); + if (otherState.Value != Constants.BatchState.Building) + { + throw new InvalidOperationException( + otherState.Value == Constants.BatchState.Executing + ? Resources.SqlBatch_RootBatch_Executing + : Resources.SqlBatch_RootBatch_Completed); + } + + // Increment the command counter + if (myState.CommandCount + otherState.CommandCount > ushort.MaxValue || + Interlocked.Add(ref myState.CommandCount, otherState.CommandCount) > ushort.MaxValue) + throw new InvalidOperationException(Resources.SqlBatch_AddCommand_OnlyAllowed65536); + + batch._parent = this; + batch._state = myState; + + _items.Add(batch); + } + + return this; + } + + /// + /// Adds the batch to this batch, only checking the state of this batch. + /// + /// The batch. + /// + private void AddBatchQuick(SqlBatch batch) + { + using (State state = GetState()) + { + state.CheckBuildingState(); + + _items.Add(batch); + } + } + + /// + /// Adds a new batch to this batch. + /// + /// A delegate to the method to use to add commands to the new batch. + /// This instance. + [NotNull] + public SqlBatch AddBatch([NotNull] Action addToBatch) + { + if (addToBatch == null) throw new ArgumentNullException(nameof(addToBatch)); + CheckBuildingStateQuick(); + + SqlBatch newBatch = new SqlBatch(this); + + AddBatchQuick(newBatch); + + addToBatch(newBatch); + + return this; + } + + /// + /// Adds a new batch to this batch. + /// Any errors that occur within the new batch wont cause an exception to be thrown for the whole batch, + /// unless an is specified and doesn't suppress the error. + /// The command that failed will still throw an exception. + /// + /// A delegate to the method to use to add commands to the new batch. + /// The optional exception handler. + /// This instance. + [NotNull] + public SqlBatch AddTryCatch(Action addToBatch, ExceptionHandler exceptionHandler = null) + { + if (addToBatch == null) throw new ArgumentNullException(nameof(addToBatch)); + CheckBuildingStateQuick(); + + SqlBatch newBatch = new SqlBatch(this, suppressErrors: true, exceptionHandler: exceptionHandler); + + AddBatchQuick(newBatch); + + addToBatch(newBatch); + + return this; + } + + /// + /// Adds a new batch to this batch. + /// If an error occurs within the batch, the transaction will rollback. + /// + /// A delegate to the method to use to add commands to the new batch. + /// The isolation level of the transaction. + /// if set to the transaction will always be rolled back. + /// if set to any errors that occur within this batch + /// wont cause an exception to be thrown for the whole batch. See . + /// The optional exception handler. + /// This instance. + [NotNull] + public SqlBatch AddTransaction( + Action addToBatch, + IsolationLevel isolationLevel, + bool rollback = false, + bool suppressErrors = false, + ExceptionHandler exceptionHandler = null) + { + if (addToBatch == null) throw new ArgumentNullException(nameof(addToBatch)); + CheckBuildingStateQuick(); + + SqlBatch newBatch = new SqlBatch( + this, + rollback ? TransactionType.Rollback : TransactionType.Commit, + isolationLevel, + suppressErrors, + exceptionHandler); + + AddBatchQuick(newBatch); + + addToBatch(newBatch); + + return this; + } /// /// Executes the batch on a single connection, asynchronously. /// /// A cancellation token which can be used to cancel the entire batch operation. /// An awaitable task which completes when the batch is complete. + [NotNull] public async Task ExecuteAsync(CancellationToken cancellationToken = default(CancellationToken)) { - lock (_addLock) + // If this isnt the root batch, need to begin executing the root then wait for this batch to complete + if (!IsRoot) + { + throw new NotImplementedException(); + } + + // If we're already completed, just return + if (_state.Value == Constants.BatchState.Completed) return; + + State state; + using (state = GetState()) { - if (_commands.Count < 1) + if (state.CommandCount < 1) throw new InvalidOperationException(Resources.SqlBatch_ExecuteAsync_Empty); - if (Interlocked.CompareExchange(ref _state, Executing, Building) == Completed) return; + // Change the state to Executing, or return if the state has already been set to completed. + if (Interlocked.CompareExchange( + ref state.Value, + Constants.BatchState.Executing, + Constants.BatchState.Building) == Constants.BatchState.Completed) return; } - using (await _executeLock.LockAsync(cancellationToken).ConfigureAwait(false)) + using (await state.ExecuteLock.LockAsync(cancellationToken).ConfigureAwait(false)) { - if (_state == Completed) return; + if (state.Value == Constants.BatchState.Completed) return; try { - string connectionString = DetermineConnection(); + Connection connection = DetermineConnection(); // Set the result count for each command to the number of connections - foreach (SqlBatchCommand command in _commands) + foreach (SqlBatchCommand command in this) command.Result.SetResultCount(1); - await ExecuteInternal(connectionString, 0, cancellationToken).ConfigureAwait(false); + await ExecuteInternal(connection, 0, cancellationToken).ConfigureAwait(false); } finally { - _state = Completed; + state.Value = Constants.BatchState.Completed; } } } @@ -405,39 +875,53 @@ public SqlBatch AddExecuteXmlReader( /// /// A cancellation token which can be used to cancel the entire batch operation. /// An awaitable task which completes when the batch is complete. + [NotNull] public async Task ExecuteAllAsync(CancellationToken cancellationToken = default(CancellationToken)) { - lock (_addLock) + if (!IsRoot) + { + throw new NotImplementedException(); + } + + // If we're already completed, just return + if (_state.Value == Constants.BatchState.Completed) return; + + State state; + using (state = GetState()) { - if (_commands.Count < 1) + if (state.CommandCount < 1) throw new InvalidOperationException(Resources.SqlBatch_ExecuteAsync_Empty); - if (Interlocked.CompareExchange(ref _state, Executing, Building) == Completed) return; + // Change the state to Executing, or return if the state has already been set to completed. + if (Interlocked.CompareExchange( + ref state.Value, + Constants.BatchState.Executing, + Constants.BatchState.Building) == Constants.BatchState.Completed) return; } - using (await _executeLock.LockAsync(cancellationToken).ConfigureAwait(false)) + using (await state.ExecuteLock.LockAsync(cancellationToken).ConfigureAwait(false)) { - if (_state == Completed) return; + if (state.Value == Constants.BatchState.Completed) return; try { // Get the connection strings which are common to each program - HashSet commonConnections = _commonConnectionStrings.Value; + HashSet commonConnections = GetCommonConnections(); Debug.Assert(commonConnections != null, "commonConnections != null"); // Set the result count for each command to the number of connections - foreach (SqlBatchCommand command in _commands) + foreach (SqlBatchCommand command in this) command.Result.SetResultCount(commonConnections.Count); Task[] tasks = commonConnections - .Select((cs, i) => Task.Run(() => ExecuteInternal(cs, i, cancellationToken))) + .Select((con, i) => Task.Run(() => ExecuteInternal(con, i, cancellationToken))) .ToArray(); await Task.WhenAll(tasks).ConfigureAwait(false); } finally { - _state = Completed; + state.Value = Constants.BatchState.Completed; } } } @@ -449,9 +933,8 @@ public SqlBatch AddExecuteXmlReader( internal void BeginExecute(bool all) { // If the state is Executing or Completed, don't need to do anything. - lock (_addLock) - if (_state != Building) - return; + if (_state.Value != Constants.BatchState.Building) + return; // Execute the batch if (all) @@ -460,82 +943,67 @@ internal void BeginExecute(bool all) Task.Run(() => ExecuteAsync()); } - /// - /// Indicates the next record set will be for the output parameters for the command. - /// - private const string OutputState = "Output"; - - /// - /// Indicates the command has ended. - /// - private const string EndState = "End"; - /// /// Executes the batch. /// - /// The connection string. + /// The connection. /// Index of the connection. /// A cancellation token. [NotNull] private async Task ExecuteInternal( - [NotNull] string connectionString, + [NotNull] Connection connection, int connectionIndex, CancellationToken cancellationToken) { - SqlStringBuilder sqlBuilder = new SqlStringBuilder(); - string uid = $"{Guid.NewGuid():B}@{DateTime.UtcNow:O}:"; + string uid = $"{ID:B} @ {DateTime.UtcNow:O}:"; - List allParameters = new List(); + DatabaseSchema schema = connection.CachedSchema; + // TODO Do we care enough? // ?? await connection.GetSchema(false, cancellationToken).ConfigureAwait(false); - Dictionary outParameters = new Dictionary(); - - Dictionary> commandOutParams = - new Dictionary>(); + BatchProcessArgs args = + new BatchProcessArgs(schema?.ServerVersion ?? DatabaseSchema.MinimumSupportedServerVersion); // Build the batch SQL and get the parameters to the commands PreProcess( uid, - connectionString, - allParameters, - outParameters, - commandOutParams, - sqlBuilder, - out AsyncSemaphore[] semaphores, + connection, + args, out CommandBehavior allBehavior); + AsyncSemaphore[] semaphores = args.GetSemaphores(); + string state = null; int index = -1; + string stateArgs = null; int actualIndex = 0; DbBatchDataReader commandReader = null; void MessageHandler(string message) { - if (!TryParseInfoMessage(message, out var info)) return; - - state = info.state; - index = info.index; + if (!TryParseInfoMessage(message, ref state, ref index, ref stateArgs)) return; if (commandReader != null) commandReader.State = BatchReaderState.Finished; } + // Wait the semaphores and setup the connection, command and reader using (await AsyncSemaphore.WaitAllAsync(cancellationToken, semaphores).ConfigureAwait(false)) - using (DbConnection dbConnection = await CreateOpenConnectionAsync(connectionString, uid, MessageHandler, cancellationToken) + using (DbConnection dbConnection = await CreateOpenConnectionAsync(connection, uid, MessageHandler, cancellationToken) .ConfigureAwait(false)) - using (DbCommand dbCommand = CreateCommand(sqlBuilder.ToString(), dbConnection, allParameters.ToArray())) + using (DbCommand dbCommand = CreateCommand(args.SqlBuilder.ToString(), dbConnection, args.AllParameters.ToArray())) using (DbDataReader reader = await dbCommand.ExecuteReaderAsync(allBehavior, cancellationToken) .ConfigureAwait(false)) { Debug.Assert(reader != null, "reader != null"); object[] values = null; - - List.Enumerator commandEnumerator = _commands.GetEnumerator(); + + IEnumerator commandEnumerator = GetEnumerator(); while (commandEnumerator.MoveNext()) { SqlBatchCommand command = commandEnumerator.Current; Debug.Assert(command != null, "command != null"); - + try { using (commandReader = CreateReader(reader, command.CommandBehavior)) @@ -559,51 +1027,57 @@ await command.HandleCommandAsync( commandReader = null; // Check the end states - while (state != EndState && index == actualIndex) + while (state != Constants.ExecuteState.End && index == actualIndex) { - if (state == OutputState) + switch (state) { - // Get the expected output parameters for the command - if (!commandOutParams.TryGetValue(command, out var outs)) - throw new NotImplementedException( - "Proper exception, unexpected output parameters"); + case Constants.ExecuteState.Output: + { + // Get the expected output parameters for the command + if (!args.CommandOutParams.TryGetValue(command, out var outs)) + throw new NotImplementedException( + "Proper exception, unexpected output parameters"); - // No longer expect the output parameters - commandOutParams.Remove(command); + // No longer expect the output parameters + args.CommandOutParams.Remove(command); - if (!await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) - throw new NotImplementedException("Proper exception, missing data"); + if (!await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + throw new NotImplementedException("Proper exception, missing data"); - if (outs.Count != reader.VisibleFieldCount) - throw new NotImplementedException("Proper exception, field count mismatch"); + if (outs.Count != reader.VisibleFieldCount) + throw new NotImplementedException("Proper exception, field count mismatch"); - // Expand the values buffer if needed - if (values == null) - values = new object[reader.VisibleFieldCount]; - else if (values.Length < reader.VisibleFieldCount) - Array.Resize(ref values, reader.VisibleFieldCount); + // Expand the values buffer if needed + if (values == null) + values = new object[reader.VisibleFieldCount]; + else if (values.Length < reader.VisibleFieldCount) + Array.Resize(ref values, reader.VisibleFieldCount); - // Get the output values record - reader.GetValues(values); + // Get the output values record + reader.GetValues(values); - // Set the output values - for (int i = 0; i < outs.Count; i++) - { - Debug.Assert(outs[i].output != null, "outs[i].output != null"); - Debug.Assert(outs[i].param != null, "outs[i].param != null"); + // Set the output values + for (int i = 0; i < outs.Count; i++) + { + Debug.Assert(outs[i].output != null, "outs[i].output != null"); + Debug.Assert(outs[i].param != null, "outs[i].param != null"); - outs[i].output.SetOutputValue(values[i], outs[i].param.BaseParameter); - } + outs[i].output.SetOutputValue(values[i], outs[i].param.BaseParameter); + } - if (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) - throw new NotImplementedException("Proper exception, unexpected data"); + if (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + throw new NotImplementedException("Proper exception, unexpected data"); - // TODO Do something with this? - bool hasNext = await reader.NextResultAsync(cancellationToken) - .ConfigureAwait(false); + // TODO Do something with this? + bool hasNext = await reader.NextResultAsync(cancellationToken) + .ConfigureAwait(false); + break; + } + case Constants.ExecuteState.Error: + throw new NotImplementedException(); + default: + throw new NotImplementedException("Proper exception, unexpected state"); } - else - throw new NotImplementedException("Proper exception, unexpected state"); } } catch (OperationCanceledException) @@ -642,158 +1116,306 @@ await command.HandleCommandAsync( /// /// The uid. /// The connection string. - /// All parameters. - /// The out parameters. - /// The command out parameters. - /// The SQL builder. - /// The semaphores. + /// The arguments. /// All behavior. private void PreProcess( [NotNull] string uid, [NotNull] string connectionString, - [NotNull] List allParameters, - [NotNull] Dictionary outParameters, - [NotNull] Dictionary> commandOutParams, - [NotNull] SqlStringBuilder sqlBuilder, - [NotNull] out AsyncSemaphore[] semaphores, + [NotNull] BatchProcessArgs args, out CommandBehavior allBehavior) { - HashSet connectionSemaphores = new HashSet(); - HashSet loadBalConnectionSemaphores = new HashSet(); - HashSet databaseSemaphores = new HashSet(); + ((IBatchItem)this).Process( + uid, + connectionString, + args); - // Start the behavior asking for sequential access. All commands must want it to be able to use it - allBehavior = CommandBehavior.SequentialAccess; + allBehavior = args.Behavior; + } - int commandIndex = 0; - foreach (SqlBatchCommand command in _commands) + /// + /// Processes the batch to be executed. + /// + /// The uid. + /// The connection string. + /// The arguments. + void IBatchItem.Process( + string uid, + string connectionString, + BatchProcessArgs args) + { + args.SqlBuilder + .AppendLine() + .AppendLine("/*") + .Append(" * Starting batch ") + .AppendLine(ID.ToString("D")) + .AppendLine(" */"); + + if (IsRoot) { - // Get the parameters for the command - SqlBatchParametersCollection parameters = command.GetParametersForConnection(connectionString); - - // Add the parameters to the collection to pass to the command - foreach (DbBatchParameter batchParameter in parameters.Parameters) - { - if (batchParameter.OutputValue != null) - { - if (!outParameters.TryGetValue(batchParameter.OutputValue, out DbParameter dbParameter)) - throw new NotImplementedException("proper error"); - batchParameter.BaseParameter = dbParameter; - } - else - { - allParameters.Add(batchParameter.BaseParameter); - } - } + // Declare the variable for storing the index of the currently executing command + args.SqlBuilder + .AppendLine("DECLARE @CmdIndex int;"); + } - // Add any output parameters to the dictionary for passing into following commands - if (parameters.OutputParameters != null) - { - foreach ((DbBatchParameter batchParameter, IOut outValue) in parameters.OutputParameters) - { - if (outParameters.ContainsKey(outValue)) - throw new NotImplementedException("proper error"); + args.SqlBuilder.AppendLine(); - outParameters.Add(outValue, (SqlParameter)batchParameter.BaseParameter); - } + bool hasTransaction = _transaction != TransactionType.None; + bool hasTryCatch = _suppressErrors || hasTransaction; - commandOutParams.Add(command, parameters.OutputParameters); + string tranName = null; + int startIndex = 0; + if (hasTryCatch) + { + if (hasTransaction) + { + string isoLevel = GetIsolationLevelStr(_isolationLevel); + if (isoLevel == null) + throw new ArgumentOutOfRangeException( + nameof(IsolationLevel), + _isolationLevel, + string.Format(Resources.SqlBatch_Process_IsolationLevelNotSupported, _isolationLevel)); + + // Set the isolation level and begin or save a transaction for the batch + tranName = "[" + ID.ToString("N") + "]"; + args.SqlBuilder + .AppendLine() + .Append("SET TRANSACTION ISOLATION LEVEL ") + .Append(isoLevel) + .AppendLine(";") + + .Append(args.InTransaction ? "SAVE" : "BEGIN") + .Append(" TRANSACTION ") + .Append(tranName) + .AppendLine(";"); + args.TransactionStack.Push(tranName, isoLevel); } - SqlProgramMapping mapping = parameters.Mapping; - SqlProgram program = command.Program; - LoadBalancedConnection loadBalancedConnection = program.Connection; - Connection connection = mapping.Connection; - - // Need to wait on the semaphores for all the connections and databases - if (connection.Semaphore != null) - connectionSemaphores.Add(connection.Semaphore); - if (loadBalancedConnection.ConnectionSemaphore != null) - loadBalConnectionSemaphores.Add(loadBalancedConnection.ConnectionSemaphore); - if (loadBalancedConnection.DatabaseSemaphore != null) - databaseSemaphores.Add(loadBalancedConnection.DatabaseSemaphore); - - // The mask the behavior with this commands behavior - allBehavior &= command.CommandBehavior; - - // Build batch SQL - sqlBuilder - .Append("-- ") - .AppendLine(command.Program.Name); + // Wrap the contents of the batch in a TRY ... CATCH block + args.SqlBuilder + .AppendLine() + .AppendLine("BEGIN TRY") + .AppendLine() + .GetLength(out startIndex); + } - command.AppendExecuteSql(sqlBuilder, parameters); + // Process the items in this batch + foreach (IBatchItem item in _items) + item.Process(uid, connectionString, args); - if (parameters.OutputParameters != null) + if (hasTryCatch) + { + if (hasTransaction) { - AppendInfo(sqlBuilder, uid, OutputState, commandIndex) - .Append("SELECT "); + args.TransactionStack.Pop(out string name, out _); + Debug.Assert(name == tranName); + } - bool firstParam = true; - foreach ((DbBatchParameter batchParameter, IOut outValue) in parameters.OutputParameters) + // If the transaction type is Commit and this is a root transaction, commit it + if (_transaction == TransactionType.Commit && !args.InTransaction) + { + args.SqlBuilder + .AppendLine() + .Append("COMMIT TRANSACTION ") + .Append(tranName) + .AppendLine(";"); + } + // If the transaction is Rollback, always roll it back + else if (_transaction == TransactionType.Rollback) + args.SqlBuilder + .Append("ROLLBACK TRANSACTION ") + .Append(tranName) + .AppendLine(";"); + + // End the TRY block and start the CATCH block + args.SqlBuilder + .IndentRegion(startIndex) + .AppendLine() + .AppendLine("END TRY") + .AppendLine("BEGIN CATCH") + .GetLength(out startIndex); + + // If there is a transaction, roll it back if possible + if (hasTransaction) + { + if (args.InTransaction) + args.SqlBuilder + .AppendLine("IF XACT_STATE() <> -1 ") + .Append("\t"); + + args.SqlBuilder + .Append("ROLLBACK TRANSACTION ") + .Append(tranName) + .AppendLine(";"); + } + + // Output an Error info message then select the error information + AppendInfo(args.SqlBuilder, uid, Constants.ExecuteState.Error, "%d", null, "@CmdIndex") + .AppendLine( + "SELECT\tERROR_NUMBER(),\r\n\tERROR_SEVERITY(),\r\n\tERROR_STATE(),\r\n\tERROR_LINE(),\r\n\tISNULL(QUOTENAME(ERROR_PROCEDURE()),'NULL'),\r\n\tERROR_MESSAGE();"); + + // If the error isnt being suppressed, rethrow it for any outer catches to handle it + if (!_suppressErrors) + { + if (args.ServerVersion.Major < 11) { - if (firstParam) firstParam = false; - else sqlBuilder.Append(", "); - - sqlBuilder.Append(batchParameter.BaseParameter.ParameterName); + // Cant rethrow the actual error, so raise a special error message + args.SqlBuilder + .Append("RAISERROR(") + .AppendVarChar($"{uid}{Constants.ExecuteState.ReThrow}:%d:") + .AppendLine(",16,0,@CmdIndex);"); + } + else + { + args.SqlBuilder.AppendLine("THROW;"); } - sqlBuilder.AppendLine(";"); } - AppendInfo(sqlBuilder, uid, EndState, commandIndex).AppendLine(); + // End the CATCH block + args.SqlBuilder + .IndentRegion(startIndex) + .AppendLine() + .AppendLine("END CATCH") + .AppendLine(); - commandIndex++; + // Reset the isolation level + if (!IsRoot && hasTransaction) + { + if (!args.TransactionStack.TryPeek(out _, out string isoLevel)) + isoLevel = GetIsolationLevelStr(IsolationLevel.Unspecified); + + if (isoLevel != null) + args.SqlBuilder + .AppendLine() + .Append("SET TRANSACTION ISOLATION LEVEL ") + .Append(isoLevel) + .AppendLine(";"); + } } + + args.SqlBuilder + .AppendLine() + .AppendLine("/*") + .Append(" * Ending batch ") + .AppendLine(ID.ToString("D")) + .AppendLine(" */"); + } - // Concat the semaphores to a single array - int semaphoreCount = - connectionSemaphores.Count + loadBalConnectionSemaphores.Count + databaseSemaphores.Count; - if (semaphoreCount < 1) - semaphores = Array.Empty; - else + private static string GetIsolationLevelStr(IsolationLevel isoLevel) + { + switch (isoLevel) { - semaphores = new AsyncSemaphore[semaphoreCount]; - int i = 0; - - // NOTE! Do NOT reorder these without also reordering the semaphores in SqlProgramCommand.WaitSemaphoresAsync - foreach (AsyncSemaphore semaphore in connectionSemaphores) - semaphores[i++] = semaphore; - foreach (AsyncSemaphore semaphore in loadBalConnectionSemaphores) - semaphores[i++] = semaphore; - foreach (AsyncSemaphore semaphore in databaseSemaphores) - semaphores[i++] = semaphore; + case IsolationLevel.ReadUncommitted: + return "READ UNCOMMITTED"; + case IsolationLevel.ReadCommitted: + case IsolationLevel.Unspecified: + return "READ COMMITTED"; + case IsolationLevel.RepeatableRead: + return "REPEATABLE READ"; + case IsolationLevel.Serializable: + return "SERIALIZABLE"; + case IsolationLevel.Snapshot: + return "SNAPSHOT"; + case IsolationLevel.Chaos: + default: + return null; } } - // this would be provider specific + /// + /// Appends an info message to the SQL builder. + /// + /// The SQL builder. + /// The uid. + /// The state. + /// The index. + /// The arguments. + /// The format arguments. + /// + /// The + /// + [NotNull] + internal static SqlStringBuilder AppendInfo( + [NotNull] SqlStringBuilder sqlBuilder, + [NotNull] string uid, + [NotNull] string state, + int index, + string args = null, + string formatArgs = null) + { + return AppendInfo( + sqlBuilder, + uid, + state, + index.ToString(), + args, + formatArgs); + } + + /// + /// Appends an info message to the SQL builder. + /// + /// The SQL builder. + /// The uid. + /// The state. + /// The index. + /// The arguments. + /// The format arguments. + /// + /// The + /// [NotNull] - private static SqlStringBuilder AppendInfo( + internal static SqlStringBuilder AppendInfo( [NotNull] SqlStringBuilder sqlBuilder, [NotNull] string uid, [NotNull] string state, - int index) + string index, + string args = null, + string formatArgs = null) { + // TODO this would be provider specific + Debug.Assert(!state.Contains(":")); - return sqlBuilder + sqlBuilder .Append("RAISERROR(") - .AppendVarChar($"{uid}{state}:{index}") - .Append(",4,") - .Append(unchecked((byte)index)) + .AppendVarChar($"{uid}{state}:{index}:{args ?? string.Empty}") + .Append(",4,0"); + if (formatArgs != null) + sqlBuilder + .Append(',') + .Append(formatArgs); + return sqlBuilder .AppendLine(");"); } - private static bool TryParseInfoMessage([NotNull] string message, out (string state, int index) info) + /// + /// Attempts to parse an information message. + /// + /// The message. + /// The state. + /// The index. + /// The arguments. + /// + private static bool TryParseInfoMessage([NotNull] string message, ref string state, ref int index, ref string args) { - int ind = message.IndexOf(':'); - if (ind < 0 || !ushort.TryParse(message.Substring(ind + 1), out ushort index)) - { - info = (null, -1); + int ind1 = message.IndexOf(':'); + int ind1p1 = ind1 + 1; + int ind2 = message.IndexOf(':', ind1p1); + if (ind1 < 0 || ind2 < 0 || !ushort.TryParse(message.Substring(ind1p1, ind2 - ind1p1), out ushort parsedIndex)) return false; - } - info = (message.Substring(0, ind), index); + state = message.Substring(0, ind1); + index = parsedIndex; + args = (ind2 + 1 >= message.Length) ? null : message.Substring(ind2 + 1); return true; } + /// + /// Registers an information message handler. + /// + /// The connection. + /// The uid. + /// The handler. + /// private void RegisterInfoMessageHandler( [NotNull] DbConnection connection, [NotNull] string uid, @@ -877,6 +1499,10 @@ private DbCommand CreateCommand( { Debug.Assert(command.Connection == connection); +#if DEBUG + _sql = text; +#endif + command.CommandText = text; command.CommandType = CommandType.Text; command.CommandTimeout = (int)BatchTimeout.TotalSeconds(); @@ -910,6 +1536,15 @@ private DbBatchDataReader CreateReader([NotNull] DbDataReader reader, CommandBeh } } + /// + /// An equality comparer that compares only the connection string for a connection. + /// + [NotNull] + private static readonly EqualityBuilder _connectionStringEquality = + new EqualityBuilder( + (a, b) => a.ConnectionString == b.ConnectionString, + c => c.ConnectionString.GetHashCode()); + /// /// Determines the connection string to use. /// @@ -917,12 +1552,13 @@ private DbBatchDataReader CreateReader([NotNull] DbDataReader reader, CommandBeh /// /// [NotNull] - private string DetermineConnection() + private Connection DetermineConnection() { - Debug.Assert(_commands.Count > 0); + Debug.Assert(IsRoot); + Debug.Assert(_state.CommandCount > 0); // Get the connection strings which are common to each program - HashSet commonConnections = _commonConnectionStrings.Value; + HashSet commonConnections = GetCommonConnections(); // If there is a single common connection string, just use that if (commonConnections.Count == 1) @@ -930,23 +1566,23 @@ private string DetermineConnection() Debug.Assert(commonConnections != null); - Dictionary connWeightCounts = - new Dictionary(); + Dictionary connWeightCounts = + new Dictionary(_connectionStringEquality); - foreach (SqlBatchCommand command in _commands) + foreach (SqlBatchCommand command in this) { SqlProgramMapping[] mappings = command.Program.Mappings - .Where(m => commonConnections.Contains(m.Connection.ConnectionString)) + .Where(m => commonConnections.Contains(m.Connection)) .ToArray(); double totWeight = mappings.Sum(m => m.Connection.Weight); foreach (SqlProgramMapping mapping in mappings) { - if (!connWeightCounts.TryGetValue(mapping.Connection.ConnectionString, out var counter)) + if (!connWeightCounts.TryGetValue(mapping.Connection, out var counter)) { counter = new WeightCounter(); - connWeightCounts.Add(mapping.Connection.ConnectionString, counter); + connWeightCounts.Add(mapping.Connection, counter); } counter.Increment(mapping.Connection.Weight / totWeight); @@ -964,29 +1600,30 @@ private string DetermineConnection() /// [NotNull] [ItemNotNull] - private HashSet GetCommonConnections() + private HashSet GetCommonConnections() { - string commonConnectionStr = null; - HashSet commonConnections = null; - foreach (SqlBatchCommand command in _commands) + Connection commonConnection = null; + HashSet commonConnections = null; + foreach (SqlBatchCommand command in this) { if (commonConnections == null) { - commonConnections = new HashSet( + commonConnections = new HashSet( command.Program.Mappings .Where(m => m.Connection.Weight > 0) - .Select(m => m.Connection.ConnectionString)); + .Select(m => m.Connection), + _connectionStringEquality); } else { // If there's a single common connection, just check if any mapping has that connection. - if (commonConnectionStr != null) + if (commonConnection != null) { bool contains = false; foreach (SqlProgramMapping mapping in command.Program.Mappings) { if (mapping.Connection.Weight > 0 && - mapping.Connection.ConnectionString == commonConnectionStr) + mapping.Connection.ConnectionString == commonConnection.ConnectionString) { contains = true; break; @@ -1001,14 +1638,14 @@ private HashSet GetCommonConnections() commonConnections.IntersectWith( command.Program.Mappings .Where(m => m.Connection.Weight > 0) - .Select(m => m.Connection.ConnectionString)); + .Select(m => m.Connection)); } if (commonConnections.Count < 1) throw new InvalidOperationException(Resources.SqlBatch_AddCommand_NoCommonConnections); if (commonConnections.Count == 1) - commonConnectionStr = commonConnections.Single(); + commonConnection = commonConnections.Single(); } Debug.Assert(commonConnections != null, "commonConnections != null"); return commonConnections; diff --git a/Database/SqlBatchCommand.cs b/Database/SqlBatchCommand.cs index 1693246d..8a6b20cb 100644 --- a/Database/SqlBatchCommand.cs +++ b/Database/SqlBatchCommand.cs @@ -26,6 +26,7 @@ #endregion using System; +using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Data.SqlClient; @@ -35,20 +36,15 @@ using System.Threading.Tasks; using WebApplications.Utilities.Annotations; using WebApplications.Utilities.Database.Schema; +using WebApplications.Utilities.Threading; namespace WebApplications.Utilities.Database { /// /// A command for controlling a batch execution of a . /// - public abstract class SqlBatchCommand + public abstract class SqlBatchCommand : IBatchItem { - /// - /// Gets the command identifier. - /// - /// The identifier. - public ushort Id { get; internal set; } - /// /// Gets the batch that this command belongs to. /// @@ -74,13 +70,18 @@ public abstract class SqlBatchCommand /// /// The result. [NotNull] - public SqlBatchResult Result { get; private set; } + public SqlBatchResult Result { get; } /// /// The behavior of the command. /// internal readonly CommandBehavior CommandBehavior; + /// + /// The index of the command in the root batch. + /// + private int _index; + /// /// Initializes a new instance of the class. /// @@ -88,11 +89,13 @@ public abstract class SqlBatchCommand /// The program to execute. /// An optional method for setting the parameters to pass to the program. /// The preferred command behavior. + /// The result object. internal SqlBatchCommand( [NotNull] SqlBatch batch, [NotNull] SqlProgram program, [CanBeNull] SetBatchParametersDelegate setParameters, - CommandBehavior commandBehavior) + CommandBehavior commandBehavior, + [NotNull] SqlBatchResult result) { Batch = batch; Program = program; @@ -100,20 +103,158 @@ internal SqlBatchCommand( if ((commandBehavior & CommandBehavior.SingleRow) == CommandBehavior.SingleRow) commandBehavior |= CommandBehavior.SingleResult; CommandBehavior = commandBehavior; + Result = result; + result.Command = this; } + /// + /// Processes the command to be executed. + /// + /// The uid. + /// The connection string. + /// The arguments. + void IBatchItem.Process( + string uid, + string connectionString, + BatchProcessArgs args) + { + _index = args.CommandIndex; + + // Get the parameters for the command + SqlBatchParametersCollection parameters = GetParametersForConnection(connectionString, args.CommandIndex); + + List dependentParams = null; + + // Add the parameters to the collection to pass to the command + foreach (DbBatchParameter batchParameter in parameters.Parameters) + { + if (batchParameter.OutputValue != null) + { + if (!args.OutParameters.TryGetValue(batchParameter.OutputValue, out DbParameter dbParameter)) + throw new NotImplementedException("proper error"); + batchParameter.BaseParameter = dbParameter; + + if (dependentParams == null) dependentParams = new List(); + dependentParams.Add(batchParameter); + } + else + { + args.AllParameters.Add(batchParameter.BaseParameter); + } + } + + // Add any output parameters to the dictionary for passing into following commands + if (parameters.OutputParameters != null) + { + foreach ((DbBatchParameter batchParameter, IOut outValue) in parameters.OutputParameters) + { + if (args.OutParameters.ContainsKey(outValue)) + throw new NotImplementedException("proper error"); + + args.OutParameters.Add(outValue, batchParameter.BaseParameter); + args.OutParameterCommands.Add(outValue, this); + } + + args.CommandOutParams.Add(this, parameters.OutputParameters); + } + + // The mask the behavior with this commands behavior + args.Behavior &= CommandBehavior; + + // Build command SQL + args.SqlBuilder + .Append("-- ") + .AppendLine(Program.Name) + + // Used in CATCH statements to know which command failed + .Append("SET @CmdIndex = ") + .Append(args.CommandIndex) + .AppendLine(";"); + + // If any of the parameters come from output parameters of previous commands, + // we need to make sure the commands executed successfully + if (dependentParams != null) + { + foreach (IGrouping cmd in dependentParams.GroupBy(p => args.OutParameterCommands[p.OutputValue])) + { + string[] paramNames = cmd.Select(p => p.ParameterName).ToArray(); + // TODO get messages from Resources? + string message = paramNames.Length == 1 + ? $"The value of the {paramNames[0]} parameter depends on the output of a previous command which has not been executed successfully." + : $"The value of the parameters {string.Join(",", paramNames)} depend on the output of a previous command which has not been executed successfully."; + + args.SqlBuilder + .Append("IF (ISNULL(@Cmd") + .Append(cmd.Key._index) + .AppendLine("Success,0) <> 1)") + .Append("\tRAISERROR(") + .AppendNVarChar(message) + .AppendLine(",16,0);"); + } + args.SqlBuilder.AppendLine(); + } + + AppendExecuteSql(args.SqlBuilder, parameters); + + if (parameters.OutputParameters != null) + { + // Sets a flag indicating that this command executed successfully, + // so any command using the output parameters can check + args.SqlBuilder + .Append("DECLARE @Cmd") + .Append(args.CommandIndex) + .Append("Success bit") + .AppendLine( + args.ServerVersion.Major > 9 + ? " = 1;" + : $"; SET @Cmd{args.CommandIndex}Success = 1;"); + + SqlBatch + .AppendInfo(args.SqlBuilder, uid, Constants.ExecuteState.Output, args.CommandIndex) + .Append("SELECT "); + + bool firstParam = true; + foreach ((DbBatchParameter batchParameter, _) in parameters.OutputParameters) + { + if (firstParam) firstParam = false; + else args.SqlBuilder.Append(", "); + + args.SqlBuilder.Append(batchParameter.BaseParameter.ParameterName); + } + args.SqlBuilder.AppendLine(";"); + } + + SqlBatch.AppendInfo(args.SqlBuilder, uid, Constants.ExecuteState.End, args.CommandIndex).AppendLine(); + + SqlProgramMapping mapping = parameters.Mapping; + SqlProgram program = Program; + LoadBalancedConnection loadBalancedConnection = program.Connection; + Connection connection = mapping.Connection; + + // Need to wait on the semaphores for all the connections and databases + if (connection.Semaphore != null) + args.ConnectionSemaphores.Add(connection.Semaphore); + if (loadBalancedConnection.ConnectionSemaphore != null) + args.LoadBalConnectionSemaphores.Add(loadBalancedConnection.ConnectionSemaphore); + if (loadBalancedConnection.DatabaseSemaphore != null) + args.DatabaseSemaphores.Add(loadBalancedConnection.DatabaseSemaphore); + + args.CommandIndex++; + } + /// /// Gets the parameters for the connection with the connection string given. /// /// The connection string. - /// + /// Index of the command in the batch. + /// The collection of parameters. [NotNull] - internal SqlBatchParametersCollection GetParametersForConnection([NotNull] string connectionString) + internal SqlBatchParametersCollection GetParametersForConnection([NotNull] string connectionString, ushort commandIndex) { SqlProgramMapping mapping = Program.Mappings.Single(m => m.Connection.ConnectionString == connectionString); Debug.Assert(mapping != null, "mapping != null"); - SqlBatchParametersCollection parameters = new SqlBatchParametersCollection(mapping, this); + SqlBatchParametersCollection parameters = new SqlBatchParametersCollection(mapping, this, commandIndex); _setParameters?.Invoke(parameters); @@ -203,11 +344,7 @@ internal class Scalar : SqlBatchCommand /// /// The result. [NotNull] - public new SqlBatchResult Result - { - get => (SqlBatchResult)base.Result; - private set => base.Result = value; - } + public new SqlBatchResult Result => (SqlBatchResult)base.Result; /// /// Initializes a new instance of the class. @@ -219,9 +356,8 @@ public Scalar( [NotNull] SqlBatch batch, [NotNull] SqlProgram program, SetBatchParametersDelegate setParameters) - : base(batch, program, setParameters, CommandBehavior.SequentialAccess) + : base(batch, program, setParameters, CommandBehavior.SequentialAccess, new SqlBatchResult()) { - Result = new SqlBatchResult(this); } /// @@ -259,11 +395,7 @@ internal class NonQuery : SqlBatchCommand /// /// The result. [NotNull] - public new SqlBatchResult Result - { - get => (SqlBatchResult)base.Result; - private set => base.Result = value; - } + public new SqlBatchResult Result => (SqlBatchResult)base.Result; /// /// Initializes a new instance of the class. @@ -275,9 +407,8 @@ public NonQuery( [NotNull] SqlBatch batch, [NotNull] SqlProgram program, SetBatchParametersDelegate setParameters) - : base(batch, program, setParameters, CommandBehavior.SequentialAccess) + : base(batch, program, setParameters, CommandBehavior.SequentialAccess, new SqlBatchResult()) { - Result = new SqlBatchResult(this); } /// @@ -323,6 +454,10 @@ internal override async Task HandleCommandAsync( } } + /// + /// Base class for calling ExecuteReader on a program in a batch. + /// + /// internal abstract class BaseReader : SqlBatchCommand { /// @@ -332,12 +467,14 @@ internal abstract class BaseReader : SqlBatchCommand /// The program to execute. /// An optional method for setting the parameters to pass to the program. /// The behavior. + /// The result. protected BaseReader( [NotNull] SqlBatch batch, [NotNull] SqlProgram program, [CanBeNull] SetBatchParametersDelegate setParameters, - CommandBehavior commandBehavior) - : base(batch, program, setParameters, commandBehavior) + CommandBehavior commandBehavior, + [NotNull] SqlBatchResult result) + : base(batch, program, setParameters, commandBehavior, result) { } @@ -348,6 +485,7 @@ protected BaseReader( /// The parameters to execute with. public override void AppendExecuteSql(SqlStringBuilder builder, SqlBatchParametersCollection parameters) { + // ReSharper disable StringLiteralTypo if ((CommandBehavior & CommandBehavior.KeyInfo) == CommandBehavior.KeyInfo) builder.AppendLine("SET NO_BROWSETABLE ON;"); if ((CommandBehavior & CommandBehavior.SchemaOnly) == CommandBehavior.SchemaOnly) @@ -359,6 +497,7 @@ public override void AppendExecuteSql(SqlStringBuilder builder, SqlBatchParamete builder.AppendLine("SET FMTONLY OFF;"); if ((CommandBehavior & CommandBehavior.KeyInfo) == CommandBehavior.KeyInfo) builder.AppendLine("SET NO_BROWSETABLE OFF;"); + // ReSharper restore StringLiteralTypo } /// @@ -422,10 +561,9 @@ public Reader( [NotNull] ResultDelegateAsync resultAction, CommandBehavior commandBehavior, [CanBeNull] SetBatchParametersDelegate setParameters) - : base(batch, program, setParameters, commandBehavior) + : base(batch, program, setParameters, commandBehavior, new SqlBatchResult()) { _resultAction = resultAction; - Result = new SqlBatchResult(this); } /// @@ -465,11 +603,7 @@ internal class Reader : BaseReader /// /// The result. [NotNull] - public new SqlBatchResult Result - { - get => (SqlBatchResult)base.Result; - private set => base.Result = value; - } + public new SqlBatchResult Result => (SqlBatchResult)base.Result; /// /// Initializes a new instance of the class. @@ -485,10 +619,9 @@ public Reader( [NotNull] ResultDelegateAsync resultFunc, CommandBehavior commandBehavior, [CanBeNull] SetBatchParametersDelegate setParameters) - : base(batch, program, setParameters, commandBehavior) + : base(batch, program, setParameters, commandBehavior, new SqlBatchResult()) { _resultFunc = resultFunc; - Result = new SqlBatchResult(this); } /// @@ -515,6 +648,10 @@ protected override void SetResult(Task task, int index) } } + /// + /// Base class for calling ExecuteXmlReader on a program in a batch. + /// + /// internal abstract class BaseXmlReader : SqlBatchCommand { /// @@ -523,14 +660,16 @@ internal abstract class BaseXmlReader : SqlBatchCommand /// The batch that this command belongs to. /// The program to execute. /// An optional method for setting the parameters to pass to the program. + /// The result. protected BaseXmlReader( [NotNull] SqlBatch batch, [NotNull] SqlProgram program, - [CanBeNull] SetBatchParametersDelegate setParameters) - : base(batch, program, setParameters, CommandBehavior.SequentialAccess) + [CanBeNull] SetBatchParametersDelegate setParameters, + [NotNull] SqlBatchResult result) + : base(batch, program, setParameters, CommandBehavior.SequentialAccess, result) { } - + /// /// Handles the command asynchronously. /// @@ -590,10 +729,9 @@ public XmlReader( [NotNull] SqlProgram program, [NotNull] XmlResultDelegateAsync resultAction, [CanBeNull] SetBatchParametersDelegate setParameters) - : base(batch, program, setParameters) + : base(batch, program, setParameters, new SqlBatchResult()) { _resultAction = resultAction; - Result = new SqlBatchResult(this); } /// @@ -633,11 +771,7 @@ internal class XmlReader : BaseXmlReader /// /// The result. [NotNull] - public new SqlBatchResult Result - { - get => (SqlBatchResult)base.Result; - private set => base.Result = value; - } + public new SqlBatchResult Result => (SqlBatchResult)base.Result; /// /// Initializes a new instance of the class. @@ -651,10 +785,9 @@ public XmlReader( [NotNull] SqlProgram program, [NotNull] XmlResultDelegateAsync resultFunc, [CanBeNull] SetBatchParametersDelegate setParameters) - : base(batch, program, setParameters) + : base(batch, program, setParameters, new SqlBatchResult()) { _resultFunc = resultFunc; - Result = new SqlBatchResult(this); } /// diff --git a/Database/SqlBatchParametersCollection.cs b/Database/SqlBatchParametersCollection.cs index 7d316072..3dc1f79e 100644 --- a/Database/SqlBatchParametersCollection.cs +++ b/Database/SqlBatchParametersCollection.cs @@ -51,6 +51,8 @@ public partial class SqlBatchParametersCollection [NotNull] private readonly SqlBatchCommand _command; + private readonly ushort _commandIndex; + [NotNull] private readonly SqlProgramMapping _mapping; @@ -100,10 +102,12 @@ public partial class SqlBatchParametersCollection /// /// The mapping. /// The command. - internal SqlBatchParametersCollection([NotNull] SqlProgramMapping mapping, [NotNull] SqlBatchCommand command) + /// Index of the command. + internal SqlBatchParametersCollection([NotNull] SqlProgramMapping mapping, [NotNull] SqlBatchCommand command, ushort commandIndex) { _mapping = mapping; _command = command; + _commandIndex = commandIndex; } /// @@ -119,10 +123,13 @@ private DbBatchParameter GetOrAddParameter([NotNull] SqlProgramParameter program if (index >= 0) return _parameters[index]; + if (_parameters.Count >= ushort.MaxValue) + throw new InvalidOperationException(Resources.SqlBatchParametersCollection_GetOrAddParameter_OnlyAllowed65536); + SqlBatchParameter batchParameter = new SqlBatchParameter( programParameter, programParameter.CreateSqlParameter(), - "_" + _command.Id.ToString("X4") + _parameters.Count.ToString("X4")); + "_" + _commandIndex.ToString("X4") + _parameters.Count.ToString("X4")); _parameters.Add(batchParameter); if (programParameter.Direction == ParameterDirection.ReturnValue) diff --git a/Database/SqlBatchResult.cs b/Database/SqlBatchResult.cs index 0ab92304..1fa3dd68 100644 --- a/Database/SqlBatchResult.cs +++ b/Database/SqlBatchResult.cs @@ -68,16 +68,31 @@ static SqlBatchResult() /// /// The command that this is the result of. /// - [NotNull] - protected readonly SqlBatchCommand Command; + [CanBeNull] + private SqlBatchCommand _command; /// /// Initializes a new instance of the class. /// - /// The command that this is the result of. - internal SqlBatchResult([NotNull] SqlBatchCommand command) + internal SqlBatchResult() { - Command = command ?? throw new ArgumentNullException(nameof(command)); + } + + /// + /// The command that this is the result of. + /// + protected internal SqlBatchCommand Command + { + get + { + Debug.Assert(_command != null); + return _command; + } + internal set + { + Debug.Assert(_command == null); + _command = value; + } } /// @@ -139,9 +154,7 @@ public sealed class SqlBatchResult : SqlBatchResult /// /// Initializes a new instance of the class. /// - /// The command that this is the result of. - internal SqlBatchResult([NotNull] SqlBatchCommand command) - : base(command) + internal SqlBatchResult() { } @@ -278,7 +291,6 @@ protected override Task GetResultInternalAsync(CancellationToken cancellationTok /// Gets the results for all connections asynchronously. /// /// A cancellation token which can be used to cancel the operation. The batch will continue running. - /// The order of the results is based on the order of the connections in the property. /// An awaitable task which returns the results for all connections. public async Task> GetResultsAsync(CancellationToken cancellationToken = default(CancellationToken)) { diff --git a/Database/SqlProgram.cs b/Database/SqlProgram.cs index 2dec5a4f..0551b7e7 100644 --- a/Database/SqlProgram.cs +++ b/Database/SqlProgram.cs @@ -693,6 +693,7 @@ await newProgram.Validate(checkOrder, false, !ignoreValidationErrors, cancellati /// Gets the valid mappings for the . /// [NotNull] + [ItemNotNull] public IEnumerable Mappings { get diff --git a/Database/SqlStringBuilder.cs b/Database/SqlStringBuilder.cs index 7f01a459..85fdd961 100644 --- a/Database/SqlStringBuilder.cs +++ b/Database/SqlStringBuilder.cs @@ -101,6 +101,37 @@ public SqlStringBuilder AppendNVarChar(string value) return this; } + /// + /// Indents a region of the builder from the start index given to the end. + /// + /// The start. + [NotNull] + public SqlStringBuilder IndentRegion(int start) + { + int newLineLength = Environment.NewLine.Length; + if (start >= newLineLength) + { + if (_builder.ToString(start - newLineLength, newLineLength) == Environment.NewLine) + start -= newLineLength; + } + + _builder.Replace(Environment.NewLine, Environment.NewLine + "\t", start, _builder.Length - start); + + return this; + } + + /// + /// Gets the length of the builder. + /// + /// The length. + /// + [NotNull] + public SqlStringBuilder GetLength(out int length) + { + length = _builder.Length; + return this; + } + #region StringBuilder methods /// Appends the string representation of a specified 8-bit unsigned integer to this instance. /// A reference to this instance after the append operation has completed. @@ -415,7 +446,7 @@ public SqlStringBuilder AppendFormat(string format, object arg0, object arg1) [NotNull] public SqlStringBuilder AppendLine(string value) { - _builder.AppendLine(value); + _builder.Append(value).Append(Environment.NewLine); return this; } @@ -425,7 +456,7 @@ public SqlStringBuilder AppendLine(string value) [NotNull] public SqlStringBuilder AppendLine() { - _builder.AppendLine(); + _builder.Append(Environment.NewLine); return this; } @@ -514,7 +545,7 @@ public int Length /// A string that represents the current object. public override string ToString() => _builder.ToString(); #endregion - + /// /// Performs an implicit conversion from to . /// diff --git a/Database/Test/TestBatching.cs b/Database/Test/TestBatching.cs index d254eea5..cd937df4 100644 --- a/Database/Test/TestBatching.cs +++ b/Database/Test/TestBatching.cs @@ -28,7 +28,9 @@ // ReSharper disable ConsiderUsingConfigureAwait,UseConfigureAwait using System; +using System.Data; using System.Diagnostics; +using System.Reflection; using System.Text; using System.Threading.Tasks; using System.Xml.Linq; @@ -41,6 +43,29 @@ namespace WebApplications.Utilities.Database.Test [TestClass] public class TestBatching : DatabaseTestBase { + private static int _count = 0; + private static double _nonQueryTime; + private static double _scalarTime; + private static double _outputTime; + private static double _tableTime; + private static double _xmlTime; + private static double _doneTime; + + private static double _batchedNonQueryTime; + private static double _batchedScalarTime; + private static double _batchedOutputTime; + private static double _batchedTableTime; + private static double _batchedXmlTime; + private static double _batchedDoneTime; + + private static void AddTime(Stopwatch sw, ref double counter) + { + var elapsed = sw.Elapsed.TotalMilliseconds; + + if (_count > 0) + counter += elapsed; + } + private static void Setup( out SqlProgram nonQueryProg, out SqlProgram returnsScalarProg, @@ -88,7 +113,7 @@ private static void Setup( returnsXmlProg = database.GetSqlProgram("spReturnsXml").Result; - randomString = Random.RandomString(Encoding.GetEncoding(1252), maxLength: 20); + randomString = Random.RandomString(Encoding.ASCII, maxLength: 20); randomInt = Random.RandomInt32(); randomDecimal = Math.Round(Random.RandomDecimal() % 1_000_000_000m, 2); randomBool = Random.RandomBoolean(); @@ -102,10 +127,28 @@ public async Task PerfTest() { for (int i = 0; i < 100; i++) { + Trace.WriteLine("Run " + i); await TestBatchEverything(); await TestNotBatchEverything(); - Trace.WriteLine(""); + _count++; } + _count = 0; + + Trace.WriteLine($"B spNonQuery in {_batchedNonQueryTime / _count}ms"); + Trace.WriteLine($"B spWithParametersReturnsScalar in {_batchedScalarTime / _count}ms"); + Trace.WriteLine($"B spOutputParameters in {_batchedOutputTime / _count}ms"); + Trace.WriteLine($"B spReturnsTable in {_batchedTableTime / _count}ms"); + Trace.WriteLine($"B spReturnsXml in {_batchedXmlTime / _count}ms"); + Trace.WriteLine($"B Done in {_batchedDoneTime / _count}ms"); + + Trace.WriteLine(""); + + Trace.WriteLine($"N spNonQuery in {_nonQueryTime / _count}ms"); + Trace.WriteLine($"N spWithParametersReturnsScalar in {_scalarTime / _count}ms"); + Trace.WriteLine($"N spOutputParameters in {_outputTime / _count}ms"); + Trace.WriteLine($"N spReturnsTable in {_tableTime / _count}ms"); + Trace.WriteLine($"N spReturnsXml in {_xmlTime / _count}ms"); + Trace.WriteLine($"N Done in {_doneTime / _count}ms"); } [TestMethod] @@ -124,7 +167,10 @@ public async Task TestBatchEverything() out Out output, out Out inputOutput); - SqlBatch batch = new SqlBatch() + SqlBatchResult outputResult = null; + SqlBatchResult tableResult = null; + + SqlBatch batch = SqlBatch.CreateTransaction(IsolationLevel.ReadUncommitted) .AddExecuteNonQuery( nonQueryProg, out SqlBatchResult nonQueryResult, @@ -137,29 +183,31 @@ public async Task TestBatchEverything() randomInt, randomDecimal, randomBool) - .AddExecuteScalar( - outputParametersProg, - out SqlBatchResult outputResult, - randomInt, - inputOutput, - output) - .AddExecuteReader( - returnsTableProg, - async (reader, token) => - { - Assert.IsTrue(await reader.ReadAsync(token)); + .AddTransaction( + b => b.AddExecuteScalar( + outputParametersProg, + out outputResult, + randomInt, + inputOutput, + output) + .AddExecuteReader( + returnsTableProg, + async (reader, token) => + { + Assert.IsTrue(await reader.ReadAsync(token)); - Assert.AreEqual(randomString, reader.GetValue(0)); - Assert.AreEqual(output.Value, reader.GetValue(1)); - Assert.AreEqual(randomDecimal, reader.GetValue(2)); - Assert.AreEqual(randomBool, reader.GetValue(3)); - }, - out SqlBatchResult tableResult, - randomString, - // Using output of previous program as input - output, - randomDecimal, - randomBool) + Assert.AreEqual(randomString, reader.GetValue(0)); + Assert.AreEqual(output.Value, reader.GetValue(1)); + Assert.AreEqual(randomDecimal, reader.GetValue(2)); + Assert.AreEqual(randomBool, reader.GetValue(3)); + }, + out tableResult, + randomString, + // Using output of previous program as input + output, + randomDecimal, + randomBool), + IsolationLevel.Serializable) .AddExecuteXmlReader( returnsXmlProg, (reader, token) => @@ -175,15 +223,26 @@ public async Task TestBatchEverything() // ReSharper disable ConsiderUsingConfigureAwait,UseConfigureAwait #pragma warning disable 4014 - nonQueryResult.GetResultAsync().ContinueWith(_ => Trace.WriteLine($"B spNonQuery @ {sw.Elapsed.TotalMilliseconds}ms")); - scalarResult.GetResultAsync().ContinueWith(_ => Trace.WriteLine($"B spWithParametersReturnsScalar @ {sw.Elapsed.TotalMilliseconds}ms")); - outputResult.GetResultAsync().ContinueWith(_ => Trace.WriteLine($"B spOutputParameters @ {sw.Elapsed.TotalMilliseconds}ms")); - tableResult.GetResultAsync().ContinueWith(_ => Trace.WriteLine($"B spReturnsTable @ {sw.Elapsed.TotalMilliseconds}ms")); - xmlResult.GetResultAsync().ContinueWith(_ => Trace.WriteLine($"B spReturnsXml @ {sw.Elapsed.TotalMilliseconds}ms")); + nonQueryResult.GetResultAsync().ContinueWith(_ => AddTime(sw, ref _batchedNonQueryTime)); + scalarResult.GetResultAsync().ContinueWith(_ => AddTime(sw, ref _batchedScalarTime)); + outputResult.GetResultAsync().ContinueWith(_ => AddTime(sw, ref _batchedOutputTime)); + tableResult.GetResultAsync().ContinueWith(_ => AddTime(sw, ref _batchedTableTime)); + xmlResult.GetResultAsync().ContinueWith(_ => AddTime(sw, ref _batchedXmlTime)); #pragma warning restore 4014 - await batch.ExecuteAsync(); - Trace.WriteLine($"B Done @ {sw.Elapsed.TotalMilliseconds}ms"); + try + { + await batch.ExecuteAsync(); + AddTime(sw, ref _batchedDoneTime); + } + finally + { + if (_count == 0) + { + Trace.WriteLine("SQL:"); + Trace.WriteLine(GetSql(batch) ?? ""); + } + } } [TestMethod] @@ -204,12 +263,12 @@ public async Task TestNotBatchEverything() Stopwatch sw = Stopwatch.StartNew(); - await nonQueryProg.ExecuteNonQueryAsync(randomString, randomInt) - .ContinueWith(_ => Trace.WriteLine($"N spNonQuery @ {sw.Elapsed.TotalMilliseconds}ms")); - await returnsScalarProg.ExecuteScalarAsync(randomString, randomInt, randomDecimal, randomBool) - .ContinueWith(_ => Trace.WriteLine($"N spWithParametersReturnsScalar @ {sw.Elapsed.TotalMilliseconds}ms")); - await outputParametersProg.ExecuteScalarAsync(randomInt, inputOutput, output) - .ContinueWith(_ => Trace.WriteLine($"N spOutputParameters @ {sw.Elapsed.TotalMilliseconds}ms")); + await nonQueryProg.ExecuteNonQueryAsync(randomString, randomInt); + AddTime(sw, ref _nonQueryTime); + await returnsScalarProg.ExecuteScalarAsync(randomString, randomInt, randomDecimal, randomBool); + AddTime(sw, ref _scalarTime); + await outputParametersProg.ExecuteScalarAsync(randomInt, inputOutput, output); + AddTime(sw, ref _outputTime); await returnsTableProg.ExecuteReaderAsync( async (reader, token) => { @@ -224,15 +283,25 @@ await returnsTableProg.ExecuteReaderAsync( // Using output of previous program as input output.Value, randomDecimal, - randomBool).ContinueWith(_ => Trace.WriteLine($"N spReturnsTable @ {sw.Elapsed.TotalMilliseconds}ms")); + randomBool); + AddTime(sw, ref _tableTime); await returnsXmlProg.ExecuteXmlReaderAsync( (reader, token) => { string xml = XElement.Load(reader).ToString(); Assert.AreEqual("bar", xml); return TaskResult.Completed; - }).ContinueWith(_ => Trace.WriteLine($"N spReturnsXml @ {sw.Elapsed.TotalMilliseconds}ms")); - Trace.WriteLine($"N Done @ {sw.Elapsed.TotalMilliseconds}ms"); + }); + AddTime(sw, ref _xmlTime); + AddTime(sw, ref _doneTime); + } + + private static string GetSql(SqlBatch batch) + { + FieldInfo field = typeof(SqlBatch).GetField("_sql", BindingFlags.NonPublic | BindingFlags.Instance); + if (field == null) return null; + + return (string)field.GetValue(batch); } } } \ No newline at end of file diff --git a/Database/WebApplications.Utilities.Database.csproj b/Database/WebApplications.Utilities.Database.csproj index 660800a4..44806176 100644 --- a/Database/WebApplications.Utilities.Database.csproj +++ b/Database/WebApplications.Utilities.Database.csproj @@ -37,6 +37,7 @@ + @@ -50,6 +51,7 @@ + @@ -147,6 +149,7 @@ SqlProgramGeneric.Generated.txt4 +