Skip to content

Commit

Permalink
Use a do end for netcdf files
Browse files Browse the repository at this point in the history
  • Loading branch information
Zinoex committed Nov 14, 2023
1 parent ee6d533 commit 0c627dc
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 29 deletions.
65 changes: 37 additions & 28 deletions src/Data/imdp.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,46 @@

function read_imdp_jl_file(path)
dataset = Dataset(path)
mdp_or_mc = Dataset(path) do
n = dataset.attrib["num_states"]
initial_state = dataset.attrib["initial_state"]
model = dataset.attrib["model"]

n = dataset.attrib["num_states"]
initial_state = dataset.attrib["initial_state"]
model = dataset.attrib["model"]
@assert model ["imdp", "imc"]
@assert dataset.attrib["rows"] == "to"
@assert dataset.attrib["cols"] ["from", "from/action"]
@assert dataset.properties["format"] == "sparse_csc"

@assert model ["imdp", "imc"]
@assert dataset.attrib["rows"] == "to"
@assert dataset.attrib["cols"] ["from", "from/action"]
@assert dataset.properties["format"] == "sparse_csc"
lower_colptr = convert.(Int32, dataset["lower_colptr"][:])
lower_rowval = convert.(Int32, dataset["lower_rowval"][:])
lower_nzval = dataset["lower_nzval"][:]
= SparseMatrixCSC(
Int32(n + 1),
Int32(n),
lower_colptr,
lower_rowval,
lower_nzval,
)

lower_colptr = convert.(Int32, dataset["lower_colptr"][:])
lower_rowval = convert.(Int32, dataset["lower_rowval"][:])
lower_nzval = dataset["lower_nzval"][:]
= SparseMatrixCSC(Int32(n + 1), Int32(n), lower_colptr, lower_rowval, lower_nzval)
upper_colptr = convert.(Int32, dataset["upper_colptr"][:])
upper_rowval = convert.(Int32, dataset["upper_rowval"][:])
upper_nzval = dataset["upper_nzval"][:]
= SparseMatrixCSC(
Int32(n + 1),
Int32(n),
upper_colptr,
upper_rowval,
upper_nzval,
)

upper_colptr = convert.(Int32, dataset["upper_colptr"][:])
upper_rowval = convert.(Int32, dataset["upper_rowval"][:])
upper_nzval = dataset["upper_nzval"][:]
= SparseMatrixCSC(Int32(n + 1), Int32(n), upper_colptr, upper_rowval, upper_nzval)
prob = MatrixIntervalProbabilities(; lower = P̲, upper = P̅)

prob = MatrixIntervalProbabilities(; lower = P̲, upper = P̅)

if model == "imdp"
mdp_or_mc = read_imdp_jl_mdp(dataset, prob, initial_state)
elseif model == "imc"
mdp_or_mc = read_imdp_jl_mc(dataset, prob, initial_state)
if model == "imdp"
return read_imdp_jl_mdp(dataset, prob, initial_state)
elseif model == "imc"
return read_imdp_jl_mc(dataset, prob, initial_state)
end
end

# IMPORTANT! Otherwise the file cannot be reopened until the OS has released the file handle.
close(dataset)

return mdp_or_mc
end

Expand All @@ -44,14 +53,14 @@ function read_imdp_jl_mdp(dataset, prob, initial_state)
mdp = IntervalMarkovDecisionProcess(prob, stateptr, action_vals, Int32(initial_state))
return mdp
end

function read_imdp_jl_mc(dataset, prob, initial_state)
@assert dataset.attrib["cols"] == "from"

mc = IntervalMarkovChain(prob, Int32(initial_state))
return mc
end

function write_imdp_jl_file(path, mdp_or_mc)
# TODO: implement
end
end
5 changes: 4 additions & 1 deletion src/Data/prism.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ function write_prism_labels_file(path_without_file_ending, mdp_or_mc, terminal_s
return write(path_without_file_ending * ".lab", join(lines, "\n"))
end

function write_prism_transitions_file(path_without_file_ending, mdp::IntervalMarkovDecisionProcess)
function write_prism_transitions_file(
path_without_file_ending,
mdp::IntervalMarkovDecisionProcess,
)
number_states = num_states(mdp)

prob = transition_prob(mdp)
Expand Down

0 comments on commit 0c627dc

Please sign in to comment.