Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Sep 2, 2025

Disclaimer: this is still 100% on hack status, and I don't understand half of the things I did

When we tried #811 it was obvious that numba compile times were prohibitive.

This PR tries a different approach (still at the hacking stage): a mode more like the CVM, where individual nodes are jit-compiled but the whole graph (i.e., the "VM") is not. This allows reusing pre-compiled/cached nodes across different functions, bringing the compilation cost down.

It requires interfacing with the numba cache locator to direct it to cached objects, which requires defining our own cache keys. Numba usually uses the file line position and contents as the cache key, but this doesn't work for dynamically generated files (at least not if they are stored in a random temp file), nor really for nested functions like those built for Elemwise. Also, some Ops are string-generated, while others are regular Python functions with globals, which numba can usually cache. All this has to be re-examined.
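As a rough illustration of the cache-key problem: a content-based key (e.g., hashing the generated source plus any extra distinguishing data) would let identical generated code hit the same cache entry regardless of which temp file it landed in. `source_cache_key` below is a hypothetical helper for illustration, not part of this PR:

```python
import hashlib

def source_cache_key(src: str, extra: tuple = ()) -> str:
    """Key a generated function by its source text (plus any extra
    distinguishing data), instead of numba's default file/line-based key,
    so the key is stable across temp-file locations."""
    h = hashlib.sha256(src.encode())
    for item in extra:
        h.update(repr(item).encode())
    return h.hexdigest()

key1 = source_cache_key("def add(x, y): return x + y")
key2 = source_cache_key("def add(x, y): return x + y")
assert key1 == key2  # same source, different temp files -> same cache entry
assert key1 != source_cache_key("def add(x, y): return x - y")
```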

We are also not calling njit on the inner Ops (the store_core_outputs / ScalarOp) of Elemwise, but instead using register_jitable. This was needed for caching to work, but I don't yet have a good mental picture of why.
If we keep this approach, we must move the control of when an Op gets jitted. It sounds like the right place would be numba_funcify_FunctionGraph.
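A toy sketch of what moving that control could look like (all names here — `funcify`, `funcify_graph`, the toy `Add`/`Mul` ops — are illustrative stand-ins, not the PR's actual code): the per-Op dispatchers return plain Python callables, and only the graph-level function decides how each one is wrapped (njit with caching, register_jitable, or nothing at all).

```python
from functools import singledispatch

class Add: pass
class Mul: pass

@singledispatch
def funcify(op):
    """Per-Op dispatch: returns a plain Python callable, no jitting here."""
    raise NotImplementedError(type(op))

@funcify.register(Add)
def _(op):
    return lambda a, b: a + b

@funcify.register(Mul)
def _(op):
    return lambda a, b: a * b

def funcify_graph(ops, wrap):
    """Graph-level dispatch: the single place that decides how (and
    whether) each node callable gets jitted/cached, via `wrap`."""
    return [wrap(funcify(op)) for op in ops]

# `wrap` would be e.g. a caching njit in the real backend; identity here.
fns = funcify_graph([Add(), Mul()], wrap=lambda f: f)
assert fns[0](2, 3) == 5 and fns[1](2, 3) == 6
```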

Results:

Second pass over tests/tensor/rewriting/test_basic.py (to allow compiling everything first):
2s with C_VM backend
54s with Numba backend
34s with Numba VM without cache
4s with Numba VM with cache

We're finally approaching the speed of the previous backend (at least for single-function compilation + eval). We could probably get there with more optimization, but a small slowdown is acceptable.

TODO:

  • We are still writing Python strings to the filesystem to compile them; this is probably not needed, as explored in Cache numba stuff #1326 (last commit?)
  • We have to compile some functions that don't really need it just so we can cache them, such as with Elemwise. This is related to https://numba.discourse.group/t/caching-redefined-functions/3057 but I don't yet have a clear picture.
  • Proper cache keys; I just hacked some quick things. Perhaps use the source code of the generated functions?
    • Composite key is certainly broken
    • Cache the whole FunctionGraph; this would avoid recompiling identical graphs in the regular Numba mode, not just NumbaCVM (it's also needed for correct caching of Composite/Blockwise/Scan/OpFromGraph, i.e., anything with inner Ops).
  • Figure out what happens with Ops that run in object mode?
  • Handle functions with pointers / large constants that can't traditionally be cached (not sure what's happening now). Related to cache=True failures with locally defined functions numba/numba#10098
  • Benchmark the slowdown from the "VM" approach on realistic functions. Consider using/adapting the CVM to orchestrate the calls to the individual nodes (would require the thunk approach). Right now the VM is the Python source code generated by the outermost unjitted FunctionGraph
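The "VM" idea described above — per-node compiled kernels driven by a thin orchestration loop — can be sketched roughly as follows. `Node` and `run_vm` are illustrative stand-ins (with plain lambdas in place of jitted/cached kernels), not the PR's actual classes:

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class Node:
    fn: Callable      # per-node compiled (or cache-restored) kernel
    inputs: list      # storage indices this node reads
    output: int       # storage index this node writes

def run_vm(nodes, storage):
    """Evaluate nodes in topological order over shared storage,
    mimicking a CVM-style loop over pre-compiled thunks."""
    for node in nodes:
        args = [storage[i] for i in node.inputs]
        storage[node.output] = node.fn(*args)
    return storage

# Tiny graph: z = (x + y) * 2, with x=3, y=4
storage = [3, 4, None, None]
nodes = [
    Node(fn=lambda a, b: a + b, inputs=[0, 1], output=2),
    Node(fn=lambda a: a * 2, inputs=[2], output=3),
]
assert run_vm(nodes, storage)[3] == 14
```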

Contributor

@jorenham jorenham left a comment


left some comments; hope you don't mind

kwargs.setdefault("cache", config.numba__cache)
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)
kwargs.setdefault("cache", True)
Contributor


Note that numba currently can't detect changes in other modules, which could lead to outdated caches. I'm guessing that's why they've decided to make it opt-in
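The staleness concern can be illustrated with a toy memoization scheme (pure Python, not numba's actual caching machinery): if the cache key covers only a function's own code and not the helpers it calls, changing a helper silently serves a stale result.

```python
import hashlib

_cache: dict = {}

def helper(x):
    return x + 1

def outer(x):
    return helper(x) * 10

def cached_call(fn, arg):
    # Key covers only fn's own bytecode, not the helpers it calls.
    key = hashlib.sha256(fn.__code__.co_code).hexdigest()
    if key not in _cache:
        _cache[key] = fn(arg)
    return _cache[key]

first = cached_call(outer, 1)   # computes helper(1) * 10 = 20 and caches it

def helper(x):                  # the helper changes in "another module"...
    return x + 100

second = cached_call(outer, 1)  # ...but outer's key didn't, so the stale 20 is served
assert first == second == 20    # a fresh call would have given 1010
```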

Member Author


Not sure what you mean? So far numba caching has been fine, to the extent that it was actually used (not much). It plays a much bigger role in this PR's approach, but we are also pretty much customizing its behavior completely.

Comment on lines +19 to +23
else:
from pytensor.link.numba.dispatch.basic import numba_njit

jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
return jitted_fn
jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
return jitted_fn
Contributor


I suppose there's no need for this else branch 🤷🏻

Member Author

@ricardoV94 ricardoV94 Sep 14, 2025


It's a stylistic choice. Should be enforced by ruff, I guess

Member Author


Feel free to open an issue to add https://docs.astral.sh/ruff/rules/superfluous-else-return/ to our rules.

Comment on lines 594 to 595
with open(filename, "wb") as f:
f.write(src.encode())
Contributor


Suggested change
with open(filename, "wb") as f:
f.write(src.encode())
filename.write_bytes(src.encode())

Member Author


TIL
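For reference, `Path.write_bytes` is the stdlib one-call equivalent of the `open`/`write` pair from the suggestion above (the temp path below is just for illustration):

```python
from pathlib import Path
import tempfile

src = "def f(x):\n    return x\n"
path = Path(tempfile.mkdtemp()) / "generated.py"

# Equivalent of: with open(path, "wb") as f: f.write(src.encode())
path.write_bytes(src.encode())
assert path.read_text() == src
```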

@ricardoV94
Member Author

left some comments; hope you don't mind

I don't, but it's still too early for that sort of feedback. I'm just tinkering around at this point.

@ricardoV94 ricardoV94 changed the title Implement Numba VM Implement Numba VM with caching of individual nodes Sep 14, 2025
@ricardoV94 ricardoV94 changed the title Implement Numba VM with caching of individual nodes HACK: Implement Numba VM with caching of individual nodes Sep 14, 2025
)

signature = create_numba_signature(node, force_scalar=True)
# signature = create_numba_signature(node, force_scalar=True)
Member Author


This was causing eager compilation during dispatch, nullifying any caching benefits

@@ -239,11 +241,21 @@ def codegen(
output_core_shapes,
)

core_signature = typingctx.resolve_function_type(
Member Author


This causes compilation of the inner function, so I moved it to the codegen; otherwise we had to pay that cost even when we could cache it? Not sure. But even if not, I guess we generally want lazy compilation by default?
