Skip to content

Commit

Permalink
Update gamma to match Normal
Browse files Browse the repository at this point in the history
  • Loading branch information
fcooper8472 committed Jan 10, 2024
1 parent fd73727 commit 88765b1
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 31 deletions.
84 changes: 54 additions & 30 deletions distribution_zoo/cont_uni/gamma.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from distribution_zoo import BaseDistribution

import altair as alt
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import scipy.stats as stats
Expand All @@ -11,66 +11,90 @@ class Gamma(BaseDistribution):

display_name = 'Gamma'
range_min = 0.0
range_max = 100.0
param_range_start = 0.0
param_range_end = 10.0
param_shape = 0.0
param_rate = 1.0
range_max = 150.0
param_shape = st.session_state['gamma_shape'] if 'gamma_shape' in st.session_state else 1.0
param_rate = st.session_state['gamma_rate'] if 'gamma_rate' in st.session_state else 0.5
param_range_start = None
param_range_end = None

def __init__(self):
super().__init__()

def sliders(self):
if 'gamma_range' not in st.session_state:
self.update_range()

# This slider's initial value is set from st.session_state['gamma_range'], set with update_range()
self.param_range_start, self.param_range_end = st.sidebar.slider(
'Range', min_value=self.range_min, max_value=self.range_max, value=(0.0, 10.0), step=0.1
'Range', min_value=self.range_min, max_value=self.range_max, value=(0.0, 10.0), step=0.1, key='gamma_range'
)

self.param_shape = st.sidebar.slider(
r'Shape ($\alpha$)', min_value=0.05, max_value=10.0, value=1.0, step=0.05
r'Shape ($\alpha$)', min_value=0.05, max_value=10.0, value=self.param_shape, step=0.05, key='gamma_shape',
on_change=self.update_range
)

self.param_rate = st.sidebar.slider(
r'Rate ($\beta$)', min_value=0.05, max_value=2.0, value=0.5, step=0.05
r'Rate ($\beta$)', min_value=0.2, max_value=2.0, value=self.param_rate, step=0.05, key='gamma_rate',
on_change=self.update_range
)

def update_range(self):

shape = st.session_state['gamma_shape'] if 'gamma_shape' in st.session_state else self.param_shape
rate = st.session_state['gamma_rate'] if 'gamma_rate' in st.session_state else self.param_rate

new_lower = 0.0
new_upper = min(round(stats.gamma(a=shape, loc=0.0, scale=1. / rate).ppf(0.999), 1), self.range_max)
st.session_state['gamma_range'] = (new_lower, new_upper)

def plot(self):

x = np.linspace(self.range_min, self.range_max, 1000)
mean = self.param_shape / self.param_rate
display_mean = round(mean, 2)

x = np.linspace(self.param_range_start, self.param_range_end, 1000)

chart_data = pd.DataFrame(
{
'x': x,
'pdf': stats.gamma.pdf(x, a=self.param_shape, loc=0.0, scale=1./self.param_rate),
'cdf': stats.gamma.cdf(x, a=self.param_shape, loc=0.0, scale=1./self.param_rate),
'pdf': stats.gamma.pdf(x, a=self.param_shape, loc=0.0, scale=1. / self.param_rate),
'cdf': stats.gamma.cdf(x, a=self.param_shape, loc=0.0, scale=1. / self.param_rate),
}
)

# Define the initial x-axis range for the view
initial_x_range = [self.param_range_start, self.param_range_end]

# Create an Altair chart for the PDF
pdf_chart = alt.Chart(chart_data).mark_line().encode(
x=alt.X('x:Q', scale=alt.Scale(domain=initial_x_range)),
y='pdf:Q',
tooltip=['x', 'pdf']
).interactive()

# Create an Altair chart for the CDF
cdf_chart = alt.Chart(chart_data).mark_line().encode(
x=alt.X('x:Q', scale=alt.Scale(domain=initial_x_range)),
y='cdf:Q',
tooltip=['x', 'cdf']
).interactive()
line_data = pd.DataFrame(
{
'x': [mean, mean],
'pdf': [0.0, max(chart_data['pdf'])],
'cdf': [0.0, max(chart_data['cdf'])],
}
)

# Create Plotly chart for the PDF
pdf_chart = go.Figure(go.Scatter(x=chart_data['x'], y=chart_data['pdf'], mode='lines', name='PDF'))
pdf_chart.add_trace(
go.Scatter(x=line_data['x'], y=line_data['pdf'], mode='lines', name=f'Mean ({display_mean})',
line=dict(color='orange', width=2)))
pdf_chart.update_layout(xaxis_title='x', yaxis_title='pdf', margin=dict(l=20, r=20, t=20, b=20))

# Create Plotly chart for the CDF
cdf_chart = go.Figure(go.Scatter(x=chart_data['x'], y=chart_data['cdf'], mode='lines', name='CDF'))
cdf_chart.add_trace(
go.Scatter(x=line_data['x'], y=line_data['cdf'], mode='lines', name=f'Mean ({display_mean})',
line=dict(color='orange', width=2)))
cdf_chart.update_layout(xaxis_title='x', yaxis_title='cdf', margin=dict(l=20, r=20, t=20, b=20))

# Streamlit columns for displaying the charts
pdf_col, cdf_col = st.columns(2)

with pdf_col:
st.subheader('Probability density function')
st.altair_chart(pdf_chart, use_container_width=True)
st.plotly_chart(pdf_chart, use_container_width=True)

with cdf_col:
st.subheader('Cumulative distribution function')
st.altair_chart(cdf_chart, use_container_width=True)
st.plotly_chart(cdf_chart, use_container_width=True)

def update_code_substitutions(self):
pass
2 changes: 1 addition & 1 deletion distribution_zoo/cont_uni/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def sliders(self):
if 'normal_range' not in st.session_state:
self.update_range()

# This slider's initial value is set from st.session_state['normal_mean'], set with update_range()
# This slider's initial value is set from st.session_state['normal_range'], set with update_range()
self.param_range_start, self.param_range_end = st.sidebar.slider(
'Range', min_value=self.range_min, max_value=self.range_max, step=0.1, key='normal_range'
)
Expand Down

0 comments on commit 88765b1

Please sign in to comment.