Showing 1 changed file with 26 additions and 188 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,202 +1,40 @@ | ||
# Proposal for a New LogDensity Function Interface | ||
|
||
## Introduction | ||
<https://github.com/TuringLang/DynamicPPL.jl/issues/691> | ||
|
||
The goal is to design a flexible and user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts such as Gibbs sampling. This interface should facilitate: | ||
The goal is to design a flexible, user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts like Gibbs sampling and Bayesian workflows. | ||
|
||
- **Conditioning**: Incorporating observed data into the model. | ||
- **Fixing**: Setting certain variables to specific values (like the `do` operator). | |
- **Generated Quantities**: Computing additional expressions or functions based on the model parameters. | ||
- **Prediction**: Making predictions by fixing parameters and unconditioning on data. | ||
## Evaluation functions: | ||
|
||
This proposal aims to redefine the interface from the user's perspective, focusing on ease of use and extensibility beyond the traditional probabilistic programming languages (PPLs). | ||
1. `evaluate` | ||
|
||
## Proposed Interface | ||
## Query functions: | ||
|
||
Below is a proposed interface with key functionalities and their implementations. | ||
1. `is_parametric(model)` | ||
2. `dimension(model)` (only defined when `is_parametric(model) == true`) | ||
3. `is_conditioned(model)` | ||
4. `is_fixed(model)` | ||
5. `logjoint(model, params)` | ||
6. `loglikelihood(model, params)` | ||
7. `logprior(model, params)` | ||
|
||
### Core Functions | ||
where `params` can be `Vector`, `NamedTuple`, `Dict`, etc. | ||
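
As a rough illustration of the query functions listed above, here is a minimal sketch on a toy model. The names `ToyModel`, `normal_logpdf`, and `get_mu` are hypothetical helpers introduced only for this example, and the model itself (`mu ~ Normal(0, 1)`, `y[i] ~ Normal(mu, 1)`) is an assumption; none of this is an existing `DynamicPPL` API.

```julia
# Hypothetical toy model: mu ~ Normal(0, 1), y[i] ~ Normal(mu, 1).
# `ToyModel` is an illustrative type, not part of the proposal.
struct ToyModel
    y::Vector{Float64}   # observed data stored in the model
end

# Log density of Normal(mean, 1) at x, written out to avoid dependencies.
normal_logpdf(x, mean) = -0.5 * (x - mean)^2 - 0.5 * log(2π)

# Reduce each accepted `params` container to the scalar `mu`.
get_mu(params::AbstractVector) = params[1]
get_mu(params::NamedTuple) = params.mu
get_mu(params::AbstractDict) = params[:mu]

# Query functions from the list above.
is_parametric(::ToyModel) = true
dimension(::ToyModel) = 1            # one free parameter: mu
is_conditioned(::ToyModel) = true    # the data `y` is already attached
is_fixed(::ToyModel) = false

logprior(m::ToyModel, params) = normal_logpdf(get_mu(params), 0.0)
loglikelihood(m::ToyModel, params) =
    sum(normal_logpdf(yi, get_mu(params)) for yi in m.y)
logjoint(m::ToyModel, params) = logprior(m, params) + loglikelihood(m, params)

model = ToyModel([0.5, -0.2, 1.0])

# The same point expressed in three parameter containers.
lj_vec  = logjoint(model, [0.3])
lj_nt   = logjoint(model, (mu = 0.3,))
lj_dict = logjoint(model, Dict(:mu => 0.3))
```

Dispatching on the container type inside one small accessor keeps the density functions themselves independent of how `params` is represented, which is one possible way to support `Vector`, `NamedTuple`, and `Dict` at once.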
|
||
#### Check if a Model is Parametric | ||
## Transformation functions: | ||
|
||
```julia | ||
# Check if a log density model is parametric | ||
function is_parametric(model::LogDensityModel) -> Bool | ||
... | ||
end | ||
``` | ||
1. `condition(model, conditioned_vars)` | ||
2. `fix(model, fixed_vars)` | ||
3. `factor(model, variables_in_the_factor)` | ||
|
||
- **Description**: Determines if the model has a parameter space with a defined dimension. | ||
- | ||
`condition` and `factor` are similar, but `factor` effectively generates a sub-model. | ||
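
The difference between `condition` and `fix` can be sketched as follows. This is only an illustration under stated assumptions: `SketchModel` and `normal_logpdf` are hypothetical names, the model (`mu ~ Normal(0, 1)`, `y ~ Normal(mu, 1)`) is made up, and the sketch assumes that a fixed variable contributes no term to the log density, as with a `do`-style intervention. `factor` is left out because the text does not pin down its semantics.

```julia
# Hypothetical model: mu ~ Normal(0, 1), y ~ Normal(mu, 1).
# `SketchModel` only records which variables are conditioned or fixed.
struct SketchModel{C<:NamedTuple,F<:NamedTuple}
    conditioned::C   # variables treated as observed data
    fixed::F         # variables pinned to given values
end
SketchModel() = SketchModel(NamedTuple(), NamedTuple())

# Both transformations return a new model and leave the input untouched.
condition(m::SketchModel, vars::NamedTuple) =
    SketchModel(merge(m.conditioned, vars), m.fixed)
fix(m::SketchModel, vars::NamedTuple) =
    SketchModel(m.conditioned, merge(m.fixed, vars))

is_conditioned(m::SketchModel) = !isempty(m.conditioned)
is_fixed(m::SketchModel) = !isempty(m.fixed)

normal_logpdf(x, mean) = -0.5 * (x - mean)^2 - 0.5 * log(2π)

# Assumption: fixed variables are used as values but add no density term;
# conditioned variables add their usual term.
function logjoint(m::SketchModel, params::NamedTuple)
    vals = merge(params, m.conditioned, m.fixed)
    lp = 0.0
    haskey(m.fixed, :mu) || (lp += normal_logpdf(vals.mu, 0.0))
    haskey(m.fixed, :y)  || (lp += normal_logpdf(vals.y, vals.mu))
    return lp
end

base = SketchModel()
conditioned_model = condition(base, (y = 1.0,))
fixed_model = fix(base, (y = 1.0,))

lj_conditioned = logjoint(conditioned_model, (mu = 0.0,))  # prior + likelihood term
lj_fixed = logjoint(fixed_model, (mu = 0.0,))              # prior term only
```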
|
||
#### Get the Dimension of a Parametric Model | ||
## Higher-order functions: | ||
|
||
```julia | ||
# Get the dimension of the parameter space (only defined when is_parametric(model) is true) | ||
function dimension(model::LogDensityModel) -> Int | ||
... | ||
end | ||
``` | ||
1. `generated_quantities(model, sample[, expr])` or `generated_quantities(model, sample, f, args...)` | |
   1. `generated_quantities` computes quantities from the sampling result. | |
   2. In `DynamicPPL`, this is the model's return value. For more flexibility, we should allow passing an expression or function. (Currently, users can achieve this in `DynamicPPL` by rewriting the model definition, but with limitations. We want to make this more generic.) | |
   3. `rand` is a special case of `generated_quantities` (when no sample is passed). | |
2. `predict(model, sample)` | |
|
||
- **Description**: Returns the dimension of the parameter space for parametric models. | ||
|
||
### Log Density Computations | ||
|
||
#### Log-Likelihood | ||
|
||
```julia | ||
# Compute the log-likelihood given parameters | ||
function loglikelihood(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 | ||
... | ||
end | ||
``` | ||
|
||
- **Description**: Computes the log-likelihood of the data given the model parameters. | ||
|
||
#### Log-Prior | ||
|
||
```julia | ||
# Compute the log-prior given parameters | ||
function logprior(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 | ||
... | ||
end | ||
``` | ||
|
||
- **Description**: Computes the log-prior probability of the model parameters. | ||
|
||
#### Log-Joint | ||
|
||
```julia | ||
# Compute the log-joint density (log-likelihood + log-prior) | ||
function logjoint(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 | ||
return loglikelihood(model, params) + logprior(model, params) | ||
end | ||
``` | ||
|
||
- **Description**: Computes the total log density by summing the log-likelihood and log-prior. | ||
|
||
### Conditioning and Fixing Variables | ||
|
||
#### Conditioning a Model | ||
|
||
```julia | ||
# Condition the model on observed data | ||
function condition(model::LogDensityModel, data::NamedTuple) -> ConditionedModel | ||
... | ||
end | ||
``` | ||
|
||
- **Description**: Incorporates observed data into the model, returning a `ConditionedModel`. | ||
|
||
#### Checking if a Model is Conditioned | ||
|
||
```julia | ||
# Check if a model is conditioned | ||
function is_conditioned(model::LogDensityModel) -> Bool | ||
... | ||
end | ||
``` | ||
|
||
- **Description**: Checks whether the model has been conditioned on data. | ||
|
||
#### Fixing Variables in a Model | ||
|
||
```julia | ||
# Fix certain variables in the model | ||
function fix(model::LogDensityModel, variables::NamedTuple) -> FixedModel | ||
... | ||
end | ||
``` | ||
|
||
- **Description**: Fixes specific variables in the model to given values, returning a `FixedModel`. | ||
|
||
#### Checking if a Model has Fixed Variables | ||
|
||
```julia | ||
# Check if a model has fixed variables | ||
function is_fixed(model::LogDensityModel) -> Bool | ||
... | ||
end | ||
``` | ||
|
||
- **Description**: Determines if any variables in the model have been fixed. | ||
|
||
### Specialized Models | ||
|
||
#### Conditioned Model Methods | ||
|
||
```julia | ||
# Log-likelihood for a conditioned model | ||
function loglikelihood(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 | ||
... | ||
end | ||
|
||
# Log-prior for a conditioned model | ||
function logprior(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 | ||
... | ||
end | ||
|
||
# Log-joint for a conditioned model | ||
function logjoint(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 | ||
return loglikelihood(model, params) + logprior(model, params) | ||
end | ||
``` | ||
|
||
- **Description**: Overrides log density computations to account for the conditioned data. | ||
|
||
#### Fixed Model Methods | ||
|
||
```julia | ||
# Log-likelihood for a fixed model | ||
function loglikelihood(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 | ||
... | ||
end | ||
|
||
# Log-prior for a fixed model | ||
function logprior(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 | ||
... | ||
end | ||
|
||
# Log-joint for a fixed model | ||
function logjoint(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 | ||
return loglikelihood(model, data) + logprior(model, data) | ||
end | ||
``` | ||
|
||
- **Description**: Adjusts log density computations based on the fixed variables. | ||
|
||
### Additional Functionalities | ||
|
||
#### Generated Quantities | ||
|
||
```julia | ||
# Compute generated quantities after fixing parameters | ||
function generated_quantities(model::LogDensityModel, fixed_vars::NamedTuple) -> NamedTuple | ||
... | ||
end | ||
``` | ||
|
||
- **Description**: Computes additional expressions or functions based on the fixed model parameters. | ||
|
||
#### Prediction | ||
|
||
```julia | ||
# Predict data based on fixed parameters | ||
function predict(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> NamedTuple | ||
... | ||
end | ||
``` | ||
|
||
- **Description**: Generates predictions by fixing the parameters and unconditioning the data. | ||
|
||
## Advantages of the Proposed Interface | ||
|
||
- **Flexibility**: Allows for advanced model operations like conditioning and fixing, essential for methods like Gibbs sampling. | ||
|
||
- **User-Centric Design**: Focuses on usability from the model user's perspective rather than the PPL implementation side. | ||
|
||
- **Consistency**: Maintains a uniform interface for both parametric and non-parametric models, simplifying the learning curve. | ||
|
||
## Usage Examples | ||
|
||
## Non-Parametric Models | ||
`generated_quantities` can be implemented by `fix`ing the model on `sample` and calling `evaluate`. | ||
`predict` can be implemented by `uncondition`ing the model on `data`, fixing it on `sample`, and calling `evaluate`. |
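
The two compositions above might look like the sketch below. Everything here is an assumption made for illustration: the `RunModel` type, the toy model (`mu ~ Normal(0, 1)`, `y ~ Normal(mu, 1)`, return value `2 * mu`), the `rng` keyword, and the shape of the `evaluate` result are hypothetical and are not specified by the proposal.

```julia
using Random

# Hypothetical model: mu ~ Normal(0, 1), y ~ Normal(mu, 1), returns 2 * mu.
struct RunModel{C<:NamedTuple,F<:NamedTuple}
    conditioned::C
    fixed::F
end
RunModel() = RunModel(NamedTuple(), NamedTuple())

condition(m::RunModel, vars::NamedTuple) = RunModel(merge(m.conditioned, vars), m.fixed)
uncondition(m::RunModel) = RunModel(NamedTuple(), m.fixed)   # drop all data
fix(m::RunModel, vars::NamedTuple) = RunModel(m.conditioned, merge(m.fixed, vars))

# Run the model once: known variables keep their values, the rest are sampled.
function evaluate(m::RunModel; rng::AbstractRNG = Random.default_rng())
    known = merge(m.conditioned, m.fixed)
    mu = haskey(known, :mu) ? known.mu : randn(rng)
    y  = haskey(known, :y)  ? known.y  : mu + randn(rng)
    return (mu = mu, y = y, retval = 2 * mu)
end

# generated_quantities: fix the model on `sample`, then evaluate.
generated_quantities(m::RunModel, sample::NamedTuple; rng = Random.default_rng()) =
    evaluate(fix(m, sample); rng = rng).retval

# predict: uncondition on the data, fix on `sample`, then evaluate.
predict(m::RunModel, sample::NamedTuple; rng = Random.default_rng()) =
    evaluate(fix(uncondition(m), sample); rng = rng).y

model = condition(RunModel(), (y = 1.0,))
gq = generated_quantities(model, (mu = 0.5,))
y_new = predict(model, (mu = 0.5,); rng = MersenneTwister(1))
```

In this sketch, calling `evaluate(RunModel())` with nothing conditioned or fixed draws every variable, which corresponds to the `rand` special case mentioned earlier.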