Adding StatsBase.predict to the API #466
Comments
Maybe this could even be part of AbstractPPL and be defined on …
Yeah, makes sense.
I'm down with this, but it's worth pointing out that just calling … And regarding adding it to APPL: we'd need to propagate that change back to v0.5 too, because v0.6 is currently not compatible with DPPL (see #440).
Would …
For maximal model-compat, yes. But you do, of course, take a performance hit as a result 😕
Hrm. Maybe then …
Adding …
This PR adds a 3-arg form of `rand` (suggested by @devmotion in a comment on TuringLang/DynamicPPL.jl#466) to the interface for `AbstractProbabilisticProgram` and implements the default 1- and 2-arg methods that dispatch to it. Tests currently fail because this breaks the fallbacks for `GraphPPL.Model`, which expects `rand` to forward to its `rand!` method. I'm not certain how we want to define the interface for this `Model`. Co-authored-by: Xianda Sun
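A minimal sketch of the dispatch pattern described above, assuming a hypothetical `ToyProbabilisticProgram` abstract type and `ToyNormalModel` in place of AbstractPPL's `AbstractProbabilisticProgram`, so it runs with only the `Random` standard library. Each model implements only the 3-arg `rand(rng, T, model)`; the shorter methods fill in a default RNG and a default return type. Using `NamedTuple` as that default is an assumption of this sketch, not something stated in the PR.

```julia
using Random

# Hypothetical stand-in for AbstractPPL.AbstractProbabilisticProgram.
abstract type ToyProbabilisticProgram end

# Hypothetical toy model: x ~ Normal(mu, 1).
struct ToyNormalModel <: ToyProbabilisticProgram
    mu::Float64
end

# The one method a model implements: the 3-arg form rand(rng, T, model).
function Base.rand(rng::Random.AbstractRNG, ::Type{NamedTuple}, model::ToyNormalModel)
    return (x = model.mu + randn(rng),)
end

# Default 2-arg methods: fill in either the return type or the RNG.
Base.rand(rng::Random.AbstractRNG, model::ToyProbabilisticProgram) = rand(rng, NamedTuple, model)
Base.rand(::Type{T}, model::ToyProbabilisticProgram) where {T} = rand(Random.default_rng(), T, model)

# Default 1-arg method: fill in both (assumed default type: NamedTuple).
Base.rand(model::ToyProbabilisticProgram) = rand(Random.default_rng(), NamedTuple, model)
```

With this layout a model author writes a single method, and `rand(model)`, `rand(rng, model)` and `rand(NamedTuple, model)` all route to it. It also shows why generic fallbacks like these can collide with a model type whose own `rand` already forwards somewhere else.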
In Turing, `StatsBase.predict` is overloaded to dispatch on `DynamicPPL.Model` and `MCMCChains.Chains` (https://github.com/TuringLang/Turing.jl/blob/d76d914231db0198b99e5ca5d69d80934ee016b3/src/inference/Inference.jl#L532-L564). This effectively does batch prediction: it conditions the model on each draw in the chains and calls `rand` on the model. We also want to do the same thing for `InferenceData` (see #465).

It would be convenient if `StatsBase.predict` were added to the DynamicPPL API; StatsBase is already an indirect dependency of this package. As suggested by @devmotion in a comment on #465, its default implementation could simply call `rand` on a conditioned model.
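The proposed default can be sketched in plain Julia. Everything here is hypothetical: `ToyModel` stands in for `DynamicPPL.Model`, `condition` for DynamicPPL's conditioning, a `Vector` of `NamedTuple`s for the draws in an `MCMCChains.Chains`, and a local `predict` function for `StatsBase.predict`, so the sketch needs only the `Random` standard library. The single-draw method calls `rand` on the conditioned model; the batch method repeats this for every draw, mirroring what the issue describes Turing's method doing for a chain.

```julia
using Random

# Hypothetical stand-in for DynamicPPL.Model: y ~ Normal(mu, 1).
# `mu === nothing` means the parameter has not been conditioned yet.
struct ToyModel
    mu::Union{Nothing,Float64}
end

# Hypothetical stand-in for conditioning: fix the parameter to one posterior draw.
condition(::ToyModel, draw::NamedTuple) = ToyModel(draw.mu)

# Sampling a conditioned model draws only the observation y.
function Base.rand(rng::Random.AbstractRNG, model::ToyModel)
    model.mu === nothing && error("condition the model on mu before sampling y")
    return (y = model.mu + randn(rng),)
end

# Proposed default: predicting for one draw = rand on the conditioned model.
predict(rng::Random.AbstractRNG, model::ToyModel, draw::NamedTuple) =
    rand(rng, condition(model, draw))

# Batch prediction: one predictive sample per draw, as for a chain.
predict(rng::Random.AbstractRNG, model::ToyModel, draws::AbstractVector{<:NamedTuple}) =
    [predict(rng, model, d) for d in draws]

# Usage: three stand-in posterior draws of mu.
rng = Random.MersenneTwister(1)
chain = [(mu = 0.0,), (mu = 10.0,), (mu = -5.0,)]
preds = predict(rng, ToyModel(nothing), chain)  # Vector of (y = ...,) NamedTuples
```

Because the batch method is written purely in terms of the single-draw one, a container such as `InferenceData` would only need its own way of iterating draws to reuse the same default.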