Skip to content

Commit

Permalink
added caching to numba compiled funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Nov 25, 2023
1 parent 6e3080f commit 6f1c12e
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions graphicle/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def flow_trace(
return traces


@nb.njit("float64[:](float64[:], float64[:], float64)")
@nb.njit("float64[:](float64[:], float64[:], float64)", cache=True)
def _rapidity(
energy: base.DoubleVector, z: base.DoubleVector, zero_tol: float
) -> base.DoubleVector:
Expand Down Expand Up @@ -457,7 +457,7 @@ def _rapidity(
return rap


@nb.vectorize("float64(float64, float64)")
@nb.vectorize("float64(float64, float64)", cache=True)
def _root_diff_two_squares(
x1: base.DoubleUfunc, x2: base.DoubleUfunc
) -> base.DoubleUfunc:
Expand Down Expand Up @@ -494,6 +494,7 @@ def _root_diff_two_squares(
@nb.njit(
"float64[:, :](float64[:], float64[:], complex128[:], complex128[:])",
parallel=True,
cache=True,
)
def _delta_R(
rapidity_1: base.DoubleVector,
Expand Down Expand Up @@ -532,7 +533,7 @@ def _delta_R(
return result


@nb.njit("float64[:, :](float64[:], complex128[:])", parallel=True)
@nb.njit("float64[:, :](float64[:], complex128[:])", parallel=True, cache=True)
def _delta_R_symmetric(
rapidity: base.DoubleVector, xy_pol: base.ComplexVector
) -> base.DoubleVector:
Expand Down Expand Up @@ -569,7 +570,7 @@ def _delta_R_symmetric(
return result


@nb.njit("float32[:](bool_[:, :])", parallel=True)
@nb.njit("float32[:](bool_[:, :])", parallel=True, cache=True)
def _clust_coeffs(adj: base.BoolVector) -> base.FloatVector:
num_nodes = adj.shape[0]
coefs = np.empty(num_nodes, dtype=np.float32)
Expand Down Expand Up @@ -664,7 +665,7 @@ def aggregate_momenta(
return momentum_class(list(it.chain.from_iterable(pmu_sums)))


@nb.njit("float64(float64, float64, float64)")
@nb.njit("float64(float64, float64, float64)", cache=True)
def _three_norm(x: float, y: float, z: float) -> float:
max_component = max(abs(x), abs(y), abs(z))
max_recip = 1.0 / max_component
Expand All @@ -674,7 +675,7 @@ def _three_norm(x: float, y: float, z: float) -> float:
return max_component * np.sqrt(x * x + y * y + z * z)


@nb.njit("float64[:](float64[:])")
@nb.njit("float64[:](float64[:])", cache=True)
def _angles_to_axis(axis_angles: base.DoubleVector) -> base.DoubleVector:
"""Given the azimuth and inclination angles, return the Cartesian
coordinates a unit vector.
Expand All @@ -699,7 +700,10 @@ def _angles_to_axis(axis_angles: base.DoubleVector) -> base.DoubleVector:
)


@nb.njit(nb.types.Tuple((nb.float64, nb.float64[:]))(nb.float64[:], PMU_DTYPE))
@nb.njit(
nb.types.Tuple((nb.float64, nb.float64[:]))(nb.float64[:], PMU_DTYPE),
cache=True,
)
def _thrust_with_grad(
axis_angles: base.DoubleVector, momenta: base.VoidVector
) -> ty.Tuple[float, base.DoubleVector]:
Expand Down Expand Up @@ -855,7 +859,10 @@ def thrust(
return thrust_val


@nb.njit(nb.types.Tuple((nb.float64, nb.float64[:]))(nb.float64[:], PMU_DTYPE))
@nb.njit(
nb.types.Tuple((nb.float64, nb.float64[:]))(nb.float64[:], PMU_DTYPE),
cache=True,
)
def _spherocity_with_grad(
axis_angles: base.DoubleVector, momenta: base.VoidVector
) -> ty.Tuple[float, base.DoubleVector]:
Expand Down Expand Up @@ -1024,7 +1031,7 @@ def spherocity(
return sph_val


@nb.njit(nb.float64(PMU_DTYPE))
@nb.njit(nb.float64(PMU_DTYPE), cache=True)
def _c_parameter(momenta: base.VoidVector) -> float:
output = norm_sum = 0.0
for idx_i, pmu_i in enumerate(momenta):
Expand Down Expand Up @@ -1071,7 +1078,8 @@ def c_parameter(pmu: "MomentumArray") -> float:
[
"float64(bool_[:], bool_[:], Optional(float64[:]))",
"float64(bool_[:], bool_[:], Omitted(None))",
]
],
cache=True,
)
def _jaccard_distance(
u: base.BoolVector,
Expand Down

0 comments on commit 6f1c12e

Please sign in to comment.