Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing a function to be called multiple times with different inputs #627

Draft
wants to merge 49 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
ed280a6
Modified 1D approx test to show get_argument bug
nicholaskl97 Sep 2, 2022
1c0f0d0
Updated get_argument for eval with multiple inputs
nicholaskl97 Oct 27, 2022
63e0ddc
Forced get_argument when strategy != Quadrature
nicholaskl97 Oct 27, 2022
4e7b1b8
Test file for fixing get_argument
nicholaskl97 Oct 27, 2022
8a612dc
Test file for debugging symbolic_discretize
nicholaskl97 Oct 27, 2022
83e2475
transform_expression uses indvars now
nicholaskl97 Dec 31, 2022
13df657
Some test files
nicholaskl97 Dec 31, 2022
b17f92b
Merge branch 'master' into get_argument-fix
nicholaskl97 Dec 31, 2022
e885f45
Reverted get_argument to original state
nicholaskl97 Dec 31, 2022
74a2749
Removed temporary debug files
nicholaskl97 Dec 31, 2022
d0df2a3
Updated _vcat to accept multiple arguments
nicholaskl97 Jan 1, 2023
41a75f6
get_argument returns all args no just first per eq
nicholaskl97 Jan 12, 2023
c5d9960
Added implicit 1D and another 2D test case
nicholaskl97 Jan 12, 2023
64b56de
generate gridtrain trainsets based of pde vars
nicholaskl97 Jan 12, 2023
55fa847
added OptimJL and OptimOptimisers
nicholaskl97 Jan 12, 2023
b7e3d7a
get_bounds works with new transform_expression
nicholaskl97 Jan 12, 2023
fb199e4
Added test of ODE with hard constraint ic
nicholaskl97 Jan 12, 2023
2572dbf
_vcat now fills out scalar inputs to match batches
nicholaskl97 Jan 24, 2023
3e36fbe
cord now only has variables that show up in the eq
nicholaskl97 Jan 26, 2023
d115eae
GridTraining train_sets now work on the GPU
nicholaskl97 Feb 7, 2023
abb85a8
_vcat maintains Array types when filling
nicholaskl97 Feb 7, 2023
c7d3dc5
Formatting change
nicholaskl97 Feb 7, 2023
d9da546
StochasticTraining now actually uses bcs_points
nicholaskl97 Feb 17, 2023
18338d3
get_bounds uses bcs_points
nicholaskl97 Feb 17, 2023
cee31db
get_bounds uses get_variables
nicholaskl97 Feb 17, 2023
ea1c3b0
Merge branch 'master' into master
nicholaskl97 Feb 17, 2023
be3abf1
Increased test number of points
nicholaskl97 Feb 20, 2023
308454c
get_bounds is now okay with eqs with no variables
nicholaskl97 Feb 20, 2023
09b6cf6
symbolic_utilities doesn't need LinearAlgebra
nicholaskl97 Feb 20, 2023
6e4206b
Merge remote-tracking branch 'origin/master' into get_argument-fix
nicholaskl97 Feb 21, 2023
55d142a
Can now handle Ix(u(x,1)) and not just Ix(u(x,y))
nicholaskl97 Feb 21, 2023
a9b6b47
import ComponentArrays used in training_strategies
nicholaskl97 Feb 21, 2023
f815469
Added import ComponentArrays statements
nicholaskl97 Feb 22, 2023
5889a1b
Revert "Added import ComponentArrays statements"
nicholaskl97 Feb 22, 2023
424a7ef
Revert "import ComponentArrays used in training_strategies"
nicholaskl97 Feb 22, 2023
d581889
Revert "added OptimJL and OptimOptimisers"
nicholaskl97 Feb 22, 2023
edcb1a7
Replaced Lux.ComponentArray with using Co...Arrays
nicholaskl97 Feb 22, 2023
b07ae13
Formatted with JuliaFormtter
nicholaskl97 Feb 23, 2023
7a1e0b5
Docstrings were counting against code coverage
nicholaskl97 Mar 7, 2023
7f527c7
Improperly used docstrings changed to comments
nicholaskl97 Mar 8, 2023
530d50e
Added comments for _vcat
nicholaskl97 Mar 8, 2023
48c8b04
Merge remote-tracking branch 'origin/master' into get_argument-fix
nicholaskl97 Mar 8, 2023
e4f1536
Updated docstring for build_symbolic_loss_function
nicholaskl97 Mar 9, 2023
238b315
Reductions needed inits for cases like u(0)=0
nicholaskl97 Mar 10, 2023
44f3a28
Formatted with JuliaFormatter
nicholaskl97 Mar 10, 2023
fc7d36c
Added a new integral test
nicholaskl97 Apr 3, 2023
550ab40
Merge remote-tracking branch 'origin/master' into get_argument-fix
nicholaskl97 Apr 3, 2023
4dcf2a8
Merge remote-tracking branch 'origin/master'
nicholaskl97 May 29, 2023
00f07fc
Merge remote-tracking branch 'origin/master'
nicholaskl97 Jul 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 108 additions & 53 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,81 @@ Take expressions in the form:

to

:((cord, θ, phi, derivative, u)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ[1:33], θ[34:66])
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[1], cord[2])
[(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, θ1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, θ2))) - 0,
(+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, θ2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, θ1))) - 0]
end
end
end)

for Flux.Chain, and

:((cord, θ, phi, derivative, u)->begin
#= ... =#
#= ... =#
begin
(u1, u2) = (θ.depvar.u1, θ.depvar.u2)
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[1], cord[2])
[(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, u1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, u1))) - 0,
(+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, u2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, u2))) - 0]
end
end
end)

for Lux.AbstractExplicitLayer
:((cord, θ, phi, derivative, integral, u, p)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ[1:205], θ[206:410])
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[[1], :], cord[[2], :])
begin
cord2 = vcat(x, y)
cord1 = vcat(x, y)
end
(+).((*).(4, derivative(phi2, u, _vcat(x, y), [[0.0, ε]], 1, θ2)), derivative(phi1, u, _vcat(x, y), [[ε, 0.0]], 1, θ1)) .- 0
Comment on lines +20 to +24
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are those made and then not used?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're not used in this case any more. I think they may be used in the integral case, but that might not be true either. I can look through the different cases to see if they are ever used and remove them if not.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it looks like deprecated code now so it would be good to just remove it

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, currently, cord1 = vcat(...) is being used for integral equations only and in my efforts to see if it's possible to remove it, I've found something I broke that wasn't being tested for the integral equations, so I can work more on fixing that next week. In particular, I'll look for a fix that removes any need for those lines.

Copy link
Author

@nicholaskl97 nicholaskl97 Mar 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another thing I realized as I was working on this was that Ix(u(sin(x)) will now be interpreted as $\int u(\sin x) dx$. However, Dx(u(sin(x)) is (under my current changes) being interpreted as $u'(\sin x)$, not $\frac{d}{dx}\left[ u(\sin x) \right] = u'(\sin x) \cos x$. They're interpreted this way because that's how the numeric integral and numeric derivative functions were already written. However, it feels a little inconsistent with the way the integral was interpreted; it's instead consistent with an interpretation of Ix(u(sin(x)) as $U(\sin x)$, where $U$ is an antiderivative of $u$.

It feels to me like Ix(u(sin(x)) is $\int u(\sin x) dx$ and Dx(u(sin(x)) is $\frac{d}{dx}\left[ u(\sin x) \right]$, but then I don't know how you would actually specify $U(\sin x)$ or $u'(\sin x)$, or if you should even be allowed to. (I'm fine not letting people use $U(\sin x)$ since it's not uniquely defined, but it feels like they should be able to use $u'(\sin x)$.)

Thoughts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, Dx(u(sin(x)) is (under my current changes) being interpreted as

That's not correct and would not play nicely. It should give the same result as what happens when basic symbolic interactions are done:

julia> using Symbolics

julia> @variables u x
2-element Vector{Num}:
 u
 x

julia> @variables u(..) x
2-element Vector{Any}:
  u
 x

julia> u(sin(x))
u(sin(x))

julia> D = Differential(x)
(::Differential) (generic function with 2 methods)

julia> D(u(sin(x)))
Differential(x)(u(sin(x)))

julia> expand_derivatives(D(u(sin(x))))
cos(x)*Differential(sin(x))(u(sin(x)))

end
end
end)

for Dx(u1(x,y)) + 4*Dy(u2(x,y)) ~ 0, and

:((cord, θ, phi, derivative, integral, u, p)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ[1:205], θ[206:410])
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[[1], :], cord[[2], :])
begin
cord2 = vcat(x, y)
cord1 = vcat(x, y)
end
(+).((*).(9, derivative(phi1, u, _vcat(x, y), [[0.0, ε]], 1, θ1)), derivative(phi2, u, _vcat(x, y), [[ε, 0.0]], 1, θ2)) .- 0
end
end
end)

for Dx(u2(x,y)) + 9*Dy(u1(x,y)) ~ 0 (i.e., separate loss functions are created for each equation)

with Flux.Chain; and

:((cord, θ, phi, derivative, integral, u, p)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ.depvar.u1, θ.depvar.u2)
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[[1], :], cord[[2], :])
begin
cord2 = vcat(x, y)
cord1 = vcat(x, y)
end
(+).((*).(4, derivative(phi2, u, _vcat(x, y), [[0.0, ε]], 1, θ2)), derivative(phi1, u, _vcat(x, y), [[ε, 0.0]], 1, θ1)) .- 0
end
end
end)

for Dx(u1(x,y)) + 4*Dy(u2(x,y)) ~ 0 and

:((cord, θ, phi, derivative, integral, u, p)->begin
#= ... =#
#= ... =#
begin
(θ1, θ2) = (θ.depvar.u1, θ.depvar.u2)
(phi1, phi2) = (phi[1], phi[2])
let (x, y) = (cord[[1], :], cord[[2], :])
begin
cord2 = vcat(x, y)
cord1 = vcat(x, y)
end
(+).((*).(9, derivative(phi1, u, _vcat(x, y), [[0.0, ε]], 1, θ1)), derivative(phi2, u, _vcat(x, y), [[ε, 0.0]], 1, θ2)) .- 0
end
end
end)

for Dx(u2(x,y)) + 9*Dy(u1(x,y)) ~ 0

with Lux.Chain
"""
function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
eq_params = SciMLBase.NullParameters(),
Expand All @@ -61,7 +107,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
this_eq_pair = pair(eqs, depvars, dict_depvars, dict_depvar_input)
this_eq_indvars = unique(vcat(values(this_eq_pair)...))
else
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the docstring above. What does the code look like now?

this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars],
this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => filter(arg -> !isempty(find_thing_in_expr(integrand,
arg)),
dict_depvar_input[intvars]),
integrating_depvars))
this_eq_indvars = transformation_vars isa Nothing ?
unique(vcat(values(this_eq_pair)...)) : transformation_vars
Expand Down Expand Up @@ -142,17 +190,10 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
vcat_expr = Expr(:block, :($(eq_pair_expr...)))
vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function) # TODO rename

if strategy isa QuadratureTraining
indvars_ex = get_indvars_ex(bc_indvars)
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs),
build_expr(:tuple, right_arg_pairs))
else
indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)]
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs),
build_expr(:tuple, right_arg_pairs))
end
indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)]
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs),
build_expr(:tuple, right_arg_pairs))

if !(dict_transformation_vars isa Nothing)
transformation_expr_ = Expr[]
Expand Down Expand Up @@ -256,7 +297,7 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D
hcat(vec(map(points -> collect(points),
Iterators.product(bc_data...)))...))

pde_train_sets = map(pde_args) do bt
pde_train_sets = map(pde_vars) do bt
span = map(b -> get(dict_var_span_, b, b), bt)
_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
Expand Down Expand Up @@ -292,7 +333,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,
dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains])
dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains])

pde_args = get_argument(eqs, dict_indvars, dict_depvars)
pde_args = get_variables(eqs, dict_indvars, dict_depvars)

pde_lower_bounds = map(pde_args) do pd
span = map(p -> get(dict_lower_bound, p, p), pd)
Expand Down Expand Up @@ -325,19 +366,33 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str
] for d in domains])

# pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains]
pde_args = get_argument(eqs, dict_indvars, dict_depvars)
pde_bounds = map(pde_args) do pde_arg
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
pde_vars = get_variables(eqs, dict_indvars, dict_depvars)
pde_bounds = map(pde_vars) do pde_var
if !isempty(pde_var)
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_var)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
else
[eltypeθ(0.0)], [eltypeθ(0.0)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what case is this handling?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, it was the case of something like $u(0)=0$. The parser will now just make that into something like u(_vcat(0)) .- 0 instead of u(cord1) .- 0. Since that doesn't have a variable in it, we don't bother making training data for that expression (we can evaluate it without a training set). However, if you passed empty arrays along, then it would error, so instead we're just giving it 0 as both the upper and lower bounds, which don't really have any meaning since there aren't any variables that range between the bounds.

end
end

bound_args = get_argument(bcs, dict_indvars, dict_depvars)
bcs_bounds = map(bound_args) do bound_arg
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
dx_bcs = 1 / strategy.bcs_points
dict_span_bcs = Dict([Symbol(d.variables) => [
infimum(d.domain) + dx_bcs,
supremum(d.domain) - dx_bcs,
] for d in domains])
bound_vars = get_variables(bcs, dict_indvars, dict_depvars)
bcs_bounds = map(bound_vars) do bound_var
if !isempty(bound_var)
bds = mapreduce(s -> get(dict_span_bcs, s, fill(s, 2)), hcat, bound_var)
bds = eltypeθ.(bds)
bds[1, :], bds[2, :]
else
[eltypeθ(0.0)], [eltypeθ(0.0)]
end
end

return pde_bounds, bcs_bounds
end

Expand Down
107 changes: 87 additions & 20 deletions src/symbolic_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,42 @@ julia> _dot_(e)
# Fallback: defer to Base's broadcast machinery to decide whether `x`
# participates in dot-call fusion.
dottable_(x) = Broadcast.dottable(x)
# All functions are considered dottable.
dottable_(x::Function) = true

"""
    _vcat(x...)

A wrapper around `vcat` that is excluded from dot-call fusion. When `x` is a
mixture of scalars and arrays, every scalar is first expanded into a one-row
array whose trailing dimensions (and array type) match those of the array
arguments, so the concatenation succeeds.

# Examples
```julia-repl
julia> _vcat([1 2], [3 4])
2×2 Matrix{Int64}:
 1  2
 3  4

julia> _vcat(0, [1 2])
2×2 Matrix{Int64}:
 0  0
 1  2
```
"""
_vcat(x::Number...) = vcat(x...)
_vcat(x::AbstractArray{<:Number}...) = vcat(x...)
function _vcat(x::Union{Number, AbstractArray{<:Number}}...)
    # Use the first array argument as a template for the fill size and the
    # concrete array type (e.g. to stay on the GPU for device arrays).
    template = first(Iterators.filter(a -> !(a isa Number), x))
    fill_dims = (1, size(template)[2:end]...)
    expanded = map(x) do el
        el isa Number ? (typeof(template))(fill(el, fill_dims)) : el
    end
    return _vcat(expanded...)
end
_vcat(x...) = vcat(x...)
# Keep `_vcat` out of broadcast fusion performed by `_dot_`.
dottable_(x::typeof(_vcat)) = false

_dot_(x) = x
function _dot_(x::Expr)
dotargs = Base.mapany(_dot_, x.args)
if x.head === :call && dottable_(x.args[1])
if x.head === :call && x.args[1] === :_vcat
Expr(x.head, dotargs...)
elseif x.head === :call && dottable_(x.args[1])
Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...))
elseif x.head === :comparison
Expr(:comparison,
Expand Down Expand Up @@ -125,17 +157,20 @@ function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = fa
_args = ex.args
for (i, e) in enumerate(_args)
if !(e isa Expr)
if e in keys(dict_depvars)
if e in keys(dict_depvars) # _args represents a call to a dependent variable
depvar = _args[1]
num_depvar = dict_depvars[depvar]
indvars = _args[2:end]
indvars = map((indvar_) -> transform_expression(pinnrep, indvar_),
_args[2:end])
var_ = is_integral ? :(u) : :($(Expr(:$, :u)))
ex.args = if !multioutput
[var_, Symbol(:cord, num_depvar), :($θ), :phi]
else
# Make something like u(x,y) into u([x,y], θ, phi), since the neural net needs to be called with parameters
# Note that [x,y] is achieved with _vcat, which can also fill scalars, as in the u(0,x) case, where vcat(0,x) would fail if x were a row vector
[var_, :((_vcat)($(indvars...))), :($θ), :phi]
nicholaskl97 marked this conversation as resolved.
Show resolved Hide resolved
else # If multioutput, there are different θ and phi for each dependent variable
[
var_,
Symbol(:cord, num_depvar),
:((_vcat)($(indvars...))),
Symbol(:($θ), num_depvar),
Symbol(:phi, num_depvar),
]
Expand All @@ -151,7 +186,8 @@ function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = fa
end
depvar = _args[1]
num_depvar = dict_depvars[depvar]
indvars = _args[2:end]
indvars = map((indvar_) -> transform_expression(pinnrep, indvar_),
_args[2:end])
dict_interior_indvars = Dict([indvar .=> j
for (j, indvar) in enumerate(dict_depvar_input[depvar])])
dim_l = length(dict_interior_indvars)
Expand All @@ -162,13 +198,13 @@ function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = fa
εs_dnv = [εs[d] for d in undv]

ex.args = if !multioutput
[var_, :phi, :u, Symbol(:cord, num_depvar), εs_dnv, order, :($θ)]
[var_, :phi, :u, :((_vcat)($(indvars...))), εs_dnv, order, :($θ)]
else
[
var_,
Symbol(:phi, num_depvar),
:u,
Symbol(:cord, num_depvar),
:((_vcat)($(indvars...))),
εs_dnv,
order,
Symbol(:($θ), num_depvar),
Expand Down Expand Up @@ -336,7 +372,8 @@ function pair(eq, depvars, dict_depvars, dict_depvar_input)
expr = toexpr(eq)
pair_ = map(depvars) do depvar
if !isempty(find_thing_in_expr(expr, depvar))
dict_depvars[depvar] => dict_depvar_input[depvar]
dict_depvars[depvar] => filter(arg -> !isempty(find_thing_in_expr(expr, arg)),
dict_depvar_input[depvar])
end
end
Dict(filter(p -> p !== nothing, pair_))
Expand Down Expand Up @@ -419,6 +456,13 @@ function find_thing_in_expr(ex::Expr, thing; ans = [])
return collect(Set(ans))
end

# Base case of `find_thing_in_expr` for a bare `Symbol`: the symbol is
# recorded in `ans` when it is exactly the thing searched for.
function find_thing_in_expr(ex::Symbol, thing::Symbol; ans = [])
    ex === thing && push!(ans, ex)
    return ans
end

"""
```julia
get_argument(eqs,_indvars::Array,_depvars::Array)
Expand All @@ -435,27 +479,50 @@ function get_argument(eqs, _indvars::Array, _depvars::Array)
get_argument(eqs, dict_indvars, dict_depvars)
end
function get_argument(eqs, dict_indvars, dict_depvars)
exprs = toexpr.(eqs)
vars = map(exprs) do expr
exprs = toexpr.(eqs) # Equations, as expressions
# vars is an array of arrays of arrays, representing instances of each dependent variable that appears in the expression, by dependent variable, by equation
vars = map(exprs) do expr # For each equation,...
# For each dependent variable, make an array of instances of the dependent variable
_vars = map(depvar -> find_thing_in_expr(expr, depvar), collect(keys(dict_depvars)))
# Remove any empty arrays, representing dependent variables that don't appear in the equation
f_vars = filter(x -> !isempty(x), _vars)
map(x -> first(x), f_vars)
end
args_ = map(vars) do _vars
ind_args_ = map(var -> var.args[2:end], _vars)

args_ = map(vars) do _vars # For each equation, ...
# _vars is an array of arrays of instances of each dependent variables that appears in the equation, by dependent variable

# For each dependent variable, for each instance of the dependent variable, get all arguments of that instance
ind_args_ = map.(var -> var.args[2:end], _vars)

# Get all arguments used in any instance of any dependent variable
all_ind_args = reduce(vcat, reduce(vcat, ind_args_, init = Any[]), init = Any[])

# Add any independent variables from expression-typed dependent variable calls
for ind_arg in all_ind_args
if ind_arg isa Expr
for ind_var in collect(keys(dict_indvars))
if !isempty(NeuralPDE.find_thing_in_expr(ind_arg, ind_var))
push!(all_ind_args, ind_var)
end
end
end
end

syms = Set{Symbol}()
filter(vcat(ind_args_...)) do ind_arg
filter(all_ind_args) do ind_arg # For each argument
if ind_arg isa Symbol
if ind_arg ∈ syms
false
false # remove symbols that have already occurred
else
push!(syms, ind_arg)
true
true # keep symbols that haven't occurred yet, but note their occurrence
end
elseif ind_arg isa Expr # we've already taken what we wanted from the expressions
false
else
true
true # keep all non-symbols
end
end
end
return args_ # TODO for all arguments
return args_
end
Loading