Skip to content

Commit

Permalink
Add first attempt at dirichlet
Browse files Browse the repository at this point in the history
  • Loading branch information
fcooper8472 committed Jan 19, 2024
1 parent 4cc6715 commit 94bef7e
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 1 deletion.
8 changes: 7 additions & 1 deletion distribution_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .utils import get_random_animal_emoji
from .utils import inject_custom_css
from .utils import get_indices_from_query_params
from .utils import unit_simplex_3d_uniform_cover

from .cont_uni import (
Normal,
Expand All @@ -16,14 +17,19 @@
Poisson,
)

from .mult import (
Dirichlet,
)

dist_mapping = {
DistributionClass('Continuous Univariate', 'cont_uni'): [
Normal(),
Gamma(),
],
DistributionClass('Discrete Univariate', 'disc_uni'): [
Poisson()
Poisson(),
],
DistributionClass('Multivariate', 'mult'): [
Dirichlet(),
],
}
1 change: 1 addition & 0 deletions distribution_zoo/mult/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .dirichlet import Dirichlet
87 changes: 87 additions & 0 deletions distribution_zoo/mult/dirichlet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from distribution_zoo import BaseDistribution
from distribution_zoo import unit_simplex_3d_uniform_cover

import plotly.figure_factory as ff
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import pickle
import scipy.stats as stats
import streamlit as st


class Dirichlet(BaseDistribution):

display_name = 'Dirichlet'
param_dim = st.session_state['dirichlet_dim'] if 'dirichlet_dim' in st.session_state else 3
param_a1 = st.session_state['dirichlet_a1'] if 'dirichlet_a1' in st.session_state else 2.0
param_a2 = st.session_state['dirichlet_a2'] if 'dirichlet_a2' in st.session_state else 2.0
param_a3 = st.session_state['dirichlet_a3'] if 'dirichlet_a3' in st.session_state else 2.0
param_a4 = st.session_state['dirichlet_a4'] if 'dirichlet_a4' in st.session_state else 2.0

def __init__(self):
super().__init__()

def sliders(self):

self.param_dim = st.sidebar.slider(
r'Dimension', min_value=2, max_value=4, value=self.param_dim, step=1, key='dirichlet_dim'
)

self.param_a1 = st.sidebar.slider(
r'Concentration $\alpha_1$', min_value=0.1, max_value=10.0, value=self.param_a1, step=0.1, key='dirichlet_a1'
)

self.param_a2 = st.sidebar.slider(
r'Concentration $\alpha_2$', min_value=0.1, max_value=10.0, value=self.param_a2, step=0.1, key='dirichlet_a2'
)

if self.param_dim > 2:
self.param_a3 = st.sidebar.slider(
r'Concentration $\alpha_3$', min_value=0.1, max_value=10.0, value=self.param_a3, step=0.1, key='dirichlet_a3'
)

if self.param_dim > 3:
self.param_a4 = st.sidebar.slider(
r'Concentration $\alpha_4$', min_value=0.1, max_value=10.0, value=self.param_a4, step=0.1, key='dirichlet_a4'
)

def plot(self):

if self.param_dim == 2:
self.plot_2d()
elif self.param_dim == 3:
self.plot_3d()
else: # self.param_dim == 4:
self.plot_4d()

def plot_2d(self):
pass

def plot_3d(self):

sample_points = unit_simplex_3d_uniform_cover(4)

samples = stats.dirichlet.pdf(sample_points, [self.param_a1, self.param_a2, self.param_a3])

# Create a 3D scatter plot
pdf_chart = ff.create_ternary_contour(
coordinates=sample_points,
values=samples,
pole_labels=['α1', 'α2', 'α3'],
colorscale='Cividis',
ncontours=128,
showscale=True,
# linecolor='blue',
)
for i in range(len(pdf_chart.data)):
pdf_chart.data[i].line.width = 0

# Display the plot in Streamlit
st.plotly_chart(pdf_chart)

def plot_4d(self):
st.info('Plot is not available for the 4-dimensional Dirichlet distribution', icon="ℹ️")

def update_code_substitutions(self):
pass
45 changes: 45 additions & 0 deletions distribution_zoo/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
import numpy as np
import random
import streamlit as st

Expand Down Expand Up @@ -61,3 +62,47 @@ def get_indices_from_query_params(dist_mapping: dict):
return class_index, None

return class_index, dist_index


@st.cache_data
def unit_simplex_3d_uniform_cover(depth=3):
"""
Uniformly cover the 3d unit simplex X+Y+Z=1 in a generative way, by repeatedly splitting the triangle into
four smaller triangles, and calculating the centre of each. This means that 4^depth samples will be generated.
This function avoids using random sampling, and creates a geometrically uniform cover. It is used primarily
to create the points on which to calculate the 3d Dirichlet distribution PDF.
Args:
depth: Maximum depth of the subdivision.
Returns:
A NumPy array of shape (3, 4^{depth}) containing the simplex points.
"""
# Define the unit simplex vertices
simplex_vertices = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# Initialize with the first level of subdivision
triangles = [simplex_vertices]

for _ in range(depth):
new_triangles = []
for triangle in triangles:
midpoint1 = (triangle[0] + triangle[1]) / 2
midpoint2 = (triangle[0] + triangle[2]) / 2
midpoint3 = (triangle[1] + triangle[2]) / 2

new_triangles.extend([
[triangle[0], midpoint1, midpoint2],
[triangle[1], midpoint1, midpoint3],
[triangle[2], midpoint2, midpoint3],
[midpoint1, midpoint2, midpoint3]
])

triangles = new_triangles

# Compute centroids of final triangles
points = np.array([np.mean(triangle, axis=0) for triangle in triangles])

# Transpose, because the Dirichlet PDF function wants the samples this way around. (This is unusual.)
return points.T
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ inflection = "^0.5.1"
pandas = "^2.1.4"
requests = "^2.31.0"
plotly = "^5.18.0"
scikit-image = "^0.22.0"

[tool.poetry.group.dev.dependencies]
flake8 = "^6.1.0"
Expand Down

0 comments on commit 94bef7e

Please sign in to comment.