From 6450d2c318c6fca001bd9821747aed88f580d185 Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Wed, 19 Feb 2025 09:33:51 -0600 Subject: [PATCH 1/5] Separate as() and into() --- src/FunSQL.jl | 1 + src/link.jl | 139 +++++++++++++++++++++--------------------- src/nodes.jl | 1 + src/nodes/as.jl | 17 +++--- src/nodes/hide.jl | 42 +++++++++++++ src/nodes/internal.jl | 20 ++---- src/nodes/into.jl | 28 +++++++++ src/resolve.jl | 34 +++++------ src/translate.jl | 66 ++++++++------------ 9 files changed, 196 insertions(+), 152 deletions(-) create mode 100644 src/nodes/hide.jl create mode 100644 src/nodes/into.jl diff --git a/src/FunSQL.jl b/src/FunSQL.jl index fb6a5ca8..618c6610 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -56,6 +56,7 @@ export funsql_group, funsql_highlight, funsql_in, + funsql_into, funsql_iterate, funsql_is_not_null, funsql_is_null, diff --git a/src/link.jl b/src/link.jl index 00015f41..3736efea 100644 --- a/src/link.jl +++ b/src/link.jl @@ -123,19 +123,15 @@ function dismantle(n::GroupNode, ctx) Group(by = by′, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′) end -function dismantle(n::IterateNode, ctx) +function dismantle(n::IntoNode, ctx) tail′ = dismantle(ctx) - iterator′ = dismantle(n.iterator, ctx) - Iterate(iterator = iterator′, tail = tail′) + Into(name = n.name, tail = tail′) end -function dismantle(n::JoinNode, ctx) - rt = row_type(n.joinee) - router = JoinRouter(Set(keys(rt.fields)), !isa(rt.group, EmptyType)) +function dismantle(n::IterateNode, ctx) tail′ = dismantle(ctx) - joinee′ = dismantle(n.joinee, ctx) - on′ = dismantle_scalar(n.on, ctx) - RoutedJoin(joinee = joinee′, on = on′, router = router, left = n.left, right = n.right, optional = n.optional, tail = tail′) + iterator′ = dismantle(n.iterator, ctx) + Iterate(iterator = iterator′, tail = tail′) end function dismantle(n::LimitNode, ctx) @@ -181,6 +177,13 @@ function dismantle_scalar(n::ResolvedNode, ctx) end end +function dismantle(n::RoutedJoinNode, ctx) + tail′ = dismantle(ctx) + joinee′ = dismantle(n.joinee, ctx) + on′ = dismantle_scalar(n.on, ctx) + RoutedJoin(joinee = joinee′, on = on′, name = n.name, left = n.left, right = n.right, optional = n.optional, tail = tail′) +end + function dismantle(n::SelectNode, ctx) tail′ = dismantle(ctx) args′ = dismantle_scalar(n.args, ctx) @@ -232,16 +235,7 @@ function link(n::AppendNode, ctx) end function link(n::AsNode, ctx) - refs = SQLQuery[] - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested(name = (local name))) - @assert name == n.name - push!(refs, tail) - else - error() - end - end - tail′ = link(ctx.tail, ctx, refs) + tail′ = link(ctx) As(name = n.name, tail = tail′) end @@ -289,10 +283,8 @@ function link(n::FromIterateNode, ctx) end function link(n::FromTableExpressionNode, ctx) - refs = ctx.cte_refs[(n.name, n.depth)] - for ref in ctx.refs - push!(refs, Nested(name = n.name, tail = ref)) - end + cte_refs = ctx.cte_refs[(n.name, n.depth)] + append!(cte_refs, ctx.refs) n end @@ -333,6 +325,20 @@ function link(n::GroupNode, ctx) Group(by = n.by, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′) end +function link(n::IntoNode, ctx) + refs = SQLQuery[] + for ref in ctx.refs + if @dissect(ref, (local tail) |> Nested(name = (local name))) + @assert name == n.name + push!(refs, tail) + else + error() + end + end + tail′ = link(ctx.tail, ctx, refs) + Into(name = n.name, tail = tail′) +end + function link(n::IterateNode, ctx) iterator′ = n.iterator defs = copy(ctx.defs) @@ -364,53 +370,6 @@ function link(n::IterateNode, ctx) Padding(tail = q′) end -function route(r::JoinRouter, ref::SQLQuery) - if @dissect(ref, Nested(name = (local name))) && name in r.label_set - return 1 - end - if @dissect(ref, Get(name = (local name))) && name in r.label_set - return 1 - end - if @dissect(ref, Agg()) && r.group - return 1 - end - return -1 -end - -function link(n::RoutedJoinNode, ctx) - lrefs = SQLQuery[] - rrefs = SQLQuery[] - for ref in ctx.refs - turn = route(n.router, ref) - push!(turn < 0 ? lrefs : rrefs, ref) - end - if n.optional && isempty(rrefs) - return link(ctx) - end - ln_ext_refs = length(lrefs) - rn_ext_refs = length(rrefs) - refs′ = SQLQuery[] - lateral_refs = SQLQuery[] - gather!(n.joinee, ctx, lateral_refs) - append!(lrefs, lateral_refs) - lateral = !isempty(lateral_refs) - gather!(n.on, ctx, refs′) - for ref in refs′ - turn = route(n.router, ref) - push!(turn < 0 ? lrefs : rrefs, ref) - end - tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs)) - joinee′ = Linked(rrefs, rn_ext_refs, tail = link(n.joinee, ctx, rrefs)) - RoutedJoin( - joinee = joinee′, - on = n.on, - router = n.router, - left = n.left, - right = n.right, - lateral = lateral, - tail = tail′) -end - function link(n::LimitNode, ctx) tail′ = Linked(ctx.refs, tail = link(ctx)) Limit(offset = n.offset, limit = n.limit, tail = tail′) @@ -459,6 +418,46 @@ function link(n::PartitionNode, ctx) Partition(by = n.by, order_by = n.order_by, frame = n.frame, name = n.name, tail = tail′) end +function link(n::RoutedJoinNode, ctx) + lrefs = SQLQuery[] + rrefs = SQLQuery[] + for ref in ctx.refs + if @dissect(ref, Nested(name = (local name))) && name === n.name + push!(rrefs, ref) + else + push!(lrefs, ref) + end + end + if n.optional && isempty(rrefs) + return link(ctx) + end + ln_ext_refs = length(lrefs) + rn_ext_refs = length(rrefs) + refs′ = SQLQuery[] + lateral_refs = SQLQuery[] + gather!(n.joinee, ctx, lateral_refs) + append!(lrefs, lateral_refs) + lateral = !isempty(lateral_refs) + gather!(n.on, ctx, refs′) + for ref in refs′ + if @dissect(ref, Nested(name = (local name))) && name === n.name + push!(rrefs, ref) + else + push!(lrefs, ref) + end + end + tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs)) + joinee′ = Linked(rrefs, rn_ext_refs, tail = link(Into(name = n.name, tail = n.joinee), ctx, rrefs)) + RoutedJoin( + joinee = joinee′, + on = n.on, + name = n.name, + left = n.left, + right = n.right, + lateral = lateral, + tail = tail′) +end + function link(n::SelectNode, ctx) refs = SQLQuery[] gather!(n.args, ctx, refs) diff --git a/src/nodes.jl b/src/nodes.jl index ab470726..5fc547c0 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -913,6 +913,7 @@ include("nodes/get.jl") include("nodes/group.jl") include("nodes/highlight.jl") include("nodes/internal.jl") +include("nodes/into.jl") include("nodes/iterate.jl") include("nodes/join.jl") include("nodes/limit.jl") diff --git a/src/nodes/as.jl b/src/nodes/as.jl index 3a4a5db7..d1a61a2d 100644 --- a/src/nodes/as.jl +++ b/src/nodes/as.jl @@ -16,8 +16,7 @@ AsNode(name) = As(name; tail = nothing) name => tail -In a scalar context, `As` specifies the name of the output column. When -applied to tabular data, `As` wraps the data in a nested record. +`As` specifies the name of the output column. The arrow operator (`=>`) is a shorthand notation for `As`. @@ -35,19 +34,19 @@ SELECT "person_1"."person_id" AS "id" FROM "person" AS "person_1" ``` -*Show all patients together with their state of residence.* +*Show all patients together with their primary care provider.* ```jldoctest -julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]); +julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :provider_id]); -julia> location = SQLTable(:location, columns = [:location_id, :state]); +julia> provider = SQLTable(:provider, columns = [:provider_id, :provider_name]); julia> q = From(:person) |> - Join(From(:location) |> As(:location), - on = Get.location_id .== Get.location.location_id) |> - Select(Get.person_id, Get.location.state); + Join(From(:provider) |> As(:pcp), + on = Get.provider_id .== Get.pcp.provider_id) |> + Select(Get.person_id, Get.pcp.provider_name); -julia> print(render(q, tables = [person, location])) +julia> print(render(q, tables = [person, provider])) SELECT "person_1"."person_id", "location_1"."state" diff --git a/src/nodes/hide.jl b/src/nodes/hide.jl new file mode 100644 index 00000000..36ced02c --- /dev/null +++ b/src/nodes/hide.jl @@ -0,0 +1,42 @@ +# Hide node + +mutable struct HideNode <: TabularNode + over::Union{SQLNode, Nothing} + names::Vector{Symbol} + label_map::FunSQL.OrderedDict{Symbol, Int} + + function HideNode(; over = nothing, names = [], label_map = nothing) + if label_map !== nothing + new(over, names, label_map) + else + n = new(over, names, FunSQL.OrderedDict{Symbol, Int}()) + for (i, name) in enumerate(n.names) + if name in keys(n.label_map) + err = FunSQL.DuplicateLabelError(name, path = [n]) + throw(err) + end + n.label_map[name] = i + end + n + end + end +end + +HideNode(names...; over = nothing) = + HideNode(over = over, names = Symbol[names...]) + +Hide(args...; kws...) = + HideNode(args...; kws...) |> SQLNode + +const funsql_hide = Hide + +dissect(scr::Symbol, ::typeof(Hide), pats::Vector{Any}) = + dissect(scr, HideNode, pats) + +function FunSQL.PrettyPrinting.quoteof(n::HideNode, ctx::FunSQL.QuoteContext) + ex = Expr(:call, nameof(Hide), quoteof(n.names, ctx)...) + if n.over !== nothing + ex = Expr(:call, :|>, FunSQL.quoteof(n.over, ctx), ex) + end + ex +end diff --git a/src/nodes/internal.jl b/src/nodes/internal.jl index 91f866fd..931afb97 100644 --- a/src/nodes/internal.jl +++ b/src/nodes/internal.jl @@ -203,29 +203,21 @@ PrettyPrinting.quoteof(n::FromFunctionNode, ctx::QuoteContext) = # Annotated Join node. -struct JoinRouter - label_set::Set{Symbol} - group::Bool -end - -PrettyPrinting.quoteof(r::JoinRouter) = - Expr(:call, :JoinRouter, quoteof(r.label_set), quoteof(r.group)) - struct RoutedJoinNode <: TabularNode joinee::SQLQuery on::SQLQuery - router::JoinRouter + name::Symbol left::Bool right::Bool lateral::Bool optional::Bool - RoutedJoinNode(; joinee, on, router, left, right, lateral = false, optional = false) = - new(joinee, on, router, left, right, lateral, optional) + RoutedJoinNode(; joinee, on, name = label(joinee), left, right, lateral = false, optional = false) = + new(joinee, on, name, left, right, lateral, optional) end -RoutedJoinNode(joinee, on; router, left = false, right = false, lateral = false, optional = false) = - RoutedJoinNode(joinee = joinee, on = on, router, left = left, right = right, lateral = lateral, optional = optional) +RoutedJoinNode(joinee, on; name = label(joinee), left = false, right = false, lateral = false, optional = false) = + RoutedJoinNode(name = name, on = on, router, left = left, right = right, lateral = lateral, optional = optional) const RoutedJoin = SQLQueryCtor{RoutedJoinNode}(:RoutedJoin) @@ -234,7 +226,7 @@ function PrettyPrinting.quoteof(n::RoutedJoinNode, ctx::QuoteContext) if !ctx.limit push!(ex.args, quoteof(n.joinee, ctx)) push!(ex.args, quoteof(n.on, ctx)) - push!(ex.args, Expr(:kw, :router, quoteof(n.router))) + push!(ex.args, Expr(:kw, :name, QuoteNode(n.name))) if n.left push!(ex.args, Expr(:kw, :left, n.left)) end diff --git a/src/nodes/into.jl b/src/nodes/into.jl new file mode 100644 index 00000000..f421ea1e --- /dev/null +++ b/src/nodes/into.jl @@ -0,0 +1,28 @@ +# Wrap the output into a nested record. + +mutable struct IntoNode <: TabularNode + name::Symbol + + IntoNode(; name::Union{Symbol, AbstractString}) = + new(Symbol(name)) +end + +IntoNode(name) = + IntoNode(; name) + +""" + Into(; name, tail = nothing) + Into(name; tail = nothing) + +`Into` wraps output columns in a nested record. +""" +const Into = SQLQueryCtor{IntoNode}(:Into) + +const funsql_into = Into + +function PrettyPrinting.quoteof(n::IntoNode, ctx::QuoteContext) + Expr(:call, :Into, quoteof(n.name)) +end + +label(n::IntoNode) = + n.name diff --git a/src/resolve.jl b/src/resolve.jl index 73bf67ac..56136edb 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -178,9 +178,8 @@ end function resolve(n::AsNode, ctx) tail′ = resolve(ctx) - t = row_type(tail′) q′ = As(name = n.name, tail = tail′) - Resolved(RowType(FieldTypeMap(n.name => t)), tail = q′) + Resolved(type(tail′), tail = q′) end function resolve_scalar(n::AsNode, ctx) @@ -401,6 +400,13 @@ resolve(::HighlightNode, ctx) = resolve_scalar(::HighlightNode, ctx) = resolve_scalar(ctx) +function resolve(n::IntoNode, ctx) + tail′ = resolve(ctx) + t = row_type(tail′) + q′ = Into(name = n.name, tail = tail′) + Resolved(RowType(FieldTypeMap(n.name => t)), tail = q′) +end + function resolve(n::IterateNode, ctx) tail′ = resolve(ResolveContext(ctx, knot_type = nothing, implicit_knot = false)) t = row_type(tail′) @@ -418,21 +424,18 @@ end function resolve(n::JoinNode, ctx) tail′ = resolve(ctx) lt = row_type(tail′) + name = label(n.joinee) joinee′ = resolve(n.joinee, ResolveContext(ctx, row_type = lt, implicit_knot = false)) rt = row_type(joinee′) fields = FieldTypeMap() for (f, ft) in lt.fields - fields[f] = get(rt.fields, f, ft) + fields[f] = ft end - for (f, ft) in rt.fields - if !haskey(fields, f) - fields[f] = ft - end - end - group = rt.group isa EmptyType ? lt.group : rt.group + fields[name] = rt + group = lt.group t = RowType(fields, group) on′ = resolve_scalar(n.on, ctx, t) - q′ = Join(joinee = joinee′, on = on′, left = n.left, right = n.right, optional = n.optional, tail = tail′) + q′ = RoutedJoin(joinee = joinee′, on = on′, name = name, left = n.left, right = n.right, optional = n.optional, tail = tail′) Resolved(t, tail = q′) end @@ -532,16 +535,7 @@ function resolve(n::Union{WithNode, WithExternalNode}, ctx) v = get(ctx.cte_types, name, nothing) depth = 1 + (v !== nothing ? v[1] : 0) t = row_type(args′[i]) - cte_t = get(t.fields, name, EmptyType()) - if !(cte_t isa RowType) - throw( - ReferenceError( - REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE, - name = name, - path = get_path(ctx))) - - end - cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, cte_t)) + cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, t)) end ctx′ = ResolveContext(ctx, cte_types = cte_types′) tail′ = resolve(ctx′) diff --git a/src/translate.jl b/src/translate.jl index 84b3bd79..0e4a16fb 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -427,26 +427,8 @@ function assemble(n::AppendNode, ctx) Assemblage(a_name, s, repl = repl, cols = dummy_cols) end -function assemble(n::AsNode, ctx) - refs′ = SQLQuery[] - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested()) - push!(refs′, tail) - else - push!(refs′, ref) - end - end - base = assemble(TranslateContext(ctx, refs = refs′)) - repl′ = Dict{SQLQuery, Symbol}() - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested()) - repl′[ref] = base.repl[tail] - else - repl′[ref] = base.repl[ref] - end - end - Assemblage(n.name, base.syntax, cols = base.cols, repl = repl′) -end +assemble(n::AsNode, ctx) = + assemble(ctx) function assemble(n::BindNode, ctx) vars′ = ctx.vars @@ -530,21 +512,12 @@ end assemble(::FromNothingNode, ctx) = assemble(nothing, ctx) -function unwrap_repl(a::Assemblage) - repl′ = Dict{SQLQuery, Symbol}() - for (ref, name) in a.repl - @dissect(ref, (local tail) |> Nested()) || error() - repl′[tail] = name - end - Assemblage(a.name, a.syntax, cols = a.cols, repl = repl′) -end - function assemble(n::FromTableExpressionNode, ctx) cte_a = ctx.ctes[ctx.cte_map[(n.name, n.depth)]] alias = allocate_alias(ctx, n.name) tbl = convert(SQLSyntax, (cte_a.qualifiers, cte_a.name)) s = FROM(AS(name = alias, tail = tbl)) - subs = make_subs(unwrap_repl(cte_a.a), alias) + subs = make_subs(cte_a.a, alias) trns = Pair{SQLQuery, SQLSyntax}[] for ref in ctx.refs push!(trns, ref => subs[ref]) @@ -675,6 +648,27 @@ function assemble(n::GroupNode, ctx) return Assemblage(base.name, s, cols = cols, repl = repl) end +function assemble(n::IntoNode, ctx) + refs′ = SQLQuery[] + for ref in ctx.refs + if @dissect(ref, (local tail) |> Nested()) + push!(refs′, tail) + else + push!(refs′, ref) + end + end + base = assemble(TranslateContext(ctx, refs = refs′)) + repl′ = Dict{SQLQuery, Symbol}() + for ref in ctx.refs + if @dissect(ref, (local tail) |> Nested()) + repl′[ref] = base.repl[tail] + else + repl′[ref] = base.repl[ref] + end + end + Assemblage(n.name, base.syntax, cols = base.cols, repl = repl′) +end + function assemble(n::IterateNode, ctx) ctx′ = TranslateContext(ctx, vars = Base.ImmutableDict{Tuple{Symbol, Int}, SQLSyntax}()) left = assemble(ctx) @@ -883,22 +877,16 @@ function assemble(n::RoutedJoinNode, ctx) right = assemble(n.joinee, ctx) end if @dissect(right.syntax, (local joinee = (ID() || AS())) |> FROM()) && (!n.left || _outer_safe(right)) - for (ref, name) in right.repl - subs[ref] = right.cols[name] - end + right_alias = nothing if ctx.catalog.dialect.has_implicit_lateral lateral = false end else right_alias = allocate_alias(ctx, right) joinee = AS(name = right_alias, tail = complete(right)) - right_cache = Dict{Symbol, SQLSyntax}() - for (ref, name) in right.repl - subs[ref] = get(right_cache, name) do - ID(name = name, tail = right_alias) - end - end end + right_subs = make_subs(right, right_alias) + merge!(subs, right_subs) on = translate(n.on, ctx, subs) s = JOIN(joinee = joinee, on = on, left = n.left, right = n.right, lateral = lateral, tail = tail) trns = Pair{SQLQuery, SQLSyntax}[] From a70eaff4d98242fafd7019d924bbf15362b96a4e Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Wed, 19 Feb 2025 21:17:08 -0600 Subject: [PATCH 2/5] Make nested records visible --- src/link.jl | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/link.jl b/src/link.jl index 3736efea..d79cdcbb 100644 --- a/src/link.jl +++ b/src/link.jl @@ -25,16 +25,26 @@ struct LinkContext knot_refs) end -function link(q::SQLQuery) - @dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError()) - ctx = LinkContext(catalog) - t = row_type(tail) +function _select(t::RowType) refs = SQLQuery[] for (f, ft) in t.fields if ft isa ScalarType push!(refs, Get(f)) + else + nested_refs = _select(ft) + for nested_ref in nested_refs + push!(refs, Nested(over = nested_ref, name = f)) + end end end + refs +end + +function link(q::SQLQuery) + @dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError()) + ctx = LinkContext(catalog) + t = row_type(tail) + refs = _select(t) tail′ = Linked(refs, tail = link(dismantle(tail, ctx), ctx, refs)) WithContext(tail = tail′, catalog = catalog, defs = ctx.defs) end @@ -555,12 +565,9 @@ end function gather!(n::IsolatedNode, ctx) def = ctx.defs[n.idx] !@dissect(def, Linked()) || return - refs = SQLQuery[] - for (f, ft) in n.type.fields - if ft isa ScalarType - push!(refs, Get(f)) - break - end + refs = _select(n.type) + if !isempty(refs) + refs = refs[1:1] end def′ = Linked(refs, tail = link(def, ctx, refs)) ctx.defs[n.idx] = def′ From 8149fcfc8f5d41fc88e4baf5f7d6bd258f50e2ad Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Sat, 22 Feb 2025 13:58:58 -0600 Subject: [PATCH 3/5] Join: add swap option --- src/nodes/join.jl | 25 ++++++++++++++++--------- src/resolve.jl | 4 ++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/nodes/join.jl b/src/nodes/join.jl index c1bc2eb4..372d2272 100644 --- a/src/nodes/join.jl +++ b/src/nodes/join.jl @@ -6,21 +6,22 @@ mutable struct JoinNode <: TabularNode left::Bool right::Bool optional::Bool + swap::Bool - JoinNode(; joinee, on, left = false, right = false, optional = false) = - new(joinee, on, left, right, optional) + JoinNode(; joinee, on, left = false, right = false, optional = false, swap = false) = + new(joinee, on, left, right, optional, swap) end -JoinNode(joinee; on, left = false, right = false, optional = false) = - JoinNode(; joinee, on, left, right, optional) +JoinNode(joinee; on, left = false, right = false, optional = false, swap = false) = + JoinNode(; joinee, on, left, right, optional, swap) -JoinNode(joinee, on; left = false, right = false, optional = false) = - JoinNode(; joinee, on, left, right, optional) +JoinNode(joinee, on; left = false, right = false, optional = false, swap = false) = + JoinNode(; joinee, on, left, right, optional, swap) """ - Join(; joinee, on, left = false, right = false, optional = false) - Join(joinee; on, left = false, right = false, optional = false) - Join(joinee, on; left = false, right = false, optional = false) + Join(; joinee, on, left = false, right = false, optional = false, swap = false) + Join(joinee; on, left = false, right = false, optional = false, swap = false) + Join(joinee, on; left = false, right = false, optional = false, swap = false) `Join` correlates two input datasets. @@ -102,8 +103,14 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) if n.optional push!(ex.args, Expr(:kw, :optional, n.optional)) end + if n.swap + push!(ex.args, Expr(:kw, :swap, n.swap)) + end else push!(ex.args, :…) end ex end + +label(n::JoinNode) = + n.swap ? label(n.joinee) : label(n.over) diff --git a/src/resolve.jl b/src/resolve.jl index 56136edb..bcdb6e21 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -422,6 +422,10 @@ function resolve(n::IterateNode, ctx) end function resolve(n::JoinNode, ctx) + if n.swap + ctx′ = ResolveContext(Ctx, tail = n.joinee) + return resolve(JoinNode(joinee = ctx.tail, on = n.on, left = n.right, right = n.left, optional = n.optional), ctx′) + end tail′ = resolve(ctx) lt = row_type(tail′) name = label(n.joinee) From 9edf1591e044a7b2e3ca76074b2207d209ec9a11 Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Sat, 22 Feb 2025 17:18:45 -0600 Subject: [PATCH 4/5] Add Show/Hide combinators --- src/FunSQL.jl | 2 ++ src/link.jl | 4 +++- src/nodes.jl | 15 +++++++++------ src/nodes/internal.jl | 2 +- src/nodes/show.jl | 38 ++++++++++++++++++++++++++++++++++++++ src/resolve.jl | 29 ++++++++++++++++++++++++++++- src/types.jl | 30 +++++++++++++++++++++++------- 7 files changed, 104 insertions(+), 16 deletions(-) create mode 100644 src/nodes/show.jl diff --git a/src/FunSQL.jl b/src/FunSQL.jl index 618c6610..56e4378a 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -54,6 +54,7 @@ export funsql_from, funsql_fun, funsql_group, + funsql_hide, funsql_highlight, funsql_in, funsql_into, @@ -83,6 +84,7 @@ export funsql_rank, funsql_row_number, funsql_select, + funsql_show, funsql_sort, funsql_sum, funsql_with diff --git a/src/link.jl b/src/link.jl index d79cdcbb..c9eada38 100644 --- a/src/link.jl +++ b/src/link.jl @@ -27,13 +27,15 @@ end function _select(t::RowType) refs = SQLQuery[] + t.visible || return refs for (f, ft) in t.fields if ft isa ScalarType + ft.visible || continue push!(refs, Get(f)) else nested_refs = _select(ft) for nested_ref in nested_refs - push!(refs, Nested(over = nested_ref, name = f)) + push!(refs, Nested(name = f, tail = nested_ref)) end end end diff --git a/src/nodes.jl b/src/nodes.jl index 5fc547c0..e16e70ef 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -54,17 +54,19 @@ terminal(q::SQLQuery) = Chain(q′, q) = convert(SQLQuery, q)(q′) -label(q::SQLQuery) = - @something label(q.head) label(q.tail) +function label(q::SQLQuery; default = :_) + l = label(q.head) + l !== nothing ? l : label(q.tail; default) +end label(n::AbstractSQLNode) = nothing -label(::Nothing) = - :_ +label(::Nothing; default = :_) = + default -label(q) = - label(convert(SQLQuery, q)) +label(q; default = :_) = + label(convert(SQLQuery, q); default) # A variant of SQLQuery for assembling a chain of identifiers. @@ -922,6 +924,7 @@ include("nodes/order.jl") include("nodes/over.jl") include("nodes/partition.jl") include("nodes/select.jl") +include("nodes/show.jl") include("nodes/sort.jl") include("nodes/variable.jl") include("nodes/where.jl") diff --git a/src/nodes/internal.jl b/src/nodes/internal.jl index 931afb97..36221e3e 100644 --- a/src/nodes/internal.jl +++ b/src/nodes/internal.jl @@ -302,7 +302,7 @@ PrettyPrinting.quoteof(n::FunSQLMacroNode, ctx::QuoteContext) = Expr(:macrocall, Symbol("@funsql"), n.line, !ctx.limit ? n.ex : :…) label(n::FunSQLMacroNode) = - label(n.query) + label(n.query, default = nothing) # Unwrap @funsql macro when displaying the query. diff --git a/src/nodes/show.jl b/src/nodes/show.jl new file mode 100644 index 00000000..f0176ec7 --- /dev/null +++ b/src/nodes/show.jl @@ -0,0 +1,38 @@ +# Show/Hide nodes + +mutable struct ShowNode <: TabularNode + names::Vector{Symbol} + visible::Bool + label_map::FunSQL.OrderedDict{Symbol, Int} + + function ShowNode(; names = [], visible = true, label_map = nothing) + if label_map !== nothing + new(names, visible, label_map) + else + n = new(names, visible, FunSQL.OrderedDict{Symbol, Int}()) + for (i, name) in enumerate(n.names) + if name in keys(n.label_map) + err = FunSQL.DuplicateLabelError(name, path = SQLQuery[n]) + throw(err) + end + n.label_map[name] = i + end + n + end + end +end + +ShowNode(names...; visible = true) = + ShowNode(names = Symbol[names...], visible = visible) + +const Show = SQLQueryCtor{ShowNode}(:Show) + +Hide(args...; kws...) = + Show(args...; kws..., visible = false) + +const funsql_show = Show +const funsql_hide = Hide + +function FunSQL.PrettyPrinting.quoteof(n::ShowNode, ctx::QuoteContext) + Expr(:call, n.visible ? :Show : :Hide, quoteof(n.names, ctx)...) +end diff --git a/src/resolve.jl b/src/resolve.jl index bcdb6e21..1e901673 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -423,7 +423,7 @@ end function resolve(n::JoinNode, ctx) if n.swap - ctx′ = ResolveContext(Ctx, tail = n.joinee) + ctx′ = ResolveContext(ctx, tail = n.joinee) return resolve(JoinNode(joinee = ctx.tail, on = n.on, left = n.right, right = n.left, optional = n.optional), ctx′) end tail′ = resolve(ctx) @@ -506,6 +506,33 @@ function resolve(n::SelectNode, ctx) Resolved(RowType(fields), tail = q′) end +function resolve(n::ShowNode, ctx) + tail′ = resolve(ctx) + t = row_type(tail′) + for name in n.names + ft = get(t.fields, name, EmptyType()) + if ft isa EmptyType + throw( + ReferenceError( + REFERENCE_ERROR_TYPE.UNDEFINED_NAME, + name = name, + path = get_path(ctx))) + end + end + fields = FieldTypeMap() + for (f, ft) in t.fields + if f in keys(n.label_map) + if ft isa ScalarType + ft = ScalarType(visible = n.visible) + else + ft = RowType(ft.fields, ft.group, visible = n.visible) + end + end + fields[f] = ft + end + Resolved(RowType(fields, t.group, visible = t.visible), tail = tail′) +end + function resolve_scalar(n::SortNode, ctx) tail′ = resolve_scalar(ctx) q′ = Sort(value = n.value, nulls = n.nulls, tail = tail′) diff --git a/src/types.jl b/src/types.jl index 856821ed..e04cfec6 100644 --- a/src/types.jl +++ b/src/types.jl @@ -13,17 +13,27 @@ PrettyPrinting.quoteof(::EmptyType) = Expr(:call, nameof(EmptyType)) struct ScalarType <: AbstractSQLType + visible::Bool + + ScalarType(; visible = true) = + new(visible) end -PrettyPrinting.quoteof(::ScalarType) = - Expr(:call, nameof(ScalarType)) +function PrettyPrinting.quoteof(t::ScalarType) + ex = Expr(:call, nameof(ScalarType)) + if !t.visible + push!(ex.args, Expr(:kw, :visible, t.visible)) + end + ex +end struct RowType <: AbstractSQLType fields::OrderedDict{Symbol, Union{ScalarType, RowType}} group::Union{EmptyType, RowType} + visible::Bool - RowType(fields, group = EmptyType()) = - new(fields, group) + RowType(fields, group = EmptyType(); visible = true) = + new(fields, group, visible) end const FieldTypeMap = OrderedDict{Symbol, Union{ScalarType, RowType}} @@ -43,6 +53,9 @@ function PrettyPrinting.quoteof(t::RowType) if !(t.group isa EmptyType) push!(ex.args, Expr(:kw, :group, quoteof(t.group))) end + if !t.visible + push!(ex.args, Expr(:kw, :visible, t.visible)) + end ex end @@ -54,8 +67,8 @@ const EMPTY_ROW = RowType() Base.intersect(::AbstractSQLType, ::AbstractSQLType) = EmptyType() -Base.intersect(::ScalarType, ::ScalarType) = - ScalarType() +Base.intersect(t1::ScalarType, t2::ScalarType) = + ScalarType(visible = t1.visible || t2.visible) function Base.intersect(t1::RowType, t2::RowType) if t1 === t2 @@ -71,7 +84,7 @@ function Base.intersect(t1::RowType, t2::RowType) end end group = intersect(t1.group, t2.group) - RowType(fields, group) + RowType(fields, group, visible = t1.visible || t2.visible) end @@ -98,5 +111,8 @@ function Base.issubset(t1::RowType, t2::RowType) if !issubset(t1.group, t2.group) return false end + if !t1.visible && t2.visible + return false + end return true end From ff18dc057054eb2733b186577950ed1f757e571a Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Wed, 5 Nov 2025 22:58:24 -0600 Subject: [PATCH 5/5] Replace show()/hide() with "private" parameter for define()/join() --- src/FunSQL.jl | 2 -- src/catalogs.jl | 23 +++++++++------ src/link.jl | 3 +- src/nodes.jl | 1 - src/nodes/define.jl | 20 ++++++++----- src/nodes/hide.jl | 42 --------------------------- src/nodes/into.jl | 19 +++++++----- src/nodes/join.jl | 24 +++++++++------- src/nodes/show.jl | 38 ------------------------ src/resolve.jl | 70 ++++++++++++++++++++++----------------------- src/types.jl | 41 ++++++++++++-------------- 11 files changed, 107 insertions(+), 176 deletions(-) delete mode 100644 src/nodes/hide.jl delete mode 100644 src/nodes/show.jl diff --git a/src/FunSQL.jl b/src/FunSQL.jl index 56e4378a..618c6610 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -54,7 +54,6 @@ export funsql_from, funsql_fun, funsql_group, - funsql_hide, funsql_highlight, funsql_in, funsql_into, @@ -84,7 +83,6 @@ export funsql_rank, funsql_row_number, funsql_select, - funsql_show, funsql_sort, funsql_sum, funsql_with diff --git a/src/catalogs.jl b/src/catalogs.jl index 0fbd446a..1fe7f2e6 100644 --- a/src/catalogs.jl +++ b/src/catalogs.jl @@ -31,22 +31,24 @@ _metadata_get(dict::SQLMetadata, key::Union{Symbol, AbstractString}, default; st end """ - SQLColumn(; name, metadata = nothing) - SQLColumn(name; metadata = nothing) + SQLColumn(; name, private = false, metadata = nothing) + SQLColumn(name; private = false, metadata = nothing) `SQLColumn` represents a column with the given `name` and optional `metadata`. +If `private` is `true`, the column is excluded from the default query output. """ struct SQLColumn name::Symbol + private::Bool metadata::SQLMetadata - function SQLColumn(; name::Union{Symbol, AbstractString}, metadata = nothing) - new(Symbol(name), _metadata(metadata)) + function SQLColumn(; name::Union{Symbol, AbstractString}, private = false, metadata = nothing) + new(Symbol(name), private, _metadata(metadata)) end end -SQLColumn(name; metadata = nothing) = - SQLColumn(name = name, metadata = metadata) +SQLColumn(name; private = false, metadata = nothing) = + SQLColumn(; name, private, metadata) Base.show(io::IO, col::SQLColumn) = print(io, quoteof(col, limit = true)) @@ -56,6 +58,9 @@ Base.show(io::IO, ::MIME"text/plain", col::SQLColumn) = function PrettyPrinting.quoteof(col::SQLColumn; limit::Bool = false) ex = Expr(:call, nameof(SQLColumn), QuoteNode(col.name)) + if col.private + push(ex.args, Expr(:kw, :private, col.private)) + end if !isempty(col.metadata) push!(ex.args, Expr(:kw, :metadata, limit ? :… : quoteof(reverse!(collect(col.metadata))))) end @@ -122,10 +127,10 @@ struct SQLTable <: AbstractDict{Symbol, SQLColumn} end SQLTable(name; qualifiers = Symbol[], columns, metadata = nothing) = - SQLTable(qualifiers = qualifiers, name = name, columns = columns, metadata = metadata) + SQLTable(; qualifiers, name, columns, metadata) SQLTable(name, columns...; qualifiers = Symbol[], metadata = nothing) = - SQLTable(qualifiers = qualifiers, name = name, columns = [columns...], metadata = metadata) + SQLTable(; qualifiers, name, columns = [columns...], metadata) _column_map(columns::OrderedDict{Symbol, SQLColumn}) = columns @@ -280,7 +285,7 @@ struct SQLCatalog <: AbstractDict{Symbol, SQLTable} end SQLCatalog(tables...; dialect = :default, cache = default_cache_maxsize, metadata = nothing) = - SQLCatalog(tables = tables, dialect = dialect, cache = cache, metadata = metadata) + SQLCatalog(; tables, dialect, cache, metadata) _table_map(tables::Dict{Symbol, SQLTable}) = tables diff --git a/src/link.jl b/src/link.jl index c9eada38..8bb14e13 100644 --- a/src/link.jl +++ b/src/link.jl @@ -27,10 +27,9 @@ end function _select(t::RowType) refs = SQLQuery[] - t.visible || return refs for (f, ft) in t.fields + !(f in t.private_fields) || continue if ft isa ScalarType - ft.visible || continue push!(refs, Get(f)) else nested_refs = _select(ft) diff --git a/src/nodes.jl b/src/nodes.jl index e16e70ef..b8e68a13 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -924,7 +924,6 @@ include("nodes/order.jl") include("nodes/over.jl") include("nodes/partition.jl") include("nodes/select.jl") -include("nodes/show.jl") include("nodes/sort.jl") include("nodes/variable.jl") include("nodes/where.jl") diff --git a/src/nodes/define.jl b/src/nodes/define.jl index fab058de..3fd44dfc 100644 --- a/src/nodes/define.jl +++ b/src/nodes/define.jl @@ -4,13 +4,14 @@ struct DefineNode <: TabularNode args::Vector{SQLQuery} before::Union{Symbol, Bool} after::Union{Symbol, Bool} + private::Bool label_map::OrderedDict{Symbol, Int} - function DefineNode(; args = [], before = nothing, after = nothing, label_map = nothing) + function DefineNode(; args = [], before = nothing, after = nothing, private = false, label_map = nothing) if label_map !== nothing - n = new(args, something(before, false), something(after, false), label_map) + n = new(args, something(before, false), something(after, false), private, label_map) else - n = new(args, something(before, false), something(after, false), OrderedDict{Symbol, Int}()) + n = new(args, something(before, false), something(after, false), private, OrderedDict{Symbol, Int}()) populate_label_map!(n) end if (n.before isa Symbol || n.before) && (n.after isa Symbol || n.after) @@ -20,12 +21,12 @@ struct DefineNode <: TabularNode end end -DefineNode(args...; before = nothing, after = nothing) = - DefineNode(args = SQLQuery[args...], before = before, after = after) +DefineNode(args...; before = nothing, after = nothing, private = false) = + DefineNode(args = SQLQuery[args...], before = before, after = after, private = private) """ - Define(; args = [], before = nothing, after = nothing, tail = nothing) - Define(args...; before = nothing, after = nothing, tail = nothing) + Define(; args = [], before = nothing, after = nothing, private = false, tail = nothing) + Define(args...; before = nothing, after = nothing, private = false, tail = nothing) The `Define` node adds or replaces output columns. @@ -35,6 +36,8 @@ both new and replaced columns at the end (after a specified column). Alternatively, set `before = true` (`before = `) to add both new and replaced columns at the front (before the specified column). +If `private` is set, the columns will be excluded from the query output. + # Examples *Show patients who are at least 16 years old.* @@ -90,5 +93,8 @@ function PrettyPrinting.quoteof(n::DefineNode, ctx::QuoteContext) if n.after !== false push!(ex.args, Expr(:kw, :after, n.after isa Symbol ? QuoteNode(n.after) : n.after)) end + if n.private !== false + push!(ex.args, Expr(:kw, :private, n.private)) + end ex end diff --git a/src/nodes/hide.jl b/src/nodes/hide.jl deleted file mode 100644 index 36ced02c..00000000 --- a/src/nodes/hide.jl +++ /dev/null @@ -1,42 +0,0 @@ -# Hide node - -mutable struct HideNode <: TabularNode - over::Union{SQLNode, Nothing} - names::Vector{Symbol} - label_map::FunSQL.OrderedDict{Symbol, Int} - - function HideNode(; over = nothing, names = [], label_map = nothing) - if label_map !== nothing - new(over, names, label_map) - else - n = new(over, names, FunSQL.OrderedDict{Symbol, Int}()) - for (i, name) in enumerate(n.names) - if name in keys(n.label_map) - err = FunSQL.DuplicateLabelError(name, path = [n]) - throw(err) - end - n.label_map[name] = i - end - n - end - end -end - -HideNode(names...; over = nothing) = - HideNode(over = over, names = Symbol[names...]) - -Hide(args...; kws...) = - HideNode(args...; kws...) |> SQLNode - -const funsql_hide = Hide - -dissect(scr::Symbol, ::typeof(Hide), pats::Vector{Any}) = - dissect(scr, HideNode, pats) - -function FunSQL.PrettyPrinting.quoteof(n::HideNode, ctx::FunSQL.QuoteContext) - ex = Expr(:call, nameof(Hide), quoteof(n.names, ctx)...) - if n.over !== nothing - ex = Expr(:call, :|>, FunSQL.quoteof(n.over, ctx), ex) - end - ex -end diff --git a/src/nodes/into.jl b/src/nodes/into.jl index f421ea1e..655652c0 100644 --- a/src/nodes/into.jl +++ b/src/nodes/into.jl @@ -2,17 +2,18 @@ mutable struct IntoNode <: TabularNode name::Symbol + private::Bool - IntoNode(; name::Union{Symbol, AbstractString}) = - new(Symbol(name)) + IntoNode(; name::Union{Symbol, AbstractString}, private::Bool = false) = + new(Symbol(name), private) end -IntoNode(name) = - IntoNode(; name) +IntoNode(name; private = false) = + IntoNode(; name, private) """ - Into(; name, tail = nothing) - Into(name; tail = nothing) + Into(; name, private = false, tail = nothing) + Into(name; private = false, tail = nothing) `Into` wraps output columns in a nested record. """ @@ -21,7 +22,11 @@ const Into = SQLQueryCtor{IntoNode}(:Into) const funsql_into = Into function PrettyPrinting.quoteof(n::IntoNode, ctx::QuoteContext) - Expr(:call, :Into, quoteof(n.name)) + ex = Expr(:call, :Into, quoteof(n.name)) + if n.private + push!(ex.args, Expr(:kw, :private, n.private)) + end + ex end label(n::IntoNode) = diff --git a/src/nodes/join.jl b/src/nodes/join.jl index 372d2272..027914f3 100644 --- a/src/nodes/join.jl +++ b/src/nodes/join.jl @@ -7,21 +7,22 @@ mutable struct JoinNode <: TabularNode right::Bool optional::Bool swap::Bool + private::Bool - JoinNode(; joinee, on, left = false, right = false, optional = false, swap = false) = - new(joinee, on, left, right, optional, swap) + JoinNode(; joinee, on, left = false, right = false, optional = false, swap = false, private = false) = + new(joinee, on, left, right, optional, swap, private) end -JoinNode(joinee; on, left = false, right = false, optional = false, swap = false) = - JoinNode(; joinee, on, left, right, optional, swap) +JoinNode(joinee; on, left = false, right = false, optional = false, swap = false, private = false) = + JoinNode(; joinee, on, left, right, optional, swap, private) -JoinNode(joinee, on; left = false, right = false, optional = false, swap = false) = - JoinNode(; joinee, on, left, right, optional, swap) +JoinNode(joinee, on; left = false, right = false, optional = false, swap = false, private = false) = + JoinNode(; joinee, on, left, right, optional, swap, private) """ - Join(; joinee, on, left = false, right = false, optional = false, swap = false) - Join(joinee; on, left = false, right = false, optional = false, swap = false) - Join(joinee, on; left = false, right = false, optional = false, swap = false) + Join(; joinee, on, left = false, right = false, optional = false, swap = false, private = false) + Join(joinee; on, left = false, right = false, optional = false, swap = false, private = false) + Join(joinee, on; left = false, right = false, optional = false, swap = false, private = false) `Join` correlates two input datasets. @@ -106,6 +107,9 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) if n.swap push!(ex.args, Expr(:kw, :swap, n.swap)) end + if n.private + push!(ex.args, Expr(:kw, :private, n.private)) + end else push!(ex.args, :…) end @@ -113,4 +117,4 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) end label(n::JoinNode) = - n.swap ? label(n.joinee) : label(n.over) + n.swap ? label(n.joinee) : nothing diff --git a/src/nodes/show.jl b/src/nodes/show.jl deleted file mode 100644 index f0176ec7..00000000 --- a/src/nodes/show.jl +++ /dev/null @@ -1,38 +0,0 @@ -# Show/Hide nodes - -mutable struct ShowNode <: TabularNode - names::Vector{Symbol} - visible::Bool - label_map::FunSQL.OrderedDict{Symbol, Int} - - function ShowNode(; names = [], visible = true, label_map = nothing) - if label_map !== nothing - new(names, visible, label_map) - else - n = new(names, visible, FunSQL.OrderedDict{Symbol, Int}()) - for (i, name) in enumerate(n.names) - if name in keys(n.label_map) - err = FunSQL.DuplicateLabelError(name, path = SQLQuery[n]) - throw(err) - end - n.label_map[name] = i - end - n - end - end -end - -ShowNode(names...; visible = true) = - ShowNode(names = Symbol[names...], visible = visible) - -const Show = SQLQueryCtor{ShowNode}(:Show) - -Hide(args...; kws...) = - Show(args...; kws..., visible = false) - -const funsql_show = Show -const funsql_hide = Hide - -function FunSQL.PrettyPrinting.quoteof(n::ShowNode, ctx::QuoteContext) - Expr(:call, n.visible ? :Show : :Hide, quoteof(n.names, ctx)...) -end diff --git a/src/resolve.jl b/src/resolve.jl index 1e901673..2628e7f9 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -263,16 +263,28 @@ function resolve(n::DefineNode, ctx) end end end + private_fields = copy(t.private_fields) + for l in keys(n.label_map) + if n.private + push!(private_fields, l) + else + delete!(private_fields, l) + end + end q′ = Define(args = args′, label_map = n.label_map, tail = tail′) - Resolved(RowType(fields, t.group), tail = q′) + Resolved(RowType(fields, t.group, private_fields), tail = q′) end function RowType(table::SQLTable) fields = FieldTypeMap() - for f in keys(table.columns) + private_fields = Set{Symbol}() + for (f, c) in table.columns fields[f] = ScalarType() + if c.private + push!(private_fields, f) + end end - RowType(fields) + RowType(fields, EmptyType(), private_fields) end function resolve(n::FromNode, ctx) @@ -390,8 +402,12 @@ function resolve(n::GroupNode, ctx) fields[n.name] = RowType(FieldTypeMap(), group) group = EmptyType() end + private_fields = Set{Symbol}() + if n.name !== nothing + push!(private_fields, n.name) + end q′ = Group(by = by′, sets = n.sets, label_map = n.label_map, tail = tail′) - Resolved(RowType(fields, group), tail = q′) + Resolved(RowType(fields, group, private_fields), tail = q′) end resolve(::HighlightNode, ctx) = @@ -404,7 +420,8 @@ function resolve(n::IntoNode, ctx) tail′ = resolve(ctx) t = row_type(tail′) q′ = Into(name = n.name, tail = tail′) - Resolved(RowType(FieldTypeMap(n.name => t)), tail = q′) + t′ = RowType(FieldTypeMap(n.name => t), EmptyType(), n.private ? Set([n.name]) : Set{Symbol}()) + Resolved(t′, tail = q′) end function resolve(n::IterateNode, ctx) @@ -424,7 +441,7 @@ end function resolve(n::JoinNode, ctx) if n.swap ctx′ = ResolveContext(ctx, tail = n.joinee) - return resolve(JoinNode(joinee = ctx.tail, on = n.on, left = n.right, right = n.left, optional = n.optional), ctx′) + return resolve(JoinNode(joinee = ctx.tail, on = n.on, left = n.right, right = n.left, optional = n.optional, private = n.private), ctx′) end tail′ = resolve(ctx) lt = row_type(tail′) @@ -437,7 +454,13 @@ function resolve(n::JoinNode, ctx) end fields[name] = rt group = lt.group - t = RowType(fields, group) + private_fields = copy(lt.private_fields) + if n.private + push!(private_fields, name) + else + delete!(private_fields, name) + end + t = RowType(fields, group, private_fields) on′ = resolve_scalar(n.on, ctx, t) q′ = RoutedJoin(joinee = joinee′, on = on′, name = name, left = n.left, right = n.right, optional = n.optional, tail = tail′) Resolved(t, tail = q′) @@ -490,8 +513,12 @@ function resolve(n::PartitionNode, ctx) end fields[n.name] = RowType(FieldTypeMap(), t) end + private_fields = copy(t.private_fields) + if n.name !== nothing + push!(private_fields, n.name) + end q′ = Partition(by = by′, order_by = order_by′, frame = n.frame, name = n.name, tail = tail′) - Resolved(RowType(fields, group), tail = q′) + Resolved(RowType(fields, group, private_fields), tail = q′) end function resolve(n::SelectNode, ctx) @@ -506,33 +533,6 @@ function resolve(n::SelectNode, ctx) Resolved(RowType(fields), tail = q′) end -function resolve(n::ShowNode, ctx) - tail′ = resolve(ctx) - t = row_type(tail′) - for name in n.names - ft = get(t.fields, name, EmptyType()) - if ft isa EmptyType - throw( - ReferenceError( - REFERENCE_ERROR_TYPE.UNDEFINED_NAME, - name = name, - path = get_path(ctx))) - end - end - fields = FieldTypeMap() - for (f, ft) in t.fields - if f in keys(n.label_map) - if ft isa ScalarType - ft = ScalarType(visible = n.visible) - else - ft = RowType(ft.fields, ft.group, visible = n.visible) - end - end - fields[f] = ft - end - Resolved(RowType(fields, t.group, visible = t.visible), tail = tail′) -end - function resolve_scalar(n::SortNode, ctx) tail′ = resolve_scalar(ctx) q′ = Sort(value = n.value, nulls = n.nulls, tail = tail′) diff --git a/src/types.jl b/src/types.jl index e04cfec6..24d0c369 100644 --- a/src/types.jl +++ b/src/types.jl @@ -13,27 +13,21 @@ PrettyPrinting.quoteof(::EmptyType) = Expr(:call, nameof(EmptyType)) struct ScalarType <: AbstractSQLType - visible::Bool - - ScalarType(; visible = true) = - new(visible) + ScalarType() = + new() end function PrettyPrinting.quoteof(t::ScalarType) - ex = Expr(:call, nameof(ScalarType)) - if !t.visible - push!(ex.args, Expr(:kw, :visible, t.visible)) - end - ex + Expr(:call, nameof(ScalarType)) end struct RowType <: AbstractSQLType fields::OrderedDict{Symbol, Union{ScalarType, RowType}} group::Union{EmptyType, RowType} - visible::Bool + private_fields::Set{Symbol} - RowType(fields, group = EmptyType(); visible = true) = - new(fields, group, visible) + RowType(fields, group = EmptyType(), private_fields = Set{Symbol}()) = + new(fields, group, private_fields) end const FieldTypeMap = OrderedDict{Symbol, Union{ScalarType, RowType}} @@ -42,8 +36,8 @@ const GroupType = Union{EmptyType, RowType} RowType() = RowType(FieldTypeMap()) -RowType(fields::Pair{Symbol, <:AbstractSQLType}...; group = EmptyType()) = - RowType(FieldTypeMap(fields), group) +RowType(fields::Pair{Symbol, <:AbstractSQLType}...; group = EmptyType(), private_fields = Set{Symbol}()) = + RowType(FieldTypeMap(fields), group, private_fields) function PrettyPrinting.quoteof(t::RowType) ex = Expr(:call, nameof(RowType)) @@ -53,8 +47,8 @@ function PrettyPrinting.quoteof(t::RowType) if !(t.group isa EmptyType) push!(ex.args, Expr(:kw, :group, quoteof(t.group))) end - if !t.visible - push!(ex.args, Expr(:kw, :visible, t.visible)) + if !isempty(t.private_fields) + push!(ex.args, Expr(:kw, :private_fields, t.private_fields)) end ex end @@ -67,24 +61,28 @@ const EMPTY_ROW = RowType() Base.intersect(::AbstractSQLType, ::AbstractSQLType) = EmptyType() -Base.intersect(t1::ScalarType, t2::ScalarType) = - ScalarType(visible = t1.visible || t2.visible) +Base.intersect(::ScalarType, ::ScalarType) = + ScalarType() function Base.intersect(t1::RowType, t2::RowType) if t1 === t2 return t1 end fields = FieldTypeMap() + private_fields = Set{Symbol}() for f in keys(t1.fields) if f in keys(t2.fields) t = intersect(t1.fields[f], t2.fields[f]) if !isa(t, EmptyType) fields[f] = t + if f in t1.private_fields && f in t2.private_fields + push!(private_fields, f) + end end end end group = intersect(t1.group, t2.group) - RowType(fields, group, visible = t1.visible || t2.visible) + RowType(fields, group, private_fields) end @@ -104,15 +102,12 @@ function Base.issubset(t1::RowType, t2::RowType) return true end for f in keys(t1.fields) - if !(f in keys(t2.fields) && issubset(t1.fields[f], t2.fields[f])) + if !(f in keys(t2.fields) && issubset(t1.fields[f], t2.fields[f]) && (!(f in t1.private_fields) || f in t2.private_fields)) return false end end if !issubset(t1.group, t2.group) return false end - if !t1.visible && t2.visible - return false - end return true end