Add get_merged_chains function #409

ChristianMichelsen

I have tried to add some small helper functions that allow one to merge the chains of a model with its generated quantities.

For more information regarding the background of this, please see this Discourse thread.

I am very new to the internals of DynamicPPL.jl, so this is most likely not the optimal implementation, but I was encouraged on Slack to make this PR anyway.
I hope this can help in some way.

Member

@devmotion devmotion left a comment

Thanks for the PR @ChristianMichelsen!

I think we should make it simpler to work with the output of generated_quantities if possible. Unfortunately, I think there are some issues with this PR.

Most importantly, DynamicPPL depends only on AbstractMCMC, not on MCMCChains, so MCMCChains-specific functionality does not seem to belong in DynamicPPL.

Regarding the implementation, some high-level suggestions are:

  • Reduce the number of functions as it seems many functions are only called in one place
  • Avoid names such as get_...
  • Avoid passing around Dicts but instead use multiple function arguments whenever possible (also reducing the number of functions will help here)
  • Add tests, and turn the examples in docstrings into doctests by using jldoctest instead of julia

Maybe an alternative approach here could be to add conversion methods such as convert(::Type{<:MCMCChains.Chains}, ::Matrix{<:NamedTuple}) to MCMCChains. Then it would be easy to obtain a Chains object if desired. It's not always desired: the generated quantities don't even have to be arrays or scalars but could be of some other, quite arbitrary type. MCMCChains also already supports concatenation of chains, so it would also be possible to concatenate the resulting chain with the original chain if desired.
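
To make the conversion idea concrete, here is a minimal sketch using only Base Julia. The function name `flatten_quantities` is hypothetical and is not part of MCMCChains or DynamicPPL. It assumes that every entry of the matrix is a NamedTuple with the same keys, and that each value is a real scalar or an array of reals of fixed size.

```julia
# Hypothetical helper (not an MCMCChains or DynamicPPL function): flatten a
# samples × chains matrix of NamedTuples into a 3-D array plus parameter names.
# Assumes every entry has the same keys and the same shapes.
function flatten_quantities(quantities::AbstractMatrix{<:NamedTuple})
    n_samples, n_chains = size(quantities)
    template = first(quantities)

    # One name per scalar entry: `x` for scalars, `x[1]`, `x[2]`, ... for arrays.
    param_names = Symbol[]
    for (key, value) in pairs(template)
        if value isa Number
            push!(param_names, key)
        else
            append!(param_names, [Symbol(key, "[", i, "]") for i in eachindex(value)])
        end
    end

    # Layout: iterations × parameters × chains.
    data = Array{Float64}(undef, n_samples, length(param_names), n_chains)
    for chain in 1:n_chains, sample in 1:n_samples
        col = 1
        for value in quantities[sample, chain]
            for v in value  # a scalar iterates as a single element
                data[sample, col, chain] = v
                col += 1
            end
        end
    end
    return data, param_names
end

# Toy input: 3 samples × 2 chains, one scalar `a` and one length-2 vector `b`.
quantities = [(a = Float64(10s + c), b = [1.0, 2.0] .* s) for s in 1:3, c in 1:2]
data, param_names = flatten_quantities(quantities)
```

The resulting array and name vector have the same shape as the `matrix` and `chain_names` that the module later in this thread passes to `Turing.Chains`, so a real `convert` method could presumably end with that constructor call. Quantities of arbitrary type, as mentioned above, would not fit this layout and would need separate handling.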

@ChristianMichelsen
Author

Thanks for your reply, @devmotion!

It seems my code might belong somewhere else. Do you have any suggestions? Turing.jl? MCMCChains.jl?

Regarding your suggestions, I have a couple of comments:

  • I thought that small functions were good, both from a compiler standpoint and for readability. I have changed that now.
  • Would a prefix such as compute_ be better? Or what would you suggest?
  • I have changed this as per your suggestion.
  • I have no experience with testing in Julia, so I'll have to add that sometime later.

My updated code is below, now as a single module:

module MergeChains

export merge


import Turing


function get_generated_quantities(model::Turing.Model, chains::Turing.Chains)
    chains_params = Turing.MCMCChains.get_sections(chains, :parameters)
    generated_quantities = Turing.generated_quantities(model, chains_params)
    return generated_quantities
end


function generated_quantities_to_chain(
    generated_quantities::AbstractMatrix,
    chains::Turing.Chains,
    variable::Union{Symbol,String},
)

    # Number of scalar components (K) of this variable: 1 for a scalar, n for a length-n vector
    K = length(first(generated_quantities)[variable])
    N_samples = length(chains)
    N_chains = length(Turing.chains(chains))

    # Fill an (iterations × K × chains) array, the layout passed to `Turing.Chains` below
    matrix = zeros(N_samples, K, N_chains)
    for chain = 1:N_chains
        for (i, xi) in enumerate(generated_quantities[:, chain])
            matrix[i, :, chain] .= xi[variable]
        end
    end

    if K == 1
        chain_names = [Symbol("$variable")]
    else
        chain_names = [Symbol("$variable[$i]") for i = 1:K]
    end
    generated_chain = Turing.Chains(matrix, chain_names, info = chains.info)

    return generated_chain
end


function generated_quantities_to_chain(
    generated_quantities::AbstractMatrix,
    chains::Turing.Chains,
    variables::Tuple,
)
    func = variable -> generated_quantities_to_chain(generated_quantities, chains, variable)
    return hcat(func.(variables)...)
end


function merge_generated_chains(chains::Turing.Chains, generated_chains::Turing.Chains)
    return hcat(chains, Turing.setrange(generated_chains, range(chains)))
end


function merge(model::Turing.Model, chains::Turing.Chains)

    generated_quantities = get_generated_quantities(model, chains)

    if generated_quantities isa Matrix{Nothing}
        return chains
    end

    variables = generated_quantities |> first |> keys
    generated_chains =
        generated_quantities_to_chain(generated_quantities, chains, variables)

    chains_merged = merge_generated_chains(chains, generated_chains)

    return chains_merged

end

end # module

@codecov

codecov bot commented Jul 4, 2023

Codecov Report

Patch coverage has no change and project coverage change: -1.17 ⚠️

Comparison is base (e6dd4ef) 76.40% compared to head (67e433a) 75.24%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #409      +/-   ##
==========================================
- Coverage   76.40%   75.24%   -1.17%     
==========================================
  Files          21       22       +1     
  Lines        2522     2561      +39     
==========================================
  Hits         1927     1927              
- Misses        595      634      +39     
Impacted Files Coverage Δ
src/merge_generated_quantities.jl 0.00% <0.00%> (ø)


@yebai
Member

yebai commented May 5, 2024

Closed in favour of #594

@yebai yebai closed this May 5, 2024