Accept new coords in model transforms #7549

Open
lucianopaz opened this issue Oct 25, 2024 · 3 comments

@lucianopaz
Contributor

lucianopaz commented Oct 25, 2024

Describe the issue:

When we call pymc.do or pymc.observe we might also want to change some coordinate values. At the moment, this can't be done.

Reproducible code example:

import pymc as pm

with pm.Model(coords={"A": range(2)}) as m:
    a = pm.Normal("a", dims="A")
    b = pm.Deterministic("b", a + 1)

with pm.do(m, {"a": list(range(5))}) as m2:
    samples = pm.draw(m2["b"], 10)
    print(samples.shape)  # prints (10, 5) as expected
    samples = pm.sample_prior_predictive(var_names="b", draws=10)  # errors

Error message:

ValueError                                Traceback (most recent call last)
Cell In[5], line 10
      8 samples = pm.draw(m2["b"], 10)
      9 print(samples.shape)
---> 10 samples = pm.sample_prior_predictive(var_names="b", draws=10)

File ~/repos/pymc/pymc/sampling/forward.py:466, in sample_prior_predictive(draws, model, var_names, random_seed, return_inferencedata, idata_kwargs, compile_kwargs, samples)
    464 if idata_kwargs:
    465     ikwargs.update(idata_kwargs)
--> 466 return pm.to_inference_data(prior=prior, **ikwargs)

File ~/repos/pymc/pymc/backends/arviz.py:532, in to_inference_data(trace, prior, posterior_predictive, log_likelihood, log_prior, coords, dims, sample_dims, model, save_warmup, include_transformed)
    517 if isinstance(trace, InferenceData):
    518     return trace
    520 return InferenceDataConverter(
    521     trace=trace,
    522     prior=prior,
    523     posterior_predictive=posterior_predictive,
    524     log_likelihood=log_likelihood,
    525     log_prior=log_prior,
    526     coords=coords,
    527     dims=dims,
    528     sample_dims=sample_dims,
    529     model=model,
    530     save_warmup=save_warmup,
    531     include_transformed=include_transformed,
--> 532 ).to_inference_data()

File ~/repos/pymc/pymc/backends/arviz.py:434, in InferenceDataConverter.to_inference_data(self)
    432     id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
    433 else:
--> 434     id_dict["constant_data"] = self.constant_data_to_xarray()
    435 idata = InferenceData(save_warmup=self.save_warmup, **id_dict)
    436 if self.log_likelihood:

File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:67, in requires.__call__.<locals>.wrapped(cls)
     65     if all((getattr(cls, prop_i) is None for prop_i in prop)):
     66         return None
---> 67 return func(cls)

File ~/repos/pymc/pymc/backends/arviz.py:398, in InferenceDataConverter.constant_data_to_xarray(self)
    395 if not constant_data:
    396     return None
--> 398 xarray_dataset = dict_to_dataset(
    399     constant_data,
    400     library=pymc,
    401     coords=self.coords,
    402     dims=self.dims,
    403     default_dims=[],
    404 )
    406 # provisional handling of scalars in constant
    407 # data to prevent promotion to rank 1
    408 # in the future this will be handled by arviz
    409 scalars = [var_name for var_name, value in constant_data.items() if np.ndim(value) == 0]

File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:394, in pytree_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
    391 except TypeError:  # probably unsortable keys -- the function will still work if
    392     pass  # it is an honest dictionary.
--> 394 data_vars = {
    395     key: numpy_to_data_array(
    396         values,
    397         var_name=key,
    398         coords=coords,
    399         dims=dims.get(key),
    400         default_dims=default_dims,
    401         index_origin=index_origin,
    402         skip_event_dims=skip_event_dims,
    403     )
    404     for key, values in data.items()
    405 }
    406 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:395, in <dictcomp>(.0)
    391 except TypeError:  # probably unsortable keys -- the function will still work if
    392     pass  # it is an honest dictionary.
    394 data_vars = {
--> 395     key: numpy_to_data_array(
    396         values,
    397         var_name=key,
    398         coords=coords,
    399         dims=dims.get(key),
    400         default_dims=default_dims,
    401         index_origin=index_origin,
    402         skip_event_dims=skip_event_dims,
    403     )
    404     for key, values in data.items()
    405 }
    406 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ~/miniforge3/lib/python3.10/site-packages/arviz/data/base.py:299, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
    297 # filter coords based on the dims
    298 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
--> 299 return xr.DataArray(ary, coords=coords, dims=dims)

File ~/miniforge3/lib/python3.10/site-packages/xarray/core/dataarray.py:455, in DataArray.__init__(self, data, coords, dims, name, attrs, indexes, fastpath)
    453 data = _check_data_shape(data, coords, dims)
    454 data = as_compatible_data(data)
--> 455 coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
    456 variable = Variable(dims, data, attrs, fastpath=True)
    458 if not isinstance(coords, Coordinates):

File ~/miniforge3/lib/python3.10/site-packages/xarray/core/dataarray.py:194, in _infer_coords_and_dims(shape, coords, dims)
    191             var.dims = (dim,)
    192             new_coords[dim] = var.to_index_variable()
--> 194 _check_coords_dims(shape, new_coords, dims_tuple)
    196 return new_coords, dims_tuple

File ~/miniforge3/lib/python3.10/site-packages/xarray/core/dataarray.py:128, in _check_coords_dims(shape, coords, dim)
    126 for d, s in v.sizes.items():
    127     if s != sizes[d]:
--> 128         raise ValueError(
    129             f"conflicting sizes for dimension {d!r}: "
    130             f"length {sizes[d]} on the data but length {s} on "
    131             f"coordinate {k!r}"
    132         )

ValueError: conflicting sizes for dimension 'A': length 5 on the data but length 2 on coordinate 'A'
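
The last frame points at the root cause: the intervened data has length 5, while the model still carries the original length-2 coordinate for dimension 'A'. The following is a minimal sketch of the same mismatch outside PyMC, using only NumPy and xarray. The variable names are illustrative and are not taken from the PyMC code base.

```python
import numpy as np
import xarray as xr

# Length-5 values, standing in for the intervention on "a".
data = np.arange(5)

# Length-2 coordinate, standing in for the coords left over from the original model.
stale_coords = {"A": np.arange(2)}

try:
    xr.DataArray(data, coords=stale_coords, dims=["A"])
    raised = False
except ValueError as err:
    # xarray refuses to attach a coordinate whose length differs from the data.
    raised = True
    print(err)

# With a coordinate of matching length, the same construction succeeds.
ok = xr.DataArray(data, coords={"A": np.arange(5)}, dims=["A"])
print(ok.sizes["A"])
```

This suggests that pm.draw succeeds because it never builds labelled arrays, whereas sample_prior_predictive fails only at the conversion to InferenceData.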

PyMC version information:

main

Context for the issue:

No response

@lucianopaz lucianopaz added the bug label Oct 25, 2024
@ricardoV94
Member

ricardoV94 commented Oct 25, 2024

Not sure I would count it as a bug; it is rather missing functionality. The user-facing transform methods should accept new coords.

Otherwise, you can still do it in two calls: 1) apply the transform, then 2) update the coords of the new model.

@lucianopaz
Contributor Author

Yeah, I agree. The two-step approach works, using m2.set_dim or a similar combination of calls.

@ricardoV94 ricardoV94 changed the title from "BUG: conditioning a model cannot change the coordinate shapes" to "Accept new coords in model transforms" on Oct 25, 2024
@lucifer4073

Hi @lucianopaz, I would like to work on this issue. Can you please assign it to me?
