From 6f1c12e8120b8e4995bc0fdc1c9968a8df6a502b Mon Sep 17 00:00:00 2001 From: Jacan Chaplais Date: Sat, 25 Nov 2023 10:23:32 +0000 Subject: [PATCH] added caching to numba compiled funcs --- graphicle/calculate.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/graphicle/calculate.py b/graphicle/calculate.py index a29e82f..b395c45 100644 --- a/graphicle/calculate.py +++ b/graphicle/calculate.py @@ -425,7 +425,7 @@ def flow_trace( return traces -@nb.njit("float64[:](float64[:], float64[:], float64)") +@nb.njit("float64[:](float64[:], float64[:], float64)", cache=True) def _rapidity( energy: base.DoubleVector, z: base.DoubleVector, zero_tol: float ) -> base.DoubleVector: @@ -457,7 +457,7 @@ def _rapidity( return rap -@nb.vectorize("float64(float64, float64)") +@nb.vectorize("float64(float64, float64)", cache=True) def _root_diff_two_squares( x1: base.DoubleUfunc, x2: base.DoubleUfunc ) -> base.DoubleUfunc: @@ -494,6 +494,7 @@ def _root_diff_two_squares( @nb.njit( "float64[:, :](float64[:], float64[:], complex128[:], complex128[:])", parallel=True, + cache=True, ) def _delta_R( rapidity_1: base.DoubleVector, @@ -532,7 +533,7 @@ def _delta_R( return result -@nb.njit("float64[:, :](float64[:], complex128[:])", parallel=True) +@nb.njit("float64[:, :](float64[:], complex128[:])", parallel=True, cache=True) def _delta_R_symmetric( rapidity: base.DoubleVector, xy_pol: base.ComplexVector ) -> base.DoubleVector: @@ -569,7 +570,7 @@ def _delta_R_symmetric( return result -@nb.njit("float32[:](bool_[:, :])", parallel=True) +@nb.njit("float32[:](bool_[:, :])", parallel=True, cache=True) def _clust_coeffs(adj: base.BoolVector) -> base.FloatVector: num_nodes = adj.shape[0] coefs = np.empty(num_nodes, dtype=np.float32) @@ -664,7 +665,7 @@ def aggregate_momenta( return momentum_class(list(it.chain.from_iterable(pmu_sums))) -@nb.njit("float64(float64, float64, float64)") +@nb.njit("float64(float64, float64, float64)", cache=True) def _three_norm(x: float, y: float, z: float) -> float: max_component = max(abs(x), abs(y), abs(z)) max_recip = 1.0 / max_component @@ -674,7 +675,7 @@ def _three_norm(x: float, y: float, z: float) -> float: return max_component * np.sqrt(x * x + y * y + z * z) -@nb.njit("float64[:](float64[:])") +@nb.njit("float64[:](float64[:])", cache=True) def _angles_to_axis(axis_angles: base.DoubleVector) -> base.DoubleVector: """Given the azimuth and inclination angles, return the Cartesian coordinates a unit vector. @@ -699,7 +700,10 @@ def _angles_to_axis(axis_angles: base.DoubleVector) -> base.DoubleVector: ) -@nb.njit(nb.types.Tuple((nb.float64, nb.float64[:]))(nb.float64[:], PMU_DTYPE)) +@nb.njit( + nb.types.Tuple((nb.float64, nb.float64[:]))(nb.float64[:], PMU_DTYPE), + cache=True, +) def _thrust_with_grad( axis_angles: base.DoubleVector, momenta: base.VoidVector ) -> ty.Tuple[float, base.DoubleVector]: @@ -855,7 +859,10 @@ def thrust( return thrust_val -@nb.njit(nb.types.Tuple((nb.float64, nb.float64[:]))(nb.float64[:], PMU_DTYPE)) +@nb.njit( + nb.types.Tuple((nb.float64, nb.float64[:]))(nb.float64[:], PMU_DTYPE), + cache=True, +) def _spherocity_with_grad( axis_angles: base.DoubleVector, momenta: base.VoidVector ) -> ty.Tuple[float, base.DoubleVector]: @@ -1024,7 +1031,7 @@ def spherocity( return sph_val -@nb.njit(nb.float64(PMU_DTYPE)) +@nb.njit(nb.float64(PMU_DTYPE), cache=True) def _c_parameter(momenta: base.VoidVector) -> float: output = norm_sum = 0.0 for idx_i, pmu_i in enumerate(momenta): @@ -1071,7 +1078,8 @@ def c_parameter(pmu: "MomentumArray") -> float: [ "float64(bool_[:], bool_[:], Optional(float64[:]))", "float64(bool_[:], bool_[:], Omitted(None))", - ] + ], + cache=True, ) def _jaccard_distance( u: base.BoolVector,