
FourierNeuralOperator and ComponentArrays error #29

Open
KirillZubov opened this issue Aug 27, 2024 · 1 comment

@KirillZubov
Member

KirillZubov commented Aug 27, 2024

I came across the issue that FNO does not work with ComponentArrays (which is needed for OptimizationProblem). Any ideas what the problem is?

using Lux, NeuralOperators, Random, ComponentArrays

rng = Random.default_rng()

fno = FourierNeuralOperator(gelu; chs = (2, 64, 64, 128, 1), modes = (16,))
θ, st = Lux.setup(rng, fno)
v = rand(rng, Float32, 2, 40, 50)
c = fno(v, θ, st)[1] .- 1.0f0
ff = θ -> fno(v, θ, st)[1] .- 1.0f0
init_params = ComponentArrays.ComponentArray(θ)

function total_loss(θ)
    sum(abs2, ff(θ))
end

total_loss(θ)            # works with the NamedTuple parameters
total_loss(init_params)  # throws the MethodError below

julia> total_loss(init_params)
ERROR: MethodError: no method matching realfloat(::Array{ComplexF32, 3})

Closest candidates are:
  realfloat(::StridedArray{<:Union{Float32, Float64}})
   @ AbstractFFTs ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:42
  realfloat(::AbstractArray{T}) where T<:Real
   @ AbstractFFTs ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:49

Stacktrace:
  [1] plan_rfft(x::Array{ComplexF32, 3}, region::UnitRange{Int64}; kws::@Kwargs{})
    @ AbstractFFTs ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:221
  [2] rfft(x::Array{ComplexF32, 3}, region::UnitRange{Int64})
    @ AbstractFFTs ~/.julia/packages/AbstractFFTs/4iQz5/src/definitions.jl:67
  [3] transform
    @ ~/.julia/packages/NeuralOperators/rTBsc/src/transform.jl:24 [inlined]
  [4] operator_conv
    @ ~/.julia/packages/NeuralOperators/rTBsc/src/functional.jl:3 [inlined]
  [5] (::OperatorConv{…})(x::Array{…}, ps::ComponentArrays.ComponentVector{…}, st::@NamedTuple{})
    @ NeuralOperators ~/.julia/packages/NeuralOperators/rTBsc/src/layers.jl:66
  [6] (::OperatorKernel{…})(x::Array{…}, ps::ComponentArrays.ComponentVector{…}, st::@NamedTuple{})
    @ NeuralOperators ~/.julia/packages/NeuralOperators/rTBsc/src/layers.jl:138
  [7] (::FourierNeuralOperator{…})(x::Array{…}, ps::ComponentArrays.ComponentVector{…}, st::@NamedTuple{})
    @ NeuralOperators ~/.julia/packages/NeuralOperators/rTBsc/src/fno.jl:70
  [8] (::var"#245#246")(θ::ComponentArrays.ComponentVector{ComplexF32, Vector{ComplexF32}, Tuple{ComponentArrays.Axis{…}}})
    @ Main ./REPL[706]:1
  [9] total_loss(θ::ComponentArrays.ComponentVector{ComplexF32, Vector{ComplexF32}, Tuple{ComponentArrays.Axis{…}}})
    @ Main ./REPL[708]:2
 [10] top-level scope
    @ REPL[710]:1
Some type information was truncated. Use `show(err)` to see complete types.
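
For reference, the ComponentVector{ComplexF32, …} types in the stacktrace point at the root cause: flattening the parameters promoted every leaf to a complex element type, so the FFT layer receives complex input. A quick check (reusing the setup above) makes this visible:

# The flattened parameters were promoted when the ComponentArray was
# built, so rfft's real-input check (realfloat) fails downstream.
eltype(init_params)  # ComplexF32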
@avik-pal
Member

The problem probably stems from some parameters being Float32 and some being ComplexF32, so constructing a ComponentVector promotes everything to ComplexF32. This is pretty much a fundamental shortcoming of any package that forces a single array element type on the parameters (hence the default parameter container Lux uses is a NamedTuple and not a ComponentArray).
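
To see the promotion in isolation, here is a minimal sketch (a toy NamedTuple standing in for the mixed Float32/ComplexF32 FNO parameters, not the actual model) of how ComponentArrays flattens mixed-eltype leaves into a single promoted buffer:

using ComponentArrays

# A toy parameter tree with one real and one complex leaf.
ps = (a = rand(Float32, 2), b = rand(ComplexF32, 2))
ca = ComponentArray(ps)

eltype(ca)    # ComplexF32 — the single flat buffer must hold both leaves
eltype(ca.a)  # ComplexF32 — the originally-real leaf is now complex,
              # which is why rfft rejects the input in the FNO forward pass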

As to how to solve the problem, I don't have any good solution.
