Explicit RNG API #1706
ricardoV94 started this conversation in Ideas
Replies: 4 comments, 1 reply
I really like it!
Dirty draft PR: #1707
I like the way JAX does explicit RNG. I prefer anything closer to that (like this proposal).
Like it
I've been thinking we should clean up the PyTensor RNG API to make it more transparent.
It looks almost like NumPy, except that each call also returns the next rng as an output, and you must use that rng afterwards. I added a warning if you try to call a method on the same rng twice.
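A minimal sketch of this calling convention, using plain NumPy rather than PyTensor graphs. `ExplicitRNG`, `from_seed`, and the warning text are hypothetical names for illustration, not the API from the draft PR. Every draw method returns `(next_rng, draws)`, the rng's own state is never mutated, and calling a method on an already-consumed rng emits a warning.

```python
import warnings

import numpy as np


class ExplicitRNG:
    """Hypothetical immutable RNG (illustration only): each draw returns (next_rng, draws)."""

    def __init__(self, state):
        # Frozen PCG64 state dict; it is only read, never modified.
        self._state = state
        self._used = False

    @classmethod
    def from_seed(cls, seed):
        return cls(np.random.PCG64(seed).state)

    def _draw(self, method, *args, **kwargs):
        if self._used:
            warnings.warn(
                "This rng was already used; use the next_rng returned by the "
                "previous call instead.",
                UserWarning,
                stacklevel=3,
            )
        self._used = True
        # Restore the frozen state into a fresh bit generator and draw from it.
        bit_gen = np.random.PCG64()
        bit_gen.state = self._state
        draws = getattr(np.random.Generator(bit_gen), method)(*args, **kwargs)
        # The advanced state becomes a new, independent ExplicitRNG.
        return ExplicitRNG(bit_gen.state), draws

    def normal(self, loc=0.0, scale=1.0, size=None):
        return self._draw("normal", loc, scale, size)

    def uniform(self, low=0.0, high=1.0, size=None):
        return self._draw("uniform", low, high, size)


rng = ExplicitRNG.from_seed(42)
next_rng, x = rng.normal(size=3)
next_rng, u = next_rng.uniform(size=2)  # must use next_rng, not rng
```

Because the state is frozen, calling `rng.normal(size=3)` again reproduces `x` (and triggers the warning), similar to reusing a key in JAX. Fresh draws require threading `next_rng` forward.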
It's not unlike the pattern `x = x[idx].set(y)`, where you have to use the new `x` if you want the version with `y` in it. Both come from the same requirement: immutability.

This would replace the `RandomStream` helper and, more importantly, the whole `SharedVariable.default_update` hack, which exists solely to try to hide the `next_rng` output and get it picked up by `pytensor.function`. That hack also adds some thorny logic inside `Scan`, which tries to figure out whether there was a new `rng.foo()` call inside the step function in order to collect default updates.

All this is quite problematic because you can't do graph cloning/manipulation correctly: you would need to check every shared variable's `default_update` and decide what to do with it. Related to #1704
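The `Scan` point can be illustrated with a NumPy-only sketch of the explicit style: the step function receives the rng state as part of its carry and returns the next state, so the loop never has to discover hidden updates. `normal`, `scan`, and `random_walk_step` are hypothetical helpers loosely modeled on the carry convention of `jax.lax.scan`; they are not PyTensor's `scan` signature.

```python
import numpy as np


def normal(rng_state, size=None):
    """Hypothetical pure draw: returns (next_rng_state, draws) without mutating rng_state."""
    bit_gen = np.random.PCG64()
    bit_gen.state = rng_state
    draws = np.random.Generator(bit_gen).normal(size=size)
    return bit_gen.state, draws


def scan(step, init_carry, n_steps):
    """Hypothetical minimal scan: the carry, rng state included, is threaded explicitly."""
    carry = init_carry
    outputs = []
    for _ in range(n_steps):
        carry, out = step(carry)
        outputs.append(out)
    return carry, np.asarray(outputs)


def random_walk_step(carry):
    rng_state, position = carry
    # The step function returns the next rng state itself; nothing has to be inferred.
    next_rng_state, increment = normal(rng_state)
    new_position = position + increment
    return (next_rng_state, new_position), new_position


init = (np.random.PCG64(0).state, 0.0)
(final_rng_state, final_position), path = scan(random_walk_step, init, n_steps=5)
```

Here the rng dependency is visible in the step function's inputs and outputs, which is the property that would let default-update inference inside `Scan` go away.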
It may be worth reviewing the implementation details in https://pytensor.readthedocs.io/en/latest/tutorial/prng.html#prng