Commit aff2a9a

Cache more Ops
1 parent ec4f39f commit aff2a9a

4 files changed: +54 additions, -17 deletions

pytensor/link/numba/dispatch/basic.py

Lines changed: 12 additions & 6 deletions
@@ -3,6 +3,7 @@
 import warnings
 from copy import copy
 from functools import singledispatch
+from hashlib import sha256
 from textwrap import dedent
 
 import numba
@@ -14,7 +15,8 @@
 from numba import types
 from numba.core.errors import NumbaWarning, TypingError
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
-from numba.extending import box, overload, register_jitable as _register_jitable
+from numba.extending import box, overload
+from numba.extending import register_jitable as _register_jitable
 
 from pytensor import In, config
 from pytensor.compile import NUMBA
@@ -25,11 +27,9 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.type import Type
 from pytensor.ifelse import IfElse
+from pytensor.link.numba.cache import compile_and_cache_numba_function_src
 from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
-from pytensor.link.utils import (
-    compile_function_src,
-    fgraph_to_python,
-)
+from pytensor.link.utils import fgraph_to_python
 from pytensor.scalar.basic import ScalarType
 from pytensor.sparse import SparseTensorType
 from pytensor.tensor.basic import Nonzero
@@ -558,7 +558,13 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
         """
     )
 
-    specify_shape = compile_function_src(func, "specify_shape", globals())
+    specify_shape = compile_and_cache_numba_function_src(
+        func,
+        "specify_shape",
+        globals(),
+        # use the sha256 of func_conditions in the key to avoid long keys
+        key=f"SpecifyShape_{sha256(';'.join(func_conditions).encode()).hexdigest()}",
+    )
     return numba_njit(specify_shape)

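For SpecifyShape, the cache key now embeds a sha256 digest of the generated shape-check conditions rather than the conditions themselves, so semantically identical nodes share a cache entry without producing unwieldy keys. A minimal sketch of that key construction, with made-up condition strings standing in for the ones the dispatcher generates:

from hashlib import sha256

# Hypothetical stand-ins for the per-dimension checks the dispatcher emits.
func_conditions = [
    "assert x.shape[0] == shape_0",
    "assert x.shape[1] == shape_1",
]

# Join the checks and hash them: the key stays short and is identical for any
# two SpecifyShape nodes that generate the same checks.
key = f"SpecifyShape_{sha256(';'.join(func_conditions).encode()).hexdigest()}"
print(key)  # SpecifyShape_<64 hex chars>
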
pytensor/link/numba/dispatch/elemwise.py

Lines changed: 20 additions & 3 deletions
@@ -1,4 +1,5 @@
 from functools import singledispatch
+from hashlib import sha256
 from textwrap import dedent, indent
 
 import numba
@@ -276,7 +277,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
     )
 
     # TODO: Proper key
-    key = "_".join(
+    core_op_key = "_".join(
         map(
             str,
             (
@@ -287,7 +288,10 @@ def numba_funcify_Elemwise(op, node, **kwargs):
             ),
         )
     )
-    core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout, core_op_key=key)
+    core_op_key = sha256(core_op_key.encode()).hexdigest()
+    core_op_fn = store_core_outputs(
+        scalar_op_fn, nin=nin, nout=nout, core_op_key=core_op_key
+    )
 
     input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
     output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
@@ -348,11 +352,24 @@ def elemwise(*inputs):
     def ov_elemwise(*inputs):
         return elemwise_wrapper
 
+    # TODO: Also input dtypes in key
+    elemwise_key = "_".join(
+        map(
+            str,
+            (
+                "Elemwise",
+                core_op_key,
+                input_bc_patterns,
+                inplace_pattern,
+            ),
+        )
+    )
+    elemwise_key = sha256(elemwise_key.encode()).hexdigest()
     f = compile_and_cache_numba_function_src(
         "def f(*inputs): return elemwise(*inputs)",
         "f",
         {**globals(), **{"elemwise": elemwise}},
-        key=f"Elemwise_{key}",
+        key=elemwise_key,
     )
 
     return numba_njit(f)

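The Elemwise key is built in two stages: the core-op description is hashed into core_op_key, and that digest is then combined with the broadcast patterns and the inplace pattern and hashed again (input dtypes are still missing, per the TODO). A sketch of the same construction with placeholder values for anything not shown in the diff:

from hashlib import sha256

# Placeholders: the real values come from the scalar op and the node's inputs.
core_op_key = sha256(b"Add_(float64, float64)").hexdigest()
input_bc_patterns = ((False, True), (False, False))
inplace_pattern = {}

elemwise_key = "_".join(
    map(str, ("Elemwise", core_op_key, input_bc_patterns, inplace_pattern))
)
elemwise_key = sha256(elemwise_key.encode()).hexdigest()
print(elemwise_key)  # stable 64-character hex digest
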
pytensor/link/numba/dispatch/subtensor.py

Lines changed: 4 additions & 2 deletions
@@ -1,9 +1,10 @@
 import numpy as np
 
 from pytensor.graph import Type
+from pytensor.link.numba.cache import compile_and_cache_numba_function_src
 from pytensor.link.numba.dispatch import numba_funcify
 from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
-from pytensor.link.utils import compile_function_src, unique_name_generator
+from pytensor.link.utils import unique_name_generator
 from pytensor.tensor import TensorType
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
@@ -95,10 +96,11 @@ def {function_name}({", ".join(input_names)}):
         return np.asarray(z)
     """
 
-    func = compile_function_src(
+    func = compile_and_cache_numba_function_src(
         subtensor_def_src,
         function_name=function_name,
         global_env=globals() | {"np": np},
+        key=f"{function_name}({', '.join(str(i.type) for i in node.inputs)})",
     )
     return numba_njit(func, boundscheck=True)

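Unlike the hashed keys above, the subtensor key stays human-readable: the generated function name plus the string form of each input type, formatted like a call signature. A sketch with hypothetical stand-ins for the name and type strings:

# Hypothetical values; in the dispatcher they come from the generated function
# name and str(i.type) for each input of the node.
function_name = "advanced_subtensor"
input_type_strs = ["Matrix(float64, shape=(?, ?))", "Vector(int64, shape=(?,))"]

key = f"{function_name}({', '.join(input_type_strs)})"
print(key)
# advanced_subtensor(Matrix(float64, shape=(?, ?)), Vector(int64, shape=(?,)))
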
pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 18 additions & 6 deletions
@@ -2,9 +2,10 @@
 
 import numpy as np
 
+from pytensor.link.numba.cache import compile_and_cache_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify
-from pytensor.link.utils import compile_function_src, unique_name_generator
+from pytensor.link.utils import unique_name_generator
 from pytensor.tensor.basic import (
     Alloc,
     AllocEmpty,
@@ -49,8 +50,11 @@ def allocempty({", ".join(shape_var_names)}):
         return np.empty(scalar_shape, dtype)
     """
 
-    alloc_fn = compile_function_src(
-        alloc_def_src, "allocempty", {**globals(), **global_env}
+    alloc_fn = compile_and_cache_numba_function_src(
+        alloc_def_src,
+        "allocempty",
+        {**globals(), **global_env},
+        key=f"AllocEmpty({op.dtype})",
     )
 
     return numba_basic.numba_njit(alloc_fn)
@@ -93,7 +97,12 @@ def alloc(val, {", ".join(shape_var_names)}):
         res[...] = val
         return res
     """
-    alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
+    alloc_fn = compile_and_cache_numba_function_src(
+        alloc_def_src,
+        "alloc",
+        {**globals(), **global_env},
+        key="Alloc",
+    )
 
     return numba_basic.numba_njit(alloc_fn)
 
@@ -212,8 +221,11 @@ def makevector({", ".join(input_names)}):
         return np.array({create_list_string(input_names)}, dtype=dtype)
     """
 
-    makevector_fn = compile_function_src(
-        makevector_def_src, "makevector", {**globals(), **global_env}
+    makevector_fn = compile_and_cache_numba_function_src(
+        makevector_def_src,
+        "makevector",
+        {**globals(), **global_env},
+        key=f"MakeVector({op.dtype})",
     )
 
     return numba_basic.numba_njit(makevector_fn)

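All four files now route generated source through compile_and_cache_numba_function_src, which lives in pytensor.link.numba.cache and is not part of this diff. As a rough mental model only (not the pytensor implementation, which presumably also hooks into Numba's own caching), a keyed compile-once helper with the same call shape could look like this:

_SRC_CACHE: dict[str, object] = {}

def compile_function_src_cached(src, function_name, global_env, key=None):
    # Fall back to the source text itself when no explicit key is given.
    cache_key = key if key is not None else src
    fn = _SRC_CACHE.get(cache_key)
    if fn is None:
        namespace = dict(global_env)
        # Build the function from its generated source and memoize it by key.
        exec(compile(src, f"<{function_name}>", "exec"), namespace)
        fn = namespace[function_name]
        _SRC_CACHE[cache_key] = fn
    return fn

# Called the way the patched sites above call the real helper:
alloc_fn = compile_function_src_cached(
    "def alloc(val, s0):\n    return [val] * s0\n",
    "alloc",
    {},
    key="Alloc",
)
assert alloc_fn(7, 3) == [7, 7, 7]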