HACK: Implement Numba VM with caching of individual nodes #1604
base: main
Conversation
left some comments; hope you don't mind
kwargs.setdefault("cache", config.numba__cache) | ||
kwargs.setdefault("no_cpython_wrapper", True) | ||
kwargs.setdefault("no_cfunc_wrapper", True) | ||
kwargs.setdefault("cache", True) |
Note that numba currently can't detect changes in other modules, which could lead to outdated caches. I'm guessing that's why they've decided to make it opt-in.
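A minimal illustration of that pitfall (hypothetical two-module layout, not code from this PR):

```python
# helpers.py -- hypothetical module imported by the cached function
import numba

@numba.njit
def scale(x):
    return 2 * x  # editing this body does NOT invalidate f's on-disk cache

# main.py
import numba
from helpers import scale

@numba.njit(cache=True)
def f(x):
    # The cache's source stamp covers main.py only, so after helpers.py
    # changes, a stale compiled f (still using the old scale) can load.
    return scale(x) + 1
```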
Not sure what you mean? So far numba caching has been fine, to the extent that it was actually used (not much). It plays a much bigger role in this PR's approach, but we are also pretty much customizing its behavior completely.
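For context, that customization goes through numba's internal cache-locator machinery. A rough sketch of the kind of hook involved, assuming numba's undocumented `numba.core.caching` internals; `PyTensorCacheLocator` and the `__pytensor_src__` attribute are made-up names for illustration:

```python
import hashlib

from numba.core.caching import CacheImpl, _CacheLocator


class PyTensorCacheLocator(_CacheLocator):
    """Hypothetical locator keying the cache on a hash of the generated
    source, instead of a real file's path, line number, and mtime."""

    def __init__(self, py_func, py_file):
        src = py_func.__pytensor_src__  # assumed: attached at generation time
        self._key = hashlib.sha256(src.encode()).hexdigest()

    def get_cache_path(self):
        return "/tmp/pytensor_numba_cache"  # placeholder directory

    def get_source_stamp(self):
        # Replaces the usual (mtime, size) stamp of a real source file.
        return self._key

    def get_disambiguator(self):
        # Distinguishes functions sharing the same "source location".
        return self._key[:16]

    @classmethod
    def from_function(cls, py_func, py_file):
        # Only claim functions we generated ourselves.
        if getattr(py_func, "__pytensor_src__", None) is None:
            return None
        return cls(py_func, py_file)


# Locators are tried in order; prepending makes this one take precedence.
CacheImpl._locator_classes.insert(0, PyTensorCacheLocator)
```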
else:
    from pytensor.link.numba.dispatch.basic import numba_njit

    jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
    return jitted_fn
I suppose there's no need for this else branch 🤷🏻
It's a stylistic choice. Should be enforced by ruff, I guess.
Feel free to open an issue to add https://docs.astral.sh/ruff/rules/superfluous-else-return/ to our rules.
pytensor/link/utils.py (outdated)
with open(filename, "wb") as f:
    f.write(src.encode())
Suggested change:
-with open(filename, "wb") as f:
-    f.write(src.encode())
+filename.write_bytes(src.encode())
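(`Path.write_bytes` opens the file in binary write mode, writes the payload, and closes it, so it's a one-call equivalent of the `open`/`write` pair, assuming `filename` is a `pathlib.Path`.)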
TIL
I don't, but it's still too early for that sort of feedback. I'm just tinkering around at this point.
force-pushed from 037e889 to aff2a9a
And make that the default backend
force-pushed from aff2a9a to 978b701
    )

-    signature = create_numba_signature(node, force_scalar=True)
+    # signature = create_numba_signature(node, force_scalar=True)
This was causing eager compilation during dispatch, nullifying any caching benefits.
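For context, a standalone illustration (not this PR's code) of the difference: giving `njit` an explicit signature compiles at decoration time, while omitting it defers compilation to the first call:

```python
import numba

@numba.njit("float64(float64)")  # eager: compiled right here, at decoration time
def f_eager(x):
    return x + 1.0

@numba.njit  # lazy: nothing compiles until the first call supplies types
def f_lazy(x):
    return x + 1.0

f_lazy(1.0)  # compilation (or a cache hit) happens only now
```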
@@ -239,11 +241,21 @@ def codegen(
        output_core_shapes,
    )

    core_signature = typingctx.resolve_function_type(
This causes compilation of the inner function, so I moved it into the codegen; otherwise we'd pay the cost even when we could cache it? Not sure. But even if not, I guess we generally want to compile lazily by default?
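A toy version of the general pattern, assuming numba's `@intrinsic` extension API (not the PR's actual code): work in the typing function runs as soon as the call is typed at dispatch, while anything inside `codegen` runs only during lowering:

```python
import numba
from numba.extending import intrinsic


@intrinsic
def add_one(typingctx, x):
    sig = x(x)  # typing runs at dispatch time; keep this part cheap

    def codegen(context, builder, signature, args):
        # Lowering runs only when the caller is actually compiled, so
        # expensive steps (e.g. resolving inner function types) are
        # better deferred to here.
        one = context.get_constant(signature.args[0], 1)
        return builder.add(args[0], one)

    return sig, codegen


@numba.njit
def f(x):
    return add_one(x)


print(f(41))  # -> 42
```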
Disclaimer: this is still 100% on hack status, and I don't understand half of the things I did
When we tried #811, it was obvious that numba compile times were prohibitive.
This PR tries a different approach (still at the hacking stage): a mode more like the CVM, where individual nodes are jit-compiled but the whole graph (i.e., the "VM") is not. This allows reusing pre-compiled/cached nodes across different functions, bringing the compilation cost down.
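A schematic of the idea with hypothetical nodes (not PyTensor's actual dispatch code): the "VM" loop stays in Python and only the per-node kernels are jitted, so each kernel's compilation (and cache entry) is independent of the graph it appears in:

```python
import numba


@numba.njit(cache=True)
def node_add(x, y):  # one jitted "node"; its cache entry is graph-independent
    return x + y


@numba.njit(cache=True)
def node_mul(x, y):
    return x * y


def vm(a, b):
    # Plain-Python driver, analogous to the CVM: it calls the node thunks
    # in order instead of compiling one fused function for the whole graph.
    t = node_add(a, b)
    return node_mul(t, t)


print(vm(2.0, 3.0))  # -> 25.0
```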
It requires interfacing with the numba cache locator to direct it to cached objects, which requires defining our own cache keys. Numba usually uses the file line position and contents as the cache key, but this doesn't work for dynamically generated files (at least not if they're stored in a random temp file), nor really for nested functions like those built for `Elemwise`. Also, some Ops are string-generated, and others are regular python functions with globals, which numba can usually cache. All this has to be re-examined.

We are also not calling njit on the inner Ops (the `store_core_outputs` / `ScalarOp`) of `Elemwise`, but instead doing `register_jittable`. This was needed for caching to work, but I also don't have a good mental picture of why yet.

If we need this approach, we must move the control of when an Op gets jitted. Sounds like the right place would be in `numba_funcify_FunctionGraph`.

Results:
We're finally approaching the speed of the previous backend (at least for single function compilation + eval). Probably could get it there with more optimizing, but a small slowdown is acceptable.
TODO:
- `cache=True` failures with locally defined functions: numba/numba#10098