BUG: model initial_point fails when pt.config.floatX = "float32" (#7608)
The full debug trace wouldn't fit in the issue body. Below is the model debug info, captured by pausing the debugger at `mcmc.py` L1386 and calling `model.debug(point, verbose=True)`:
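For context on why every failing variable below is a `ZeroSumNormal`: the constraint being checked is `mean(value, axis=n_zerosum_axes) = 0`. My suspicion (an assumption, not confirmed from the trace alone) is that under `floatX="float32"` the zero-sum projection leaves a rounding residue on the order of float32 epsilon, which a tolerance tuned for float64 then rejects. A minimal NumPy sketch of that precision gap, independent of PyMC:

```python
import numpy as np

rng = np.random.default_rng(42)

# Project a random vector onto the zero-sum subspace in float64.
x64 = rng.normal(size=7)
x64 = x64 - x64.mean()

# Casting to float32 (roughly what floatX="float32" does to the
# initial point) reintroduces a rounding residue in the mean.
x32 = x64.astype(np.float32)

dev64 = abs(float(x64.mean()))
dev32 = abs(float(x32.mean()))

print(f"float64 deviation from zero mean: {dev64:.2e}")
print(f"float32 deviation from zero mean: {dev32:.2e}")

# A float64-scale tolerance can reject values that are as zero-sum
# as float32 can represent; a float32-scale tolerance accepts them.
print(np.isclose(dev32, 0.0, atol=1e-12))  # strict, float64-scale
print(np.isclose(dev32, 0.0, atol=1e-5))   # float32-scale
```

The residue in `dev32` is typically around 1e-8, i.e. right at the default `atol=1e-8` of `np.isclose`-style checks, which would explain the intermittent constraint failures.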
point={'mu_adstock_logodds__': array([-0.4352023 , -0.2916657 , -0.22959015, -0.13479875, -0.30916345,
-1.5283642 , -1.655654 , -0.48334104, -0.30794948, 0.33021173,
0.0656831 , -1.6596988 , -0.47080922, -0.91784286, -0.15092045,
-1.1535047 , -1.0793369 , -0.8457366 , -1.3219562 , -1.3434141 ],
dtype=float32), 'mu_lambda_log__': array([2.2058203 , 1.453856 , 1.6157596 , 0.43113965, 1.6577175 ,
0.84230906, 0.27920187, 0.899368 , 1.8292453 , 2.2519255 ,
2.3814824 , 2.108992 , 1.4626849 , 1.5980046 , 0.8299414 ,
0.3984745 , 1.8697526 , 1.4488894 , 2.1219566 , 0.18294996],
dtype=float32), 'mu_a': array(0.16189706, dtype=float32), 'z_a_state_zerosum__': array([-0.28566697, -0.48986882, -0.24404879, -0.4398191 , -0.85556257,
0.10748474], dtype=float32), 'z_a_age_zerosum__': array([-0.8873828 , -0.32871515, 0.57389575, 0.30648488, -0.6199051 ],
dtype=float32), 'z_a_brand_zerosum__': array([-0.46795005], dtype=float32), 'z_a_cohort_zerosum__': array([0.4478183], dtype=float32), 'roas_rv_log__': array([-0.6555429 , -0.068955 , -0.9309296 , -0.47975117, -1.3743148 ,
0.47711676, 0.14864558, 0.10893539, 0.38675046, -0.8370424 ,
-0.66264176, -1.5664616 , -0.572411 , -1.4761899 , -1.0794916 ,
-0.2966979 , 0.05331168, 1.0579 , 0.13020533, 0.8946314 ],
dtype=float32), 'z_b_state_zerosum__': array([[-0.00928065, 0.22661608, 0.26506114, 0.10868097, 0.47448924,
0.6638588 ],
[ 0.39453208, 0.9125831 , -0.05330055, 0.16203904, -0.7428469 ,
-0.2878859 ],
[-0.22169496, 0.25026786, 0.7721265 , -0.16524933, -0.8161399 ,
0.5124463 ],
[ 0.4635996 , 0.935813 , 0.3664374 , -0.39854062, -0.11831979,
0.23826346],
[ 0.4144873 , -0.07588482, -0.4675184 , 0.9954296 , 0.44995347,
0.6562674 ],
[ 0.94758034, 0.2068893 , 0.6966277 , 0.31964955, -0.8013234 ,
-0.59591883],
[-0.17715056, -0.7038275 , 0.18067661, 0.01431344, -0.9491178 ,
-0.3321023 ],
[ 0.3942007 , 0.9996393 , -0.31270924, -0.08990093, -0.09300919,
-0.16450764],
[-0.00447197, -0.61609423, 0.8628801 , 0.96006954, 0.7203218 ,
-0.7518324 ],
[ 0.6942957 , -0.44699988, 0.57910615, 0.8879041 , 0.531556 ,
-0.9510816 ],
[ 0.78471005, 0.10752742, 0.4335172 , -0.58196217, -0.9966123 ,
-0.17337854],
[-0.06716231, -0.5351729 , -0.1103561 , -0.15798165, -0.15524508,
0.8739795 ],
[ 0.47066316, 0.03429028, -0.2272006 , 0.57281727, 0.9989922 ,
-0.26203355],
[-0.59414744, -0.34866652, -0.58397436, -0.12034182, 0.16198853,
-0.36454397],
[ 0.12944746, -0.05762197, 0.99427617, 0.81767935, 0.5921547 ,
0.9800794 ],
[ 0.9717736 , 0.9814946 , 0.4856121 , -0.5534532 , 0.11700594,
0.9247631 ],
[-0.2042932 , -0.411241 , 0.27332023, 0.9046378 , 0.6154953 ,
-0.08056752],
[ 0.9214974 , -0.65947914, -0.41038954, 0.54713 , -0.3560202 ,
-0.9969325 ],
[-0.08087815, 0.18727091, 0.84307253, -0.48801887, 0.29456693,
0.5796735 ],
[ 0.27902 , 0.29730567, -0.40406513, -0.18478568, -0.61452186,
-0.5549851 ]], dtype=float32), 'z_b_age_zerosum__': array([[-6.8486583e-01, -4.6815168e-02, -5.9378707e-01, 8.9602423e-01,
8.5502052e-01],
[ 7.8728044e-01, -3.6670679e-01, -3.3962426e-01, 3.0838227e-01,
3.6406529e-01],
[ 2.6335692e-02, -6.4281446e-01, 5.1187193e-01, 8.4743094e-01,
2.2725777e-01],
[ 1.9795492e-01, 9.2090023e-01, -8.8563585e-01, -2.8022802e-01,
-2.3639840e-01],
[ 2.9900335e-02, 7.1486712e-01, 6.7400551e-01, -7.0308822e-01,
-6.9614536e-01],
[ 1.3013636e-01, 1.0248652e-01, -5.7761997e-02, -8.4077924e-01,
-4.7718164e-01],
[-5.9512955e-01, -9.2812777e-02, -9.9525869e-02, 9.1666229e-02,
7.0176089e-01],
[ 4.7530079e-01, 2.4512438e-01, -4.2329890e-01, -6.4359361e-01,
4.8717073e-01],
[-4.3700787e-01, -3.6620468e-01, -5.4181212e-01, -7.7344787e-01,
-5.0397778e-01],
[ 5.0274485e-01, 8.3750290e-01, -1.0284081e-01, 6.3057953e-01,
9.5303126e-02],
[ 9.8040715e-02, 3.1213897e-01, 9.7941196e-01, 6.9000393e-01,
1.3390434e-01],
[ 5.9936064e-01, -1.6784413e-01, 9.4419844e-02, -5.2747607e-02,
1.4664505e-01],
[-8.8135488e-02, -1.6365480e-01, -2.4431337e-01, 2.9717276e-01,
8.4692138e-01],
[ 4.3495792e-01, 5.3633377e-02, 4.0893257e-01, -1.4952035e-01,
-2.0427135e-01],
[-4.2102399e-01, 7.0554173e-01, -7.1471161e-01, -6.7319351e-01,
6.4274400e-01],
[ 2.9349172e-01, -4.3267983e-01, -6.5261596e-01, -7.2232783e-01,
8.7439209e-01],
[ 3.5815242e-01, 9.3956250e-01, 7.0418483e-01, 6.0373771e-01,
3.7868690e-02],
[-8.4775686e-01, -4.1210197e-02, -8.3802587e-01, -7.5553125e-01,
1.8493308e-01],
[-9.3824327e-01, 4.9592870e-01, 7.3176724e-01, 1.4875463e-01,
4.6959123e-01],
[-9.1654237e-04, -3.3683771e-01, 2.9216877e-01, -1.3542533e-01,
3.8034424e-02]], dtype=float32), 'z_b_brand_zerosum__': array([[ 0.1755699 ],
[-0.64566314],
[-0.164115 ],
[-0.7532449 ],
[ 0.8172519 ],
[-0.07537084],
[ 0.18247645],
[-0.6951736 ],
[-0.34786522],
[ 0.13041314],
[-0.26256517],
[-0.75953436],
[ 0.13290128],
[ 0.7542543 ],
[ 0.43538788],
[ 0.2650157 ],
[-0.22280337],
[ 0.9436325 ],
[ 0.88120073],
[-0.2978969 ]], dtype=float32), 'z_b_cohort_zerosum__': array([[ 0.08365148],
[ 0.4037446 ],
[ 0.219702 ],
[-0.04724727],
[-0.39688712],
[-0.05613985],
[ 0.8531689 ],
[ 0.21144761],
[-0.7068673 ],
[ 0.0072215 ],
[-0.5336951 ],
[ 0.34807727],
[ 0.39438143],
[ 0.7545098 ],
[ 0.8487969 ],
[-0.93471295],
[ 0.36905304],
[ 0.23143601],
[-0.85752094],
[-0.93718106]], dtype=float32), 'mu_b_pos_con': array([-2.4778757, -0.7670131, -2.4533374, -2.1609292, -1.4259497],
dtype=float32), 'z_b_pos_con_state_zerosum__': array([[-0.12270445, 0.24486493, 0.39655864, -0.2905437 , -0.34443325,
0.00207796],
[-0.6199098 , -0.9983607 , 0.17514546, 0.70688826, -0.32889152,
-0.14889538],
[-0.7383829 , -0.3004701 , -0.3937158 , 0.05630853, 0.04613764,
-0.7968154 ],
[-0.83736265, 0.53274584, 0.00774734, 0.3415022 , -0.02279032,
-0.80878764],
[-0.29654893, 0.8825508 , -0.31042358, -0.4510256 , -0.8262907 ,
-0.539443 ]], dtype=float32), 'z_b_pos_con_age_zerosum__': array([[ 0.78295594, 0.05425891, -0.34404022, 0.54096305, 0.30847767],
[-0.73004913, -0.9038148 , -0.16850933, -0.00236369, -0.6851207 ],
[ 0.39292192, 0.9246333 , 0.6826134 , -0.11111186, 0.9510223 ],
[-0.83537763, 0.82692575, -0.73304725, 0.18374918, -0.9958687 ],
[ 0.56813633, 0.7769569 , -0.94843704, 0.06147736, -0.7710584 ]],
dtype=float32), 'z_b_pos_con_brand_zerosum__': array([[-0.9075556 ],
[-0.27464116],
[ 0.7783447 ],
[-0.8754568 ],
[-0.9165574 ]], dtype=float32), 'z_b_pos_con_cohort_zerosum__': array([[-0.37049973],
[ 0.95798755],
[ 0.71250516],
[ 0.83639866],
[-0.19618082]], dtype=float32), 'mu_b_neg_con': array([-1.0723019], dtype=float32), 'z_b_neg_con_state_zerosum__': array([[ 0.9188992 , -0.9298911 , 0.19285995, -0.25151858, 0.37471578,
-0.33717418]], dtype=float32), 'z_b_neg_con_age_zerosum__': array([[-0.80139756, 0.6711084 , 0.80932933, -0.03985522, 0.22147077]],
dtype=float32), 'z_b_neg_con_brand_zerosum__': array([[0.5175842]], dtype=float32), 'z_b_neg_con_cohort_zerosum__': array([[-0.61060673]], dtype=float32), 'mu_b_lag': array([-2.16711], dtype=float32), 'z_b_lag_state_zerosum__': array([[-0.79260015, -0.74764526, -0.82469386, -0.03429089, -0.9838945 ,
-0.7090971 ]], dtype=float32), 'z_b_lag_age_zerosum__': array([[-0.48504904, 0.87696105, -0.742052 , 0.9109938 , 0.04824264]],
dtype=float32), 'z_b_lag_brand_zerosum__': array([[0.00786572]], dtype=float32), 'z_b_lag_cohort_zerosum__': array([[-0.23193653]], dtype=float32), 'mu_b_fourier_year': array([-0.77838343, 0.6178618 , 0.57258415, 0.7673871 , 0.3929526 ,
0.8627489 , 0.4622131 , -0.769917 , 0.13717452, -0.13153929,
-0.83192456, 0.1140356 , 0.05431605, -0.10188404, 0.9689602 ,
0.05997486, 0.06146386, -0.9775481 , -0.47403213, -0.6081761 ],
dtype=float32), 'sd_y_log__': array(3.3264434, dtype=float32)}
The variable z_a_state has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [] [id C] <Vector(int64, shape=(0,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [7]
2: []
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axes=None}.0, Composite{...}.1)
Toposort index: 14
Inputs types: [TensorType(float32, shape=()), TensorType(bool, shape=())]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [array(-57.60402, dtype=float32), array(False)]
Outputs clients: [[output[0](z_a_state_zerosum___logprob)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_vars.py", line 267, in eval_in_context
result = eval(compiled, global_vars, local_vars)
File "<string>", line 1, in <module>
File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 1886, in debug
self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 696, in logp
rv_logps = transformed_conditional_logp(
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 595, in transformed_conditional_logp
temp_logp_terms = conditional_logp(
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 122, in transformed_value_logprob
logprobs_jac.append(logp + log_jac_det)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_a_age has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [] [id C] <Vector(int64, shape=(0,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [6]
2: []
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axes=None}.0, Composite{...}.1)
Toposort index: 14
Inputs types: [TensorType(float32, shape=()), TensorType(bool, shape=())]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [array(-78.23544, dtype=float32), array(False)]
Outputs clients: [[output[0](z_a_age_zerosum___logprob)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_vars.py", line 267, in eval_in_context
result = eval(compiled, global_vars, local_vars)
File "<string>", line 1, in <module>
File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 1886, in debug
self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 696, in logp
rv_logps = transformed_conditional_logp(
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 595, in transformed_conditional_logp
temp_logp_terms = conditional_logp(
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 122, in transformed_value_logprob
logprobs_jac.append(logp + log_jac_det)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_state has the following parameters:
0: 0.05 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [20] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.05000000074505806
1: [7]
2: [20]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(20,), ()]
Inputs strides: [(4,), ()]
Inputs values: ['not shown', array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_age has the following parameters:
0: 0.05 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [20] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.05000000074505806
1: [6]
2: [20]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(20,), ()]
Inputs strides: [(4,), ()]
Inputs values: ['not shown', array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_brand has the following parameters:
0: 0.05 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [20] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.05000000074505806
1: [2]
2: [20]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(20,), ()]
Inputs strides: [(4,), ()]
Inputs values: ['not shown', array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_cohort has the following parameters:
0: 0.05 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [20] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.05000000074505806
1: [2]
2: [20]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(20,), ()]
Inputs strides: [(4,), ()]
Inputs values: ['not shown', array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_pos_con_state has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [5] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [7]
2: [5]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(5,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-13.464529, -93.78406 , -63.23401 , -79.51486 , -98.71708 ],
dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_pos_con_age has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [5] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [6]
2: [5]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(5,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([ -49.188126, -85.46398 , -112.68599 , -140.30869 , -114.29598 ],
dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_pos_con_brand has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [5] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [2]
2: [5]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(5,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-39.79921 , -2.3877418, -28.907381 , -36.937584 , -40.620224 ],
dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_pos_con_cohort has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [5] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [2]
2: [5]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(5,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([ -5.4798565 , -44.503365 , -23.999535 , -33.594482 ,
-0.54069936], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_neg_con_state has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [7]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-94.879524], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_neg_con_age has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [6]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-82.99558], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_neg_con_cohort has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [2]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-17.258383], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_lag_state has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [7]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-158.66568], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_lag_age has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [6]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-112.442345], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
Your traceback doesn't look like pymc=5.19, the code doesn't use
Your traceback does not correspond to the example code either, it's going through a jax sampler, but I don't see it specified in
I can reproduce it, though. Perhaps the logp/dlogp underflows in float32.
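If underflow is the mechanism, it would bite much earlier in float32 than in float64: `exp` already flushes to zero around an argument of -104 in float32, versus around -745 in float64. A quick numpy sanity check of those ranges (an illustration of the suspected mechanism, not a confirmed diagnosis of this model):

```python
import numpy as np

# float32 exp() flushes to zero near exp(-104); float64 holds on until ~exp(-745).
# So a log-density contribution of order -100 can already zero out
# intermediate float32 values that float64 would represent just fine.
print(np.exp(np.float32(-104)))  # underflows to 0.0 in float32
print(np.exp(np.float64(-104)))  # still representable in float64
print(np.finfo(np.float32).tiny, np.finfo(np.float64).tiny)
```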
5.19 definitely uses it, see Line 1390 in a714b24.
As I mentioned in the original comment, the traceback I provided is from my own model, but the reproducible example is using the radon dataset. The only code we use from
Ah my bad, it does, after it fails with the jitter. Either way, when you sample with the pymc NUTS it's not failing in the jitter but only inside NUTS already. Setting an older pymc-marketing gets around the pymc limitation, but it's not a solution: that older version just did not pin pymc as strictly as the new one does. Either way, that's not the problem you're seeing here.
Right, pymc-marketing is irrelevant here.
Ah, I can see the confusion though, because the "reproducible" example doesn't reproduce exactly what I was reporting. I took my coworker's word for it at 4pm on a Friday afternoon, haha. This, for example, doesn't throw the error:

import numpy as np
import pandas as pd
import pymc as pm
import pytensor as pt
pt.config.floatX = "float32"
pt.config.warn_float64 = "ignore"
# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}
# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
intercept = pm.Normal("intercept", sigma=10)
# County effects
raw = pm.ZeroSumNormal("county_raw", dims="county")
sd = pm.HalfNormal("county_sd")
county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")
# Global floor effect
floor_effect = pm.Normal("floor_effect", sigma=2)
# County:floor interaction
raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
sd = pm.HalfNormal("county_floor_sd")
county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")
mu = (
intercept
+ county_effect[county_idx]
+ floor_effect * data.floor.values
+ county_floor_effect[county_idx] * data.floor.values
)
sigma = pm.HalfNormal("sigma", sigma=1.5)
pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")
idata = pm.sample(
model=model,
chains=1,
tune=500,
draws=500,
progressbar=False,
compute_convergence_checks=False,
return_inferencedata=False,
nuts_sampler="numpyro",
# compile_kwargs=dict(mode="NUMBA")
)
Here is a more direct reproducible example:

import pytensor
pytensor.config.floatX = "float32"
import numpy as np
import pandas as pd
import pymc as pm
# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float32)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}
# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
intercept = pm.Normal("intercept", sigma=10)
# County effects
raw = pm.ZeroSumNormal("county_raw", dims="county")
sd = pm.HalfNormal("county_sd")
county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")
# Global floor effect
floor_effect = pm.Normal("floor_effect", sigma=2)
# County:floor interaction
raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
sd = pm.HalfNormal("county_floor_sd")
county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")
mu = (
intercept
+ county_effect[county_idx]
+ floor_effect * data.floor.values
+ county_floor_effect[county_idx] * data.floor.values
)
sigma = pm.HalfNormal("sigma", sigma=1.5)
pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")
# Bad point fished from the interactive debugger
from numpy import array, float32
q = {'intercept': array(-0.8458743, dtype=float32),
'county_raw_zerosum__': array([-0.8473211 , 0.97756225, 0.5851473 , -0.8831246 , 0.67874885,
0.74649656, 0.40699005, 0.9938065 , 0.90805703, -0.55194354,
0.7369223 , -0.8693557 , -0.18068689, 0.34439757, 0.8696054 ,
-0.90608346, -0.19901727, 0.18405294, 0.85029787, 0.69731015,
-0.11369044, -0.45499414, 0.4499965 , -0.78362477, -0.42028612,
0.33963433, -0.56401193, 0.45644552, -0.39769658, -0.00929202,
-0.9610129 , 0.40683702, 0.11690333, -0.21440822, -0.35790983,
-0.72231764, -0.7358892 , -0.76221883, -0.44132066, 0.8106245 ,
-0.01106247, 0.89837337, 0.15829656, -0.48148382, -0.07137716,
-0.37613812, -0.36517394, 0.14016594, -0.63096076, -0.42230594,
0.776719 , -0.3128489 , 0.56846076, 0.11121392, 0.5724536 ,
-0.46519637, 0.83556646, -0.3795832 , -0.24870592, -0.908497 ,
-0.62978345, -0.23155476, 0.21914907, 0.5683378 , 0.4083237 ,
0.45315483, -0.06205622, 0.63755155, 0.97950894, -0.05648626,
-0.16262522, 0.40750283, -0.9556285 , -0.42807412, 0.6204139 ,
0.5904101 , -0.7840837 , -0.45694816, -0.6592951 , -0.20405641,
0.7004118 , 0.09331694, 0.06100031, 0.10267377], dtype=float32),
'county_sd_log__': array(0.45848975, dtype=float32),
'floor_effect': array(0.43849692, dtype=float32),
'county_floor_raw_zerosum__': array([ 0.68369645, 0.6433043 , -0.0029135 , -0.49709547, -0.02687999,
0.8271722 , -0.10023019, -0.30813244, -0.4091758 , -0.591417 ,
0.2297259 , -0.6770909 , 0.46815294, 0.23881096, 0.41891697,
0.6744159 , -0.8680713 , 0.9475378 , 0.36461526, -0.11404609,
-0.2285417 , -0.52589136, 0.9446311 , 0.5722908 , 0.86332804,
-0.42848182, -0.1902879 , 0.95098126, 0.1297681 , 0.51527834,
0.7873266 , -0.5753548 , 0.4216227 , -0.08488699, -0.3141113 ,
0.6385347 , -0.26448518, -0.0412051 , -0.6691395 , -0.8684154 ,
0.48946136, -0.5839668 , -0.43648678, -0.20375745, 0.6134852 ,
0.34660435, 0.2335634 , -0.30285057, 0.0682539 , -0.7834195 ,
-0.54660916, -0.94278365, 0.5532979 , -0.76577055, -0.6490462 ,
-0.492982 , -0.74057543, -0.7026031 , 0.5502333 , -0.8355645 ,
-0.16759473, -0.1209451 , 0.5091448 , -0.76411086, 0.14865868,
-0.71105725, 0.8838853 , -0.43318895, -0.8210448 , 0.04136186,
-0.11312467, -0.92210877, -0.19974665, -0.87211764, -0.8225621 ,
0.03210128, -0.31010386, -0.5447676 , -0.79350907, 0.737303 ,
-0.04805126, -0.7177033 , -0.77231514, 0.45744413], dtype=float32),
'county_floor_sd_log__': array(-0.49195403, dtype=float32),
'sigma_log__': array(0.16196507, dtype=float32)}
model.compile_logp()(q) # array(-2813.50744694)
model.logp_dlogp_function(ravel_inputs=False)._pytensor_function(**q)[0]  # array(-inf)

I suspect it may be some change on the PyTensor side, not PyMC, if it was indeed working in older versions.
import numpy as np
import pandas as pd
import pymc as pm
import pytensor as pt
pt.config.floatX = "float32"
pt.config.warn_float64 = "ignore"
# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}
# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
intercept = pm.Normal("intercept", sigma=10)
# County effects
raw = pm.ZeroSumNormal("county_raw", dims="county")
sd = pm.HalfNormal("county_sd")
county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")
# Global floor effect
floor_effect = pm.Normal("floor_effect", sigma=2)
# County:floor interaction
raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
sd = pm.HalfNormal("county_floor_sd")
county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")
mu = (
intercept
+ county_effect[county_idx]
+ floor_effect * data.floor.values
+ county_floor_effect[county_idx] * data.floor.values
)
sigma = pm.HalfNormal("sigma", sigma=1.5)
pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")
try:
idata = pm.sample(
model=model,
chains=1,
tune=500,
draws=500,
progressbar=False,
compute_convergence_checks=False,
return_inferencedata=False,
# nuts_sampler="numpyro",
# compile_kwargs=dict(mode="NUMBA")
)
except Exception as e:
print(e)
model.debug()

results in:

Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [intercept, county_raw, county_sd, floor_effect, county_floor_raw, county_floor_sd, sigma]
Bad initial energy: SamplerWarning(kind=<WarningType.BAD_ENERGY: 8>, message='Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:\n[]\n.Try model.debug() to identify parametrization problems.', level='critical', step=0, exec_info=None, extra=None, divergence_point_source=None, divergence_point_dest=None, divergence_info=None)
point={'intercept': array(0., dtype=float32), 'county_raw_zerosum__': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
dtype=float32), 'county_sd_log__': array(0., dtype=float32), 'floor_effect': array(0., dtype=float32), 'county_floor_raw_zerosum__': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
dtype=float32), 'county_floor_sd_log__': array(0., dtype=float32), 'sigma_log__': array(0.4054651, dtype=float32)}
No problems found
The problem is not the initial point, not even the one with jitter, but the point from the first step of NUTS, which is already mutated. It also only shows up with the logp_dlogp_function that fuses the logp and dlogp.
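The failing check is the ZeroSumNormal constraint `mean(value, axis=n_zerosum_axes) = 0`, which can only ever hold to within a finite tolerance. A small sketch (assuming plain floating-point rounding is the mechanism, nothing pymc-specific) of how merely casting a zero-sum vector to float32 perturbs its mean:

```python
import numpy as np

rng = np.random.default_rng(123)
x64 = rng.normal(size=85)      # same length as the county dimension above
x64 -= x64.mean()              # zero-sum to float64 precision

x32 = x64.astype(np.float32)   # casting rounds each entry independently,
                               # so the rounding errors no longer cancel exactly
print(abs(x64.mean()))         # essentially zero at float64 precision
print(abs(x32.mean()))         # orders of magnitude larger after the cast
```

A tolerance calibrated for float64 residuals can therefore reject a perfectly reasonable float32 point, especially once NUTS has mutated it.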
This snippet #7608 (comment) fails in 5.18.0 as well, so maybe the model is just not well-behaved enough in float32. So it may be irrelevant to your issue?
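One possible direction would be a dtype-aware tolerance for the zero-mean check, so the same model passes in both precisions. The `mean_is_zero` helper below is purely illustrative (a hypothetical sketch, not pymc's actual implementation or API):

```python
import numpy as np

def mean_is_zero(value, axis=-1):
    # Hypothetical dtype-aware check: widen the absolute tolerance with
    # the epsilon of the value's own dtype and the length of the axis.
    atol = np.finfo(value.dtype).eps * value.shape[axis] * 10
    return bool(np.all(np.isclose(value.mean(axis=axis), 0.0, atol=atol)))

x = np.random.default_rng(0).normal(size=85)
x -= x.mean()                                 # zero-sum in float64
print(mean_is_zero(x))                        # float64 passes
print(mean_is_zero(x.astype(np.float32)))     # float32 also passes with the scaled atol
```

Scaling the tolerance by `finfo(dtype).eps` keeps the check strict in float64 while accepting the rounding noise that float32 inevitably introduces.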
import numpy as np
import pandas as pd
import pymc as pm
import pytensor as pt
from numpy import array, float32
pt.config.floatX = "float32"
pt.config.warn_float64 = "ignore"
# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}
# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
intercept = pm.Normal("intercept", sigma=10)
# County effects
raw = pm.ZeroSumNormal("county_raw", dims="county")
sd = pm.HalfNormal("county_sd")
county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")
# Global floor effect
floor_effect = pm.Normal("floor_effect", sigma=2)
# County:floor interaction
raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
sd = pm.HalfNormal("county_floor_sd")
county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")
mu = (
intercept
+ county_effect[county_idx]
+ floor_effect * data.floor.values
+ county_floor_effect[county_idx] * data.floor.values
)
sigma = pm.HalfNormal("sigma", sigma=1.5)
pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")
model_logp_fn = model.compile_logp()
q = {
"intercept": array(-0.8458743, dtype=float32),
"county_raw_zerosum__": array(
[
-0.8473211,
0.97756225,
0.5851473,
-0.8831246,
0.67874885,
0.74649656,
0.40699005,
0.9938065,
0.90805703,
-0.55194354,
0.7369223,
-0.8693557,
-0.18068689,
0.34439757,
0.8696054,
-0.90608346,
-0.19901727,
0.18405294,
0.85029787,
0.69731015,
-0.11369044,
-0.45499414,
0.4499965,
-0.78362477,
-0.42028612,
0.33963433,
-0.56401193,
0.45644552,
-0.39769658,
-0.00929202,
-0.9610129,
0.40683702,
0.11690333,
-0.21440822,
-0.35790983,
-0.72231764,
-0.7358892,
-0.76221883,
-0.44132066,
0.8106245,
-0.01106247,
0.89837337,
0.15829656,
-0.48148382,
-0.07137716,
-0.37613812,
-0.36517394,
0.14016594,
-0.63096076,
-0.42230594,
0.776719,
-0.3128489,
0.56846076,
0.11121392,
0.5724536,
-0.46519637,
0.83556646,
-0.3795832,
-0.24870592,
-0.908497,
-0.62978345,
-0.23155476,
0.21914907,
0.5683378,
0.4083237,
0.45315483,
-0.06205622,
0.63755155,
0.97950894,
-0.05648626,
-0.16262522,
0.40750283,
-0.9556285,
-0.42807412,
0.6204139,
0.5904101,
-0.7840837,
-0.45694816,
-0.6592951,
-0.20405641,
0.7004118,
0.09331694,
0.06100031,
0.10267377,
],
dtype=float32,
),
"county_sd_log__": array(0.45848975, dtype=float32),
"floor_effect": array(0.43849692, dtype=float32),
"county_floor_raw_zerosum__": array(
[
0.68369645,
0.6433043,
-0.0029135,
-0.49709547,
-0.02687999,
0.8271722,
-0.10023019,
-0.30813244,
-0.4091758,
-0.591417,
0.2297259,
-0.6770909,
0.46815294,
0.23881096,
0.41891697,
0.6744159,
-0.8680713,
0.9475378,
0.36461526,
-0.11404609,
-0.2285417,
-0.52589136,
0.9446311,
0.5722908,
0.86332804,
-0.42848182,
-0.1902879,
0.95098126,
0.1297681,
0.51527834,
0.7873266,
-0.5753548,
0.4216227,
-0.08488699,
-0.3141113,
0.6385347,
-0.26448518,
-0.0412051,
-0.6691395,
-0.8684154,
0.48946136,
-0.5839668,
-0.43648678,
-0.20375745,
0.6134852,
0.34660435,
0.2335634,
-0.30285057,
0.0682539,
-0.7834195,
-0.54660916,
-0.94278365,
0.5532979,
-0.76577055,
-0.6490462,
-0.492982,
-0.74057543,
-0.7026031,
0.5502333,
-0.8355645,
-0.16759473,
-0.1209451,
0.5091448,
-0.76411086,
0.14865868,
-0.71105725,
0.8838853,
-0.43318895,
-0.8210448,
0.04136186,
-0.11312467,
-0.92210877,
-0.19974665,
-0.87211764,
-0.8225621,
0.03210128,
-0.31010386,
-0.5447676,
-0.79350907,
0.737303,
-0.04805126,
-0.7177033,
-0.77231514,
0.45744413,
],
dtype=float32,
),
"county_floor_sd_log__": array(-0.49195403, dtype=float32),
"sigma_log__": array(0.16196507, dtype=float32),
}
model.debug(q, verbose=True)

makes a nicer debug trace:

point={'intercept': array(-0.8458743, dtype=float32), 'county_raw_zerosum__': array([-0.8473211 , 0.97756225, 0.5851473 , -0.8831246 , 0.67874885,
0.74649656, 0.40699005, 0.9938065 , 0.90805703, -0.55194354,
0.7369223 , -0.8693557 , -0.18068689, 0.34439757, 0.8696054 ,
-0.90608346, -0.19901727, 0.18405294, 0.85029787, 0.69731015,
-0.11369044, -0.45499414, 0.4499965 , -0.78362477, -0.42028612,
0.33963433, -0.56401193, 0.45644552, -0.39769658, -0.00929202,
-0.9610129 , 0.40683702, 0.11690333, -0.21440822, -0.35790983,
-0.72231764, -0.7358892 , -0.76221883, -0.44132066, 0.8106245 ,
-0.01106247, 0.89837337, 0.15829656, -0.48148382, -0.07137716,
-0.37613812, -0.36517394, 0.14016594, -0.63096076, -0.42230594,
0.776719 , -0.3128489 , 0.56846076, 0.11121392, 0.5724536 ,
-0.46519637, 0.83556646, -0.3795832 , -0.24870592, -0.908497 ,
-0.62978345, -0.23155476, 0.21914907, 0.5683378 , 0.4083237 ,
0.45315483, -0.06205622, 0.63755155, 0.97950894, -0.05648626,
-0.16262522, 0.40750283, -0.9556285 , -0.42807412, 0.6204139 ,
0.5904101 , -0.7840837 , -0.45694816, -0.6592951 , -0.20405641,
0.7004118 , 0.09331694, 0.06100031, 0.10267377], dtype=float32), 'county_sd_log__': array(0.45848975, dtype=float32), 'floor_effect': array(0.43849692, dtype=float32), 'county_floor_raw_zerosum__': array([ 0.68369645, 0.6433043 , -0.0029135 , -0.49709547, -0.02687999,
0.8271722 , -0.10023019, -0.30813244, -0.4091758 , -0.591417 ,
0.2297259 , -0.6770909 , 0.46815294, 0.23881096, 0.41891697,
0.6744159 , -0.8680713 , 0.9475378 , 0.36461526, -0.11404609,
-0.2285417 , -0.52589136, 0.9446311 , 0.5722908 , 0.86332804,
-0.42848182, -0.1902879 , 0.95098126, 0.1297681 , 0.51527834,
0.7873266 , -0.5753548 , 0.4216227 , -0.08488699, -0.3141113 ,
0.6385347 , -0.26448518, -0.0412051 , -0.6691395 , -0.8684154 ,
0.48946136, -0.5839668 , -0.43648678, -0.20375745, 0.6134852 ,
0.34660435, 0.2335634 , -0.30285057, 0.0682539 , -0.7834195 ,
-0.54660916, -0.94278365, 0.5532979 , -0.76577055, -0.6490462 ,
-0.492982 , -0.74057543, -0.7026031 , 0.5502333 , -0.8355645 ,
-0.16759473, -0.1209451 , 0.5091448 , -0.76411086, 0.14865868,
-0.71105725, 0.8838853 , -0.43318895, -0.8210448 , 0.04136186,
-0.11312467, -0.92210877, -0.19974665, -0.87211764, -0.8225621 ,
0.03210128, -0.31010386, -0.5447676 , -0.79350907, 0.737303 ,
-0.04805126, -0.7177033 , -0.77231514, 0.45744413], dtype=float32), 'county_floor_sd_log__': array(-0.49195403, dtype=float32), 'sigma_log__': array(0.16196507, dtype=float32)}
The variable county_raw has the following parameters:
0: 1.0 [id A] <Scalar(float32, shape=())>
1: MakeVector{dtype='int64'} [id B] <Vector(int64, shape=(1,))>
└─ county [id C] <Scalar(int64, shape=())>
2: [] [id D] <Vector(int64, shape=(0,))>
The parameters evaluate to:
0: 1.0
1: [85]
2: []
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axes=None}.0, Composite{...}.1)
Toposort index: 14
Inputs types: [TensorType(float32, shape=()), TensorType(bool, shape=())]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [array(-91.11339, dtype=float32), array(False)]
Outputs clients: [[output[0](county_raw_zerosum___logprob)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
File "/workspaces/mmm_v2/reproducible.py", line 231, in <module>
model.debug(q, verbose=True)
File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 1886, in debug
self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 696, in logp
rv_logps = transformed_conditional_logp(
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 595, in transformed_conditional_logp
temp_logp_terms = conditional_logp(
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 122, in transformed_value_logprob
logprobs_jac.append(logp + log_jac_det)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable county_floor_raw has the following parameters:
0: 1.0 [id A] <Scalar(float32, shape=())>
1: MakeVector{dtype='int64'} [id B] <Vector(int64, shape=(1,))>
└─ county [id C] <Scalar(int64, shape=())>
2: [] [id D] <Vector(int64, shape=(0,))>
The parameters evaluate to:
0: 1.0
1: [85]
2: []
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0
mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axes=None}.0, Composite{...}.1)
Toposort index: 14
Inputs types: [TensorType(float32, shape=()), TensorType(bool, shape=())]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [array(-90.89934, dtype=float32), array(False)]
Outputs clients: [[output[0](county_floor_raw_zerosum___logprob)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
[value.type()],
File "/workspaces/mmm_v2/reproducible.py", line 231, in <module>
model.debug(q, verbose=True)
File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 1886, in debug
self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 696, in logp
rv_logps = transformed_conditional_logp(
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 595, in transformed_conditional_logp
temp_logp_terms = conditional_logp(
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
node_logprobs = _logprob(
File "/usr/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 122, in transformed_value_logprob
logprobs_jac.append(logp + log_jac_det)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
If you do the above but change floatX to float64 and the dtypes in the arrays to float64 you get:

point={'intercept': array(-0.8458743), 'county_raw_zerosum__': array([-0.8473211 , 0.97756225, 0.5851473 , -0.8831246 , 0.67874885,
0.74649656, 0.40699005, 0.9938065 , 0.90805703, -0.55194354,
0.7369223 , -0.8693557 , -0.18068689, 0.34439757, 0.8696054 ,
-0.90608346, -0.19901727, 0.18405294, 0.85029787, 0.69731015,
-0.11369044, -0.45499414, 0.4499965 , -0.78362477, -0.42028612,
0.33963433, -0.56401193, 0.45644552, -0.39769658, -0.00929202,
-0.9610129 , 0.40683702, 0.11690333, -0.21440822, -0.35790983,
-0.72231764, -0.7358892 , -0.76221883, -0.44132066, 0.8106245 ,
-0.01106247, 0.89837337, 0.15829656, -0.48148382, -0.07137716,
-0.37613812, -0.36517394, 0.14016594, -0.63096076, -0.42230594,
0.776719 , -0.3128489 , 0.56846076, 0.11121392, 0.5724536 ,
-0.46519637, 0.83556646, -0.3795832 , -0.24870592, -0.908497 ,
-0.62978345, -0.23155476, 0.21914907, 0.5683378 , 0.4083237 ,
0.45315483, -0.06205622, 0.63755155, 0.97950894, -0.05648626,
-0.16262522, 0.40750283, -0.9556285 , -0.42807412, 0.6204139 ,
0.5904101 , -0.7840837 , -0.45694816, -0.6592951 , -0.20405641,
0.7004118 , 0.09331694, 0.06100031, 0.10267377]), 'county_sd_log__': array(0.45848975), 'floor_effect': array(0.43849692), 'county_floor_raw_zerosum__': array([ 0.68369645, 0.6433043 , -0.0029135 , -0.49709547, -0.02687999,
0.8271722 , -0.10023019, -0.30813244, -0.4091758 , -0.591417 ,
0.2297259 , -0.6770909 , 0.46815294, 0.23881096, 0.41891697,
0.6744159 , -0.8680713 , 0.9475378 , 0.36461526, -0.11404609,
-0.2285417 , -0.52589136, 0.9446311 , 0.5722908 , 0.86332804,
-0.42848182, -0.1902879 , 0.95098126, 0.1297681 , 0.51527834,
0.7873266 , -0.5753548 , 0.4216227 , -0.08488699, -0.3141113 ,
0.6385347 , -0.26448518, -0.0412051 , -0.6691395 , -0.8684154 ,
0.48946136, -0.5839668 , -0.43648678, -0.20375745, 0.6134852 ,
0.34660435, 0.2335634 , -0.30285057, 0.0682539 , -0.7834195 ,
-0.54660916, -0.94278365, 0.5532979 , -0.76577055, -0.6490462 ,
-0.492982 , -0.74057543, -0.7026031 , 0.5502333 , -0.8355645 ,
-0.16759473, -0.1209451 , 0.5091448 , -0.76411086, 0.14865868,
-0.71105725, 0.8838853 , -0.43318895, -0.8210448 , 0.04136186,
-0.11312467, -0.92210877, -0.19974665, -0.87211764, -0.8225621 ,
0.03210128, -0.31010386, -0.5447676 , -0.79350907, 0.737303 ,
-0.04805126, -0.7177033 , -0.77231514, 0.45744413]), 'county_floor_sd_log__': array(-0.49195403), 'sigma_log__': array(0.16196507)}
No problems found
Yes, it's a precision issue that sometimes shows up. debug is not the most useful here because it evaluates each variable's logp one at a time. It also doesn't count as a bug just yet, because float32 is inherently less precise (that's the whole point of it). To mark it as a bug we need proof of some regression (it used to work when evaluated at the same point) or a justification for why the imprecision is unreasonable here.
I'm out of time to check it tonight, but I'm 99% sure the value returned by feeding the same bad point into the compiled logp function is different between 5.18 and 5.19. I can test by running the debugger to the first line of …
I tested the script I pasted above in 5.18.0 and it wasn't different: both cases underflowed to -inf. We also need to narrow the focus. Ignore model.debug and check_start_values; that's too noisy and indirect. Evaluating the logp_dlogp_function like the snippet I shared and checking whether the first value is -inf is what matters for us/NUTS.
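To make the underflow being discussed concrete, here is a small stdlib-only sketch (not PyMC code; the helper name is made up for illustration) showing how a probability that is representable in float64 collapses to zero once rounded to float32, which is exactly the situation where its log-probability becomes -inf:

```python
import math
import struct


def round_to_float32(x: float) -> float:
    # round-trip a Python float (IEEE-754 binary64) through binary32
    return struct.unpack("f", struct.pack("f", x))[0]


p64 = math.exp(-110.0)           # tiny but representable in float64
p32 = round_to_float32(p64)      # below float32's smallest subnormal -> 0.0

print(p64 > 0.0)    # True
print(p32 == 0.0)   # True: log(p32) evaluates to -inf in array libraries
```

This is why the same jittered start point can have a finite log-probability in float64 but an infinite one in float32.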
I think I've worked it out. Regardless of whether there is a difference in the compiled logp function or the generated points, we would never have failed in `_init_jitter` previously:

def _init_jitter
    ...  # not important stuff
    initial_points = []
    for ipfn, seed in zip(ipfns, seeds):
        rng = np.random.RandomState(seed)
        for i in range(jitter_max_retries + 1):
            point = ipfn(seed)
            if i < jitter_max_retries:
                try:
                    model.check_start_vals(point)
                except SamplingError:
                    # Retry with a new seed
                    seed = rng.randint(2**30, dtype=np.int64)
                else:
                    break
        initial_points.append(point)
    return initial_points

Here in `_init_jitter`, `model.check_start_vals(point)` is only called while i < jitter_max_retries, so on the final attempt the point is appended and returned without ever being validated. From there it returns up the stack to … I've tried to follow the code through where it initialises the numpyro NUTS kernel and MCMC sampler and runs the sampler with the initial points, but it's pretty deep in the JAX internals, and because it uses … So I guess the problem for me here is that …
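The retry semantics just described can be reduced to a generic, self-contained sketch (hypothetical names, not the PyMC implementation): a loop that validates every attempt except the last one, so an invalid final attempt is returned silently:

```python
def retry_unchecked(generate, is_valid, max_retries: int):
    """Return the first valid candidate, but keep the LAST candidate
    unvalidated if every earlier attempt failed (sketch of the
    pre-5.19 behaviour described above)."""
    for i in range(max_retries + 1):
        candidate = generate()
        if i < max_retries and is_valid(candidate):
            break
        # on the final iteration the candidate is kept without any check
    return candidate


# every attempt is "invalid", yet no error is raised:
result = retry_unchecked(lambda: float("-inf"), lambda x: x > 0, max_retries=2)
print(result)  # -inf
```

Under this scheme a chain can silently start from a point whose log-probability is -inf, which matches the reported symptom.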
That sounds right @nataziel. Question: was numpyro actually sampling, or did it just return 100% divergences for that chain? It usually does so silently. If it was working fine, it could be that the initial point is not underflowing with the JAX backend, in which case we should probably use the JAX logp function to evaluate the init jitter instead of the default one.
I edited the code in … I don't quite understand what needs to be passed to the jaxified logp function; it seems like the values returned by …
I meant that we should pass the jaxified logp function to be used here, instead of defaulting to the function defined at lines 1373 to 1379 in a714b24.
I tried this:

from pymc.sampling.jax import get_jaxified_logp

model_logp_fn: Callable
if logp_dlogp_func is None:
    model_logp_fn = get_jaxified_logp(model=model, negative_logp=False)
else:

    def model_logp_fn(ip):
        q, _ = DictToArrayBijection.map(ip)
        return logp_dlogp_func([q], extra_vars={})[0]

and tried passing in this point to model_logp_fn:

{'mu_adstock_logodds__': array([-0.7406309 , -0.2676502 , 0.72276473, 0.39711773, -1.4843715 ,
-0.9303382 , -0.70388365, 0.12047255, -1.20291 , -1.0157162 ,
-0.9788475 , -0.16161698, -0.5186577 , 0.44525263, 0.50664103,
-0.5502606 , -0.9929964 , -1.7073784 , -0.469832 , -0.7702299 ],
dtype=float32), 'mu_lambda_log__': array([1.4899505 , 0.87828857, 0.33736384, 0.12322366, 1.6119224 ,
2.1611521 , 0.19036102, 0.78274024, 1.9591975 , 0.86114645,
1.00022 , 1.8908874 , 0.7735787 , 1.2600771 , 0.8721991 ,
0.3919416 , 0.628017 , 0.5571408 , 2.2277155 , 0.701397 ],
dtype=float32), 'mu_a': array(0.46919715, dtype=float32), 'z_a_state_zerosum__': array([ 0.96741545, 0.19968931, 0.55482584, -0.40800413, -0.783277 ,
0.6665936 ], dtype=float32), 'z_a_age_zerosum__': array([ 0.97223747, -0.88604414, -0.60649115, 0.691295 , -0.17161931],
dtype=float32), 'z_a_brand_zerosum__': array([-0.21275534], dtype=float32), 'z_a_cohort_zerosum__': array([-0.01881989], dtype=float32), 'roas_rv_log__': array([-1.8874438 , -1.3390365 , -2.1119297 , -0.30115628, -1.3759781 ,
-0.9544507 , 1.3654704 , -0.80472004, 1.3217607 , -1.6872417 ,
-1.0485291 , -0.90976775, -1.1248429 , -1.5477487 , -0.30651912,
0.51637214, 0.5301037 , -0.49982694, 1.757268 , 1.03213 ],
dtype=float32), 'z_b_state_zerosum__': array([[ 0.00446961, -0.85987175, -0.74123687, -0.46256822, -0.52106553,
0.28104278],
[ 0.05966987, -0.8486371 , -0.43098626, -0.12444586, 0.1801346 ,
-0.37303272],
[ 0.08682439, 0.53125477, -0.4337221 , -0.80694795, -0.41105202,
0.8999604 ],
[ 0.01910053, 0.2654662 , -0.07900505, 0.47407308, 0.7956779 ,
-0.64507806],
[ 0.10577084, -0.01806336, -0.4654986 , 0.00858531, 0.4964019 ,
-0.15452549],
[-0.58119875, -0.533203 , 0.8720117 , -0.9220113 , -0.08726341,
-0.33014426],
[ 0.5597552 , 0.21657923, -0.6274215 , -0.00888674, 0.5606966 ,
-0.6045255 ],
[-0.49455065, 0.5478223 , 0.9508188 , -0.7354254 , 0.19366987,
0.5816819 ],
[-0.82646775, -0.5263257 , 0.20099497, 0.88074464, -0.4345398 ,
0.06769424],
[ 0.26323393, 0.61359143, 0.01295813, -0.40680176, -0.3380146 ,
0.3240754 ],
[ 0.6390363 , -0.07461884, 0.17888807, -0.17294951, -0.8052904 ,
-0.2960819 ],
[-0.88565934, 0.13199767, -0.09011242, -0.57291055, 0.71278757,
-0.06531783],
[ 0.4843889 , 0.9435816 , 0.14761145, 0.2508237 , -0.02830961,
0.40583134],
[ 0.64028126, -0.09345473, 0.44015244, -0.18035695, 0.63984483,
0.40306124],
[ 0.85732955, 0.20738094, -0.77978706, -0.5081236 , 0.25628823,
0.9576838 ],
[-0.43284124, 0.49378812, -0.7574774 , -0.9391033 , 0.6099457 ,
-0.83641356],
[-0.6440243 , 0.68688387, 0.4862265 , 0.5263312 , -0.3289637 ,
0.18450338],
[-0.7553003 , 0.8161998 , -0.88512534, -0.06603678, 0.24693777,
-0.78690183],
[ 0.1868632 , -0.21966957, 0.3369232 , -0.9996609 , -0.35670304,
0.4821175 ],
[ 0.3532054 , -0.5449791 , -0.00193312, -0.16562222, 0.51523185,
-0.3292687 ]], dtype=float32), 'z_b_age_zerosum__': array([[ 0.4639124 , 0.94854355, -0.32051557, 0.5695813 , -0.464497 ],
[ 0.24478397, -0.38236296, 0.2442325 , 0.6532343 , -0.5803767 ],
[-0.43226814, 0.53636163, 0.3303343 , -0.42391777, -0.04977154],
[-0.0449917 , -0.18323518, 0.09939765, 0.44787315, 0.21340491],
[-0.48593655, 0.9875687 , -0.30522144, -0.24290714, 0.7979216 ],
[-0.5306124 , 0.43397802, -0.20600496, 0.8865641 , 0.36890575],
[ 0.16192637, -0.85434455, 0.14579847, -0.6387437 , -0.6332226 ],
[-0.00474756, -0.4770844 , -0.80014896, -0.4984475 , 0.08337943],
[-0.38859433, 0.81244034, -0.15071645, 0.7578935 , -0.22230786],
[ 0.21995819, -0.45969793, -0.05771023, -0.3626073 , 0.8941617 ],
[ 0.07187908, -0.25421968, 0.11764435, -0.01395176, -0.6094777 ],
[-0.13571994, 0.9205862 , -0.6560107 , 0.3603058 , -0.8363712 ],
[ 0.78542286, -0.4191767 , -0.27891508, 0.2105725 , 0.38422632],
[-0.7178596 , 0.49950433, -0.2591695 , -0.6500654 , 0.78156734],
[ 0.24467742, -0.09884497, -0.9059215 , 0.69811964, -0.04913842],
[-0.00168463, -0.1506732 , -0.8326015 , 0.260028 , -0.318087 ],
[ 0.40579167, -0.42483094, 0.15233344, 0.3852206 , 0.26324713],
[ 0.6354356 , -0.5003003 , -0.09142033, 0.80062026, -0.05573656],
[-0.64219224, 0.75683314, -0.25206646, 0.9859022 , -0.7528035 ],
[ 0.04350585, -0.21413967, 0.7432214 , -0.6038442 , 0.219704 ]],
dtype=float32), 'z_b_brand_zerosum__': array([[ 0.77682245],
[-0.68846065],
[ 0.15427889],
[ 0.85911787],
[ 0.7141093 ],
[ 0.11126634],
[ 0.8475281 ],
[-0.018953 ],
[-0.08016697],
[-0.4876936 ],
[ 0.92964715],
[ 0.77537847],
[ 0.23121522],
[ 0.11847817],
[-0.7639938 ],
[-0.22309716],
[-0.41808844],
[ 0.23701279],
[-0.04789526],
[ 0.09624694]], dtype=float32), 'z_b_cohort_zerosum__': array([[-0.5643933 ],
[ 0.9322058 ],
[-0.3288698 ],
[ 0.00913158],
[-0.7385534 ],
[-0.68776625],
[-0.5444413 ],
[ 0.42466304],
[ 0.7728684 ],
[-0.7700562 ],
[-0.18284462],
[ 0.6402906 ],
[ 0.6988464 ],
[ 0.8284381 ],
[-0.7142591 ],
[-0.12452263],
[ 0.1419726 ],
[ 0.20289187],
[ 0.6634637 ],
[ 0.77786607]], dtype=float32), 'mu_b_pos_con': array([-1.3094065, -1.7115856, -1.4250492, -2.2511344, -1.226164 ],
dtype=float32), 'z_b_pos_con_state_zerosum__': array([[-0.39936054, 0.8805549 , 0.97654635, 0.6494237 , 0.5060455 ,
0.8129397 ],
[-0.7020755 , 0.8573673 , -0.11473656, -0.81267875, 0.52816015,
0.25964367],
[-0.30995888, -0.909639 , -0.03129133, -0.83288676, -0.8827531 ,
0.8252884 ],
[-0.23201741, 0.5135355 , -0.8893724 , -0.00104977, -0.5592616 ,
0.8351593 ],
[-0.03384887, 0.25019094, -0.80081666, 0.45951134, -0.35681835,
-0.8254566 ]], dtype=float32), 'z_b_pos_con_age_zerosum__': array([[-0.61742157, -0.09719887, 0.58104664, -0.92894936, -0.9795723 ],
[-0.42654088, 0.64068526, 0.30092153, 0.24177577, 0.2526327 ],
[-0.80097747, 0.9057477 , 0.43585142, -0.85004056, -0.01753056],
[-0.31914535, 0.14012223, -0.6530986 , 0.7002828 , 0.6456084 ],
[-0.16960691, 0.26178694, -0.47111732, -0.3870159 , 0.63950986]],
dtype=float32), 'z_b_pos_con_brand_zerosum__': array([[-0.06096065],
[ 0.35775375],
[ 0.8893246 ],
[ 0.14325647],
[ 0.9434139 ]], dtype=float32), 'z_b_pos_con_cohort_zerosum__': array([[-0.954737 ],
[-0.07808845],
[ 0.56892526],
[-0.37843582],
[-0.66838884]], dtype=float32), 'mu_b_neg_con': array([-0.25086808], dtype=float32), 'z_b_neg_con_state_zerosum__': array([[-0.932691 , -0.56728685, -0.08727422, 0.06912095, 0.8635172 ,
-0.2142895 ]], dtype=float32), 'z_b_neg_con_age_zerosum__': array([[ 0.90207416, -0.88017714, -0.83211 , -0.5490533 , 0.6520192 ]],
dtype=float32), 'z_b_neg_con_brand_zerosum__': array([[0.25788188]], dtype=float32), 'z_b_neg_con_cohort_zerosum__': array([[0.56738347]], dtype=float32), 'mu_b_lag': array([-2.4469194], dtype=float32), 'z_b_lag_state_zerosum__': array([[ 0.1700482 , 0.5989304 , 0.47253153, -0.75125474, -0.5838406 ,
0.68338937]], dtype=float32), 'z_b_lag_age_zerosum__': array([[-0.33903205, -0.63544524, 0.03893599, 0.47806814, -0.04220857]],
dtype=float32), 'z_b_lag_brand_zerosum__': array([[-0.08016187]], dtype=float32), 'z_b_lag_cohort_zerosum__': array([[-0.18082687]], dtype=float32), 'mu_b_fourier_year': array([ 0.8240624 , 0.646912 , -0.452118 , -0.08140495, -0.96048754,
-0.74027205, -0.9938018 , 0.20062245, -0.28748137, -0.82254994,
-0.4910437 , -0.322535 , -0.09896964, -0.30639052, 0.8899779 ,
0.2462373 , -0.25278836, 0.16529965, -0.17628683, 0.96998924],
dtype=float32), 'sd_y_log__': array(3.4784012, dtype=float32)}

and am getting this error:

model_logp_fn(point)
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 703, in dtype
dt = np.result_type(x)
TypeError: data type 'sd_y_log__' not understood
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 5394, in array
dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 713, in _lattice_result_type
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 713, in <genexpr>
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 516, in _dtype_and_weaktype
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 705, in dtype
raise TypeError(f"Cannot determine dtype of {x}") from err
TypeError: Cannot determine dtype of sd_y_log__
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/pymc/sampling/jax.py", line 154, in logp_fn_wrap
return logp_fn(*x)[0]
File "/tmp/tmpa31z87gr", line 3, in jax_funcified_fgraph
tensor_variable = elemwise_fn(sd_y_log_)
File "/usr/local/lib/python3.10/dist-packages/pytensor/link/jax/dispatch/elemwise.py", line 17, in elemwise_fn
Elemwise._check_runtime_broadcast(node, tuple(map(jnp.asarray, inputs)))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 5592, in asarray
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 5399, in array
dtype = dtypes._lattice_result_type(*leaves)[0]
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 713, in _lattice_result_type
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 713, in <genexpr>
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 516, in _dtype_and_weaktype
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 707, in dtype
raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
TypeError: Value 'sd_y_log__' with dtype <U10 is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

I tried changing the value of …
Ah! The culprit is here: Lines 153 to 156 in a714b24
It's passing the keys instead of the values, so we can do:

model_logp_fn: Callable
if logp_dlogp_func is None:
    model_logp_fn = get_jaxified_logp(model=model, negative_logp=False)
else:

    def model_logp_fn(ip):
        q, _ = DictToArrayBijection.map(ip)
        return logp_dlogp_func([q], extra_vars={})[0]

initial_points = []
for ipfn, seed in zip(ipfns, seeds):
    rng = np.random.default_rng(seed)
    for i in range(jitter_max_retries + 1):
        point = ipfn(seed)
        point_logp = model_logp_fn(point)
        if not np.isfinite(point_logp):
            if i == jitter_max_retries:
                # Print informative message on last attempted point
                model.check_start_vals(point.values())
            # Retry with a new seed
            seed = rng.integers(2**30, dtype=np.int64)
        else:
            break

And that doesn't underflow.
Instead of changing the logic inside init_nuts, pass the jax logp fn wrapped in a callable that handles the conversion |
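A minimal, self-contained sketch of that suggestion (the helper and its exact calling convention are hypothetical; the real jaxified logp signature may differ): wrap the logp function in a callable that converts a point dict into the positional value arrays it expects, so the caller's logic never changes:

```python
def wrap_point_logp(raw_logp_fn):
    """Adapt a logp function that takes positional value arrays so it can
    be called with a PyMC-style point dict. Hypothetical sketch."""
    def point_logp(point: dict):
        # pass the *values* (arrays), not the keys (variable names)
        return raw_logp_fn(list(point.values()))
    return point_logp


# stand-in for a jaxified logp function: just sums its inputs
fake_logp = lambda values: sum(values)
fn = wrap_point_logp(fake_logp)
print(fn({"mu": 1.0, "sigma_log__": 2.0}))  # 3.0
```

Passing `point.keys()` (or iterating the dict directly) into such a function is exactly what produced the `dtype <U10` error above, since the variable names are strings.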
Yep, sounds good. Working on a PR now.
Describe the issue:
I have a hierarchical MMM model setup (with pymc, not pymc-marketing) and have been successfully using it with float32s up to the 5.19 release. It is using numpyro/jax to sample.

With 5.19 I am getting errors in `_init_jitter`; it appears that something is going wrong when passing the generated initial points to the compiled logp function. I think the use of ZeroSumNormal distributions is causing the problem, but I'm not sure if it's the values returned by the `ipfn(seed)` call or the evaluation in the compiled `model_logp_fn`. I've included the verbose model debug return from my model, but the reproducible example below uses the example radon model.

Reproducible code example:
Error message:
PyMC version information:
Running on a Windows machine in a Linux container.
pymc installed via poetry (PyPI), using libopenblas.
annotated-types 0.7.0
arviz 0.19.0
babel 2.16.0
blinker 1.4
build 1.2.2.post1
CacheControl 0.14.1
cachetools 5.5.0
certifi 2024.8.30
cfgv 3.4.0
charset-normalizer 3.4.0
cleo 2.1.0
click 8.1.7
cloudpickle 3.1.0
colorama 0.4.6
cons 0.4.6
contourpy 1.3.1
coverage 7.6.8
crashtest 0.4.1
cryptography 3.4.8
cycler 0.12.1
dbus-python 1.2.18
distlib 0.3.9
distro 1.7.0
distro-info 1.1+ubuntu0.2
dm-tree 0.1.8
dulwich 0.21.7
etuples 0.3.9
exceptiongroup 1.2.2
fastjsonschema 2.21.1
filelock 3.16.1
fonttools 4.55.1
ghp-import 2.1.0
graphviz 0.20.3
griffe 1.5.1
h5netcdf 1.4.1
h5py 3.12.1
httplib2 0.20.2
identify 2.6.3
idna 3.10
importlib_metadata 8.5.0
iniconfig 2.0.0
installer 0.7.0
jaraco.classes 3.4.0
jax 0.4.35
jaxlib 0.4.35
jeepney 0.7.1
Jinja2 3.1.4
joblib 1.4.2
keyring 24.3.1
kiwisolver 1.4.7
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
logical-unification 0.4.6
loguru 0.7.2
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 3.0.2
matplotlib 3.9.3
mdurl 0.1.2
mergedeep 1.3.4
miniKanren 1.0.3
mkdocs 1.6.1
mkdocs-autorefs 1.2.0
mkdocs-gen-files 0.5.0
mkdocs-get-deps 0.2.0
mkdocs-glightbox 0.4.0
mkdocs-literate-nav 0.6.1
mkdocs-material 9.5.47
mkdocs-material-extensions 1.3.1
mkdocs-section-index 0.3.9
mkdocstrings 0.26.2
mkdocstrings-python 1.12.2
ml_dtypes 0.5.0
mmm_v2 0.0.1 /workspaces/mmm_v2
more-itertools 8.10.0
msgpack 1.1.0
multimethod 1.10
multipledispatch 1.0.0
mypy-extensions 1.0.0
nodeenv 1.9.1
numpy 1.26.4
numpyro 0.15.3
oauthlib 3.2.0
opt_einsum 3.4.0
packaging 24.2
paginate 0.5.7
pandas 2.2.3
pandera 0.20.4
pathspec 0.12.1
pexpect 4.9.0
pillow 11.0.0
pip 24.3.1
pkginfo 1.12.0
platformdirs 4.3.6
pluggy 1.5.0
poetry 1.8.4
poetry-core 1.9.1
poetry-plugin-export 1.8.0
pre_commit 4.0.1
ptyprocess 0.7.0
pyarrow 18.1.0
pydantic 2.10.3
pydantic_core 2.27.1
Pygments 2.18.0
PyGObject 3.42.1
PyJWT 2.3.0
pymc 5.19.1
pymc-marketing 0.6.0
pymdown-extensions 10.12
pyparsing 3.2.0
pyproject_hooks 1.2.0
pytensor 2.26.4
pytest 8.3.4
pytest-cov 6.0.0
python-apt 2.4.0+ubuntu3
python-dateutil 2.9.0.post0
pytz 2024.2
PyYAML 6.0.2
pyyaml_env_tag 0.1
RapidFuzz 3.10.1
regex 2024.11.6
requests 2.32.3
requests-toolbelt 1.0.0
rich 13.9.4
ruff 0.8.1
scikit-learn 1.5.2
scipy 1.14.1
seaborn 0.13.2
SecretStorage 3.3.1
setuptools 75.6.0
shellingham 1.5.4
six 1.17.0
ssh-import-id 5.11
threadpoolctl 3.5.0
tomli 2.2.1
tomlkit 0.13.2
toolz 1.0.0
tqdm 4.67.1
trove-classifiers 2024.10.21.16
typeguard 4.4.1
typing_extensions 4.12.2
typing-inspect 0.9.0
tzdata 2024.2
unattended-upgrades 0.1
urllib3 2.2.3
virtualenv 20.28.0
wadllib 1.3.6
watchdog 6.0.0
wheel 0.38.4
wrapt 1.17.0
xarray 2024.9.0
xarray-einstats 0.8.0
zipp 3.21.0
Context for the issue:
I want to use float32 due to memory constraints: I've got a big model and dataset with a hierarchy that takes up 40+ GB of RAM when running with float32, so I'd need a huge box to go up to float64.
Maybe the optimisations in 5.19 make that problem moot? I've managed to get it running with float64 temporarily, but I'm not sure if it'll be a long-term solution.