Pytorch backend slow with pymc model #1110
Comments
The script I used for reference:
import arviz as az
import numpy as np
import multiprocessing
import pandas as pd
import pymc as pm
import pytensor as pt
import pytensor.tensor.random as ptr
import time


def main():
    # Load the radon dataset
    data = pd.read_csv(pm.get_data("radon.csv"))
    data["log_radon"] = data["log_radon"].astype(np.float64)
    county_idx, counties = pd.factorize(data.county)
    coords = {"county": counties, "obs_id": np.arange(len(county_idx))}
    # Create a simple hierarchical model for the radon dataset
    with pm.Model(coords=coords, check_bounds=False) as model:
        intercept = pm.Normal("intercept", sigma=10)
        # County effects
        raw = pm.ZeroSumNormal("county_raw", dims="county")
        sd = pm.HalfNormal("county_sd")
        county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")
        # Global floor effect
        floor_effect = pm.Normal("floor_effect", sigma=2)
        # County:floor interaction
        raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
        sd = pm.HalfNormal("county_floor_sd")
        county_floor_effect = pm.Deterministic(
            "county_floor_effect", raw * sd, dims="county"
        )
        mu = (
            intercept
            + county_effect[county_idx]
            + floor_effect * data.floor.values
            + county_floor_effect[county_idx] * data.floor.values
        )
        sigma = pm.HalfNormal("sigma", sigma=1.5)
        pm.Normal(
            "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
        )
    from pymc.model.transform.optimization import freeze_dims_and_data

    model = freeze_dims_and_data(model)
    for mode in ("NUMBA", "PYTORCH"):
        start = time.perf_counter()
        trace = pm.sample(
            model=model,
            cores=1,
            chains=1,
            tune=500,
            draws=500,
            progressbar=False,
            compute_convergence_checks=False,
            return_inferencedata=False,
            compile_kwargs=dict(mode=mode)
        )
        end = time.perf_counter()
        idata = pm.to_inference_data(trace, model=model)
        print(az.summary(idata, kind="diagnostics"))
        print(mode, trace._report.t_sampling, end - start)


if __name__ == "__main__":
    multiprocessing.freeze_support()
    main()
|
I called compile logp and dlogp as well to narrow down the time
|
You have to call it once (perhaps assert they output the same) to jit compile and then only timeit |
@ricardoV94 can you assign this issue to me by chance? I profiled a bit more. The logp and dlogp pytensor functions don't take long to generate, but executing them is slower. Both numba and torch get faster if you execute them multiple times, but numba is much faster
I'm seeing some interesting data in the Graph Count: 8
Graph Break Count: 7
Op Count: 198
Break Reasons:
Break Reason 1:
Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
User Stack:
<FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 35 in pytorch_funcified_fgraph>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
Break Reason 2:
Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
User Stack:
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
Break Reason 3:
Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
User Stack:
<FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 75 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_35>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
Break Reason 4:
Reason: Dynamic slicing on data-dependent value is not supported
User Stack:
<FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 133 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_75>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 77 in inc_subtensor>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in indices_from_subtensor>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in <genexpr>>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 131 in convert_indices>
Break Reason 5:
Reason: Dynamic slicing on data-dependent value is not supported
User Stack:
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 81 in torch_dynamo_resume_in_inc_subtensor_at_78>
Break Reason 6:
Reason: Dynamic slicing on data-dependent value is not supported
User Stack:
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 81 in torch_dynamo_resume_in_inc_subtensor_at_81>
Break Reason 7:
Reason: Dynamic slicing on data-dependent value is not supported
User Stack:
<FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 193 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_133>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 77 in inc_subtensor>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in indices_from_subtensor>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in <genexpr>>
<FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 131 in convert_indices>
The first thing I'm gonna do is clean up some of the warnings. These timings are only possible because of sending the warning logs out to dev null, if you don't the timings balloon a bit. |
The first time is the compilation, it's not as relevant since it's a one time thing. If we're recompiling multiple times that's a different thing. Also are you using Did you confirm the outputs match? |
Those breaks are interesting but most are not data dependent? Like the slice is constant in this model. Can you enable that capture scalar outputs option? Also are you freezing the model data and dims like in the original example? When we have static shapes we could forward those to the dispatch |
Yeah my bad; this is how the profiling looks:
from pymc.model.transform.optimization import freeze_dims_and_data

model = freeze_dims_and_data(model)
val = model.initial_point(123)
for mode in ("PYTORCH", "NUMBA"):
    fn = model.compile_logp(mode=mode)
    for _ in range(3):
        start = time.perf_counter()
        fn(val)
        end = time.perf_counter()
        print("| ", " | ".join((mode, "logp.__call__", "{:.4f}".format(end - start))), " |")
|
Btw, I only showed logp, here is dlogp, which is potentially more problematic
dlogp seems to be a bit more problematic. I reran timing this morning and got really poor numbers (8s), but last night it was "better". Not sure what's going on there.
import os
import arviz as az
import numpy as np
import multiprocessing
import pandas as pd
import pymc as pm
import pytensor as pt
import pytensor.tensor.random as ptr
import time


def main():
    # Load the radon dataset
    data = pd.read_csv(pm.get_data("radon.csv"))
    data["log_radon"] = data["log_radon"].astype(np.float64)
    county_idx, counties = pd.factorize(data.county)
    coords = {"county": counties, "obs_id": np.arange(len(county_idx))}
    # Create a simple hierarchical model for the radon dataset
    with pm.Model(coords=coords, check_bounds=False) as model:
        intercept = pm.Normal("intercept", sigma=10)
        # County effects
        raw = pm.ZeroSumNormal("county_raw", dims="county")
        sd = pm.HalfNormal("county_sd")
        county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")
        # Global floor effect
        floor_effect = pm.Normal("floor_effect", sigma=2)
        # County:floor interaction
        raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
        sd = pm.HalfNormal("county_floor_sd")
        county_floor_effect = pm.Deterministic(
            "county_floor_effect", raw * sd, dims="county"
        )
        mu = (
            intercept
            + county_effect[county_idx]
            + floor_effect * data.floor.values
            + county_floor_effect[county_idx] * data.floor.values
        )
        sigma = pm.HalfNormal("sigma", sigma=1.5)
        pm.Normal(
            "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
        )
    from pymc.model.transform.optimization import freeze_dims_and_data

    model = freeze_dims_and_data(model)
    val = model.initial_point(123)
    os.environ["PYTORCH_RECORD_EXPLAIN"] = "yes"
    for mode in ("PYTORCH", "NUMBA"):
        fn = model.compile_dlogp(mode=mode)
        for _ in range(3):
            start = time.perf_counter()
            fn(val)
            end = time.perf_counter()
            print("| ", " |".join((mode, "dlogp.__call__", "{:.4f}".format(end - start))), " |")


if __name__ == "__main__":
    multiprocessing.freeze_support()
    main()
|
Addressing the q's
yea, i pasted the profiling code, mb
I find this a bit strange too; I think what might be happening is we have tensors of a single value (like what we had in #1031) and that's causing a lot of "data dependent" operators? I need to understand the plumbing a bit more I think to really know.
Unbelievably, this crashes
Yes! just using
I time 3 calls. I could time more but I mostly just cared to see that the first one was big and the rest were "smaller" |
The first slow next fast is really not surprising, compilation happens on the first call. 3 calls is not usually enough to measure without noise. I believe torch is slower of course but you should average the time it takes to eval like 100 times at least. You can use the %timeit magic to do this automatically on an ipython (or jupyter) environment. |
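The warm-up-then-average advice above can be sketched outside IPython with the standard-library `timeit` module. This is only an illustration: the helper name `bench` and the stand-in function `fake_logp` are hypothetical, and with the model from this thread you would pass something like `lambda: fn(val)` instead.

```python
import statistics
import timeit


def bench(call, n_loops=100, n_runs=20):
    """Return (mean, stdev) in seconds per call, excluding the first call."""
    call()  # warm-up: JIT backends compile on the first call
    timer = timeit.Timer(call)
    # Each entry is the total time of n_loops calls; divide for per-call time.
    totals = timer.repeat(repeat=n_runs, number=n_loops)
    per_call = [t / n_loops for t in totals]
    return statistics.mean(per_call), statistics.stdev(per_call)


# Stand-in for a compiled logp function (hypothetical; replace with fn(val)).
def fake_logp():
    return sum(i * i for i in range(1000))


mean, std = bench(fake_logp)
print(f"{mean * 1e6:.1f} us +/- {std * 1e6:.1f} us per call")
```

Keeping the warm-up call outside the timed region is what separates one-time compilation cost from steady-state evaluation cost.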
What's the compiled function. Curious why it doesn't think those slice indices are constant. You can do |
Yea, agree that the pattern isn't surprising, but i was worried each call was getting a recompile (like the dlogp seems to do perhaps intermittently). I'll do a more robust measure. What would a dynamic slice look like? |
a = pt.scalar("a", dtype=int)
x = pt.vector("x")
out = x[:a] |
Here are the graphs, i dumped them into a text file. Logp first, then dlogp, separated by |
looking at the times the problem is definitely the dlogp graph. How much better if we get it through torch autodiff? |
Okay just to drive the numbers point home, I did this in ipython.
fn = model.compile_logp(mode="PYTORCH")
fn(val)
%timeit -n 100 -r 100 fn(val)
2.18 ms ± 411 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
fn = model.compile_logp(mode="NUMBA")
fn(val)
%timeit -n 100 -r 100 fn(val)
50.5 μs ± 4.29 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
That's for logp. For dlogp
fn = model.compile_dlogp(mode="PYTORCH")
fn(val)
%timeit -n 100 -r 100 fn(val)
2.47 ms ± 430 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
fn = model.compile_dlogp(mode="NUMBA")
fn(val)
%timeit -n 100 -r 100 fn(val)
62.3 μs ± 5.34 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
I haven't looked at going through the torch autodiff engine for dlogp, I can put that on the list of possible things. |
I can't super tell from the dprint if the slices are constant or not. I do see
|
There are two warnings I wanna dig into.
|
@Ch0ronomato you need to eval the function once before doing the timeit. The first call will take care of the jit compilation which shouldn't be mixed with the eval cost. That's why you're getting those warnings that the slowest eval took x longer than fastest |
Cleaned up |
So it's 4-5x slower than numba. Can you show what you get on the default c backend and jax on your machine as well? |
My bad @ricardoV94, here is jax and c
fn = model.compile_logp(mode="JAX")
fn(val)
%timeit -n 100 -r 100 fn(val)
98.5 μs ± 11.1 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
fn = model.compile_logp()
fn(val)
%timeit -n 100 -r 100 fn(val)
117 μs ± 5.44 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
Here were the torch and numba as well
fn = model.compile_logp(mode="NUMBA")
fn(val)
%timeit -n 100 -r 100 fn(val)
50.5 μs ± 4.29 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
fn = model.compile_logp(mode="PYTORCH")
fn(val)
%timeit -n 100 -r 100 fn(val)
2.18 ms ± 411 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
Similar for torch: 2.81 ms ± 576 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
numba: 59.7 μs ± 2.92 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
c: 117 μs ± 5.44 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
|
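To put the logp timings above on one scale, the ratios can be worked out directly from the reported means. This is plain arithmetic on the numbers quoted in this thread (one machine, one model), not a new measurement; the dictionary and variable names are just for illustration.

```python
# Mean logp eval times reported above, in microseconds.
logp_us = {
    "PYTORCH": 2180.0,  # reported as 2.18 ms
    "C": 117.0,
    "JAX": 98.5,
    "NUMBA": 50.5,
}

fastest = min(logp_us.values())
for backend, t in sorted(logp_us.items(), key=lambda kv: kv[1]):
    # Ratio of each backend's mean to the fastest backend's mean.
    print(f"{backend:8s} {t:8.1f} us  {t / fastest:5.1f}x the fastest")

torch_vs_numba = logp_us["PYTORCH"] / logp_us["NUMBA"]
torch_vs_c = logp_us["PYTORCH"] / logp_us["C"]
```

On these quoted means, torch comes out at roughly 43x the numba time and a bit under 19x the C backend time for logp.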
I run the profiler with the following code:
import pymc as pm
import numpy as np
import pandas as pd
from pymc import DictToArrayBijection
from pymc.model.transform.optimization import freeze_dims_and_data

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}
# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)
    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")
    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)
    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic(
        "county_floor_effect", raw * sd, dims="county"
    )
    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )
    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal(
        "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
    )
frozen_model = freeze_dims_and_data(model)


def test_benchmark():
    logp_f = frozen_model.logp_dlogp_function(mode="PYTORCH", ravel_inputs=True)._pytensor_function
    logp_f.trust_input = True
    ip = DictToArrayBijection.map(model.initial_point()).data
    logp_f(ip)  # trigger compilation
    import cProfile

    with cProfile.Profile() as pr:
        [logp_f(ip) for _ in range(1000)]
    pr.dump_stats("torch_logp_fn.prof")
Then snakeviz shows there are 5 graph breaks, which is quite a lot? It seems For curiosity the overhead of
|
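As an alternative to snakeviz, a dumped profile can also be inspected in the terminal with the standard-library `pstats` module. The sketch below profiles a stand-in function (`work` is hypothetical; the real target would be `logp_f(ip)`) and writes to a temporary file rather than `torch_logp_fn.prof`.

```python
import cProfile
import io
import os
import pstats
import tempfile


def work():
    # Stand-in for logp_f(ip); hypothetical workload for illustration only.
    return sum(i * i for i in range(10_000))


prof_path = os.path.join(tempfile.mkdtemp(), "example.prof")

with cProfile.Profile() as pr:  # context-manager form needs Python 3.8+
    for _ in range(100):
        work()
pr.dump_stats(prof_path)

# Load the dump and list the 10 entries with the largest cumulative time.
buf = io.StringIO()
stats = pstats.Stats(prof_path, stream=buf)
stats.strip_dirs().sort_stats("cumulative").print_stats(10)
report = buf.getvalue()
print(report)
```

Sorting by cumulative time tends to surface Python-level wrappers that sit above the compiled kernels, which is where per-call overhead would show up if that is the bottleneck.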
I also grabbed a bit of info from torch. Unfortunately, on its own it's not incredibly useful. I'm digging around a little to see if I can get what is going on. For those functions, I wonder if those are in the torch guards, since those will always run in python. For example, here is a shape check:
Guard 1778:
Name: "G['elemwise_fn_76'].__closure__[1].cell_contents.inputs[0].type.shape[0]"
Source: global
Create Function: CONSTANT_MATCH
Guard Types: ['EQUALS_MATCH']
Code List: ["G['elemwise_fn_76'].__closure__[1].cell_contents.inputs[0].type.shape[0] == 1"]
Object Weakref: None
Guarded Class Weakref: <weakref at 0x10c8ca7a0; to 'type' at 0x10c6b0538 (int)>
There are thousands of them.
|
I did some investigating on using autodiff instead of the way pytensor does it. It doesn't look like that gives us much in terms of raw performance.
%timeit -n 100 -r 100 without_autograd_fn(initial_point)
2.15 ms ± 104 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
%timeit -n 100 -r 100 with_autograd_fn(initial_point)
3.05 ms ± 138 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
where the autodiff'd function looks like this:
AutogradOp.0 [id A] 0
├─ intercept [id B]
├─ county_raw_zerosum__ [id C]
├─ county_sd_log__ [id D]
├─ floor_effect [id E]
├─ county_floor_raw_zerosum__ [id F]
├─ county_floor_sd_log__ [id G]
└─ sigma_log__ [id H]
AutogradOp.1 [id A] 0
└─ ···
AutogradOp.2 [id A] 0
└─ ···
AutogradOp.3 [id A] 0
└─ ···
AutogradOp.4 [id A] 0
└─ ···
AutogradOp.5 [id A] 0
└─ ···
AutogradOp.6 [id A] 0
└─ ···
Inner graphs:
AutogradOp [id A]
← Add [id I]
├─ Check{sigma > 0} [id J] 'intercept_logprob'
│ ├─ Add [id K]
│ │ ├─ -3.221523658174155 [id L]
│ │ └─ Mul [id M]
│ │ ├─ -0.5 [id N]
│ │ └─ Pow [id O]
│ │ ├─ Mul [id P]
│ │ │ ├─ 0.1 [id Q]
│ │ │ └─ intercept [id R]
│ │ └─ 2 [id S]
│ └─ True [id T]
├─ Check{mean(value, axis=n_zerosum_axes) = 0} [id U]
│ ├─ Sum{axes=None} [id V]
│ │ └─ Sub [id W]
│ │ ├─ Mul [id X]
│ │ │ ├─ [-0.5] [id Y]
│ │ │ └─ Pow [id Z]
│ │ │ ├─ Sub [id BA]
│ │ │ │ ├─ Join [id BB]
│ │ │ │ │ ├─ 0 [id BC]
│ │ │ │ │ ├─ county_raw_zerosum__ [id BD]
│ │ │ │ │ └─ Sub [id BE]
│ │ │ │ │ ├─ True_div [id BF]
│ │ │ │ │ │ ├─ ExpandDims{axis=0} [id BG]
│ │ │ │ │ │ │ └─ Sum{axes=None} [id BH]
│ │ │ │ │ │ │ └─ county_raw_zerosum__ [id BD]
│ │ │ │ │ │ └─ Add [id BI]
│ │ │ │ │ │ ├─ Sqrt [id BJ]
│ │ │ │ │ │ │ └─ Cast{float64} [id BK]
│ │ │ │ │ │ │ └─ Add [id BL]
│ │ │ │ │ │ │ ├─ [1] [id BM]
│ │ │ │ │ │ │ └─ Shape [id BN]
│ │ │ │ │ │ │ └─ county_raw_zerosum__ [id BD]
│ │ │ │ │ │ └─ Cast{float64} [id BK]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ True_div [id BO]
│ │ │ │ │ ├─ ExpandDims{axis=0} [id BG]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ Sqrt [id BJ]
│ │ │ │ │ └─ ···
│ │ │ │ └─ True_div [id BF]
│ │ │ │ └─ ···
│ │ │ └─ [2] [id BP]
│ │ └─ True_div [id BQ]
│ │ ├─ Mul [id BR]
│ │ │ ├─ [0.91893853] [id BS]
│ │ │ └─ Cast{float64} [id BT]
│ │ │ └─ IncSubtensor{start:} [id BU]
│ │ │ ├─ Shape [id BV]
│ │ │ │ └─ Sub [id BA]
│ │ │ │ └─ ···
│ │ │ ├─ -1 [id BW]
│ │ │ └─ -1 [id BX]
│ │ └─ Cast{float64} [id BY]
│ │ └─ Shape [id BV]
│ │ └─ ···
│ └─ And [id BZ]
│ ├─ Le [id CA]
│ │ ├─ True_div [id CB]
│ │ │ ├─ Abs [id CC]
│ │ │ │ └─ Sum{axes=None} [id CD]
│ │ │ │ └─ Sub [id BA]
│ │ │ │ └─ ···
│ │ │ └─ Abs [id CE]
│ │ │ └─ Cast{float64} [id CF]
│ │ │ └─ DropDims{axis=0} [id CG]
│ │ │ └─ Shape [id BV]
│ │ │ └─ ···
│ │ └─ 1e-09 [id CH]
│ └─ Invert [id CI]
│ └─ Or [id CJ]
│ ├─ Isnan [id CK]
│ │ └─ True_div [id CL] 'mean'
│ │ ├─ Sum{axes=None} [id CD]
│ │ │ └─ ···
│ │ └─ Cast{float64} [id CF]
│ │ └─ ···
│ └─ Isinf [id CM]
│ └─ True_div [id CL] 'mean'
│ └─ ···
├─ Check{sigma > 0} [id CN]
│ ├─ Switch [id CO]
│ │ ├─ Ge [id CP]
│ │ │ ├─ Exp [id CQ]
│ │ │ │ └─ county_sd_log__ [id CR]
│ │ │ └─ 0.0 [id CS]
│ │ ├─ Add [id CT]
│ │ │ ├─ -0.22579135264472738 [id CU]
│ │ │ └─ Mul [id CV]
│ │ │ ├─ -0.5 [id N]
│ │ │ └─ Pow [id CW]
│ │ │ ├─ Exp [id CQ]
│ │ │ │ └─ ···
│ │ │ └─ 2 [id S]
│ │ └─ -inf [id CX]
│ └─ True [id T]
├─ county_sd_log__ [id CR]
├─ Check{sigma > 0} [id CY] 'floor_effect_logprob'
│ ├─ Add [id CZ]
│ │ ├─ -1.6120857156692723 [id DA]
│ │ └─ Mul [id DB]
│ │ ├─ -0.5 [id N]
│ │ └─ Pow [id DC]
│ │ ├─ Mul [id DD]
│ │ │ ├─ 0.5 [id DE]
│ │ │ └─ floor_effect [id DF]
│ │ └─ 2 [id S]
│ └─ True [id T]
├─ Check{mean(value, axis=n_zerosum_axes) = 0} [id DG]
│ ├─ Sum{axes=None} [id DH]
│ │ └─ Sub [id DI]
│ │ ├─ Mul [id DJ]
│ │ │ ├─ [-0.5] [id Y]
│ │ │ └─ Pow [id DK]
│ │ │ ├─ Sub [id DL]
│ │ │ │ ├─ Join [id DM]
│ │ │ │ │ ├─ 0 [id BC]
│ │ │ │ │ ├─ county_floor_raw_zerosum__ [id DN]
│ │ │ │ │ └─ Sub [id DO]
│ │ │ │ │ ├─ True_div [id DP]
│ │ │ │ │ │ ├─ ExpandDims{axis=0} [id DQ]
│ │ │ │ │ │ │ └─ Sum{axes=None} [id DR]
│ │ │ │ │ │ │ └─ county_floor_raw_zerosum__ [id DN]
│ │ │ │ │ │ └─ Add [id DS]
│ │ │ │ │ │ ├─ Sqrt [id DT]
│ │ │ │ │ │ │ └─ Cast{float64} [id DU]
│ │ │ │ │ │ │ └─ Add [id DV]
│ │ │ │ │ │ │ ├─ [1] [id BM]
│ │ │ │ │ │ │ └─ Shape [id DW]
│ │ │ │ │ │ │ └─ county_floor_raw_zerosum__ [id DN]
│ │ │ │ │ │ └─ Cast{float64} [id DU]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ True_div [id DX]
│ │ │ │ │ ├─ ExpandDims{axis=0} [id DQ]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ Sqrt [id DT]
│ │ │ │ │ └─ ···
│ │ │ │ └─ True_div [id DP]
│ │ │ │ └─ ···
│ │ │ └─ [2] [id BP]
│ │ └─ True_div [id DY]
│ │ ├─ Mul [id DZ]
│ │ │ ├─ [0.91893853] [id BS]
│ │ │ └─ Cast{float64} [id EA]
│ │ │ └─ IncSubtensor{start:} [id EB]
│ │ │ ├─ Shape [id EC]
│ │ │ │ └─ Sub [id DL]
│ │ │ │ └─ ···
│ │ │ ├─ -1 [id BW]
│ │ │ └─ -1 [id BX]
│ │ └─ Cast{float64} [id ED]
│ │ └─ Shape [id EC]
│ │ └─ ···
│ └─ And [id EE]
│ ├─ Le [id EF]
│ │ ├─ True_div [id EG]
│ │ │ ├─ Abs [id EH]
│ │ │ │ └─ Sum{axes=None} [id EI]
│ │ │ │ └─ Sub [id DL]
│ │ │ │ └─ ···
│ │ │ └─ Abs [id EJ]
│ │ │ └─ Cast{float64} [id EK]
│ │ │ └─ DropDims{axis=0} [id EL]
│ │ │ └─ Shape [id EC]
│ │ │ └─ ···
│ │ └─ 1e-09 [id CH]
│ └─ Invert [id EM]
│ └─ Or [id EN]
│ ├─ Isnan [id EO]
│ │ └─ True_div [id EP] 'mean'
│ │ ├─ Sum{axes=None} [id EI]
│ │ │ └─ ···
│ │ └─ Cast{float64} [id EK]
│ │ └─ ···
│ └─ Isinf [id EQ]
│ └─ True_div [id EP] 'mean'
│ └─ ···
├─ Check{sigma > 0} [id ER]
│ ├─ Switch [id ES]
│ │ ├─ Ge [id ET]
│ │ │ ├─ Exp [id EU]
│ │ │ │ └─ county_floor_sd_log__ [id EV]
│ │ │ └─ 0.0 [id CS]
│ │ ├─ Add [id EW]
│ │ │ ├─ -0.22579135264472738 [id CU]
│ │ │ └─ Mul [id EX]
│ │ │ ├─ -0.5 [id N]
│ │ │ └─ Pow [id EY]
│ │ │ ├─ Exp [id EU]
│ │ │ │ └─ ···
│ │ │ └─ 2 [id S]
│ │ └─ -inf [id CX]
│ └─ True [id T]
├─ county_floor_sd_log__ [id EV]
├─ Check{sigma > 0} [id EZ]
│ ├─ Switch [id FA]
│ │ ├─ Ge [id FB]
│ │ │ ├─ Exp [id FC]
│ │ │ │ └─ sigma_log__ [id FD]
│ │ │ └─ 0.0 [id CS]
│ │ ├─ Add [id FE]
│ │ │ ├─ -0.6312564488800027 [id FF]
│ │ │ └─ Mul [id FG]
│ │ │ ├─ -0.5 [id N]
│ │ │ └─ Pow [id FH]
│ │ │ ├─ Mul [id FI]
│ │ │ │ ├─ 0.6666666666666666 [id FJ]
│ │ │ │ └─ Exp [id FC]
│ │ │ │ └─ ···
│ │ │ └─ 2 [id S]
│ │ └─ -inf [id CX]
│ └─ True [id T]
├─ sigma_log__ [id FD]
└─ Sum{axes=None} [id FK]
└─ Check{sigma > 0} [id FL] 'log_radon_logprob'
├─ Sub [id FM]
│ ├─ Add [id FN]
│ │ ├─ [-0.91893853] [id FO]
│ │ └─ Mul [id FP]
│ │ ├─ [-0.5] [id Y]
│ │ └─ Pow [id FQ]
│ │ ├─ True_div [id FR]
│ │ │ ├─ Sub [id FS]
│ │ │ │ ├─ log_radon{[ 0.832909 ... .09861229]} [id FT]
│ │ │ │ └─ Add [id FU]
│ │ │ │ ├─ ExpandDims{axis=0} [id FV]
│ │ │ │ │ └─ intercept [id R]
│ │ │ │ ├─ AdvancedSubtensor [id FW]
│ │ │ │ │ ├─ Mul [id FX] 'county_effect'
│ │ │ │ │ │ ├─ Sub [id BA]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id FY]
│ │ │ │ │ │ └─ Exp [id CQ]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ [ 0 0 0 ... 83 84 84] [id FZ]
│ │ │ │ ├─ Mul [id GA]
│ │ │ │ │ ├─ ExpandDims{axis=0} [id GB]
│ │ │ │ │ │ └─ floor_effect [id DF]
│ │ │ │ │ └─ [1. 0. 0. ... 0. 0. 0.] [id GC]
│ │ │ │ └─ Mul [id GD]
│ │ │ │ ├─ AdvancedSubtensor [id GE]
│ │ │ │ │ ├─ Mul [id GF] 'county_floor_effect'
│ │ │ │ │ │ ├─ Sub [id DL]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id GG]
│ │ │ │ │ │ └─ Exp [id EU]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ [ 0 0 0 ... 83 84 84] [id FZ]
│ │ │ │ └─ [1. 0. 0. ... 0. 0. 0.] [id GC]
│ │ │ └─ ExpandDims{axis=0} [id GH]
│ │ │ └─ Exp [id FC]
│ │ │ └─ ···
│ │ └─ [2] [id BP]
│ └─ Log [id GI]
│ └─ ExpandDims{axis=0} [id GH]
│ └─ ···
└─ Gt [id GJ]
├─ Exp [id FC]
│ └─ ···
└─ 0 [id BC]
and the normal one looks like this:
Join [id A] 117
├─ 0 [id B]
├─ Reshape{1} [id C] 116
│ ├─ Add [id D] 115
│ │ ├─ Mul [id E] 114
│ │ │ ├─ -0.0100000 ... 0000000002 [id F]
│ │ │ └─ intercept [id G]
│ │ └─ True_div [id H] 113
│ │ ├─ True_div [id I] 112
│ │ │ ├─ Sum{axes=None} [id J] 111
│ │ │ │ └─ Sub [id K] 41
│ │ │ │ ├─ log_radon{[ 0.832909 ... .09861229]} [id L]
│ │ │ │ └─ Add [id M] 40
│ │ │ │ ├─ ExpandDims{axis=0} [id N] 39
│ │ │ │ │ └─ intercept [id G]
│ │ │ │ ├─ AdvancedSubtensor1 [id O] 38
│ │ │ │ │ ├─ Mul [id P] 'county_effect' 37
│ │ │ │ │ │ ├─ Sub [id Q] 36
│ │ │ │ │ │ │ ├─ Join [id R] 35
│ │ │ │ │ │ │ │ ├─ 0 [id B]
│ │ │ │ │ │ │ │ ├─ county_raw_zerosum__ [id S]
│ │ │ │ │ │ │ │ └─ Sub [id T] 34
│ │ │ │ │ │ │ │ ├─ True_div [id U] 32
│ │ │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id V] 31
│ │ │ │ │ │ │ │ │ │ └─ Sum{axes=None} [id W] 30
│ │ │ │ │ │ │ │ │ │ └─ county_raw_zerosum__ [id S]
│ │ │ │ │ │ │ │ │ └─ Add [id X] 29
│ │ │ │ │ │ │ │ │ ├─ Sqrt [id Y] 28
│ │ │ │ │ │ │ │ │ │ └─ Cast{float64} [id Z] 27
│ │ │ │ │ │ │ │ │ │ └─ Add [id BA] 26
│ │ │ │ │ │ │ │ │ │ ├─ [1] [id BB]
│ │ │ │ │ │ │ │ │ │ └─ MakeVector{dtype='int64'} [id BC] 25
│ │ │ │ │ │ │ │ │ │ └─ Shape_i{0} [id BD] 24
│ │ │ │ │ │ │ │ │ │ └─ county_raw_zerosum__ [id S]
│ │ │ │ │ │ │ │ │ └─ Cast{float64} [id Z] 27
│ │ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ │ └─ True_div [id BE] 33
│ │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id V] 31
│ │ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ │ └─ Sqrt [id Y] 28
│ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ └─ True_div [id U] 32
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id BF] 23
│ │ │ │ │ │ └─ Exp [id BG] 22
│ │ │ │ │ │ └─ county_sd_log__ [id BH]
│ │ │ │ │ └─ [ 0 0 0 ... 83 84 84] [id BI]
│ │ │ │ ├─ Mul [id BJ] 21
│ │ │ │ │ ├─ ExpandDims{axis=0} [id BK] 20
│ │ │ │ │ │ └─ floor_effect [id BL]
│ │ │ │ │ └─ [1. 0. 0. ... 0. 0. 0.] [id BM]
│ │ │ │ └─ Mul [id BN] 19
│ │ │ │ ├─ AdvancedSubtensor1 [id BO] 18
│ │ │ │ │ ├─ Mul [id BP] 'county_floor_effect' 17
│ │ │ │ │ │ ├─ Sub [id BQ] 16
│ │ │ │ │ │ │ ├─ Join [id BR] 15
│ │ │ │ │ │ │ │ ├─ 0 [id B]
│ │ │ │ │ │ │ │ ├─ county_floor_raw_zerosum__ [id BS]
│ │ │ │ │ │ │ │ └─ Sub [id BT] 14
│ │ │ │ │ │ │ │ ├─ True_div [id BU] 12
│ │ │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id BV] 11
│ │ │ │ │ │ │ │ │ │ └─ Sum{axes=None} [id BW] 10
│ │ │ │ │ │ │ │ │ │ └─ county_floor_raw_zerosum__ [id BS]
│ │ │ │ │ │ │ │ │ └─ Add [id BX] 9
│ │ │ │ │ │ │ │ │ ├─ Sqrt [id BY] 8
│ │ │ │ │ │ │ │ │ │ └─ Cast{float64} [id BZ] 7
│ │ │ │ │ │ │ │ │ │ └─ Add [id CA] 6
│ │ │ │ │ │ │ │ │ │ ├─ [1] [id BB]
│ │ │ │ │ │ │ │ │ │ └─ MakeVector{dtype='int64'} [id CB] 5
│ │ │ │ │ │ │ │ │ │ └─ Shape_i{0} [id CC] 4
│ │ │ │ │ │ │ │ │ │ └─ county_floor_raw_zerosum__ [id BS]
│ │ │ │ │ │ │ │ │ └─ Cast{float64} [id BZ] 7
│ │ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ │ └─ True_div [id CD] 13
│ │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id BV] 11
│ │ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ │ └─ Sqrt [id BY] 8
│ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ └─ True_div [id BU] 12
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id CE] 3
│ │ │ │ │ │ └─ Exp [id CF] 2
│ │ │ │ │ │ └─ county_floor_sd_log__ [id CG]
│ │ │ │ │ └─ [ 0 0 0 ... 83 84 84] [id BI]
│ │ │ │ └─ [1. 0. 0. ... 0. 0. 0.] [id BM]
│ │ │ └─ Exp [id CH] 0
│ │ │ └─ sigma_log__ [id CI]
│ │ └─ Exp [id CH] 0
│ │ └─ ···
│ └─ [-1] [id CJ]
├─ Add [id CK] 110
│ ├─ Split{2}.0 [id CL] 102
│ │ ├─ Sub [id CM] 101
│ │ │ ├─ Mul [id CN] 100
│ │ │ │ ├─ AdvancedIncSubtensor1{no_inplace,inc} [id CO] 89
│ │ │ │ │ ├─ Alloc [id CP] 88
│ │ │ │ │ │ ├─ [0.] [id CQ]
│ │ │ │ │ │ └─ Add [id CR] 87
│ │ │ │ │ │ ├─ Shape_i{0} [id BD] 24
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id CS]
│ │ │ │ │ ├─ True_div [id CT] 86
│ │ │ │ │ │ ├─ True_div [id CU] 43
│ │ │ │ │ │ │ ├─ Sub [id K] 41
│ │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id CV] 42
│ │ │ │ │ │ │ └─ Exp [id CH] 0
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id CV] 42
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ [ 0 0 0 ... 83 84 84] [id BI]
│ │ │ │ └─ ExpandDims{axis=0} [id BF] 23
│ │ │ │ └─ ···
│ │ │ └─ Sub [id Q] 36
│ │ │ └─ ···
│ │ ├─ 0 [id B]
│ │ └─ MakeVector{dtype='int64'} [id CW] 99
│ │ ├─ Shape_i{0} [id BD] 24
│ │ │ └─ ···
│ │ └─ 1 [id CS]
│ ├─ True_div [id CX] 109
│ │ ├─ Sub [id CY] 108
│ │ │ ├─ SpecifyShape [id CZ] 103
│ │ │ │ ├─ Split{2}.1 [id CL] 102
│ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id DA]
│ │ │ └─ ExpandDims{axis=0} [id DB] 107
│ │ │ └─ Sum{axes=None} [id DC] 106
│ │ │ └─ Sub [id CM] 101
│ │ │ └─ ···
│ │ └─ Add [id X] 29
│ │ └─ ···
│ └─ True_div [id DD] 105
│ ├─ Neg [id DE] 104
│ │ └─ SpecifyShape [id CZ] 103
│ │ └─ ···
│ └─ Sqrt [id Y] 28
│ └─ ···
├─ Reshape{1} [id DF] 98
│ ├─ Add [id DG] 97
│ │ ├─ 1.0 [id DH]
│ │ └─ Mul [id DI] 96
│ │ ├─ Add [id DJ] 95
│ │ │ ├─ Switch [id DK] 94
│ │ │ │ ├─ Ge [id DL] 93
│ │ │ │ │ ├─ Exp [id BG] 22
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 0.0 [id DM]
│ │ │ │ ├─ Neg [id DN] 92
│ │ │ │ │ └─ Exp [id BG] 22
│ │ │ │ │ └─ ···
│ │ │ │ └─ 0.0 [id DM]
│ │ │ └─ Sum{axes=None} [id DO] 91
│ │ │ └─ Mul [id DP] 90
│ │ │ ├─ AdvancedIncSubtensor1{no_inplace,inc} [id CO] 89
│ │ │ │ └─ ···
│ │ │ └─ Sub [id Q] 36
│ │ │ └─ ···
│ │ └─ Exp [id BG] 22
│ │ └─ ···
│ └─ [-1] [id CJ]
├─ Reshape{1} [id DQ] 85
│ ├─ Add [id DR] 84
│ │ ├─ Mul [id DS] 83
│ │ │ ├─ -0.25 [id DT]
│ │ │ └─ floor_effect [id BL]
│ │ └─ True_div [id DU] 82
│ │ ├─ Sum{axes=None} [id DV] 81
│ │ │ └─ Mul [id DW] 55
│ │ │ ├─ True_div [id CU] 43
│ │ │ │ └─ ···
│ │ │ └─ [1. 0. 0. ... 0. 0. 0.] [id BM]
│ │ └─ Exp [id CH] 0
│ │ └─ ···
│ └─ [-1] [id CJ]
├─ Add [id DX] 80
│ ├─ Split{2}.0 [id DY] 72
│ │ ├─ Sub [id DZ] 71
│ │ │ ├─ Mul [id EA] 70
│ │ │ │ ├─ AdvancedIncSubtensor1{no_inplace,inc} [id EB] 59
│ │ │ │ │ ├─ Alloc [id EC] 58
│ │ │ │ │ │ ├─ [0.] [id CQ]
│ │ │ │ │ │ └─ Add [id ED] 57
│ │ │ │ │ │ ├─ Shape_i{0} [id CC] 4
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id CS]
│ │ │ │ │ ├─ True_div [id EE] 56
│ │ │ │ │ │ ├─ Mul [id DW] 55
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id CV] 42
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ [ 0 0 0 ... 83 84 84] [id BI]
│ │ │ │ └─ ExpandDims{axis=0} [id CE] 3
│ │ │ │ └─ ···
│ │ │ └─ Sub [id BQ] 16
│ │ │ └─ ···
│ │ ├─ 0 [id B]
│ │ └─ MakeVector{dtype='int64'} [id EF] 69
│ │ ├─ Shape_i{0} [id CC] 4
│ │ │ └─ ···
│ │ └─ 1 [id CS]
│ ├─ True_div [id EG] 79
│ │ ├─ Sub [id EH] 78
│ │ │ ├─ SpecifyShape [id EI] 73
│ │ │ │ ├─ Split{2}.1 [id DY] 72
│ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id DA]
│ │ │ └─ ExpandDims{axis=0} [id EJ] 77
│ │ │ └─ Sum{axes=None} [id EK] 76
│ │ │ └─ Sub [id DZ] 71
│ │ │ └─ ···
│ │ └─ Add [id BX] 9
│ │ └─ ···
│ └─ True_div [id EL] 75
│ ├─ Neg [id EM] 74
│ │ └─ SpecifyShape [id EI] 73
│ │ └─ ···
│ └─ Sqrt [id BY] 8
│ └─ ···
├─ Reshape{1} [id EN] 68
│ ├─ Add [id EO] 67
│ │ ├─ 1.0 [id DH]
│ │ └─ Mul [id EP] 66
│ │ ├─ Add [id EQ] 65
│ │ │ ├─ Switch [id ER] 64
│ │ │ │ ├─ Ge [id ES] 63
│ │ │ │ │ ├─ Exp [id CF] 2
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 0.0 [id DM]
│ │ │ │ ├─ Neg [id ET] 62
│ │ │ │ │ └─ Exp [id CF] 2
│ │ │ │ │ └─ ···
│ │ │ │ └─ 0.0 [id DM]
│ │ │ └─ Sum{axes=None} [id EU] 61
│ │ │ └─ Mul [id EV] 60
│ │ │ ├─ AdvancedIncSubtensor1{no_inplace,inc} [id EB] 59
│ │ │ │ └─ ···
│ │ │ └─ Sub [id BQ] 16
│ │ │ └─ ···
│ │ └─ Exp [id CF] 2
│ │ └─ ···
│ └─ [-1] [id CJ]
└─ Reshape{1} [id EW] 54
├─ Add [id EX] 53
│ ├─ 1.0 [id DH]
│ ├─ Switch [id EY] 52
│ │ ├─ Ge [id EZ] 51
│ │ │ ├─ Exp [id CH] 0
│ │ │ │ └─ ···
│ │ │ └─ 0.0 [id DM]
│ │ ├─ Mul [id FA] 50
│ │ │ ├─ Exp [id FB] 49
│ │ │ │ └─ Add [id FC] 48
│ │ │ │ ├─ sigma_log__ [id CI]
│ │ │ │ └─ sigma_log__ [id CI]
│ │ │ └─ -0.4444444444444444 [id FD]
│ │ └─ 0.0 [id DM]
│ ├─ Mul [id FE] 47
│ │ ├─ True_div [id FF] 46
│ │ │ ├─ Sum{axes=None} [id FG] 45
│ │ │ │ └─ Mul [id FH] 44
│ │ │ │ ├─ True_div [id CU] 43
│ │ │ │ │ └─ ···
│ │ │ │ └─ Sub [id K] 41
│ │ │ │ └─ ···
│ │ │ └─ Sqr [id FI] 1
│ │ │ └─ Exp [id CH] 0
│ │ │ └─ ···
│ │ └─ Exp [id CH] 0
│ │ └─ ···
│ └─ -919.0 [id FJ]
└─ [-1] [id CJ] |
Description
@ricardoV94 did a nice perf improvement in pymc-devs/pymc#7578 to try to speed up jitted backends. I tried out torch as well. The model performed quite slowly.
We need to investigate why.
When doing perf evaluations, keep in mind that torch does a lot of caching. If you want a truly cache-less eval, you can either add
torch.compiler.reset()
or set the env variable to disable the dynamo cache (google it).