
Commit 1e71aa6

Support NetCDF input tables as Arrow alternative (#2542)
Fixes most of #2434, but not all of it, because only a subset of the tables is supported. This essentially gives NetCDF files the same treatment we currently have for Arrow input files. We choose the file format based on whether the path has a `.arrow` or `.nc` extension. So in Python, if you do:

```py
model.basin.state.set_filepath(Path("basin-state.nc"))
```

(Note you need [this](#2039 (comment)) if you read the model from disk.)

Ribasim Python will convert with `df.to_xarray().to_netcdf()`, and if you read it back in with Python, with `ds.to_dataframe()`. In the TOML you get:

```toml
[basin]
state = "basin-state.nc"
```

The core then also reads based on the file extension. To keep things simple, when reading data into the core we convert the N-dimensional arrays in the NetCDF to tables, so the rest of the initialization doesn't need to handle both data structures.

We should also think about what to do for irregular tables that don't fit in a structured array. Perhaps just throw an error?
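As a sketch of that round-trip (a minimal example, assuming a Basin / state table with `node_id` and `level` columns, and assuming Ribasim Python indexes the table by its dimension column before converting):

```py
from pathlib import Path

import pandas as pd
import xarray as xr

# A hypothetical Basin / state table: one row per node.
df = pd.DataFrame({"node_id": [1, 2, 3], "level": [1.0, 2.0, 3.0]})

# On write: DataFrame -> xarray -> NetCDF.
path = Path("basin-state.nc")
df.set_index("node_id").to_xarray().to_netcdf(path)

# On read: NetCDF -> xarray -> DataFrame.
with xr.open_dataset(path) as ds:
    df_roundtrip = ds.to_dataframe().reset_index()

assert df.equals(df_roundtrip)
```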
1 parent 71ea905 commit 1e71aa6

File tree

12 files changed: +279 −82 lines


core/src/Ribasim.jl

Lines changed: 1 addition & 1 deletion
```diff
@@ -153,7 +153,7 @@ using StructArrays: StructVector
 using DataStructures: OrderedSet, OrderedDict, counter, inc!
 
 # NCDatasets is used to read and write NetCDF files.
-using NCDatasets: NCDataset, defDim, defVar
+using NCDatasets: NCDataset, defDim, defVar, dimnames
 
 using Dates: Second
 
```

core/src/config.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -259,7 +259,7 @@ function Base.getproperty(config::Config, sym::Symbol)
 end
 
 "Construct a path relative to both the TOML directory and the optional `input_dir`"
-function input_path(config::Config, path::String="")
+function input_path(config::Config, path::String = "")
     return normpath(config.dir, config.input_dir, path)
 end
 
@@ -269,7 +269,7 @@ function database_path(config::Config)
 end
 
 "Construct a path relative to both the TOML directory and the optional `results_dir`"
-function results_path(config::Config, path::String="")
+function results_path(config::Config, path::String = "")
     # If the path is empty, we return the results directory.
     if !isempty(path)
        name, ext = splitext(path)
```

core/src/read.jl

Lines changed: 115 additions & 70 deletions
```diff
@@ -1714,40 +1714,135 @@ DateTime. This is used to convert between the solver's inner float time, and the
 datetime_since(t::Real, t0::DateTime)::DateTime = t0 + Millisecond(round(1000 * t))
 
 """
-    load_data(db::DB, config::Config, nodetype::Symbol, kind::Symbol)::Union{Arrow.Table, Query, Nothing}
+    load_netcdf(table_path::String, table_type::Type{<:Table})::NamedTuple
 
-Load data from Arrow files if available, otherwise the database.
-Returns either an `Arrow.Table`, `SQLite.Query` or `nothing` if the data is not present.
+Load a table from a NetCDF file. The data is stored as multi-dimensional arrays, and
+converted to a table for compatibility with the rest of the internals.
+"""
+function load_netcdf(table_path::String, table_type::Type{<:Table})::NamedTuple
+    table = NCDataset(table_path) do ds
+        names = fieldnames(table_type)
+        table = OrderedDict{Symbol, AbstractVector}()
+        data_varnames = filter(x -> !(String(x) in nc_dim_names), names)
+        for data_varname in data_varnames
+            var = ds[data_varname]
+            dim_names = dimnames(var)
+            if dim_names == ("node_id",)
+                table[:node_id] = ds["node_id"][:]
+            elseif dim_names == ("node_id", "time")
+                node_id_data = ds["node_id"][:]
+                time_data = ds["time"][:]
+                ntime = length(time_data)
+                nnode = length(node_id_data)
+                table[:node_id] = repeat(node_id_data; outer = ntime)
+                table[:time] = repeat(time_data; inner = nnode)
+            else
+                error("Unsupported dimensions: $dim_names, must be (node_id, [time])")
+            end
+            table[data_varname] = vec(var[:])
+        end
+        table
+    end
+    return columntable(table)
+end
+
+"""
+    load_data(db::DB, config::Config, nodetype::Symbol, kind::Symbol)::Union{NamedTuple, Nothing}
+
+Load data from Arrow or NetCDF files if available, otherwise the database.
+Returns either a `NamedTuple` of Vectors or `nothing` if the data is not present.
 """
 function load_data(
     db::DB,
     config::Config,
     table_type::Type{<:Table},
-)::Union{Arrow.Table, Query, Nothing}
-    # TODO load_data doesn't need both config and db, use config to check which one is needed
-
+)::Union{NamedTuple, Nothing}
     toml = getfield(config, :toml)
     section_name = snake_case(node_type(table_type))
     section = getproperty(toml, section_name)
     kind = table_name(table_type)
     sql_name = sql_table_name(table_type)
 
-    path = if hasproperty(section, kind)
-        getproperty(section, kind)
-    else
-        nothing
-    end
+    path = hasproperty(section, kind) ? getproperty(section, kind) : nothing
 
-    table = if !isnothing(path)
+    if !isnothing(path)
+        # the TOML specifies a file outside the database
+        path = getproperty(section, kind)
         table_path = input_path(config, path)
-        Arrow.Table(read(table_path); convert = false)
-    elseif exists(db, sql_name)
-        execute(db, "select * from $(esc_id(sql_name))")
+        # check suffix and read with Arrow or NCDatasets
+        ext = lowercase(splitext(table_path)[2])
+        if ext == ".nc"
+            return load_netcdf(table_path, table_type)
+        elseif ext == ".arrow"
+            bytes = read(table_path)
+            arrow_table = Arrow.Table(bytes; convert = false)
+            return arrow_columntable(arrow_table, table_type)
+        else
+            error("Unsupported file format: $table_path")
+        end
     else
-        nothing
+        if exists(db, sql_name)
+            table = execute(db, "select * from $(esc_id(sql_name))")
+            return sqlite_columntable(table, db, config, table_type)
+        else
+            return nothing
+        end
+    end
+end
+
+"Faster alternative to Tables.columntable that preallocates based on the schema."
+function sqlite_columntable(
+    table::Query,
+    db::DB,
+    config::Config,
+    T::Type{<:Table},
+)::NamedTuple
+    sql_name = sql_table_name(T)
+    nrows = execute(db, "SELECT COUNT(*) FROM $(esc_id(sql_name))") |> first |> first
+
+    names = fieldnames(T)
+    types = fieldtypes(T)
+    vals = ntuple(i -> Vector{types[i]}(undef, nrows), length(names))
+    nt = NamedTuple{names}(vals)
+
+    for (i, row) in enumerate(table)
+        for name in names
+            val = row[name]
+            if name == :time
+                # time has type timestamp and is stored as a String in the database
+                # currently SQLite.jl does not automatically convert it to DateTime
+                val = if ismissing(val)
+                    DateTime(config.starttime)
+                else
+                    DateTime(
+                        replace(val, r"(\.\d{3})\d+$" => s"\1"), # remove sub ms precision
+                        dateformat"yyyy-mm-dd HH:MM:SS.s",
+                    )
+                end
+            end
+            nt[name][i] = val
+        end
     end
+    nt
+end
 
-    return table
+"Alternative to Tables.columntable that converts time to our own to_datetime."
+function arrow_columntable(table::Arrow.Table, T::Type{<:Table})::NamedTuple
+    nrows = length(first(table))
+    names = fieldnames(T)
+    types = fieldtypes(T)
+    vals = ntuple(i -> Vector{types[i]}(undef, nrows), length(names))
+    nt = NamedTuple{names}(vals)
+
+    for name in names
+        if name == :time
+            time_col = getproperty(table, name)
+            nt[name] .= [to_datetime(t) for t in time_col]
+        else
+            nt[name] .= getproperty(table, name)
+        end
+    end
+    nt
 end
 
 # alternative to convert that doesn't have warntimestamp
@@ -1762,7 +1857,7 @@ end
 """
     load_structvector(db::DB, config::Config, ::Type{T})::StructVector{T}
 
-Load data from Arrow files if available, otherwise the database.
+Load data from Arrow or NetCDF files if available, otherwise the database.
 Always returns a StructVector of the given struct type T, which is empty if the table is
 not found. This function validates the schema, and enforces the required sort order.
 """
@@ -1771,62 +1866,12 @@ function load_structvector(
     config::Config,
     ::Type{T},
 )::StructVector{T} where {T <: Table}
-    table = load_data(db, config, T)
+    nt = load_data(db, config, T)
 
-    if table === nothing
+    if nt === nothing
         return StructVector{T}(undef, 0)
     end
 
-    table_in_db = table isa Query
-
-    nt = if table_in_db
-        # faster alternative to Tables.columntable that preallocates based on the schema
-        sql_name = sql_table_name(T)
-        nrows =
-            execute(db, "SELECT COUNT(*) FROM $(esc_id(sql_name))") |> first |> first
-
-        names = fieldnames(T)
-        types = fieldtypes(T)
-        vals = ntuple(i -> Vector{types[i]}(undef, nrows), length(names))
-        nt = NamedTuple{names}(vals)
-
-        for (i, row) in enumerate(table)
-            for name in names
-                val = row[name]
-                if name == :time
-                    # time has type timestamp and is stored as a String in the database
-                    # currently SQLite.jl does not automatically convert it to DateTime
-                    val = if ismissing(val)
-                        DateTime(config.starttime)
-                    else
-                        DateTime(
-                            replace(val, r"(\.\d{3})\d+$" => s"\1"), # remove sub ms precision
-                            dateformat"yyyy-mm-dd HH:MM:SS.s",
-                        )
-                    end
-                end
-                nt[name][i] = val
-            end
-        end
-        nt
-    else
-        nrows = length(first(table))
-        names = fieldnames(T)
-        types = fieldtypes(T)
-        vals = ntuple(i -> Vector{types[i]}(undef, nrows), length(names))
-        nt = NamedTuple{names}(vals)
-
-        for name in names
-            if name == :time
-                time_col = getproperty(table, name)
-                nt[name] .= [to_datetime(t) for t in time_col]
-            else
-                nt[name] .= getproperty(table, name)
-            end
-        end
-        nt
-    end
-
     table = StructVector{T}(nt)
     return sorted_table!(table)
 end
```
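One subtlety worth spelling out: the `("node_id", "time")` case in `load_netcdf` relies on Julia's column-major layout, where `vec(var[:])` flattens the 2-D array with `node_id` (the first dimension) varying fastest. That ordering is exactly what `repeat(node_id_data; outer = ntime)` and `repeat(time_data; inner = nnode)` produce for the id and time columns. A minimal sketch of the same flattening in numpy terms (this mirrors the Julia logic rather than calling it; `flow_rate` is a hypothetical variable name):

```py
import numpy as np

node_id = np.array([1, 2, 3])  # nnode = 3
time = np.array([0.0, 86400.0])  # ntime = 2
nnode, ntime = len(node_id), len(time)

# A data variable with dims ("node_id", "time"), shape (nnode, ntime).
flow_rate = np.arange(nnode * ntime, dtype=float).reshape((nnode, ntime), order="F")

# Julia's repeat(node_id; outer = ntime): tile the whole vector.
node_id_col = np.tile(node_id, ntime)  # [1 2 3 1 2 3]
# Julia's repeat(time; inner = nnode): repeat each element.
time_col = np.repeat(time, nnode)  # [0 0 0 86400 86400 86400]
# Julia's vec: column-major (Fortran-order) flattening, node_id varies fastest.
flow_rate_col = flow_rate.flatten(order="F")

# Each table row now holds one (node_id, time, flow_rate) combination.
assert len(node_id_col) == len(time_col) == len(flow_rate_col) == nnode * ntime
```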

core/test/io_test.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,3 +314,54 @@ end
314314
storage2_begin = current_storage
315315
@test storage1_end storage2_begin
316316
end
317+
318+
@testitem "warm state netcdf" begin
319+
# This tests that we can write Basin / state results to NetCDF, and read this in again
320+
# as a warm state, such that the storages at the end of one run are equal to those
321+
# at the beginning of the second run.
322+
323+
using IOCapture: capture
324+
using Ribasim: solve!, write_results
325+
import TOML
326+
327+
model_path_src = normpath(@__DIR__, "../../generated_testmodels/basic/")
328+
329+
# avoid changing the original model for other tests
330+
model_path = normpath(@__DIR__, "../../generated_testmodels/basic_warm_netcdf/")
331+
cp(model_path_src, model_path; force = true)
332+
toml_path = normpath(model_path, "ribasim.toml")
333+
334+
# Configure model to use NetCDF format
335+
toml_dict = TOML.parsefile(toml_path)
336+
toml_dict["results"] = Dict("format" => "netcdf")
337+
open(toml_path, "w") do io
338+
TOML.print(io, toml_dict)
339+
end
340+
341+
config = Ribasim.Config(toml_path)
342+
model = Ribasim.Model(config)
343+
(; p_independent, state_time_dependent_cache) = model.integrator.p
344+
(; current_storage) = state_time_dependent_cache
345+
storage1_begin = copy(current_storage)
346+
solve!(model)
347+
storage1_end = current_storage
348+
@test storage1_begin != storage1_end
349+
350+
# copy state results to input
351+
write_results(model)
352+
state_path = Ribasim.results_path(config, Ribasim.RESULTS_FILENAME.basin_state)
353+
cp(state_path, Ribasim.input_path(config, "warm_state.nc"))
354+
355+
# point TOML to the warm state NetCDF file
356+
toml_dict = TOML.parsefile(toml_path)
357+
toml_dict["basin"] = Dict("state" => "warm_state.nc")
358+
open(toml_path, "w") do io
359+
TOML.print(io, toml_dict)
360+
end
361+
362+
model = Ribasim.Model(toml_path)
363+
(; p_independent, state_time_dependent_cache) = model.integrator.p
364+
(; current_storage) = state_time_dependent_cache
365+
storage2_begin = current_storage
366+
@test storage1_end storage2_begin
367+
end
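For reference, after the test's two TOML edits the relevant sections of `ribasim.toml` end up roughly like this (all other keys of the original model are left in place):

```toml
[results]
format = "netcdf"

[basin]
state = "warm_state.nc"
```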

pixi.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default.

pixi.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -235,7 +235,7 @@ jupyter = "*"
 matplotlib = ">=3.7"
 minio = "*"
 mypy = "*"
-netcdf4 = "*"
+netcdf4 = ">=1.7.1"
 networkx = ">=3.3"
 numpy = ">=1.25, <2.2"
 packaging = ">=23.0"
@@ -264,7 +264,7 @@ teamcity-messages = "*"
 tomli = ">=2.0"
 tomli-w = ">=1.0"
 twine = "*"
-xarray = "*"
+xarray = ">=2025.8.0"
 xmipy = ">=1.3"
 xugrid = "*"
 
```

python/ribasim/pyproject.toml

Lines changed: 2 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@ dependencies = [
     "datacompy >=0.16",
     "geopandas >=1.0",
     "matplotlib >=3.7",
+    "netCDF4 >=1.7.1",
     "numpy >=1.25",
     "packaging >=23.0",
     "pandas >=2.0",
@@ -29,6 +30,7 @@ dependencies = [
     "shapely >=2.0",
     "tomli >=2.0",
     "tomli-w >=1.0",
+    "xarray >=2025.8.0",
 ]
 dynamic = ["version"]
 
```
