diff --git a/public/dimensional_analysis.py b/public/dimensional_analysis.py index 5b0c5662..4d6e1254 100644 --- a/public/dimensional_analysis.py +++ b/public/dimensional_analysis.py @@ -966,8 +966,8 @@ def ensure_dims_all_compatible(*args): if len(args) == 1: return first_arg - first_arg_dims = custom_get_dimensional_dependencies(first_arg) - if all(custom_get_dimensional_dependencies(arg) == first_arg_dims for arg in args[1:]): + first_arg_dims = normalize_dims_dict(custom_get_dimensional_dependencies(first_arg)) + if all(normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == first_arg_dims for arg in args[1:]): return first_arg raise TypeError('All input arguments to function need to have compatible units') @@ -984,13 +984,13 @@ def ensure_dims_all_compatible_piecewise(*args): return ensure_dims_all_compatible(*[arg[0] for arg in args]) def ensure_unitless_in_angle_out(arg): - if custom_get_dimensional_dependencies(arg) == {}: + if normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == {}: return angle else: raise TypeError('Unitless input argument required for function') def ensure_unitless_in(arg): - if custom_get_dimensional_dependencies(arg) == {}: + if normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == {}: return arg else: raise TypeError('Unitless input argument required for function') @@ -1182,8 +1182,8 @@ def IndexMatrix(expression: Expr, i: Expr, j: Expr) -> Expr: return expression[i-1, j-1] # type: ignore def IndexMatrix_dims(dim_values: DimValues, expression: Expr, i: Expr, j: Expr) -> Expr: - if custom_get_dimensional_dependencies(i) != {} or \ - custom_get_dimensional_dependencies(j) != {}: + if normalize_dims_dict(custom_get_dimensional_dependencies(i)) != {} or \ + normalize_dims_dict(custom_get_dimensional_dependencies(j)) != {}: raise TypeError('Matrix Index Not Dimensionless') i_value = dim_values["args"][1] @@ -1262,7 +1262,7 @@ def custom_pow(base: Expr, exponent: Expr): return Pow(base, exponent) def custom_pow_dims(dim_values: DimValues, base: Expr, exponent: Expr): - if custom_get_dimensional_dependencies(exponent) != {}: + if normalize_dims_dict(custom_get_dimensional_dependencies(exponent)) != {}: raise TypeError('Exponent Not Dimensionless') return Pow(base.evalf(PRECISION), (dim_values["args"][1]).evalf(PRECISION)) @@ -1726,7 +1726,7 @@ def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr], def custom_get_dimensional_dependencies(expression: Expr | None): if expression is not None: expression = subs_wrapper(expression, {cast(Symbol, symbol): S.One for symbol in (expression.free_symbols - dimension_symbols)}) - return normalize_dims_dict(dimsys_SI.get_dimensional_dependencies(expression)) + return dimsys_SI.get_dimensional_dependencies(expression) def dimensional_analysis(dimensional_analysis_expression: Expr | None, dim_sub_error: Exception | None, custom_base_units: CustomBaseUnits | None = None):