Skip to content

Commit

Permalink
Update range slider when the parameters change
Browse files Browse the repository at this point in the history
  • Loading branch information
fcooper8472 committed Jan 7, 2024
1 parent 2e9ed74 commit f499c0b
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions distribution_zoo/cont_uni/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,43 @@ class Normal(BaseDistribution):
display_name = 'Normal'
range_min = -50.0
range_max = 50.0
param_range_start = -10.0
param_range_end = 10.0
param_mean = 0.0
param_std = 1.0
param_mean = st.session_state['normal_mean'] if 'normal_mean' in st.session_state else 0.0
param_std = st.session_state['normal_std'] if 'normal_std' in st.session_state else 1.0
param_range_start = None
param_range_end = None

def __init__(self):
super().__init__()

def sliders(self):

if 'normal_range' not in st.session_state:
self.update_range()

# This slider's initial value is set from st.session_state['normal_mean'], set with update_range()
self.param_range_start, self.param_range_end = st.sidebar.slider(
'Range', min_value=self.range_min, max_value=self.range_max, value=(-10.0, 10.0), step=0.1, key='normal_range'
'Range', min_value=self.range_min, max_value=self.range_max, step=0.1, key='normal_range'
)

self.param_mean = st.sidebar.slider(
r'Mean ($\mu$)', min_value=-30.0, max_value=30.0, value=0.0, step=0.1, key='normal_mean'
r'Mean ($\mu$)', min_value=-16.0, max_value=16.0, value=self.param_mean, step=0.1, key='normal_mean',
on_change=self.update_range
)

self.param_std = st.sidebar.slider(
r'Standard deviation ($\sigma$)', min_value=0.1, max_value=20.0, value=1.0, step=0.1, key='normal_std'
r'Standard deviation ($\sigma$)', min_value=0.1, max_value=8.0, value=self.param_std, step=0.1,
key='normal_std', on_change=self.update_range
)

def update_range(self):

mean = st.session_state['normal_mean'] if 'normal_mean' in st.session_state else self.param_mean
std = st.session_state['normal_std'] if 'normal_std' in st.session_state else self.param_std

new_lower = round(stats.norm(loc=mean, scale=std).ppf(0.0001), 1)
new_upper = round(stats.norm(loc=mean, scale=std).ppf(0.9999), 1)
st.session_state['normal_range'] = (new_lower, new_upper)

def plot(self):

x = np.linspace(self.range_min, self.range_max, 1000)
Expand Down

0 comments on commit f499c0b

Please sign in to comment.