simplify
sunxd3 committed Oct 23, 2024
1 parent 21df935 commit 6117e4e
Showing 1 changed file with 26 additions and 188 deletions.
214 changes: 26 additions & 188 deletions design_notes/logdensity_interface.md
@@ -1,202 +1,40 @@
# Proposal for a New LogDensity Function Interface

## Introduction
<https://github.com/TuringLang/DynamicPPL.jl/issues/691>

The goal is to design a flexible and user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts such as Gibbs sampling. This interface should facilitate:
The goal is to design a flexible, user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts like Gibbs sampling and Bayesian workflows.

- **Conditioning**: Incorporating observed data into the model.
- **Fixing**: Fixing certain variables to specific values (like the `do` operator).
- **Generated Quantities**: Computing additional expressions or functions based on the model parameters.
- **Prediction**: Making predictions by fixing parameters and unconditioning on data.
## Evaluation functions:

This proposal aims to redefine the interface from the user's perspective, focusing on ease of use and extensibility beyond the traditional probabilistic programming languages (PPLs).
1. `evaluate`

## Proposed Interface
## Query functions:

Below is a proposed interface with key functionalities and their implementations.
1. `is_parametric(model)`
2. `dimension(model)` (only defined when `is_parametric(model) == true`)
3. `is_conditioned(model)`
4. `is_fixed(model)`
5. `logjoint(model, params)`
6. `loglikelihood(model, params)`
7. `logprior(model, params)`

### Core Functions
where `params` can be `Vector`, `NamedTuple`, `Dict`, etc.
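
The query functions listed above can be sketched on a toy model. This is only an illustration under stated assumptions: `ToyNormalModel`, `normal_logpdf`, and `get_mu` are hypothetical names that do not appear in the proposal, and the model (one parameter `mu` with a standard normal prior and unit-variance normal observations) is chosen purely for brevity.

```julia
# Hypothetical toy model (not part of the proposal):
# mu ~ Normal(0, 1), y[i] ~ Normal(mu, 1).
struct ToyNormalModel
    y::Vector{Float64}  # observed data the model is conditioned on
end

# Log density of Normal(mean, 1) at x, written out to avoid dependencies.
normal_logpdf(x, mean) = -0.5 * log(2π) - 0.5 * (x - mean)^2

# Reduce each supported `params` container to the single parameter `mu`.
get_mu(params::AbstractVector) = params[1]
get_mu(params::NamedTuple) = params.mu
get_mu(params::AbstractDict) = params[:mu]

# Query functions from the list above.
is_parametric(::ToyNormalModel) = true
dimension(::ToyNormalModel) = 1
is_conditioned(model::ToyNormalModel) = !isempty(model.y)
is_fixed(::ToyNormalModel) = false

logprior(::ToyNormalModel, params) = normal_logpdf(get_mu(params), 0.0)
loglikelihood(model::ToyNormalModel, params) =
    sum(y -> normal_logpdf(y, get_mu(params)), model.y; init = 0.0)
logjoint(model::ToyNormalModel, params) =
    loglikelihood(model, params) + logprior(model, params)
```

With this sketch, `logjoint(model, [0.5])`, `logjoint(model, (mu = 0.5,))`, and `logjoint(model, Dict(:mu => 0.5))` all return the same value, which is the point of letting `params` be a `Vector`, `NamedTuple`, or `Dict`.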

#### Check if a Model is Parametric
## Transformation functions:

```julia
# Check if a log density model is parametric
function is_parametric(model::LogDensityModel)::Bool
...
end
```
1. `condition(model, conditioned_vars)`
2. `fix(model, fixed_vars)`
3. `factor(model, variables_in_the_factor)`

- **Description**: Determines if the model has a parameter space with a defined dimension.
-
`condition` and `factor` are similar, but `factor` effectively generates a sub-model.
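
As a rough illustration of the transformation functions, the sketch below contrasts `condition` and `fix` on a two-variable toy model. It assumes the usual reading of the `do` operator: a conditioned variable still contributes its density term, while a fixed variable is treated as a constant and contributes none. `ToyModel` and `normal_logpdf` are hypothetical names, and `factor` is left out because its exact semantics are not spelled out here.

```julia
# Hypothetical two-variable model (not part of the proposal):
# mu ~ Normal(0, 1), y ~ Normal(mu, 1).
struct ToyModel{C<:NamedTuple,F<:NamedTuple}
    conditioned::C  # observed values; still contribute to the log density
    fixed::F        # fixed values; treated as constants, no density term
end
ToyModel() = ToyModel(NamedTuple(), NamedTuple())

normal_logpdf(x, mean) = -0.5 * log(2π) - 0.5 * (x - mean)^2

# Both transformations return a new model and leave the input untouched.
condition(m::ToyModel, vars::NamedTuple) =
    ToyModel(merge(m.conditioned, vars), m.fixed)
fix(m::ToyModel, vars::NamedTuple) =
    ToyModel(m.conditioned, merge(m.fixed, vars))

is_conditioned(m::ToyModel) = !isempty(m.conditioned)
is_fixed(m::ToyModel) = !isempty(m.fixed)

function logjoint(m::ToyModel, params::NamedTuple)
    # Later arguments win in `merge`, so fixed values take precedence.
    vals = merge(params, m.conditioned, m.fixed)
    lp = 0.0
    # Assumption: fixed variables are skipped when accumulating density.
    haskey(m.fixed, :mu) || (lp += normal_logpdf(vals.mu, 0.0))
    haskey(m.fixed, :y) || (lp += normal_logpdf(vals.y, vals.mu))
    return lp
end
```

Under these assumptions, conditioning on `y = 1.0` and fixing `y = 1.0` give log joint values that differ by exactly the density term of `y`.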

#### Get the Dimension of a Parametric Model
## Higher-order functions:

```julia
# Get the dimension of the parameter space (only defined when is_parametric(model) is true)
function dimension(model::LogDensityModel)::Int
...
end
```
1. `generated_quantities(model, sample[, expr])` or `generated_quantities(model, sample, f, args...)`
1. `generated_quantities` computes things from the sampling result.
2. In `DynamicPPL`, this is the model's return value. For more flexibility, we should allow passing an expression or function. (Currently, users can rewrite the model definition to achieve this in `DynamicPPL`, but with limitations. We want to make this more generic.)
3. `rand` is a special case of `generated_quantities` (when no sample is passed).
2. `predict(model, sample)`

- **Description**: Returns the dimension of the parameter space for parametric models.

### Log Density Computations

#### Log-Likelihood

```julia
# Compute the log-likelihood given parameters
function loglikelihood(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict})::Float64
...
end
```

- **Description**: Computes the log-likelihood of the data given the model parameters.

#### Log-Prior

```julia
# Compute the log-prior given parameters
function logprior(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict})::Float64
...
end
```

- **Description**: Computes the log-prior probability of the model parameters.

#### Log-Joint

```julia
# Compute the log-joint density (log-likelihood + log-prior)
function logjoint(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict})::Float64
return loglikelihood(model, params) + logprior(model, params)
end
```

- **Description**: Computes the total log density by summing the log-likelihood and log-prior.

### Conditioning and Fixing Variables

#### Conditioning a Model

```julia
# Condition the model on observed data
function condition(model::LogDensityModel, data::NamedTuple)::ConditionedModel
...
end
```

- **Description**: Incorporates observed data into the model, returning a `ConditionedModel`.

#### Checking if a Model is Conditioned

```julia
# Check if a model is conditioned
function is_conditioned(model::LogDensityModel)::Bool
...
end
```

- **Description**: Checks whether the model has been conditioned on data.

#### Fixing Variables in a Model

```julia
# Fix certain variables in the model
function fix(model::LogDensityModel, variables::NamedTuple)::FixedModel
...
end
```

- **Description**: Fixes specific variables in the model to given values, returning a `FixedModel`.

#### Checking if a Model has Fixed Variables

```julia
# Check if a model has fixed variables
function is_fixed(model::LogDensityModel)::Bool
...
end
```

- **Description**: Determines if any variables in the model have been fixed.

### Specialized Models

#### Conditioned Model Methods

```julia
# Log-likelihood for a conditioned model
function loglikelihood(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict})::Float64
...
end

# Log-prior for a conditioned model
function logprior(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict})::Float64
...
end

# Log-joint for a conditioned model
function logjoint(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict})::Float64
return loglikelihood(model, params) + logprior(model, params)
end
```

- **Description**: Overrides log density computations to account for the conditioned data.

#### Fixed Model Methods

```julia
# Log-likelihood for a fixed model
function loglikelihood(model::FixedModel, data::Union{Vector, NamedTuple, Dict})::Float64
...
end

# Log-prior for a fixed model
function logprior(model::FixedModel, data::Union{Vector, NamedTuple, Dict})::Float64
...
end

# Log-joint for a fixed model
function logjoint(model::FixedModel, data::Union{Vector, NamedTuple, Dict})::Float64
return loglikelihood(model, data) + logprior(model, data)
end
```

- **Description**: Adjusts log density computations based on the fixed variables.

### Additional Functionalities

#### Generated Quantities

```julia
# Compute generated quantities after fixing parameters
function generated_quantities(model::LogDensityModel, fixed_vars::NamedTuple)::NamedTuple
...
end
```

- **Description**: Computes additional expressions or functions based on the fixed model parameters.

#### Prediction

```julia
# Predict data based on fixed parameters
function predict(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict})::NamedTuple
...
end
```

- **Description**: Generates predictions by fixing the parameters and unconditioning the data.

## Advantages of the Proposed Interface

- **Flexibility**: Allows for advanced model operations like conditioning and fixing, essential for methods like Gibbs sampling.

- **User-Centric Design**: Focuses on usability from the model user's perspective rather than the PPL implementation side.

- **Consistency**: Maintains a uniform interface for both parametric and non-parametric models, simplifying the learning curve.

## Usage Examples

## Non-Parametric Models
`generated_quantities` can be implemented by `fix`ing the model on `sample` and calling `evaluate`.
`predict` can be implemented by `uncondition`ing the model on `data`, fixing it on `sample`, and calling `evaluate`.
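
The two recipes above can be sketched as follows. This is a minimal sketch, not the proposed implementation: `ToyModel` is a hypothetical stand-in, the semantics of `evaluate` (run the model once, drawing every variable that is neither fixed nor conditioned) are an assumption, the explicit `rng` argument is added only to make the example reproducible, and `uncondition` here drops all conditioned values rather than a chosen subset.

```julia
using Random

# Hypothetical model (not part of the proposal):
# mu ~ Normal(0, 1), y ~ Normal(mu, 1), with return value 2mu.
struct ToyModel{C<:NamedTuple,F<:NamedTuple}
    conditioned::C
    fixed::F
end

condition(m::ToyModel, vars::NamedTuple) =
    ToyModel(merge(m.conditioned, vars), m.fixed)
fix(m::ToyModel, vars::NamedTuple) =
    ToyModel(m.conditioned, merge(m.fixed, vars))
# Simplification: drop every conditioned value at once.
uncondition(m::ToyModel) = ToyModel(NamedTuple(), m.fixed)

# Assumed semantics of `evaluate`: run the model once, drawing each variable
# that is neither fixed nor conditioned, and return all values plus the
# model's return value.
function evaluate(rng::AbstractRNG, m::ToyModel)
    known = merge(m.conditioned, m.fixed)
    mu = haskey(known, :mu) ? known.mu : randn(rng)
    y = haskey(known, :y) ? known.y : mu + randn(rng)
    return (mu = mu, y = y, retval = 2mu)
end

# Fix the model on `sample`, then evaluate and read off the return value.
generated_quantities(rng, m::ToyModel, sample::NamedTuple) =
    evaluate(rng, fix(m, sample)).retval

# Uncondition on the data, fix on `sample`, then evaluate to draw new data.
predict(rng, m::ToyModel, sample::NamedTuple) =
    evaluate(rng, fix(uncondition(m), sample)).y
```

In this sketch, `generated_quantities` is deterministic once `mu` is fixed, while `predict` draws a fresh `y` because the observed value has been removed before evaluation.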
