Skip to content

Commit

Permalink
refactor: move all multiplication dimensional analysis logic to place…
Browse files Browse the repository at this point in the history
…holder function

Simplifies replace_placeholder_funcs function by removing a special case
  • Loading branch information
mgreminger committed Jan 3, 2025
1 parent b108310 commit c9f5f73
Showing 1 changed file with 27 additions and 30 deletions.
57 changes: 27 additions & 30 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,10 +990,32 @@ def custom_matmul(exp1: Expr, exp2: Expr):
else:
return Mul(exp1, exp2)

def custom_matmul_dims(*args: Expr):
if len(args) == 2 and is_matrix(args[0]) and is_matrix(args[1]) and \
def custom_multiply_dims(matmult: bool, *args: Expr):
matrix_args: list[Matrix] = []
scalar_args: list[Expr] = []
for arg in args:
if is_matrix(arg):
matrix_args.append(arg)
else:
scalar_args.append(arg)

if len(matrix_args) > 0 and len(scalar_args) > 0:
first_matrix = matrix_args[0]
scalar = Mul(*scalar_args)
new_rows = []
for i in range(first_matrix.rows):
new_row = []
new_rows.append(new_row)
for j in range(first_matrix.cols):
new_row.append(scalar*first_matrix[i,j]) # type: ignore

matrix_args[0] = Matrix(new_rows)
args = cast(tuple[Expr], matrix_args)

if matmult and len(args) == 2 and is_matrix(args[0]) and is_matrix(args[1]) and \
(((args[0].rows == 3 and args[0].cols == 1) and (args[1].rows == 3 and args[1].cols == 1)) or \
((args[0].rows == 1 and args[0].cols == 3) and (args[1].rows == 1 and args[1].cols == 3))):
# cross product detected for matrix multiplication operator

result = Matrix([Add(Mul(args[0][1],args[1][2]),Mul(args[0][2],args[1][1])),
Add(Mul(args[0][2],args[1][0]),Mul(args[0][0],args[1][2])),
Expand Down Expand Up @@ -1455,8 +1477,8 @@ def get_next_id(self):
cast(Function, Function('_Inverse')) : {"dim_func": ensure_inverse_dims, "sympy_func": UniversalInverse},
cast(Function, Function('_Transpose')) : {"dim_func": custom_transpose, "sympy_func": custom_transpose},
cast(Function, Function('_Determinant')) : {"dim_func": custom_determinant, "sympy_func": custom_determinant},
cast(Function, Function('_mat_multiply')) : {"dim_func": custom_matmul_dims, "sympy_func": custom_matmul},
cast(Function, Function('_multiply')) : {"dim_func": Mul, "sympy_func": Mul},
cast(Function, Function('_mat_multiply')) : {"dim_func": partial(custom_multiply_dims, True), "sympy_func": custom_matmul},
cast(Function, Function('_multiply')) : {"dim_func": partial(custom_multiply_dims, False), "sympy_func": Mul},
cast(Function, Function('_IndexMatrix')) : {"dim_func": IndexMatrix, "sympy_func": IndexMatrix},
cast(Function, Function('_Eq')) : {"dim_func": Eq, "sympy_func": Eq},
cast(Function, Function('_norm')) : {"dim_func": custom_norm, "sympy_func": custom_norm},
Expand Down Expand Up @@ -1506,32 +1528,7 @@ def replace_placeholder_funcs(expr: Expr,
if len(expr.args) == 0:
return expr

if func_key == "dim_func" and expr.func in multiply_placeholder_set:
processed_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) for arg in expr.args]
matrix_args = []
scalar_args = []
for arg in processed_args:
if is_matrix(cast(Expr, arg)):
matrix_args.append(arg)
else:
scalar_args.append(arg)

if len(matrix_args) > 0 and len(scalar_args) > 0:
first_matrix = matrix_args[0]
scalar = math.prod(scalar_args)
new_rows = []
for i in range(first_matrix.rows):
new_row = []
new_rows.append(new_row)
for j in range(first_matrix.cols):
new_row.append(scalar*first_matrix[i,j])

matrix_args[0] = Matrix(new_rows)

return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*matrix_args))
else:
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*processed_args))
elif expr.func in dummy_var_placeholder_set and func_key == "dim_func":
if expr.func in dummy_var_placeholder_set and func_key == "dim_func":
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args))))
elif expr.func in placeholder_set:
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) for arg in expr.args)))
Expand Down

0 comments on commit c9f5f73

Please sign in to comment.