Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/FunSQL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ export
funsql_group,
funsql_highlight,
funsql_in,
funsql_into,
funsql_iterate,
funsql_is_not_null,
funsql_is_null,
Expand Down
23 changes: 14 additions & 9 deletions src/catalogs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,24 @@ _metadata_get(dict::SQLMetadata, key::Union{Symbol, AbstractString}, default; st
end

"""
SQLColumn(; name, metadata = nothing)
SQLColumn(name; metadata = nothing)
SQLColumn(; name, private = false, metadata = nothing)
SQLColumn(name; private = false, metadata = nothing)

`SQLColumn` represents a column with the given `name` and optional `metadata`.
If `private` is `true`, the column is excluded from the default query output.
"""
struct SQLColumn
name::Symbol
private::Bool
metadata::SQLMetadata

function SQLColumn(; name::Union{Symbol, AbstractString}, metadata = nothing)
new(Symbol(name), _metadata(metadata))
function SQLColumn(; name::Union{Symbol, AbstractString}, private = false, metadata = nothing)
new(Symbol(name), private, _metadata(metadata))
end
end

SQLColumn(name; metadata = nothing) =
SQLColumn(name = name, metadata = metadata)
SQLColumn(name; private = false, metadata = nothing) =
SQLColumn(; name, private, metadata)

Base.show(io::IO, col::SQLColumn) =
print(io, quoteof(col, limit = true))
Expand All @@ -56,6 +58,9 @@ Base.show(io::IO, ::MIME"text/plain", col::SQLColumn) =

function PrettyPrinting.quoteof(col::SQLColumn; limit::Bool = false)
ex = Expr(:call, nameof(SQLColumn), QuoteNode(col.name))
if col.private
push(ex.args, Expr(:kw, :private, col.private))
end
if !isempty(col.metadata)
push!(ex.args, Expr(:kw, :metadata, limit ? :… : quoteof(reverse!(collect(col.metadata)))))
end
Expand Down Expand Up @@ -122,10 +127,10 @@ struct SQLTable <: AbstractDict{Symbol, SQLColumn}
end

SQLTable(name; qualifiers = Symbol[], columns, metadata = nothing) =
SQLTable(qualifiers = qualifiers, name = name, columns = columns, metadata = metadata)
SQLTable(; qualifiers, name, columns, metadata)

SQLTable(name, columns...; qualifiers = Symbol[], metadata = nothing) =
SQLTable(qualifiers = qualifiers, name = name, columns = [columns...], metadata = metadata)
SQLTable(; qualifiers, name, columns = [columns...], metadata)

_column_map(columns::OrderedDict{Symbol, SQLColumn}) =
columns
Expand Down Expand Up @@ -280,7 +285,7 @@ struct SQLCatalog <: AbstractDict{Symbol, SQLTable}
end

SQLCatalog(tables...; dialect = :default, cache = default_cache_maxsize, metadata = nothing) =
SQLCatalog(tables = tables, dialect = dialect, cache = cache, metadata = metadata)
SQLCatalog(; tables, dialect, cache, metadata)

_table_map(tables::Dict{Symbol, SQLTable}) =
tables
Expand Down
167 changes: 87 additions & 80 deletions src/link.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,27 @@ struct LinkContext
knot_refs)
end

function link(q::SQLQuery)
@dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError())
ctx = LinkContext(catalog)
t = row_type(tail)
function _select(t::RowType)
refs = SQLQuery[]
for (f, ft) in t.fields
!(f in t.private_fields) || continue
if ft isa ScalarType
push!(refs, Get(f))
else
nested_refs = _select(ft)
for nested_ref in nested_refs
push!(refs, Nested(name = f, tail = nested_ref))
end
end
end
refs
end

function link(q::SQLQuery)
@dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError())
ctx = LinkContext(catalog)
t = row_type(tail)
refs = _select(t)
tail′ = Linked(refs, tail = link(dismantle(tail, ctx), ctx, refs))
WithContext(tail = tail′, catalog = catalog, defs = ctx.defs)
end
Expand Down Expand Up @@ -123,19 +134,15 @@ function dismantle(n::GroupNode, ctx)
Group(by = by′, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′)
end

function dismantle(n::IterateNode, ctx)
function dismantle(n::IntoNode, ctx)
tail′ = dismantle(ctx)
iterator′ = dismantle(n.iterator, ctx)
Iterate(iterator = iterator′, tail = tail′)
Into(name = n.name, tail = tail′)
end

function dismantle(n::JoinNode, ctx)
rt = row_type(n.joinee)
router = JoinRouter(Set(keys(rt.fields)), !isa(rt.group, EmptyType))
function dismantle(n::IterateNode, ctx)
tail′ = dismantle(ctx)
joinee′ = dismantle(n.joinee, ctx)
on′ = dismantle_scalar(n.on, ctx)
RoutedJoin(joinee = joinee′, on = on′, router = router, left = n.left, right = n.right, optional = n.optional, tail = tail′)
iterator′ = dismantle(n.iterator, ctx)
Iterate(iterator = iterator′, tail = tail′)
end

function dismantle(n::LimitNode, ctx)
Expand Down Expand Up @@ -181,6 +188,13 @@ function dismantle_scalar(n::ResolvedNode, ctx)
end
end

function dismantle(n::RoutedJoinNode, ctx)
tail′ = dismantle(ctx)
joinee′ = dismantle(n.joinee, ctx)
on′ = dismantle_scalar(n.on, ctx)
RoutedJoin(joinee = joinee′, on = on′, name = n.name, left = n.left, right = n.right, optional = n.optional, tail = tail′)
end

function dismantle(n::SelectNode, ctx)
tail′ = dismantle(ctx)
args′ = dismantle_scalar(n.args, ctx)
Expand Down Expand Up @@ -232,16 +246,7 @@ function link(n::AppendNode, ctx)
end

function link(n::AsNode, ctx)
refs = SQLQuery[]
for ref in ctx.refs
if @dissect(ref, (local tail) |> Nested(name = (local name)))
@assert name == n.name
push!(refs, tail)
else
error()
end
end
tail′ = link(ctx.tail, ctx, refs)
tail′ = link(ctx)
As(name = n.name, tail = tail′)
end

Expand Down Expand Up @@ -289,10 +294,8 @@ function link(n::FromIterateNode, ctx)
end

function link(n::FromTableExpressionNode, ctx)
refs = ctx.cte_refs[(n.name, n.depth)]
for ref in ctx.refs
push!(refs, Nested(name = n.name, tail = ref))
end
cte_refs = ctx.cte_refs[(n.name, n.depth)]
append!(cte_refs, ctx.refs)
n
end

Expand Down Expand Up @@ -333,6 +336,20 @@ function link(n::GroupNode, ctx)
Group(by = n.by, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′)
end

function link(n::IntoNode, ctx)
refs = SQLQuery[]
for ref in ctx.refs
if @dissect(ref, (local tail) |> Nested(name = (local name)))
@assert name == n.name
push!(refs, tail)
else
error()
end
end
tail′ = link(ctx.tail, ctx, refs)
Into(name = n.name, tail = tail′)
end

function link(n::IterateNode, ctx)
iterator′ = n.iterator
defs = copy(ctx.defs)
Expand Down Expand Up @@ -364,53 +381,6 @@ function link(n::IterateNode, ctx)
Padding(tail = q′)
end

function route(r::JoinRouter, ref::SQLQuery)
if @dissect(ref, Nested(name = (local name))) && name in r.label_set
return 1
end
if @dissect(ref, Get(name = (local name))) && name in r.label_set
return 1
end
if @dissect(ref, Agg()) && r.group
return 1
end
return -1
end

function link(n::RoutedJoinNode, ctx)
lrefs = SQLQuery[]
rrefs = SQLQuery[]
for ref in ctx.refs
turn = route(n.router, ref)
push!(turn < 0 ? lrefs : rrefs, ref)
end
if n.optional && isempty(rrefs)
return link(ctx)
end
ln_ext_refs = length(lrefs)
rn_ext_refs = length(rrefs)
refs′ = SQLQuery[]
lateral_refs = SQLQuery[]
gather!(n.joinee, ctx, lateral_refs)
append!(lrefs, lateral_refs)
lateral = !isempty(lateral_refs)
gather!(n.on, ctx, refs′)
for ref in refs′
turn = route(n.router, ref)
push!(turn < 0 ? lrefs : rrefs, ref)
end
tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs))
joinee′ = Linked(rrefs, rn_ext_refs, tail = link(n.joinee, ctx, rrefs))
RoutedJoin(
joinee = joinee′,
on = n.on,
router = n.router,
left = n.left,
right = n.right,
lateral = lateral,
tail = tail′)
end

function link(n::LimitNode, ctx)
tail′ = Linked(ctx.refs, tail = link(ctx))
Limit(offset = n.offset, limit = n.limit, tail = tail′)
Expand Down Expand Up @@ -459,6 +429,46 @@ function link(n::PartitionNode, ctx)
Partition(by = n.by, order_by = n.order_by, frame = n.frame, name = n.name, tail = tail′)
end

function link(n::RoutedJoinNode, ctx)
lrefs = SQLQuery[]
rrefs = SQLQuery[]
for ref in ctx.refs
if @dissect(ref, Nested(name = (local name))) && name === n.name
push!(rrefs, ref)
else
push!(lrefs, ref)
end
end
if n.optional && isempty(rrefs)
return link(ctx)
end
ln_ext_refs = length(lrefs)
rn_ext_refs = length(rrefs)
refs′ = SQLQuery[]
lateral_refs = SQLQuery[]
gather!(n.joinee, ctx, lateral_refs)
append!(lrefs, lateral_refs)
lateral = !isempty(lateral_refs)
gather!(n.on, ctx, refs′)
for ref in refs′
if @dissect(ref, Nested(name = (local name))) && name === n.name
push!(rrefs, ref)
else
push!(lrefs, ref)
end
end
tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs))
joinee′ = Linked(rrefs, rn_ext_refs, tail = link(Into(name = n.name, tail = n.joinee), ctx, rrefs))
RoutedJoin(
joinee = joinee′,
on = n.on,
name = n.name,
left = n.left,
right = n.right,
lateral = lateral,
tail = tail′)
end

function link(n::SelectNode, ctx)
refs = SQLQuery[]
gather!(n.args, ctx, refs)
Expand Down Expand Up @@ -556,12 +566,9 @@ end
function gather!(n::IsolatedNode, ctx)
def = ctx.defs[n.idx]
!@dissect(def, Linked()) || return
refs = SQLQuery[]
for (f, ft) in n.type.fields
if ft isa ScalarType
push!(refs, Get(f))
break
end
refs = _select(n.type)
if !isempty(refs)
refs = refs[1:1]
end
def′ = Linked(refs, tail = link(def, ctx, refs))
ctx.defs[n.idx] = def′
Expand Down
15 changes: 9 additions & 6 deletions src/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,19 @@ terminal(q::SQLQuery) =
Chain(q′, q) =
convert(SQLQuery, q)(q′)

label(q::SQLQuery) =
@something label(q.head) label(q.tail)
function label(q::SQLQuery; default = :_)
l = label(q.head)
l !== nothing ? l : label(q.tail; default)
end

label(n::AbstractSQLNode) =
nothing

label(::Nothing) =
:_
label(::Nothing; default = :_) =
default

label(q) =
label(convert(SQLQuery, q))
label(q; default = :_) =
label(convert(SQLQuery, q); default)


# A variant of SQLQuery for assembling a chain of identifiers.
Expand Down Expand Up @@ -913,6 +915,7 @@ include("nodes/get.jl")
include("nodes/group.jl")
include("nodes/highlight.jl")
include("nodes/internal.jl")
include("nodes/into.jl")
include("nodes/iterate.jl")
include("nodes/join.jl")
include("nodes/limit.jl")
Expand Down
17 changes: 8 additions & 9 deletions src/nodes/as.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ AsNode(name) =
As(name; tail = nothing)
name => tail

In a scalar context, `As` specifies the name of the output column. When
applied to tabular data, `As` wraps the data in a nested record.
`As` specifies the name of the output column.

The arrow operator (`=>`) is a shorthand notation for `As`.

Expand All @@ -35,19 +34,19 @@ SELECT "person_1"."person_id" AS "id"
FROM "person" AS "person_1"
```

*Show all patients together with their state of residence.*
*Show all patients together with their primary care provider.*

```jldoctest
julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]);
julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :provider_id]);

julia> location = SQLTable(:location, columns = [:location_id, :state]);
julia> provider = SQLTable(:provider, columns = [:provider_id, :provider_name]);

julia> q = From(:person) |>
Join(From(:location) |> As(:location),
on = Get.location_id .== Get.location.location_id) |>
Select(Get.person_id, Get.location.state);
Join(From(:provider) |> As(:pcp),
on = Get.provider_id .== Get.pcp.provider_id) |>
Select(Get.person_id, Get.pcp.provider_name);

julia> print(render(q, tables = [person, location]))
julia> print(render(q, tables = [person, provider]))
SELECT
"person_1"."person_id",
"location_1"."state"
Expand Down
Loading