Skip to content

Commit

Permalink
refactor: Quadrature Training with Integrals.jl@v4
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Jan 17, 2024
1 parent afdce79 commit c747ac0
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/training_strategies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,10 @@ function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::Quadrature
# mean(abs2,loss_(x,θ), dims=2)
# size_x = fill(size(x)[2],(1,1))
x = adapt(parameterless_type(ComponentArrays.getdata(θ)), x)
sum(abs2, loss_(x, θ), dims = 2) #./ size_x
sum(abs2, vec(loss_(x, θ)), dims = 2) #./ size_x
end
prob = IntegralProblem(integrand, lb, ub, θ, batch = strategy.batch, nout = 1)
integral_function = BatchIntegralFunction(integrand, max_batch = strategy.batch)
prob = IntegralProblem(integral_function, lb, ub, θ)
solve(prob,
strategy.quadrature_alg,
reltol = strategy.reltol,
Expand Down

0 comments on commit c747ac0

Please sign in to comment.