az.compare stacking weights do not sum to one #2359

Open
hschuett opened this issue Jul 17, 2024 · 4 comments
@hschuett

Dear Arviz team,

Issue
When using arviz.compare() to compute stacking weights for model averaging, we got (a) weights that do not sum to one and (b) two different sets of weights in two different runs on the same data. Literally all other numbers in the compare output are the same; only the weights differ. Below are the two compare result dataframes we got.

First run (on a grid batch job):

rank    elpd_loo     p_loo  elpd_diff  weight      se     dse  warning  scale
   0  -299161.07  20338.32       0.00    0.93  757.80  0.0000     True    log
   1  -315073.54  29909.27   15912.47    0.39  741.13  238.65     True    log
   2  -318882.55    805.33   19721.49    0.00  726.20  272.98     True    log
   3  -328994.76    617.96   29833.70    0.00  717.08  320.34     True    log
   4  -343025.50     89.02   43864.43    0.00  694.47  381.01     True    log

Arviz version: 0.18.0
Pandas version: 2.0.3
Numpy version: 1.24.3
Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
OS info: posix, Linux, 4.19.0-25-amd64

Second run (on a Jupyter server using the same resources):

rank      elpd_loo       p_loo   elpd_diff  weight        se       dse  warning  scale
   0  -299161.0654  20338.3236      0.0000  1.0000  757.7960    0.0000     True    log
   1  -315073.5328  29909.2672  15912.4675  0.0019  741.1303  238.6455     True    log
   2  -318882.5520    805.3308  19721.4866  0.0011  726.2029  272.9801     True    log
   3  -328994.7609    617.9605  29833.6956  0.0005  717.0755  320.3400     True    log
   4  -343025.4954     89.0185  43864.4300  0.0000  694.4689  381.0094     True    log

Arviz version: 0.18.0
Pandas version: 2.0.3
Numpy version: 1.24.3
Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
OS info: posix, Linux, 4.19.0-25-amd64

Both runs produce the same values in all columns except the weight column, and in both cases the weights do not sum to 1. (I have checked the two run environments for differences and could not find the source of the discrepancy.)

Code that produces these tables from a bunch of loo objects, constructed with az.loo():

import pickle

import arviz as az


def gen_model_comparison(loo_filepaths, model_names):
    """Compute stacked model weights with arviz from a list of pickled az.loo() results."""
    # Load the pickled loo objects
    loo_objs = []
    for fp in loo_filepaths:
        with open(fp, "rb") as file:
            loo_objs.append(pickle.load(file))
    # Print n_data_points for each model as a sanity check
    for nr, loo in enumerate(loo_objs):
        print(model_names[nr] + " -> " + str(loo.head()["n_data_points"]))
    comp_dict = {name: loo for name, loo in zip(model_names, loo_objs)}
    comp1 = az.compare(comp_dict, ic="loo", method="stacking", scale="log")
    return comp1

The models are quite big, but if necessary, I can provide more data.

Expected behavior
Unless I'm mistaken, the weights should sum to 1.00.
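
For completeness, this is a minimal check of what I expect, assuming comp is the dataframe returned by gen_model_comparison above (the variable name is just for illustration):

import numpy as np

# Stacking weights form a simplex, so they should sum to 1 (up to floating point error)
print(comp["weight"].sum())
assert np.isclose(comp["weight"].sum(), 1.0, atol=1e-6)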

@OriolAbril
Member

Thanks for reporting, it looks quite strange. Have you been able to reproduce it with any other models? Are there NaNs or anything of the sort in the data?

Also, could you share some more info about the "grid batch job"? The 2nd one is much closer to summing to 1 and might be due to numerical stability issues (in which case a fix could be renormalizing the weights right before returning them). The 1st case looks weirder, though. I am also a bit surprised to see only 2 decimals in the table, is that the exact same output you got from az.compare? There is no rounding functionality in compare, so I doubt those are the results with all the decimal places.
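
To be explicit about what I mean by renormalizing, a rough sketch (an illustration only, not the actual implementation inside compare):

import numpy as np

def renormalize_weights(weights):
    # Push the stacking weights back onto the simplex so they sum to exactly 1
    weights = np.asarray(weights, dtype=float)
    return weights / weights.sum()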

@hschuett
Author

hschuett commented Jul 18, 2024

No, you are right, we rounded for the logs. I switched that off. I may also have found part of the issue: the arviz inference data showed older versions of arviz and numpyro than what I had in both of my environments. The loo objects had been created with arviz 0.17.0 and numpyro 0.13.2. We updated all environments and reran everything with arviz 0.18.0 and numpyro 0.15.1. Since then, both environments give me the same answers. It is yet another different answer, though, and the weights still do not sum to 1 (0.7183 + 0.3622 = 1.0805).

rank      elpd_loo       p_loo   elpd_diff  weight        se       dse  warning  scale
   0  -299155.8269  20325.1512      0.0000  0.7183  757.7038    0.0000     True    log
   1  -315073.5328  29909.2672  15917.7059  0.3622  741.1303  238.5335     True    log
   2  -318876.9458    799.8740  19721.1188  0.0000  726.2186  272.9542     True    log
   3  -328994.7609    617.9605  29838.9340  0.0000  717.0755  320.3247     True    log
   4  -343025.4954     89.0185  43869.6684  0.0000  694.4689  380.9634     True    log

Even though these are re-runs, all columns except the weight are very similar (each model is big: 345,539 observations). Unfortunately, my colleague didn't save the original arviz 0.17 loo objects before re-running, so I cannot reproduce the old numbers anymore.
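
In case it helps, version information like this can be read off the InferenceData group attributes; the attribute names below are what ArviZ converters typically write, so treat them as an assumption rather than a guaranteed API:

import arviz as az

idata = az.from_netcdf("model_fit.nc")  # hypothetical file name
# Converters usually record the libraries and versions used to create the data
print(idata.posterior.attrs.get("arviz_version"))
print(idata.posterior.attrs.get("inference_library_version"))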

@OriolAbril
Member

All your estimates have warning=True, so different results across different MCMC runs are possible, even when using the same versions of all libraries. What should be deterministic (and does not depend on warning=True indicating an untrustworthy estimate) is running compare on the exact same samples. Do I understand correctly that when using the exact same samples you always get the same result (save rounding errors)?

As for the normalization, there is probably a bug in the way it is implemented right now; we'll look into it and fix it.

we rounded for the logs

How did you do that? Is there an easy way for us to check inside az.compare or az.loo, for example, and warn/raise an error if it is active? These computations rely on operations with logs and logsumexp and should not be expected to work without precise computations and float64.
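
As a rough illustration of the kind of check I have in mind (the group and variable names are placeholders, not a fixed API):

import numpy as np

def check_float64(idata, var_name="obs"):
    # Raise if the pointwise log likelihood is not stored as float64,
    # since the logsumexp-based computations need full precision
    log_lik = idata.log_likelihood[var_name]
    if log_lik.dtype != np.float64:
        raise ValueError(
            f"log_likelihood['{var_name}'] has dtype {log_lik.dtype}, expected float64"
        )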

@hschuett
Author

hschuett commented Jul 18, 2024

Do I understand correctly that when using the exact same samples you always get the same result (save rounding errors)?

Yes, that is correct

"we rounded for the logs", How did you do that?

Apologies, I didn't phrase that well. What I meant was that we rounded to two decimals only when printing the compare dataframe to the log files of our scripts; that's why you only saw two decimals in the first table. We didn't do any rounding or other adjustments to the loo objects before feeding them to az.compare.
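
Concretely, the log printing was along these lines (loo_files and model_names as in the earlier snippet); the rounding only touches the printed string, never the loo objects:

comp = gen_model_comparison(loo_files, model_names)
# Round only for the human-readable log output; comp itself is untouched
print(comp.round(2).to_string())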

Noted, and agreed on the warning=True point.
