Skip to content

Commit

Permalink
feat: add custom error messages for all dimension errors
Browse files Browse the repository at this point in the history
  • Loading branch information
mgreminger committed Jan 12, 2025
1 parent b9e2d84 commit cb816a3
Showing 1 changed file with 48 additions and 45 deletions.
93 changes: 48 additions & 45 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ def custom_latex(expression: Expr) -> str:

_range = Function("_range")

def ensure_dims_all_compatible(*args):
def ensure_dims_all_compatible(*args, error_message: str | None = None):
if args[0].is_zero:
if all(arg.is_zero for arg in args):
first_arg = S.Zero
Expand All @@ -970,30 +970,35 @@ def ensure_dims_all_compatible(*args):
if all(normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == first_arg_dims for arg in args[1:]):
return first_arg

raise TypeError('All input arguments to function need to have compatible units')
if error_message is None:
raise TypeError('All input arguments to function need to have compatible units')
else:
raise TypeError(error_message)

def ensure_dims_all_compatible_scalar_or_matrix(*args, func_name = ""):
error_message = f"{func_name} function requires that all input values have the same units"

def ensure_dims_all_compatible_scalar_or_matrix(*args):
if len(args) == 1 and is_matrix(args[0]):
return ensure_dims_all_compatible(*args[0])
return ensure_dims_all_compatible(*args[0], error_message=error_message)
else:
return ensure_dims_all_compatible(*args)
return ensure_dims_all_compatible(*args, error_message=error_message)

def ensure_dims_all_compatible_piecewise(*args):
# Need to make sure first element in tuples passed to Piecewise all have compatible units
# The second element of the tuples has already been checked by And, StrictLessThan, etc.
return ensure_dims_all_compatible(*[arg[0] for arg in args])
return ensure_dims_all_compatible(*[arg[0] for arg in args], error_message="Units not consistent for piecewise cell")

def ensure_unitless_in_angle_out(arg):
def ensure_unitless_in_angle_out(arg, func_name=""):
if normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == {}:
return angle
else:
raise TypeError('Unitless input argument required for function')
raise TypeError(f'Unitless input argument required for {func_name} function')

def ensure_unitless_in(arg):
def ensure_unitless_in(arg, func_name=""):
if normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == {}:
return arg
else:
raise TypeError('Unitless input argument required for function')
raise TypeError(f'Unitless input argument required for {func_name} function')

def ensure_any_unit_in_angle_out(arg):
# ensure input arg units make sense (will raise if inconsistent)
Expand All @@ -1019,11 +1024,9 @@ def ensure_inverse_dims(arg):
for j in range(arg.cols):
row.append(cast(Expr, arg[j,i])**-1)
column_dims.setdefault(i, []).append(arg[j,i])
try:
for _, values in column_dims.items():
ensure_dims_all_compatible(*values)
except TypeError:
raise TypeError('Dimensions not consistent for matrix inverse')

for _, values in column_dims.items():
ensure_dims_all_compatible(*values, error_message='Dimensions not consistent for matrix inverse')

return Matrix(rows)

Expand Down Expand Up @@ -1159,7 +1162,7 @@ def custom_range(*args: Expr):
return Matrix(values)

def custom_range_dims(dim_values: DimValues, *args: Expr):
return Matrix([ensure_dims_all_compatible(*args)]*len(cast(Matrix, dim_values["result"])))
return Matrix([ensure_dims_all_compatible(*args, error_message="All inputs to the range function must have the same units")]*len(cast(Matrix, dim_values["result"])))

class PlaceholderFunction(TypedDict):
dim_func: Callable | Function
Expand Down Expand Up @@ -1236,7 +1239,7 @@ def custom_integral_dims(local_expr: Expr, global_expr: Expr, dummy_integral_var
lower_limit: Expr | None = None, upper_limit: Expr | None = None,
lower_limit_dims: Expr | None = None, upper_limit_dims: Expr | None = None):
if lower_limit is not None and upper_limit is not None:
ensure_dims_all_compatible(lower_limit_dims, upper_limit_dims)
ensure_dims_all_compatible(lower_limit_dims, upper_limit_dims, error_message="Upper and lower integral limits must have the same dimensions")
return global_expr * lower_limit_dims # type: ignore
else:
return global_expr * integral_var # type: ignore
Expand Down Expand Up @@ -1304,8 +1307,8 @@ def fdiff(self, argindex=1):


def fluid_dims(fluid_function: FluidFunction, input1, input2):
ensure_dims_all_compatible(get_dims(fluid_function["input1Dims"]), input1)
ensure_dims_all_compatible(get_dims(fluid_function["input2Dims"]), input2)
ensure_dims_all_compatible(get_dims(fluid_function["input1Dims"]), input1, error_message=f"First input to fluid function {fluid_function['name'].removesuffix('_as_variable')} has the incorrect units")
ensure_dims_all_compatible(get_dims(fluid_function["input2Dims"]), input2, error_message=f"Second input to fluid function {fluid_function['name'].removesuffix('_as_variable')} has the incorrect units")

return get_dims(fluid_function["outputDims"])

Expand Down Expand Up @@ -1389,9 +1392,9 @@ def fdiff(self, argindex=1):


def HA_fluid_dims(fluid_function: FluidFunction, input1, input2, input3):
ensure_dims_all_compatible(get_dims(fluid_function["input1Dims"]), input1)
ensure_dims_all_compatible(get_dims(fluid_function["input2Dims"]), input2)
ensure_dims_all_compatible(get_dims(fluid_function.get("input3Dims", [])), input3)
ensure_dims_all_compatible(get_dims(fluid_function["input1Dims"]), input1, error_message=f"First input to fluid function {fluid_function['name'].removesuffix('_as_variable')} has the incorrect units")
ensure_dims_all_compatible(get_dims(fluid_function["input2Dims"]), input2, error_message=f"Second input to fluid function {fluid_function['name'].removesuffix('_as_variable')} has the incorrect units")
ensure_dims_all_compatible(get_dims(fluid_function.get("input3Dims", [])), input3, error_message=f"Third input to fluid function {fluid_function['name'].removesuffix('_as_variable')} has the incorrect units")

return get_dims(fluid_function["outputDims"])

Expand Down Expand Up @@ -1446,7 +1449,7 @@ def _imp_(arg1):

def _eval_evalf(self, prec):
if (len(self.args) != 1):
raise TypeError(f'The interpolation function {interpolation_function["name"]} requires 1 input value, ({len(self.args)} given)')
raise TypeError(f"The interpolation function {interpolation_function['name'].removesuffix('_as_variable')} requires 1 input value, ({len(self.args)} given)")

if (self.args[0].is_number):
float_input = float(cast(Expr, self.args[0]))
Expand All @@ -1465,7 +1468,7 @@ def fdiff(self, argindex=1):
interpolation_wrapper.__name__ = interpolation_function["name"]

def interpolation_dims_wrapper(input):
ensure_dims_all_compatible(get_dims(interpolation_function["inputDims"]), input)
ensure_dims_all_compatible(get_dims(interpolation_function["inputDims"]), input, error_message=f"Incorrect units for interpolation function {interpolation_function['name'].removesuffix('_as_variable')}")

return get_dims(interpolation_function["outputDims"])

Expand All @@ -1490,7 +1493,7 @@ def eval(cls, arg1: Expr):
polyfit_wrapper.__name__ = polyfit_function["name"]

def polyfit_dims_wrapper(input):
ensure_dims_all_compatible(get_dims(polyfit_function["inputDims"]), input)
ensure_dims_all_compatible(get_dims(polyfit_function["inputDims"]), input, error_message=f"Incorrect units for polyfit function {polyfit_function['name'].removesuffix('_as_variable')}")

return get_dims(polyfit_function["outputDims"])

Expand Down Expand Up @@ -1531,28 +1534,28 @@ def get_next_id(self):
function_id_wrapper = Function('_function_id_wrapper')

global_placeholder_map: dict[Function, PlaceholderFunction] = {
cast(Function, Function('_StrictLessThan')) : {"dim_func": ensure_dims_all_compatible, "sympy_func": StrictLessThan},
cast(Function, Function('_LessThan')) : {"dim_func": ensure_dims_all_compatible, "sympy_func": LessThan},
cast(Function, Function('_StrictGreaterThan')) : {"dim_func": ensure_dims_all_compatible, "sympy_func": StrictGreaterThan},
cast(Function, Function('_GreaterThan')) : {"dim_func": ensure_dims_all_compatible, "sympy_func": GreaterThan},
cast(Function, Function('_And')) : {"dim_func": ensure_dims_all_compatible, "sympy_func": And},
cast(Function, Function('_StrictLessThan')) : {"dim_func": partial(ensure_dims_all_compatible, error_message="Piecewise cell comparison dimensions must match"), "sympy_func": StrictLessThan},
cast(Function, Function('_LessThan')) : {"dim_func": partial(ensure_dims_all_compatible, error_message="Piecewise cell comparison dimensions must match"), "sympy_func": LessThan},
cast(Function, Function('_StrictGreaterThan')) : {"dim_func": partial(ensure_dims_all_compatible, error_message="Piecewise cell comparison dimensions must match"), "sympy_func": StrictGreaterThan},
cast(Function, Function('_GreaterThan')) : {"dim_func": partial(ensure_dims_all_compatible, error_message="Piecewise cell comparison dimensions must match"), "sympy_func": GreaterThan},
cast(Function, Function('_And')) : {"dim_func": partial(ensure_dims_all_compatible, error_message="Piecewise cell comparison dimensions must match"), "sympy_func": And},
cast(Function, Function('_Piecewise')) : {"dim_func": ensure_dims_all_compatible_piecewise, "sympy_func": Piecewise},
cast(Function, Function('_asin')) : {"dim_func": ensure_unitless_in_angle_out, "sympy_func": asin},
cast(Function, Function('_acos')) : {"dim_func": ensure_unitless_in_angle_out, "sympy_func": acos},
cast(Function, Function('_atan')) : {"dim_func": ensure_unitless_in_angle_out, "sympy_func": atan},
cast(Function, Function('_asec')) : {"dim_func": ensure_unitless_in_angle_out, "sympy_func": asec},
cast(Function, Function('_acsc')) : {"dim_func": ensure_unitless_in_angle_out, "sympy_func": acsc},
cast(Function, Function('_acot')) : {"dim_func": ensure_unitless_in_angle_out, "sympy_func": acot},
cast(Function, Function('_asin')) : {"dim_func": partial(ensure_unitless_in_angle_out, func_name="arcsin"), "sympy_func": asin},
cast(Function, Function('_acos')) : {"dim_func": partial(ensure_unitless_in_angle_out, func_name="arccos"), "sympy_func": acos},
cast(Function, Function('_atan')) : {"dim_func": partial(ensure_unitless_in_angle_out, func_name="arctan"), "sympy_func": atan},
cast(Function, Function('_asec')) : {"dim_func": partial(ensure_unitless_in_angle_out, func_name="arcsec"), "sympy_func": asec},
cast(Function, Function('_acsc')) : {"dim_func": partial(ensure_unitless_in_angle_out, func_name="arcscs"), "sympy_func": acsc},
cast(Function, Function('_acot')) : {"dim_func": partial(ensure_unitless_in_angle_out, func_name="arccot"), "sympy_func": acot},
cast(Function, Function('_arg')) : {"dim_func": ensure_any_unit_in_angle_out, "sympy_func": arg},
cast(Function, Function('_re')) : {"dim_func": ensure_any_unit_in_same_out, "sympy_func": re},
cast(Function, Function('_im')) : {"dim_func": ensure_any_unit_in_same_out, "sympy_func": im},
cast(Function, Function('_conjugate')) : {"dim_func": ensure_any_unit_in_same_out, "sympy_func": conjugate},
cast(Function, Function('_Max')) : {"dim_func": ensure_dims_all_compatible_scalar_or_matrix, "sympy_func": custom_max},
cast(Function, Function('_Min')) : {"dim_func": ensure_dims_all_compatible_scalar_or_matrix, "sympy_func": custom_min},
cast(Function, Function('_sum')) : {"dim_func": ensure_dims_all_compatible_scalar_or_matrix, "sympy_func": custom_sum},
cast(Function, Function('_average')) : {"dim_func": ensure_dims_all_compatible_scalar_or_matrix, "sympy_func": custom_average},
cast(Function, Function('_stdev')) : {"dim_func": ensure_dims_all_compatible_scalar_or_matrix, "sympy_func": partial(custom_stdev, False)},
cast(Function, Function('_stdevp')) : {"dim_func": ensure_dims_all_compatible_scalar_or_matrix, "sympy_func": partial(custom_stdev, True)},
cast(Function, Function('_Max')) : {"dim_func": partial(ensure_dims_all_compatible_scalar_or_matrix, func_name="max"), "sympy_func": custom_max},
cast(Function, Function('_Min')) : {"dim_func": partial(ensure_dims_all_compatible_scalar_or_matrix, func_name="min"), "sympy_func": custom_min},
cast(Function, Function('_sum')) : {"dim_func": partial(ensure_dims_all_compatible_scalar_or_matrix, func_name="sum"), "sympy_func": custom_sum},
cast(Function, Function('_average')) : {"dim_func": partial(ensure_dims_all_compatible_scalar_or_matrix, func_name="average"), "sympy_func": custom_average},
cast(Function, Function('_stdev')) : {"dim_func": partial(ensure_dims_all_compatible_scalar_or_matrix, func_name="stdev"), "sympy_func": partial(custom_stdev, False)},
cast(Function, Function('_stdevp')) : {"dim_func": partial(ensure_dims_all_compatible_scalar_or_matrix, func_name="stdevp"), "sympy_func": partial(custom_stdev, True)},
cast(Function, Function('_count')) : {"dim_func": custom_count, "sympy_func": custom_count},
cast(Function, Function('_Abs')) : {"dim_func": ensure_any_unit_in_same_out, "sympy_func": Abs},
cast(Function, Function('_Inverse')) : {"dim_func": ensure_inverse_dims, "sympy_func": UniversalInverse},
Expand All @@ -1564,9 +1567,9 @@ def get_next_id(self):
cast(Function, Function('_Eq')) : {"dim_func": Eq, "sympy_func": Eq},
cast(Function, Function('_norm')) : {"dim_func": custom_norm, "sympy_func": custom_norm},
cast(Function, Function('_dot')) : {"dim_func": custom_dot, "sympy_func": custom_dot},
cast(Function, Function('_ceil')) : {"dim_func": ensure_unitless_in, "sympy_func": ceiling},
cast(Function, Function('_floor')) : {"dim_func": ensure_unitless_in, "sympy_func": floor},
cast(Function, Function('_round')) : {"dim_func": ensure_unitless_in, "sympy_func": custom_round},
cast(Function, Function('_ceil')) : {"dim_func": partial(ensure_unitless_in, func_name="ceil"), "sympy_func": ceiling},
cast(Function, Function('_floor')) : {"dim_func": partial(ensure_unitless_in, func_name="floor"), "sympy_func": floor},
cast(Function, Function('_round')) : {"dim_func": partial(ensure_unitless_in, func_name="round"), "sympy_func": custom_round},
cast(Function, Function('_Derivative')) : {"dim_func": custom_derivative_dims, "sympy_func": custom_derivative},
cast(Function, Function('_Integral')) : {"dim_func": custom_integral_dims, "sympy_func": custom_integral},
cast(Function, Function('_range')) : {"dim_func": custom_range_dims, "sympy_func": custom_range},
Expand Down

0 comments on commit cb816a3

Please sign in to comment.