diff --git a/README.md b/README.md index 7bdaf5a..3f1f476 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,8 @@ A typical workflow would entail: mini, result, prob=0.95, limits=0.5, log=False, points=11, return_CIclass=True ) >>> print(c[0]) # OrderedDict of parameter names and corresponding confidence intervals ->>> c[1].plot_ci('a') # plot confidence interval for parameter 'a' +>>> c[1].plot_ci('a') # plot confidence interval for parameter 'a' +>>> c[1].plot_all_ci() # plot confidence intervals for all parameters ``` When using the [Model](https://lmfit.github.io/lmfit-py/model.html) class, the diff --git a/identifiability.py b/identifiability.py index 7636146..1290d3b 100644 --- a/identifiability.py +++ b/identifiability.py @@ -200,9 +200,10 @@ def calc_threshold(self): threshold = threshold_scaled * nfix / nfree return threshold - def plot_ci(self, para): + def plot_ci(self, para, ax=None): assert para in self.p_names, 'para must be one of ' + str(self.p_names) - f, ax = plt.subplots() + if not ax: + f, ax = plt.subplots() xx = self.trace_dict[para]['value'] yy = self.trace_dict[para]['dchi'] t = self.trace_dict[para]['threshold'] @@ -224,6 +225,26 @@ def plot_ci(self, para): ax.set_ylabel(r'$\chi^2\left/\chi^2_0\right. - 1$') ax.set_title(para) + def plot_all_ci(self): + num = len(self.p_names) + numcols = 3 + numrows = num // numcols + 1 + f, ax = plt.subplots(nrows=numrows, ncols=numcols, figsize=(9, 2.5*numrows)) + for i in range(num): + if num <= numcols: + theax = ax[i] + else: + theax = ax[i//numcols, i%numcols] + self.plot_ci(self.p_names[i], ax=theax) + # remove empty axes + empty = numcols - num%numcols + if empty != 0: + for i in range(-empty, 0): + if num <= numcols: + ax[i].set_visible(False) + else: + ax[num // numcols, i].set_visible(False) + f.tight_layout() def conf_interval( minimizer,