Skip to content

Commit

Permalink
add method 'plot_all_ci()' to plot CIs for all parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
jmrohwer committed Apr 16, 2022
1 parent afc5eb0 commit 4b6193d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ A typical workflow would entail:
mini, result, prob=0.95, limits=0.5, log=False, points=11, return_CIclass=True
)
>>> print(c[0]) # OrderedDict of parameter names and corresponding confidence intervals
>>> c[1].plot_ci('a') # plot confidence interval for parameter 'a'
>>> c[1].plot_ci('a') # plot confidence interval for parameter 'a'
>>> c[1].plot_all_ci() # plot confidence intervals for all parameters
```

When using the [Model](https://lmfit.github.io/lmfit-py/model.html) class, the
Expand Down
25 changes: 23 additions & 2 deletions identifiability.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,10 @@ def calc_threshold(self):
threshold = threshold_scaled * nfix / nfree
return threshold

def plot_ci(self, para):
def plot_ci(self, para, ax=None):
assert para in self.p_names, 'para must be one of ' + str(self.p_names)
f, ax = plt.subplots()
if not ax:
f, ax = plt.subplots()
xx = self.trace_dict[para]['value']
yy = self.trace_dict[para]['dchi']
t = self.trace_dict[para]['threshold']
Expand All @@ -224,6 +225,26 @@ def plot_ci(self, para):
ax.set_ylabel(r'$\chi^2\left/\chi^2_0\right. - 1$')
ax.set_title(para)

def plot_all_ci(self):
num = len(self.p_names)
numcols = 3
numrows = num // numcols + 1
f, ax = plt.subplots(nrows=numrows, ncols=numcols, figsize=(9, 2.5*numrows))
for i in range(num):
if num <= numcols:
theax = ax[i]
else:
theax = ax[i//numcols, i%numcols]
self.plot_ci(self.p_names[i], ax=theax)
# remove empty axes
empty = numcols - num%numcols
if empty != 0:
for i in range(-empty, 0):
if num <= numcols:
ax[i].set_visible(False)
else:
ax[num // numcols, i].set_visible(False)
f.tight_layout()

def conf_interval(
minimizer,
Expand Down

0 comments on commit 4b6193d

Please sign in to comment.