Push more computations into node functions to improve performance #235
+297
−161
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Current Performance
On the Apple M2 chip with the
Rats
example, a manually unrolled log density function has a median benchmark time of 1.113 μs, while the current master branch at 50.916 μs. In principle, we should be able to push the performance to within 10 times of the unrolled version.Performance Analysis
Profiling reveals type instabilities in:
setindex!
operationslogpdf
computation due to runtime distribution determinationThe root cause appears to be storing node functions in
Vector{Any}
, which disallows type inference.Proposed Solution
Put
bijector
,setindex!
, andlogpdf
computations directly into node functions where distribution types are known at compile time.In master branch: (consider
alpha[1]
, which is a stochastic variableIn this PR:
On a high level, previously, node functions are pure functions, now node function models effects to the environment.
The tradeoffs are: (1) Increased complexity in node functions, (2)the log density evaluation is less readable.
Challenges
Type instability is still a problem: in the scope of
evaluate!!
, node function types are still not known by the compiler, because the node functions are still stored in aVector{Any}
. To push this forward, maybe we would need to useFunctionWrappers
for type stability.