Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push more computations into node functions to improve performance #235

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

sunxd3
Copy link
Member

@sunxd3 sunxd3 commented Nov 11, 2024

Current Performance

On the Apple M2 chip with the Rats example, a manually unrolled log density function has a median benchmark time of 1.113 μs, while the current master branch at 50.916 μs. In principle, we should be able to push the performance to within 10 times of the unrolled version.

Performance Analysis

Screenshot 2024-11-11 at 7 58 01 AM

Profiling reveals type instabilities in:

  1. Bijector transformations
  2. BangBang's setindex! operations
  3. logpdf computation due to runtime distribution determination

The root cause appears to be storing node functions in Vector{Any}, which disallows type inference.

Proposed Solution

Put bijector, setindex!, and logpdf computations directly into node functions where distribution types are known at compile time.

In master branch: (consider alpha[1], which is a stochastic variable

function (evaluation_env, loop_vars)
      (; var"alpha.c", var"alpha.tau") = evaluation_env
      (;) = loop_vars
      return dnorm(var"alpha.c", var"alpha.tau")
 end

In this PR:

function (__evaluation_env__::NamedTuple{__vars__}, __loop_vars__::NamedTuple{__loop_vars_names__}, __vn__::AbstractPPL.VarName, __is_transformed__::Bool, __is_observed__::Bool, __params__::AbstractVector{<:Real}) where {__vars__, __loop_vars_names__}
      (; var"alpha.c", var"alpha.tau") = __evaluation_env__
      (;) = __loop_vars__
      __dist__ = dnorm(var"alpha.c", var"alpha.tau")
      if !__is_observed__
          if __is_transformed__
              __b__ = Bijectors.bijector(__dist__)
              __b_inv__ = Bijectors.inverse(__b__)
              __reconstructed_value__ = JuliaBUGS.reconstruct(__b_inv__, __dist__, __params__)
              (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
          else
              (__value__, __logjac__) = (JuliaBUGS.reconstruct(__dist__, __params__), 0.0)
          end
          __evaluation_env__ = BangBang.setindex!!(__evaluation_env__, __value__, __vn__)
          __logp__ = Distributions.logpdf(__dist__, __value__) + __logjac__
      else
          __logp__ = Distributions.logpdf(__dist__, AbstractPPL.get(__evaluation_env__, __vn__))
      end
      return (__logp__::Float64, __evaluation_env__::NamedTuple{__vars__})
  end

On a high level, previously, node functions are pure functions, now node function models effects to the environment.

The tradeoffs are: (1) Increased complexity in node functions, (2)the log density evaluation is less readable.

Challenges

Type instability is still a problem: in the scope of evaluate!!, node function types are still not known by the compiler, because the node functions are still stored in a Vector{Any}. To push this forward, maybe we would need to use FunctionWrappers for type stability.

@sunxd3 sunxd3 marked this pull request as draft November 11, 2024 07:28
@sunxd3 sunxd3 self-assigned this Nov 11, 2024
@sunxd3
Copy link
Member Author

sunxd3 commented Nov 14, 2024

The current approach to computing log density in JuliaBUGS faces fundamental limitations that make further performance optimization really difficult.

Consider a simplified version of our evaluation logic:

function compute_logp(model, values, params)
   # values is a NamedTuple storing all variable values (`evaluation_env` in `BUGSModel`)
   # vns is a Vector{VarName} in topological order
   logp = 0.0
   
   for (vn, node_fn) in zip(model.vns, model.node_functions)
       # the compiler only knows `vn` is a `VarName` but doesn't know the `sym`
       if is_deterministic(vn)
           # `node_fn` here calls `BangBang.setindex!!` on values and return, so the type of `values` can't be known at compile time
           values = node_fn(values)
       else
           values, logjac = node_fn(values, params) # similar issue here
           logp += logpdf(dist, param_values) + logjac
       end
   end
   return logp
end

There are several fundamental type stability issues that prevent further optimization. First, the vns vector cannot be type-stable except in trivial single-variable cases, since each VarName{sym} has a different concrete type (e.g., VarName{:x} vs VarName{:y}).

The second major issue arises from updates to values. Each node_fn call potentially modifies the type of values. We don't want to make the type too restrictive because: (1) BUGS' implicit casting between Int and Float and (2) ReverseDiff's DualNumber types during gradient computation. So although we know the variables names will not change, their values might.

To push for more performance, we would have to opt in for either FunctionWrapper or MistyClosure, but this means we will be limited in terms of the AD backend we could support (pretty much only Enzyme and Mooncake, where the latter needs to be modified). This, in my opinion, would be a bad solution, because it will be hard to understand and maintain.

So, the natural next step would be start generating Julia code directly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant