Skip to content

Commit

Permalink
refactor: only perform unit exponent rounding when necessary for comp…
Browse files Browse the repository at this point in the history
…arisons
  • Loading branch information
mgreminger committed Jan 12, 2025
1 parent 77dcdbc commit ffeaec1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,8 +966,8 @@ def ensure_dims_all_compatible(*args):
if len(args) == 1:
return first_arg

first_arg_dims = custom_get_dimensional_dependencies(first_arg)
if all(custom_get_dimensional_dependencies(arg) == first_arg_dims for arg in args[1:]):
first_arg_dims = normalize_dims_dict(custom_get_dimensional_dependencies(first_arg))
if all(normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == first_arg_dims for arg in args[1:]):
return first_arg

raise TypeError('All input arguments to function need to have compatible units')
Expand All @@ -984,13 +984,13 @@ def ensure_dims_all_compatible_piecewise(*args):
return ensure_dims_all_compatible(*[arg[0] for arg in args])

def ensure_unitless_in_angle_out(arg):
if custom_get_dimensional_dependencies(arg) == {}:
if normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == {}:
return angle
else:
raise TypeError('Unitless input argument required for function')

def ensure_unitless_in(arg):
if custom_get_dimensional_dependencies(arg) == {}:
if normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == {}:
return arg
else:
raise TypeError('Unitless input argument required for function')
Expand Down Expand Up @@ -1182,8 +1182,8 @@ def IndexMatrix(expression: Expr, i: Expr, j: Expr) -> Expr:
return expression[i-1, j-1] # type: ignore

def IndexMatrix_dims(dim_values: DimValues, expression: Expr, i: Expr, j: Expr) -> Expr:
if custom_get_dimensional_dependencies(i) != {} or \
custom_get_dimensional_dependencies(j) != {}:
if normalize_dims_dict(custom_get_dimensional_dependencies(i)) != {} or \
normalize_dims_dict(custom_get_dimensional_dependencies(j)) != {}:
raise TypeError('Matrix Index Not Dimensionless')

i_value = dim_values["args"][1]
Expand Down Expand Up @@ -1262,7 +1262,7 @@ def custom_pow(base: Expr, exponent: Expr):
return Pow(base, exponent)

def custom_pow_dims(dim_values: DimValues, base: Expr, exponent: Expr):
if custom_get_dimensional_dependencies(exponent) != {}:
if normalize_dims_dict(custom_get_dimensional_dependencies(exponent)) != {}:
raise TypeError('Exponent Not Dimensionless')
return Pow(base.evalf(PRECISION), (dim_values["args"][1]).evalf(PRECISION))

Expand Down Expand Up @@ -1726,7 +1726,7 @@ def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr],
def custom_get_dimensional_dependencies(expression: Expr | None):
if expression is not None:
expression = subs_wrapper(expression, {cast(Symbol, symbol): S.One for symbol in (expression.free_symbols - dimension_symbols)})
return normalize_dims_dict(dimsys_SI.get_dimensional_dependencies(expression))
return dimsys_SI.get_dimensional_dependencies(expression)

def dimensional_analysis(dimensional_analysis_expression: Expr | None, dim_sub_error: Exception | None,
custom_base_units: CustomBaseUnits | None = None):
Expand Down

0 comments on commit ffeaec1

Please sign in to comment.