diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index e463874590..74562175ae 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -25,6 +25,7 @@ using Symbolics: wrap, unwrap, arguments, operation
 using SymbolicUtils
 using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains
 using MonteCarloMeasurements
+
 import ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives
 import DomainSets: Domain, ClosedInterval
 import ModelingToolkit: Interval, infimum, supremum #,Ball
diff --git a/src/PDE_BPINN.jl b/src/PDE_BPINN.jl
index 344d007963..d7642a3c2d 100644
--- a/src/PDE_BPINN.jl
+++ b/src/PDE_BPINN.jl
@@ -62,6 +62,12 @@ mutable struct PDELogTargetDensity{
     end
 end
 
+LogDensityProblems.dimension(Tar::PDELogTargetDensity) = Tar.dim
+
+function LogDensityProblems.capabilities(::PDELogTargetDensity)
+    LogDensityProblems.LogDensityOrder{1}()
+end
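+
+# Together with `logdensity` below, these methods complete the LogDensityProblems.jl
+# interface that AdvancedHMC samples against. A minimal downstream sketch (placeholder
+# name `ℓ`; the real call sites live in `ahmc_bayesian_pinn_pde`):
+#   metric      = DiagEuclideanMetric(LogDensityProblems.dimension(ℓ))
+#   hamiltonian = Hamiltonian(metric, ℓ, ForwardDiff)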
+
 function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ)
     # for parameter estimation it is necessary to use the multioutput case
     return Tar.full_loglikelihood(setparameters(Tar, θ),
@@ -87,11 +93,12 @@ function setparameters(Tar::PDELogTargetDensity, θ)
 
     a = ComponentArrays.ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams))
 
-    if Tar.extraparams > 0
+    if Tar.extraparams > 0
         b = θ[(end - Tar.extraparams + 1):end]
+
         return ComponentArrays.ComponentArray(;
-            depvar = a,
-            p = b)
+            depvar = a,
+            p = b)
     else
         return ComponentArrays.ComponentArray(;
             depvar = a)
@@ -298,6 +305,12 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
         Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
         numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
     pinnrep = symbolic_discretize(pde_system, discretization)
+
+    # reset the iteration counter that the adaptive loss reweighting reads
+    pinnrep.iteration = [0]
+
     dataset_pde, dataset_bc = discretization.dataset
 
     if ((dataset_bc isa Nothing) && (dataset_pde isa Nothing))
@@ -428,12 +441,20 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
         end
         return bpinnsols
     else
+        println("now 1")
+
+        println("now 1")
+
         initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
         integrator = integratorchoice(Integratorkwargs, initial_ϵ)
         adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric),
             StepSizeAdaptor(targetacceptancerate, integrator))
 
         Kernel = AdvancedHMC.make_kernel(Kernel, integrator)
+        println("now 2")
+
+        println("now 2")
+
         samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
             adaptor; progress = progress, verbose = verbose)
 
diff --git a/src/adaptive_losses.jl b/src/adaptive_losses.jl
index 3a1c4a79db..496478c2ac 100644
--- a/src/adaptive_losses.jl
+++ b/src/adaptive_losses.jl
@@ -11,10 +11,13 @@ function vectorify(x, t::Type{T}) where {T <: Real}
 end
 
 # Dispatches
+
 """
-    NonAdaptiveLoss(; pde_loss_weights = 1.0,
-                      bc_loss_weights = 1.0,
-                      additional_loss_weights = 1.0)
+```julia
+NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+                     bc_loss_weights = 1,
+                     additional_loss_weights = 1)
+```
 
 A way of loss weighting the components of the loss function in the total sum that does not
 change during optimization
@@ -23,42 +26,45 @@ mutable struct NonAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
     pde_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
-    SciMLBase.@add_kwonly function NonAdaptiveLoss{T}(; pde_loss_weights = 1.0,
-                                                      bc_loss_weights = 1.0,
-                                                      additional_loss_weights = 1.0) where {
-                                                                                          T <:
-                                                                                          Real
-                                                                                          }
+    SciMLBase.@add_kwonly function NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+            bc_loss_weights = 1,
+            additional_loss_weights = 1) where {T <: Real}
         new(vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
             vectorify(additional_loss_weights, T))
     end
 end
 
 # default to Float64
-SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1.0, bc_loss_weights = 1.0,
-                                               additional_loss_weights = 1.0)
+SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, bc_loss_weights = 1,
+        additional_loss_weights = 1)
     NonAdaptiveLoss{Float64}(; pde_loss_weights = pde_loss_weights,
-                             bc_loss_weights = bc_loss_weights,
-                             additional_loss_weights = additional_loss_weights)
+        bc_loss_weights = bc_loss_weights,
+        additional_loss_weights = additional_loss_weights)
 end
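+
+# Example usage (a sketch; the weights here are illustrative, not defaults):
+#   NonAdaptiveLoss(; pde_loss_weights = 1.0, bc_loss_weights = 1.0,
+#       additional_loss_weights = 10.0)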
 
 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
-                                         adaloss::NonAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
+        adaloss::NonAdaptiveLoss,
+        pde_loss_functions, bc_loss_functions)
     function null_nonadaptive_loss(θ, pde_losses, bc_losses)
         nothing
     end
 end
 
 """
-    GradientScaleAdaptiveLoss(reweight_every;
-                            weight_change_inertia = 0.9,
-                            pde_loss_weights = 1.0,
-                            bc_loss_weights = 1.0,
-                            additional_loss_weights = 1.0)
+```julia
+GradientScaleAdaptiveLoss(reweight_every;
+                          weight_change_inertia = 0.9,
+                          pde_loss_weights = 1,
+                          bc_loss_weights = 1,
+                          additional_loss_weights = 1)
+```
 
 A way of adaptively reweighting the components of the loss function in the total sum such
-that BC_i loss weights are scaled by the exponential moving average of max(|∇pde_loss|) / mean(|∇bc_i_loss|)).
+that BC_i loss weights are scaled by the exponential moving average of
+max(|∇pde_loss|) / mean(|∇bc_i_loss|).
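+The update implemented below is, schematically,
+`bc_loss_weights .= inertia .* bc_loss_weights .+ (1 - inertia) .* max(|∇pde_loss|) ./ mean(|∇bc_i_loss|)`
+with `inertia = weight_change_inertia`.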
 
 ## Positional Arguments
 
@@ -80,41 +86,54 @@ https://arxiv.org/abs/2001.04536v1
 With code reference:
 https://github.com/PredictiveIntelligenceLab/GradientPathologiesPINNs
 """
-mutable struct GradientScaleAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
+mutable struct GradientScaleAdaptiveLoss{T <: Real, WT <: Real} <:
+               AbstractAdaptiveLoss
     reweight_every::Int64
     weight_change_inertia::T
-    pde_loss_weights::Vector{T}
-    bc_loss_weights::Vector{T}
-    additional_loss_weights::Vector{T}
-    SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T}(reweight_every;
-                                                                weight_change_inertia = 0.9,
-                                                                pde_loss_weights = 1.0,
-                                                                bc_loss_weights = 1.0,
-                                                                additional_loss_weights = 1.0) where {
-                                                                                                    T <:
-                                                                                                    Real
-                                                                                                    }
+    pde_loss_weights::Vector{WT}
+    bc_loss_weights::Vector{WT}
+    additional_loss_weights::Vector{WT}
+    SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T, WT}(reweight_every;
+            weight_change_inertia = 0.9,
+            pde_loss_weights = 1,
+            bc_loss_weights = 1,
+            additional_loss_weights = 1) where {T <: Real, WT <: Real}
         new(convert(Int64, reweight_every), convert(T, weight_change_inertia),
-            vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-            vectorify(additional_loss_weights, T))
+            vectorify(pde_loss_weights, WT), vectorify(bc_loss_weights, WT),
+            vectorify(additional_loss_weights, WT))
     end
 end
 # default to Float64
 SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every;
-                                                         weight_change_inertia = 0.9,
-                                                         pde_loss_weights = 1.0,
-                                                         bc_loss_weights = 1.0,
-                                                         additional_loss_weights = 1.0)
-    GradientScaleAdaptiveLoss{Float64}(reweight_every;
-                                       weight_change_inertia = weight_change_inertia,
-                                       pde_loss_weights = pde_loss_weights,
-                                       bc_loss_weights = bc_loss_weights,
-                                       additional_loss_weights = additional_loss_weights)
+        weight_change_inertia = 0.9,
+        pde_loss_weights = 1,
+        bc_loss_weights = 1,
+        additional_loss_weights = 1)
+    GradientScaleAdaptiveLoss{Float64, Float64}(reweight_every;
+        weight_change_inertia = weight_change_inertia,
+        pde_loss_weights = pde_loss_weights,
+        bc_loss_weights = bc_loss_weights,
+        additional_loss_weights = additional_loss_weights)
+end
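+
+# Example usage (a sketch; `chain` and `dx` are placeholders):
+#   adaloss = GradientScaleAdaptiveLoss(20; weight_change_inertia = 0.9)
+#   PhysicsInformedNN(chain, GridTraining(dx); adaptive_loss = adaloss)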
+
+# The BPINN log-likelihood gradients carry ForwardDiff.Dual numbers, so allow Dual-typed loss weights
+SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every;
+        weight_change_inertia = 0.9,
+        pde_loss_weights = ForwardDiff.Dual{Float64}(1.0, ntuple(_ -> 0.0, 1)),
+        bc_loss_weights = ForwardDiff.Dual{Float64}(1.0, ntuple(_ -> 0.0, 1)),
+        additional_loss_weights = ForwardDiff.Dual{Float64}(1.0, ntuple(_ -> 0.0, 1)))
+    GradientScaleAdaptiveLoss{Float64, ForwardDiff.Dual{Float64, Float64}}(reweight_every;
+        weight_change_inertia = weight_change_inertia,
+        pde_loss_weights = pde_loss_weights,
+        bc_loss_weights = bc_loss_weights,
+        additional_loss_weights = additional_loss_weights)
 end
 
 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
-                                         adaloss::GradientScaleAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
+        adaloss::GradientScaleAdaptiveLoss,
+        pde_loss_functions, bc_loss_functions)
     weight_change_inertia = adaloss.weight_change_inertia
     iteration = pinnrep.iteration
     adaloss_T = eltype(adaloss.pde_loss_weights)
@@ -137,30 +156,32 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                        (1 .- weight_change_inertia) .*
                                        bc_loss_weights_proposed
             logscalar(pinnrep.logger, pde_grads_max, "adaptive_loss/pde_grad_max",
-                      iteration[1])
+                iteration[1])
             logvector(pinnrep.logger, pde_grads_maxes, "adaptive_loss/pde_grad_maxes",
-                      iteration[1])
+                iteration[1])
             logvector(pinnrep.logger, bc_grads_mean, "adaptive_loss/bc_grad_mean",
-                      iteration[1])
+                iteration[1])
             logvector(pinnrep.logger, adaloss.bc_loss_weights,
-                      "adaptive_loss/bc_loss_weights",
-                      iteration[1])
+                "adaptive_loss/bc_loss_weights",
+                iteration[1])
         end
         nothing
     end
 end
 
 """
-    function MiniMaxAdaptiveLoss(reweight_every;
-                                pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
-                                bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
-                                pde_loss_weights = 1,
-                                bc_loss_weights = 1,
-                                additional_loss_weights = 1)
+```julia
+function MiniMaxAdaptiveLoss(reweight_every;
+                             pde_max_optimiser = Flux.ADAM(1e-4),
+                             bc_max_optimiser = Flux.ADAM(0.5),
+                             pde_loss_weights = 1,
+                             bc_loss_weights = 1,
+                             additional_loss_weights = 1)
+```
 
 A way of adaptively reweighting the components of the loss function in the total sum such
 that the loss weights are maximized by an internal optimizer, which leads to a behavior
-where loss functions that have not been satisfied get a greater weight.
+where loss functions that have not been satisfied get a greater weight.
 
 ## Positional Arguments
 
@@ -170,9 +191,9 @@ where loss functions that have not been satisfied get a greater weight.
 
 ## Keyword Arguments
 
-* `pde_max_optimiser`: a OptimizationOptimisers optimiser that is used internally to
+* `pde_max_optimiser`: a Flux.Optimise.AbstractOptimiser that is used internally to
   maximize the weights of the PDE loss functions.
-* `bc_max_optimiser`: a OptimizationOptimisers optimiser that is used internally to maximize
+* `bc_max_optimiser`: a Flux.Optimise.AbstractOptimiser that is used internally to maximize
   the weights of the BC loss functions.
 
 ## References
@@ -182,8 +203,8 @@ Levi McClenny, Ulisses Braga-Neto
 https://arxiv.org/abs/2009.04544
 """
 mutable struct MiniMaxAdaptiveLoss{T <: Real,
-                                   PDE_OPT,
-                                   BC_OPT} <:
+    PDE_OPT <: Flux.Optimise.AbstractOptimiser,
+    BC_OPT <: Flux.Optimise.AbstractOptimiser} <:
                AbstractAdaptiveLoss
     reweight_every::Int64
     pde_max_optimiser::PDE_OPT
@@ -192,17 +213,19 @@ mutable struct MiniMaxAdaptiveLoss{T <: Real,
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
     SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss{T,
-                                                       PDE_OPT, BC_OPT}(reweight_every;
-                                                                        pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
-                                                                        bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
-                                                                        pde_loss_weights = 1.0,
-                                                                        bc_loss_weights = 1.0,
-                                                                        additional_loss_weights = 1.0) where {
-                                                                                                            T <:
-                                                                                                            Real,
-                                                                                                            PDE_OPT,
-                                                                                                            BC_OPT
-                                                                                                            }
+            PDE_OPT, BC_OPT}(reweight_every;
+            pde_max_optimiser = Flux.ADAM(1e-4),
+            bc_max_optimiser = Flux.ADAM(0.5),
+            pde_loss_weights = 1,
+            bc_loss_weights = 1,
+            additional_loss_weights = 1) where {
+            T <: Real,
+            PDE_OPT <: Flux.Optimise.AbstractOptimiser,
+            BC_OPT <: Flux.Optimise.AbstractOptimiser,
+        }
         new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser),
             convert(BC_OPT, bc_max_optimiser),
             vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
@@ -212,38 +235,50 @@ end
 
 # default to Float64, ADAM, ADAM
 SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss(reweight_every;
-                                                   pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
-                                                   bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
-                                                   pde_loss_weights = 1.0,
-                                                   bc_loss_weights = 1.0,
-                                                   additional_loss_weights = 1.0)
+        pde_max_optimiser = Flux.ADAM(1e-4),
+        bc_max_optimiser = Flux.ADAM(0.5),
+        pde_loss_weights = 1,
+        bc_loss_weights = 1,
+        additional_loss_weights = 1)
     MiniMaxAdaptiveLoss{Float64, typeof(pde_max_optimiser),
-                        typeof(bc_max_optimiser)}(reweight_every;
-                                                  pde_max_optimiser = pde_max_optimiser,
-                                                  bc_max_optimiser = bc_max_optimiser,
-                                                  pde_loss_weights = pde_loss_weights,
-                                                  bc_loss_weights = bc_loss_weights,
-                                                  additional_loss_weights = additional_loss_weights)
+        typeof(bc_max_optimiser)}(reweight_every;
+        pde_max_optimiser = pde_max_optimiser,
+        bc_max_optimiser = bc_max_optimiser,
+        pde_loss_weights = pde_loss_weights,
+        bc_loss_weights = bc_loss_weights,
+        additional_loss_weights = additional_loss_weights)
 end
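+
+# Example usage (a sketch; `chain` and `dx` are placeholders):
+#   adaloss = MiniMaxAdaptiveLoss(10; pde_max_optimiser = Flux.ADAM(1e-4),
+#                                     bc_max_optimiser = Flux.ADAM(0.5))
+#   PhysicsInformedNN(chain, GridTraining(dx); adaptive_loss = adaloss)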
 
 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
-                                         adaloss::MiniMaxAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
+        adaloss::MiniMaxAdaptiveLoss,
+        pde_loss_functions, bc_loss_functions)
     pde_max_optimiser = adaloss.pde_max_optimiser
-    pde_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(pde_max_optimiser, adaloss.pde_loss_weights)
     bc_max_optimiser = adaloss.bc_max_optimiser
-    bc_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(bc_max_optimiser, adaloss.bc_loss_weights)
     iteration = pinnrep.iteration
 
     function run_minimax_adaptive_loss(θ, pde_losses, bc_losses)
         if iteration[1] % adaloss.reweight_every == 0
-            OptimizationOptimisers.Optimisers.update!(pde_max_optimiser_setup, adaloss.pde_loss_weights, -pde_losses)
-            OptimizationOptimisers.Optimisers.update!(bc_max_optimiser_setup, adaloss.bc_loss_weights, -bc_losses)
+            # skip non-finite losses so a diverged term cannot corrupt the weights
+            for i in eachindex(pde_losses)
+                isfinite(pde_losses[i].value) || continue
+                Flux.Optimise.update!(pde_max_optimiser, adaloss.pde_loss_weights[i],
+                    pde_losses[i].value)
+            end
+            for i in eachindex(bc_losses)
+                isfinite(bc_losses[i].value) || continue
+                Flux.Optimise.update!(bc_max_optimiser, adaloss.bc_loss_weights[i],
+                    bc_losses[i].value)
+            end
             logvector(pinnrep.logger, adaloss.pde_loss_weights,
-                      "adaptive_loss/pde_loss_weights", iteration[1])
+                "adaptive_loss/pde_loss_weights", iteration[1])
             logvector(pinnrep.logger, adaloss.bc_loss_weights,
-                      "adaptive_loss/bc_loss_weights",
-                      iteration[1])
+                "adaptive_loss/bc_loss_weights",
+                iteration[1])
         end
         nothing
     end
diff --git a/src/discretize.jl b/src/discretize.jl
index af035980b3..e4c7a5e195 100644
--- a/src/discretize.jl
+++ b/src/discretize.jl
@@ -26,14 +26,14 @@ to
 for Lux.AbstractExplicitLayer.
 """
 function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
-                                      eq_params = SciMLBase.NullParameters(),
-                                      param_estim = false,
-                                      default_p = nothing,
-                                      bc_indvars = pinnrep.indvars,
-                                      integrand = nothing,
-                                      dict_transformation_vars = nothing,
-                                      transformation_vars = nothing,
-                                      integrating_depvars = pinnrep.depvars)
+        eq_params = SciMLBase.NullParameters(),
+        param_estim = false,
+        default_p = nothing,
+        bc_indvars = pinnrep.indvars,
+        integrand = nothing,
+        dict_transformation_vars = nothing,
+        transformation_vars = nothing,
+        integrating_depvars = pinnrep.depvars)
     @unpack indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input,
     phi, derivative, integral,
     multioutput, init_params, strategy, eq_params,
@@ -47,7 +47,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
         this_eq_indvars = unique(vcat(values(this_eq_pair)...))
     else
         this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars],
-                                integrating_depvars))
+            integrating_depvars))
         this_eq_indvars = transformation_vars isa Nothing ?
                           unique(vcat(values(this_eq_pair)...)) : transformation_vars
         loss_function = integrand
@@ -91,7 +91,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
             push!(params_symbols, Symbol(:($eq_param)))
         end
         params_eq = Expr(:(=), build_expr(:tuple, params_symbols),
-                         build_expr(:tuple, expr_params))
+            build_expr(:tuple, expr_params))
         push!(ex.args, params_eq)
     end
 
@@ -103,7 +103,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
             push!(params_symbols, Symbol(:($eq_param)))
         end
         params_eq = Expr(:(=), build_expr(:tuple, params_symbols),
-                         build_expr(:tuple, expr_params))
+            build_expr(:tuple, expr_params))
         push!(ex.args, params_eq)
     end
 
@@ -118,12 +118,12 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
         indvars_ex = get_indvars_ex(bc_indvars)
         left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex
         vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs),
-                       build_expr(:tuple, right_arg_pairs))
+            build_expr(:tuple, right_arg_pairs))
     else
         indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)]
         left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex
         vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs),
-                       build_expr(:tuple, right_arg_pairs))
+            build_expr(:tuple, right_arg_pairs))
     end
 
     if !(dict_transformation_vars isa Nothing)
@@ -133,11 +133,13 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
         end
         transformation_expr = Expr(:block, :($(transformation_expr_...)))
         vcat_expr_loss_functions = Expr(:block, transformation_expr, vcat_expr,
-                                        loss_function)
+            loss_function)
     end
     let_ex = Expr(:let, vars_eq, vcat_expr_loss_functions)
     push!(ex.args, let_ex)
-    expr_loss_function = :(($vars) -> begin $ex end)
+    expr_loss_function = :(($vars) -> begin
+        $ex
+    end)
 end
 
 """
@@ -152,14 +154,16 @@ function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars)
     bc_indvars = bc_indvars === nothing ? pinnrep.indvars : bc_indvars
 
     expr_loss_function = build_symbolic_loss_function(pinnrep, eqs;
-                                                      bc_indvars = bc_indvars,
-                                                      eq_params = eq_params,
-                                                      param_estim = param_estim,
-                                                      default_p = default_p)
+        bc_indvars = bc_indvars,
+        eq_params = eq_params,
+        param_estim = param_estim,
+        default_p = default_p)
     u = get_u()
     _loss_function = @RuntimeGeneratedFunction(expr_loss_function)
-    loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, derivative, integral, u,
-                                                      default_p) end
+    loss_function = (cord, θ) -> begin
+        _loss_function(cord, θ, phi, derivative, integral, u,
+            default_p)
+    end
     return loss_function
 end
 
@@ -172,16 +176,16 @@ strategy.
 function generate_training_sets end
 
 function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, _indvars::Array,
-                                _depvars::Array)
+        _depvars::Array)
     depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars,
-                                                                               _depvars)
+        _depvars)
     return generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars,
-                                  dict_depvars)
+        dict_depvars)
 end
 
 # Generate training set in the domain and on the boundary
 function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::Dict,
-                                dict_depvars::Dict)
+        dict_depvars::Dict)
     if dx isa Array
         dxs = dx
     else
@@ -213,20 +217,20 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D
     bcs_train_sets = map(bound_args) do bt
         span = map(b -> get(dict_var_span, b, b), bt)
         _set = adapt(eltypeθ,
-                     hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
+            hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
     end
 
     pde_vars = get_variables(eqs, dict_indvars, dict_depvars)
     pde_args = get_argument(eqs, dict_indvars, dict_depvars)
 
     pde_train_set = adapt(eltypeθ,
-                          hcat(vec(map(points -> collect(points),
-                                       Iterators.product(bc_data...)))...))
+        hcat(vec(map(points -> collect(points),
+            Iterators.product(bc_data...)))...))
 
     pde_train_sets = map(pde_args) do bt
         span = map(b -> get(dict_var_span_, b, b), bt)
         _set = adapt(eltypeθ,
-                     hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
+            hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
     end
     [pde_train_sets, bcs_train_sets]
 end
@@ -241,19 +245,19 @@ function get_bounds end
 
 function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, strategy)
     depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars,
-                                                                               _depvars)
+        _depvars)
     return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
 end
 
 function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array,
-                    strategy::QuadratureTraining)
+        strategy::QuadratureTraining)
     depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars,
-                                                                               _depvars)
+        _depvars)
     return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
 end
 
 function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,
-                    strategy::QuadratureTraining)
+        strategy::QuadratureTraining)
     dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains])
     dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains])
 
@@ -285,9 +289,9 @@ end
 function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
     dx = 1 / strategy.points
     dict_span = Dict([Symbol(d.variables) => [
-                          infimum(d.domain) + dx,
-                          supremum(d.domain) - dx,
-                      ] for d in domains])
+        infimum(d.domain) + dx,
+        supremum(d.domain) - dx,
+    ] for d in domains])
 
     # pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains]
     pde_args = get_argument(eqs, dict_indvars, dict_depvars)
@@ -330,7 +334,7 @@ function get_numeric_integral(pinnrep::PINNRepresentation)
                 ChainRulesCore.@ignore_derivatives lb_[i, :] = fill(l, 1, size(cord)[2])
             else
                 ChainRulesCore.@ignore_derivatives lb_[i, :] = l(cord, θ, phi, derivative,
-                                                                 nothing, u, nothing)
+                    nothing, u, nothing)
             end
         end
         for (i, u_) in enumerate(ub)
@@ -338,7 +342,7 @@ function get_numeric_integral(pinnrep::PINNRepresentation)
                 ChainRulesCore.@ignore_derivatives ub_[i, :] = fill(u_, 1, size(cord)[2])
             else
                 ChainRulesCore.@ignore_derivatives ub_[i, :] = u_(cord, θ, phi, derivative,
-                                                                  nothing, u, nothing)
+                    nothing, u, nothing)
             end
         end
         integration_arr = Matrix{Float64}(undef, 1, 0)
@@ -346,7 +350,7 @@ function get_numeric_integral(pinnrep::PINNRepresentation)
             # ub__ = @Zygote.ignore getindex(ub_, :,  i)
             # lb__ = @Zygote.ignore getindex(lb_, :,  i)
             integration_arr = hcat(integration_arr,
-                                   integration_(cord[:, i], lb_[:, i], ub_[:, i], θ))
+                integration_(cord[:, i], lb_[:, i], ub_[:, i], θ))
         end
         return integration_arr
     end
@@ -365,7 +369,7 @@ which is later optimized upon to give Solution or the Solution Distribution of t
 For more information, see `discretize` and `PINNRepresentation`.
 """
 function SciMLBase.symbolic_discretize(pde_system::PDESystem,
-    discretization::AbstractPINN)
+        discretization::AbstractPINN)
     eqs = pde_system.eqs
     bcs = pde_system.bcs
     chain = discretization.chain
@@ -381,7 +385,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
     adaloss = discretization.adaptive_loss
 
     depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(pde_system.indvars,
-                                                                               pde_system.depvars)
+        pde_system.depvars)
 
     multioutput = discretization.multioutput
     init_params = discretization.init_params
@@ -393,7 +397,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
         if chain isa AbstractArray
             x = map(chain) do x
                 _x = ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(),
-                                                                          x))
+                    x))
                 Float64.(_x) # No ComponentArray GPU support
             end
             names = ntuple(i -> depvars[i], length(chain))
@@ -401,7 +405,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
                                                                            for i in x))
         else
             init_params = Float64.(ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(),
-                                                                                            chain)))
+                chain)))
         end
     else
         init_params = init_params
@@ -436,11 +440,11 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
     if (phi isa Vector && phi[1].f isa Lux.AbstractExplicitLayer)
         for ϕ in phi
             ϕ.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)),
-                         ϕ.st)
+                ϕ.st)
         end
     elseif (!(phi isa Vector) && phi.f isa Lux.AbstractExplicitLayer)
         phi.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)),
-                       phi.st)
+            phi.st)
     end
 
     derivative = discretization.derivative
@@ -471,24 +475,24 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
     bc_integration_vars = get_integration_variables(bcs, dict_indvars, dict_depvars)
 
     pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p,
-                                 param_estim, additional_loss, adaloss, depvars, indvars,
-                                 dict_indvars, dict_depvars, dict_depvar_input, logger,
-                                 multioutput, iteration, init_params, flat_init_params, phi,
-                                 derivative,
-                                 strategy, pde_indvars, bc_indvars, pde_integration_vars,
-                                 bc_integration_vars, nothing, nothing, nothing, nothing)
+        param_estim, additional_loss, adaloss, depvars, indvars,
+        dict_indvars, dict_depvars, dict_depvar_input, logger,
+        multioutput, iteration, init_params, flat_init_params, phi,
+        derivative,
+        strategy, pde_indvars, bc_indvars, pde_integration_vars,
+        bc_integration_vars, nothing, nothing, nothing, nothing)
 
     integral = get_numeric_integral(pinnrep)
 
     symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq;
-                                                                bc_indvars = pde_indvar)
+        bc_indvars = pde_indvar)
                                    for (eq, pde_indvar) in zip(eqs, pde_indvars,
-                                                               pde_integration_vars)]
+        pde_integration_vars)]
 
     symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc;
-                                                               bc_indvars = bc_indvar)
+        bc_indvars = bc_indvar)
                                   for (bc, bc_indvar) in zip(bcs, bc_indvars,
-                                                             bc_integration_vars)]
+        bc_integration_vars)]
 
     pinnrep.integral = integral
     pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions
@@ -496,18 +500,18 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
 
     datafree_pde_loss_functions = [build_loss_function(pinnrep, eq, pde_indvar)
                                    for (eq, pde_indvar, integration_indvar) in zip(eqs,
-                                                                                   pde_indvars,
-                                                                                   pde_integration_vars)]
+        pde_indvars,
+        pde_integration_vars)]
 
     datafree_bc_loss_functions = [build_loss_function(pinnrep, bc, bc_indvar)
                                   for (bc, bc_indvar, integration_indvar) in zip(bcs,
-                                                                                 bc_indvars,
-                                                                                 bc_integration_vars)]
+        bc_indvars,
+        bc_integration_vars)]
 
     pde_loss_functions, bc_loss_functions = merge_strategy_with_loss_function(pinnrep,
-                                                                              strategy,
-                                                                              datafree_pde_loss_functions,
-                                                                              datafree_bc_loss_functions)
+        strategy,
+        datafree_pde_loss_functions,
+        datafree_bc_loss_functions)
     # setup for all adaptive losses
     num_pde_losses = length(pde_loss_functions)
     num_bc_losses = length(bc_loss_functions)
@@ -523,8 +527,8 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
                                       adaloss.additional_loss_weights
 
     reweight_losses_func = generate_adaptive_loss_function(pinnrep, adaloss,
-                                                           pde_loss_functions,
-                                                           bc_loss_functions)
+        pde_loss_functions,
+        bc_loss_functions)
 
     function get_likelihood_estimate_function(discretization::PhysicsInformedNN)
         function full_loss_function(θ, p)
@@ -548,7 +552,8 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
 
             sum_weighted_pde_losses = sum(weighted_pde_losses)
             sum_weighted_bc_losses = sum(weighted_bc_losses)
-            weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_bc_losses
+            weighted_loss_before_additional = sum_weighted_pde_losses +
+                                              sum_weighted_bc_losses
 
             full_weighted_loss = if additional_loss isa Nothing
                 weighted_loss_before_additional
@@ -562,7 +567,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
                     return additional_loss(phi, θ_, p_)
                 end
                 weighted_additional_loss_val = adaloss.additional_loss_weights[1] *
-                                            _additional_loss(phi, θ)
+                                               _additional_loss(phi, θ)
                 weighted_loss_before_additional + weighted_additional_loss_val
             end
 
@@ -608,15 +613,17 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
 
     function get_likelihood_estimate_function(discretization::BayesianPINN)
         dataset_pde, dataset_bc = discretization.dataset
-        
+
         # required as Physics loss also needed on the discrete dataset domain points
         # data points are discrete and so by default GridTraining loss applies
         # passing placeholder dx with GridTraining, it uses data points irl
-        datapde_loss_functions, databc_loss_functions = if (!(dataset_bc isa Nothing)||!(dataset_pde isa Nothing))
+        datapde_loss_functions, databc_loss_functions = if (!(dataset_bc isa Nothing) ||
+                                                            !(dataset_pde isa Nothing))
             merge_strategy_with_loglikelihood_function(pinnrep,
                 GridTraining(0.1),
                 datafree_pde_loss_functions,
-                datafree_bc_loss_functions, train_sets_pde = dataset_pde, train_sets_bc = dataset_bc)
+                datafree_bc_loss_functions, train_sets_pde = dataset_pde,
+                train_sets_bc = dataset_bc)
         else
             (nothing, nothing)
         end
@@ -625,32 +632,30 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
             stdpdes, stdbcs, stdextra = allstd
             # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
             pde_loglikelihoods = [logpdf(Normal(0, stdpdes[i]), pde_loss_function(θ))
-                                for (i, pde_loss_function) in enumerate(pde_loss_functions)]
+                                  for (i, pde_loss_function) in enumerate(pde_loss_functions)]
 
             bc_loglikelihoods = [logpdf(Normal(0, stdbcs[j]), bc_loss_function(θ))
-                                for (j, bc_loss_function) in enumerate(bc_loss_functions)]
+                                 for (j, bc_loss_function) in enumerate(bc_loss_functions)]
 
             if !(datapde_loss_functions isa Nothing)
                 pde_loglikelihoods += [logpdf(Normal(0, stdpdes[j]), pde_loss_function(θ))
-                                    for (j, pde_loss_function) in enumerate(datapde_loss_functions)]
-
+                                       for (j, pde_loss_function) in enumerate(datapde_loss_functions)]
             end
 
             if !(databc_loss_functions isa Nothing)
-                bc_loglikelihoods += [logpdf(Normal(0, stdbcs[j]), bc_loss_function(θ)) 
-                                    for (j, bc_loss_function) in enumerate(databc_loss_functions)]
+                bc_loglikelihoods += [logpdf(Normal(0, stdbcs[j]), bc_loss_function(θ))
+                                      for (j, bc_loss_function) in enumerate(databc_loss_functions)]
             end
 
             # this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized
             # that's why we prefer the user to maintain the increment in the outer loop callback during optimization
-            ChainRulesCore.@ignore_derivatives if self_increment
-                iteration[1] += 1
-            end
 
-            ChainRulesCore.@ignore_derivatives begin
-                reweight_losses_func(θ, pde_loglikelihoods,
-                    bc_loglikelihoods)
-            end
+            # println("terminate increment")
+            # ChainRulesCore.@ignore_derivatives if iteration[1] > 100
+            #     self_increment=false
+            # end
+            pinnrep.iteration[1] += 1
+            reweight_losses_func(θ, pde_loglikelihoods, bc_loglikelihoods)
 
             weighted_pde_loglikelihood = adaloss.pde_loss_weights .* pde_loglikelihoods
             weighted_bc_loglikelihood = adaloss.bc_loss_weights .* bc_loglikelihoods
@@ -658,7 +663,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
             sum_weighted_pde_loglikelihood = sum(weighted_pde_loglikelihood)
             sum_weighted_bc_loglikelihood = sum(weighted_bc_loglikelihood)
             weighted_loglikelihood_before_additional = sum_weighted_pde_loglikelihood +
-                                                    sum_weighted_bc_loglikelihood
+                                                       sum_weighted_bc_loglikelihood
 
             full_weighted_loglikelihood = if additional_loss isa Nothing
                 weighted_loglikelihood_before_additional
@@ -678,9 +683,46 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
                 weighted_additional_loglikelihood = adaloss.additional_loss_weights[1] *
                                                     _additional_loglikelihood
 
-                    weighted_loglikelihood_before_additional + weighted_additional_loglikelihood
+                weighted_loglikelihood_before_additional + weighted_additional_loglikelihood
             end
 
+            # println("full_weighted_loglikelihood : ", full_weighted_loglikelihood)
+
+            # ChainRulesCore.@ignore_derivatives begin
+            #     println(" inside lower chainrules logging log_frequency part ")
+            #     if iteration[1] % log_frequency == 0
+            #         logvector(pinnrep.logger, pde_loglikelihoods, "unweighted_likelihood/pde_loglikelihoods",
+            #             iteration[1])
+            #         logvector(pinnrep.logger,
+            #             bc_loglikelihoods,
+            #             "unweighted_likelihood/bc_loglikelihoods",
+            #             iteration[1])
+            #         logvector(pinnrep.logger, weighted_pde_loglikelihood,
+            #             "weighted_likelihood/weighted_pde_loglikelihood",
+            #             iteration[1])
+            #         logvector(pinnrep.logger, weighted_bc_loglikelihood,
+            #             "weighted_likelihood/weighted_bc_loglikelihood",
+            #             iteration[1])
+            #         if !(additional_loss isa Nothing)
+            #             logscalar(pinnrep.logger, weighted_additional_loglikelihood,
+            #                 "weighted_likelihood/weighted_additional_loglikelihood", iteration[1])
+            #         end
+            #         logscalar(pinnrep.logger, sum_weighted_pde_loglikelihood,
+            #             "weighted_likelihood/sum_weighted_pde_loglikelihood", iteration[1])
+            #         logscalar(pinnrep.logger, sum_weighted_bc_loglikelihood,
+            #             "weighted_likelihood/sum_weighted_bc_loglikelihood", iteration[1])
+            #         logscalar(pinnrep.logger, full_weighted_loglikelihood,
+            #             "weighted_likelihood/full_weighted_loglikelihood",
+            #             iteration[1])
+            #         logvector(pinnrep.logger, adaloss.pde_loss_weights,
+            #             "adaptive_loss/pde_loss_weights",
+            #             iteration[1])
+            #         logvector(pinnrep.logger, adaloss.bc_loss_weights,
+            #             "adaptive_loss/bc_loss_weights",
+            #             iteration[1])
+            #     end
+            # end
+
             return full_weighted_loglikelihood
         end
 
@@ -689,12 +731,11 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
 
     full_loss_function = get_likelihood_estimate_function(discretization)
     pinnrep.loss_functions = PINNLossFunctions(bc_loss_functions, pde_loss_functions,
-                                                full_loss_function, additional_loss, 
-                                                datafree_pde_loss_functions,
-                                                datafree_bc_loss_functions)
+        full_loss_function, additional_loss,
+        datafree_pde_loss_functions,
+        datafree_bc_loss_functions)
 
     return pinnrep
-
 end
 
 """
@@ -707,6 +748,6 @@ solution is the solution to the PDE.
 function SciMLBase.discretize(pde_system::PDESystem, discretization::PhysicsInformedNN)
     pinnrep = symbolic_discretize(pde_system, discretization)
     f = OptimizationFunction(pinnrep.loss_functions.full_loss_function,
-                             Optimization.AutoZygote())
+        Optimization.AutoZygote())
     Optimization.OptimizationProblem(f, pinnrep.flat_init_params)
 end
diff --git a/test/BPINN_PDE_tests.jl b/test/BPINN_PDE_tests.jl
index 6dd3637f5a..9d7fb1f9d2 100644
--- a/test/BPINN_PDE_tests.jl
+++ b/test/BPINN_PDE_tests.jl
@@ -206,3 +206,452 @@ end
     u_predict = pmean(sol1.ensemblesol[1])
     @test u_predict≈u_real atol=0.8
 end
+
+@testset "Example 1: 2D Periodic System" begin
+    # Cos(pi*t) example
+    @parameters t
+    @variables u(..)
+    Dt = Differential(t)
+    eqs = Dt(u(t)) - cos(2 * π * t) ~ 0
+    bcs = [u(0) ~ 0.0]
+    domains = [t ∈ Interval(0.0, 2.0)]
+    chainl = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 1))
+    initl, st = Lux.setup(Random.default_rng(), chainl)
+    @named pde_system = PDESystem(eqs, bcs, domains, [t], [u(t)])
+
+    # non adaptive case
+    discretization = BayesianPINN([chainl], GridTraining([0.01]))
+
+    sol1 = ahmc_bayesian_pinn_pde(pde_system,
+        discretization;
+        draw_samples = 1500,
+        bcstd = [0.02],
+        phystd = [0.01],
+        priorsNNw = (0.0, 1.0),
+        saveats = [1 / 50.0])
+
+    analytic_sol_func(u0, t) = u0 + sin(2 * π * t) / (2 * π)
+    ts = vec(sol1.timepoints[1])
+    u_real = [analytic_sol_func(0.0, t) for t in ts]
+    u_predict = pmean(sol1.ensemblesol[1])
+    @test u_predict≈u_real atol=0.5
+    @test mean(u_predict .- u_real) < 0.1
+end
+
+@testset "Example 2: 1D ODE" begin
+    @parameters θ
+    @variables u(..)
+    Dθ = Differential(θ)
+
+    # 1D ODE
+    eq = Dθ(u(θ)) ~ θ^3 + 2 * θ + (θ^2) * ((1 + 3 * (θ^2)) / (1 + θ + (θ^3))) -
+                    u(θ) * (θ + ((1 + 3 * (θ^2)) / (1 + θ + θ^3)))
+
+    # Initial and boundary conditions
+    bcs = [u(0.0) ~ 1.0]
+
+    # Space and time domains
+    domains = [θ ∈ Interval(0.0, 1.0)]
+
+    # Neural network
+    chain = Lux.Chain(Lux.Dense(1, 12, Lux.σ), Lux.Dense(12, 1))
+
+    discretization = BayesianPINN([chain], GridTraining([0.01]))
+
+    @named pde_system = PDESystem(eq, bcs, domains, [θ], [u])
+
+    sol1 = ahmc_bayesian_pinn_pde(pde_system,
+        discretization;
+        draw_samples = 500,
+        bcstd = [0.1],
+        phystd = [0.05],
+        priorsNNw = (0.0, 10.0),
+        saveats = [1 / 100.0])
+
+    analytic_sol_func(t) = exp(-(t^2) / 2) / (1 + t + t^3) + t^2
+    ts = sol1.timepoints[1]
+    u_real = vec([analytic_sol_func(t) for t in ts])
+    u_predict = pmean(sol1.ensemblesol[1])
+    @test u_predict≈u_real atol=0.8
+end
+
+@testset "Example 3: 3rd Degree ODE" begin
+    @parameters x
+    @variables u(..), Dxu(..), Dxxu(..), O1(..), O2(..)
+    Dxxx = Differential(x)^3
+    Dx = Differential(x)
+
+    # ODE
+    eq = Dx(Dxxu(x)) ~ cos(pi * x)
+
+    # Initial and boundary conditions
+    ep = (cbrt(eps(eltype(Float64))))^2 / 6
+
+    bcs = [u(0.0) ~ 0.0,
+        u(1.0) ~ cos(pi),
+        Dxu(1.0) ~ 1.0,
+        Dxu(x) ~ Dx(u(x)) + ep * O1(x),
+        Dxxu(x) ~ Dx(Dxu(x)) + ep * O2(x)]
+
+    # Space and time domains
+    domains = [x ∈ Interval(0.0, 1.0)]
+
+    # Neural network
+    chain = [
+        Lux.Chain(Lux.Dense(1, 10, Lux.tanh), Lux.Dense(10, 10, Lux.tanh),
+            Lux.Dense(10, 1)), Lux.Chain(Lux.Dense(1, 10, Lux.tanh), Lux.Dense(10, 10, Lux.tanh),
+            Lux.Dense(10, 1)), Lux.Chain(Lux.Dense(1, 10, Lux.tanh), Lux.Dense(10, 10, Lux.tanh),
+            Lux.Dense(10, 1)), Lux.Chain(Lux.Dense(1, 4, Lux.tanh), Lux.Dense(4, 1)),
+        Lux.Chain(Lux.Dense(1, 4, Lux.tanh), Lux.Dense(4, 1))]
+
+    discretization = BayesianPINN(chain, GridTraining(0.01))
+
+    @named pde_system = PDESystem(eq, bcs, domains, [x],
+        [u(x), Dxu(x), Dxxu(x), O1(x), O2(x)])
+
+    sol1 = ahmc_bayesian_pinn_pde(pde_system,
+        discretization;
+        draw_samples = 200,
+        bcstd = [0.01, 0.01, 0.01, 0.01, 0.01],
+        phystd = [0.005],
+        priorsNNw = (0.0, 10.0),
+        saveats = [1 / 100.0])
+
+    analytic_sol_func(x) = (π * x * (-x + (π^2) * (2 * x - 3) + 1) - sin(π * x)) / (π^3)
+
+    u_predict = pmean(sol1.ensemblesol[1])
+    xs = vec(sol1.timepoints[1])
+    u_real = [analytic_sol_func(x) for x in xs]
+    @test u_predict≈u_real atol=0.5
+end
+
+@testset "Example 4: 2D Poissons equation" begin
+    @parameters x y
+    @variables u(..)
+    Dxx = Differential(x)^2
+    Dyy = Differential(y)^2
+
+    # 2D PDE
+    eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y)
+
+    # Boundary conditions
+    bcs = [u(0, y) ~ 0.0, u(1, y) ~ 0.0,
+        u(x, 0) ~ 0.0, u(x, 1) ~ 0.0]
+
+    # Space and time domains
+    domains = [x ∈ Interval(0.0, 1.0),
+        y ∈ Interval(0.0, 1.0)]
+
+    # Neural network
+    dim = 2 # number of dimensions
+    chain = Lux.Chain(Lux.Dense(dim, 9, Lux.σ), Lux.Dense(9, 9, Lux.σ), Lux.Dense(9, 1))
+
+    # Discretization
+    dx = 0.04
+    discretization = BayesianPINN([chain], GridTraining(dx))
+
+    @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
+
+    sol1 = ahmc_bayesian_pinn_pde(pde_system,
+        discretization;
+        draw_samples = 200,
+        bcstd = [0.003, 0.003, 0.003, 0.003],
+        phystd = [0.003],
+        priorsNNw = (0.0, 10.0),
+        saveats = [1 / 100.0, 1 / 100.0])
+
+    xs = sol1.timepoints[1]
+    analytic_sol_func(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2)
+
+    u_predict = pmean(sol1.ensemblesol[1])
+    u_real = [analytic_sol_func(xs[:, i][1], xs[:, i][2]) for i in 1:length(xs[1, :])]
+    @test u_predict≈u_real atol=1.5
+end
+
+@testset "Translating from Flux" begin
+    @parameters θ
+    @variables u(..)
+    Dθ = Differential(θ)
+
+    # 1D ODE
+    eq = Dθ(u(θ)) ~ θ^3 + 2 * θ + (θ^2) * ((1 + 3 * (θ^2)) / (1 + θ + (θ^3))) -
+                    u(θ) * (θ + ((1 + 3 * (θ^2)) / (1 + θ + θ^3)))
+
+    # Initial and boundary conditions
+    bcs = [u(0.0) ~ 1.0]
+
+    # Space and time domains
+    domains = [θ ∈ Interval(0.0, 1.0)]
+
+    # Neural network
+    chain = Flux.Chain(Flux.Dense(1, 12, Flux.σ), Flux.Dense(12, 1))
+
+    discretization = BayesianPINN([chain], GridTraining([0.01]))
+    @test discretization.chain[1] isa Lux.AbstractExplicitLayer
+
+    @named pde_system = PDESystem(eq, bcs, domains, [θ], [u])
+
+    sol1 = ahmc_bayesian_pinn_pde(pde_system,
+        discretization;
+        draw_samples = 500,
+        bcstd = [0.1],
+        phystd = [0.05],
+        priorsNNw = (0.0, 10.0),
+        saveats = [1 / 100.0])
+
+    analytic_sol_func(t) = exp(-(t^2) / 2) / (1 + t + t^3) + t^2
+    ts = sol1.timepoints[1]
+    u_real = vec([analytic_sol_func(t) for t in ts])
+    u_predict = pmean(sol1.ensemblesol[1])
+    @test u_predict≈u_real atol=0.8
+end
+
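+# The remainder of this file is an exploratory script (not wrapped in a @testset):
+# a generalized Kuramoto–Sivashinsky equation whose travelling-wave solution is
+# sampled, perturbed with 5% relative noise, and used with BayesianPINN plus
+# GradientScaleAdaptiveLoss to recover the coefficient α via ahmc_bayesian_pinn_pde,
+# followed by a deterministic PhysicsInformedNN run of the same equation.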
+using NeuralPDE, Flux, Lux, ModelingToolkit, LinearAlgebra, AdvancedHMC, Distributions
+import ModelingToolkit: Interval, infimum, supremum
+
+@parameters x, t, α
+@variables u(..)
+Dt = Differential(t)
+Dx = Differential(x)
+Dx2 = Differential(x)^2
+Dx3 = Differential(x)^3
+Dx4 = Differential(x)^4
+
+# α = 1
+β = 4
+γ = 1
+eq = Dt(u(x, t)) + u(x, t) * Dx(u(x, t)) + α * Dx2(u(x, t)) + β * Dx3(u(x, t)) + γ * Dx4(u(x, t)) ~ 0
+
+u_analytic(x, t; z = -x / 2 + t) = 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
+du(x, t; z = -x / 2 + t) = 15 / 2 * (tanh(z) + 1) * (3 * tanh(z) - 1) * sech(z)^2
+
+bcs = [u(x, 0) ~ u_analytic(x, 0),
+    u(-10, t) ~ u_analytic(-10, t),
+    u(10, t) ~ u_analytic(10, t),
+    Dx(u(-10, t)) ~ du(-10, t),
+    Dx(u(10, t)) ~ du(10, t)]
+
+# Space and time domains
+domains = [x ∈ Interval(-10.0, 10.0),
+    t ∈ Interval(0.0, 1.0)]
+
+# Discretization
+dx = 0.4;
+dt = 0.2;
+
+# Function to compute analytical solution at a specific point (x, t)
+function u_analytic_point(x, t)
+    z = -x / 2 + t
+    return 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
+end
+
+# Function to generate the dataset matrix
+function generate_dataset_matrix(domains, dx, dt)
+    x_values = -10:dx:10
+    t_values = 0.0:dt:1.0
+
+    dataset = []
+
+    for t in t_values
+        for x in x_values
+            u_value = u_analytic_point(x, t)
+            push!(dataset, [u_value, x, t])
+        end
+    end
+
+    return vcat([data' for data in dataset]...)
+end
+
+using Plots, MonteCarloMeasurements, StatsPlots
+plotly()
+
+datasetpde = [generate_dataset_matrix(domains, dx, dt)]
+plot(datasetpde[1][:, 2], datasetpde[1][:, 1])
+
+# Add noise to dataset
+datasetpde[1][:, 1] = datasetpde[1][:, 1] .+
+                      randn(size(datasetpde[1][:, 1])) .* 5 / 100 .*
+                      datasetpde[1][:, 1]
+plot!(datasetpde[1][:, 2], datasetpde[1][:, 1])
+
+# Neural network
+chain = Lux.Chain(Lux.Dense(2, 8, Lux.tanh),
+    Lux.Dense(8, 8, Lux.tanh),
+    Lux.Dense(8, 1))
+
+discretization = NeuralPDE.BayesianPINN([chain], GridTraining([dx, dt]);
+    adaptive_loss = GradientScaleAdaptiveLoss(5),  # alternative: MiniMaxAdaptiveLoss(5)
+    param_estim = true, dataset = [datasetpde, nothing])
+@named pde_system = PDESystem(eq,
+    bcs,
+    domains,
+    [x, t],
+    [u(x, t)],
+    [α],
+    defaults = Dict([α => 0.5]))
+
+sol1 = ahmc_bayesian_pinn_pde(pde_system,
+    discretization;
+    draw_samples = 1000,
+    Kernel = AdvancedHMC.NUTS(0.80),
+    bcstd = [1.0, 1.0, 1.0, 1.0, 1.0],
+    phystd = [0.1], l2std = [0.05], param = [Distributions.LogNormal(0.5, 2)],
+    priorsNNw = (0.0, 10.0),
+    saveats = [1 / 100.0, 1 / 100.0], progress = true)
+
+phi = discretization.phi[1]
+xs, ts = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, [dx / 10, dt])]
+u_predict = [[first(pmean(phi([x, t], sol1.estimated_nn_params[1]))) for x in xs]
+             for t in ts]
+u_real = [[u_analytic(x, t) for x in xs] for t in ts]
+diff_u = [[abs(u_analytic(x, t) - first(pmean(phi([x, t], sol1.estimated_nn_params[1]))))
+           for x in xs]
+          for t in ts]
+
+p1 = plot(xs, u_predict, title = "predict")
+p2 = plot(xs, u_real, title = "analytic")
+p3 = plot(xs, diff_u, title = "error")
+plot(p1, p2, p3)
+
+using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL
+import ModelingToolkit: Interval, infimum, supremum
+
+@parameters x, t
+@variables u(..)
+Dt = Differential(t)
+Dx = Differential(x)
+Dx2 = Differential(x)^2
+Dx3 = Differential(x)^3
+Dx4 = Differential(x)^4
+
+α = 1
+β = 4
+γ = 1
+eq = Dt(u(x, t)) + u(x, t) * Dx(u(x, t)) + α * Dx2(u(x, t)) + β * Dx3(u(x, t)) + γ * Dx4(u(x, t)) ~ 0
+
+u_analytic(x, t; z = -x / 2 + t) = 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
+du(x, t; z = -x / 2 + t) = 15 / 2 * (tanh(z) + 1) * (3 * tanh(z) - 1) * sech(z)^2
+
+bcs = [u(x, 0) ~ u_analytic(x, 0),
+    u(-10, t) ~ u_analytic(-10, t),
+    u(10, t) ~ u_analytic(10, t),
+    Dx(u(-10, t)) ~ du(-10, t),
+    Dx(u(10, t)) ~ du(10, t)]
+
+# Space and time domains
+domains = [x ∈ Interval(-10.0, 10.0),
+    t ∈ Interval(0.0, 1.0)]
+# Discretization
+dx = 0.4;
+dt = 0.2;
+
+# Neural network
+chain = Lux.Chain(Lux.Dense(2, 8, Lux.tanh),
+    Lux.Dense(8, 8, Lux.tanh),
+    Lux.Dense(8, 1))
+
+discretization = PhysicsInformedNN(chain, GridTraining([dx, dt]);
+    adaptive_loss = GradientScaleAdaptiveLoss(1))
+@named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)])
+prob = discretize(pde_system, discretization)
+
+callback = function (p, l)
+    println("Current loss is: $l")
+    return false
+end
+
+opt = OptimizationOptimJL.BFGS()
+res = Optimization.solve(prob, opt; callback = callback, maxiters = 100)
+phi = discretization.phi
\ No newline at end of file