Skip to content

Commit

Permalink
Remove workarounds now that sparse-0.9.1 has numba support
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-wieser committed Jan 23, 2020
1 parent aa01463 commit 039f924
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 23 deletions.
27 changes: 10 additions & 17 deletions clifford/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,20 +196,16 @@ def _get_mult_function(mt: sparse.COO):
func : function (array_like (n_dims,), array_like (n_dims,)) -> array_like (n_dims,)
A function that computes the appropriate multiplication
"""
# unpack for numba
dims = mt.shape[1]
k_list, l_list, m_list = mt.coords
mult_table_vals = mt.data

@numba.generated_jit(nopython=True)
def mv_mult(value, other_value):
# this casting will be done at jit-time
ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype)
mult_table_vals_t = mult_table_vals.astype(ret_dtype)
ret_dtype = _get_mult_function_result_type(value, other_value, mt.dtype)
mt_t = mt.astype(ret_dtype)

def mult_inner(value, other_value):
output = np.zeros(dims, dtype=ret_dtype)
for k, l, m, val in zip(k_list, l_list, m_list, mult_table_vals_t):
output = np.zeros(mt_t.shape[1], dtype=ret_dtype)
k_list, l_list, m_list = mt_t.coords[0], mt_t.coords[1], mt_t.coords[2]
for k, l, m, val in zip(k_list, l_list, m_list, mt_t.data):
output[l] += value[k] * val * other_value[m]
return output

Expand All @@ -227,27 +223,24 @@ def _get_mult_function_runtime_sparse(mt: sparse.COO):
TODO: determine if this actually helps.
"""
# unpack for numba
dims = mt.shape[1]
k_list, l_list, m_list = mt.coords
mult_table_vals = mt.data

@numba.generated_jit(nopython=True)
def mv_mult(value, other_value):
# this casting will be done at jit-time
ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype)
mult_table_vals_t = mult_table_vals.astype(ret_dtype)
ret_dtype = _get_mult_function_result_type(value, other_value, mt.dtype)
mt_t = mt.astype(ret_dtype)

def mult_inner(value, other_value):
output = np.zeros(dims, dtype=ret_dtype)
output = np.zeros(mt_t.shape[1], dtype=ret_dtype)
k_list, l_list, m_list = mt_t.coords[0], mt_t.coords[1], mt_t.coords[2]
for ind, k in enumerate(k_list):
v_val = value[k]
if v_val != 0.0:
m = m_list[ind]
ov_val = other_value[m]
if ov_val != 0.0:
l = l_list[ind]
output[l] += v_val * mult_table_vals_t[ind] * ov_val
output[l] += v_val * mt_t.data[ind] * ov_val
return output
return mult_inner

Expand Down
7 changes: 1 addition & 6 deletions clifford/tools/g3c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,15 +499,10 @@ def scale_TR_translation(TR, scale):

def left_gmt_generator(mt=layout.gmt):
# unpack for numba
k_list, l_list, m_list = mt.coords
mult_table_vals = mt.data
gaDims = mt.shape[1]
val_get_left_gmt_matrix = cf._numba_val_get_left_gmt_matrix

@numba.njit
def get_left_gmt(x_val):
return val_get_left_gmt_matrix(
x_val, k_list, l_list, m_list, mult_table_vals, gaDims)
return val_get_left_gmt_matrix(x_val, mt)
return get_left_gmt


Expand Down

0 comments on commit 039f924

Please sign in to comment.