Skip to content

Commit

Permalink
refactor: simplify logic for preventing unit canceling
Browse files Browse the repository at this point in the history
Removes one recursive walk through the expression tree.
  • Loading branch information
mgreminger committed Jan 3, 2025
1 parent 306bd45 commit b108310
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 35 deletions.
39 changes: 6 additions & 33 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,35 +897,6 @@ def custom_latex(expression: Expr) -> str:

_range = Function("_range")

def walk_tree(grandparent_func, parent_func, expr) -> Expr:

if is_matrix(expr):
rows = []
for i in range(expr.rows):
row = []
rows.append(row)
for j in range(expr.cols):
row.append(walk_tree(parent_func, Matrix, expr[i,j]))

return cast(Expr, Matrix(rows))

if len(expr.args) == 0:
if parent_func is not Pow and parent_func is not Inverse and expr.is_negative:
return -1*expr
else:
return expr

if expr.func == _range:
new_args = expr.args
else:
new_args = (walk_tree(parent_func, expr.func, arg) for arg in expr.args)

return expr.func(*new_args)

def subtraction_to_addition(expression: Expr | Matrix) -> Expr:
return walk_tree("root", "root", expression)


def ensure_dims_all_compatible(*args):
if args[0].is_zero:
if all(arg.is_zero for arg in args):
Expand Down Expand Up @@ -1184,6 +1155,9 @@ def custom_integral_dims(local_expr: Expr, global_expr: Expr, dummy_integral_var
return global_expr * lower_limit_dims # type: ignore
else:
return global_expr * integral_var # type: ignore

def custom_add_dims(*args: Expr):
return Add(*[Abs(arg) for arg in args])


CP = None
Expand Down Expand Up @@ -1494,6 +1468,7 @@ def get_next_id(self):
cast(Function, Function('_Integral')) : {"dim_func": custom_integral_dims, "sympy_func": custom_integral},
cast(Function, Function('_range')) : {"dim_func": custom_range, "sympy_func": custom_range},
cast(Function, Function('_factorial')) : {"dim_func": factorial, "sympy_func": CustomFactorial},
cast(Function, Function('_add')) : {"dim_func": custom_add_dims, "sympy_func": Add},
}

global_placeholder_set = set(global_placeholder_map.keys())
Expand Down Expand Up @@ -1612,10 +1587,8 @@ def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr],
expression: Expr,
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function]) -> tuple[Expr | None, Exception | None]:
# need to remove any subtractions or unary negative since this may
# lead to unintentional cancellation during the parameter substitution process
positive_only_expression = subtraction_to_addition(expression)
expression_with_parameter_subs = cast(Expr, positive_only_expression.xreplace(parameter_subs))

expression_with_parameter_subs = cast(Expr, expression.xreplace(parameter_subs))

error = None
final_expression = None
Expand Down
4 changes: 2 additions & 2 deletions src/parser/LatexToSympy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1987,11 +1987,11 @@ export class LatexToSympy extends LatexParserVisitor<string | Statement | UnitBl
}

visitAdd = (ctx: AddContext) => {
return `Add(${this.visit(ctx.expr(0))}, ${this.visit(ctx.expr(1))})`;
return `_add(${this.visit(ctx.expr(0))}, ${this.visit(ctx.expr(1))})`;
}

visitSubtract = (ctx: SubtractContext) => {
return `Add(${this.visit(ctx.expr(0))}, -(${this.visit(ctx.expr(1))}))`;
return `_add(${this.visit(ctx.expr(0))}, -(${this.visit(ctx.expr(1))}))`;
}

visitVariable = (ctx: VariableContext) => {
Expand Down

0 comments on commit b108310

Please sign in to comment.