diff --git a/LinqToQuerystring/Extensions.cs b/LinqToQuerystring/Extensions.cs index c755d33..3bb6a7e 100644 --- a/LinqToQuerystring/Extensions.cs +++ b/LinqToQuerystring/Extensions.cs @@ -5,6 +5,8 @@ using System.Collections.Generic; using System.IO; using System.Linq; + using System.Linq.Expressions; + using System.Reflection; using Antlr.Runtime; using Antlr.Runtime.Tree; @@ -95,7 +97,7 @@ public static object LinqToQuerystring(this IQueryable query, Type inputType, st { var children = tree.Children.Cast<TreeNode>().ToList(); children.Sort(); - + // These should always come first foreach (var node in children.Where(o => !(o is SelectNode) && !(o is InlineCountNode))) { @@ -140,8 +142,9 @@ private static void BuildQuery(TreeNode node, ref IQueryable queryResult, ref IQ } else { - queryResult = queryResult.Provider.CreateQuery( - node.BuildLinqExpression(queryResult, queryResult.Expression)); + var expression = node.BuildLinqExpression(queryResult, queryResult.Expression); + var queryType = expression.Type.GetGenericArguments()[0]; + queryResult = CreateQuery(queryResult.Provider, expression, queryType); } } @@ -152,12 +155,18 @@ private static void BuildQuery(TreeNode node, ref IQueryable queryResult, ref IQ } else { - constrainedQuery = - constrainedQuery.Provider.CreateQuery( - node.BuildLinqExpression(constrainedQuery, constrainedQuery.Expression)); + var expression = node.BuildLinqExpression(constrainedQuery, constrainedQuery.Expression); + var queryType = expression.Type.GetGenericArguments()[0]; + constrainedQuery = CreateQuery(constrainedQuery.Provider, expression, queryType); } } + private static IQueryable CreateQuery(IQueryProvider provider, Expression expression, Type queryType) + { + var genericMethod = _createQueryMethodInfo.MakeGenericMethod(queryType); + return (IQueryable)genericMethod.Invoke(provider,new object[] { expression }); + } + private static IQueryable ProjectQuery(IQueryable constrainedQuery, TreeNode node) { // TODO: Find a solution to the following: @@ -189,5 +198,7 @@ static IEnumerable Iterate(this IEnumerator iterator) while (iterator.MoveNext()) yield return iterator.Current; } + + private static MethodInfo _createQueryMethodInfo = typeof(IQueryProvider).GetMethods().First(x => x.Name == "CreateQuery" && x.IsGenericMethod); } } \ No newline at end of file