Skip to content

Commit

Permalink
fix: use the adding dims logic for functions that accept multiple inp…
Browse files Browse the repository at this point in the history
…uts with same dims
  • Loading branch information
mgreminger committed Jan 12, 2025
1 parent 8a8814e commit 67f7ba4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 28 deletions.
55 changes: 27 additions & 28 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@
factorial,
Basic,
Rational,
Integer
Integer,
S
)

class ExprWithAssumptions(Expr):
Expand Down Expand Up @@ -748,6 +749,20 @@ def get_base_units(custom_base_units: CustomBaseUnits | None= None) -> dict[tupl

ZERO_PLACEHOLDER = "implicit_param__zero"

def normalize_dims_dict(input):
keys_to_remove = set()
for key, value in input.items():
new_value = value.round(EXP_NUM_DIGITS)
if new_value == S.Zero:
keys_to_remove.add(key)
else:
input[key] = new_value

for key in keys_to_remove:
input.pop(key)

return input

# Monkey patch of SymPy's get_dimensional_dependencies so that units that have a small
# exponent difference (within EXP_NUM_DIGITS) are still considered equivalent for addition
def custom_get_dimensional_dependencies_for_name(self, dimension):
Expand All @@ -759,7 +774,7 @@ def custom_get_dimensional_dependencies_for_name(self, dimension):
if dimension.name.is_Symbol:
# Dimensions not included in the dependencies are considered
# as base dimensions:
return dict(self.dimensional_dependencies.get(dimension, {dimension: 1}))
return dict(self.dimensional_dependencies.get(dimension, {dimension: S.One}))

if dimension.name.is_number or dimension.name.is_NumberSymbol:
return {}
Expand All @@ -775,22 +790,7 @@ def custom_get_dimensional_dependencies_for_name(self, dimension):
return {k: v for (k, v) in ret.items() if v != 0}

if dimension.name.is_Add:
dicts = [get_for_name(i) for i in dimension.name.args]

for d in dicts:
keys_to_remove = set()
for key, exp in d.items():
if isinstance(exp, int):
exp = sympify(float(exp))

new_exp = exp.round(EXP_NUM_DIGITS)
if new_exp == sympify("0"):
keys_to_remove.add(key)
else:
d[key] = new_exp

for key in keys_to_remove:
d.pop(key)
dicts = [normalize_dims_dict(get_for_name(i)) for i in dimension.name.args]

if all(d == dicts[0] for d in dicts[1:]):
return dicts[0]
Expand Down Expand Up @@ -957,17 +957,17 @@ def custom_latex(expression: Expr) -> str:
def ensure_dims_all_compatible(*args):
if args[0].is_zero:
if all(arg.is_zero for arg in args):
first_arg = sympify('0')
first_arg = S.Zero
else:
first_arg = sympify('1')
first_arg = S.One
else:
first_arg = args[0]

if len(args) == 1:
return first_arg

first_arg_dims = custom_get_dimensional_dependencies(first_arg)
if all(custom_get_dimensional_dependencies(arg) == first_arg_dims for arg in args[1:]):
first_arg_dims = normalize_dims_dict(custom_get_dimensional_dependencies(first_arg))
if all(normalize_dims_dict(custom_get_dimensional_dependencies(arg)) == first_arg_dims for arg in args[1:]):
return first_arg

raise TypeError('All input arguments to function need to have compatible units')
Expand Down Expand Up @@ -1019,7 +1019,7 @@ def ensure_inverse_dims(arg):
for j in range(arg.cols):
dim, _ = get_mathjs_units(cast(dict[Dimension, float], custom_get_dimensional_dependencies(cast(Expr, arg[j,i]))))
if dim == "":
row.append(sympify('0'))
row.append(S.Zero)
else:
row.append(cast(Expr, arg[j,i])**-1)
column_dims.setdefault(i, []).append(dim)
Expand Down Expand Up @@ -1137,8 +1137,8 @@ def custom_range(*args: Expr):
if not all( (arg.is_real and arg.is_finite and not isinstance(arg, Dimension) for arg in args ) ): # type: ignore
raise TypeError('All range inputs must be unitless and must evaluate to real and finite values')

start = cast(Expr, sympify('1'))
step = cast(Expr, sympify('1'))
start = cast(Expr, S.One)
step = cast(Expr, S.One)

if len(args) == 1:
stop = args[0]
Expand Down Expand Up @@ -1609,7 +1609,7 @@ def replace_placeholder_funcs(expr: Expr,
expr = cast(Expr, expr.args[1])

if (not is_matrix(expr)) and isinstance(expr, Symbol) and expr.name == "_zero_delayed_substitution":
return sympify('0')
return S.Zero

if is_matrix(expr):
rows = []
Expand Down Expand Up @@ -1725,7 +1725,7 @@ def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr],

def custom_get_dimensional_dependencies(expression: Expr | None):
if expression is not None:
expression = subs_wrapper(expression, {cast(Symbol, symbol): sympify('1') for symbol in (expression.free_symbols - dimension_symbols)})
expression = subs_wrapper(expression, {cast(Symbol, symbol): S.One for symbol in (expression.free_symbols - dimension_symbols)})
return dimsys_SI.get_dimensional_dependencies(expression)

def dimensional_analysis(dimensional_analysis_expression: Expr | None, dim_sub_error: Exception | None,
Expand Down Expand Up @@ -1755,7 +1755,6 @@ def dimensional_analysis(dimensional_analysis_expression: Expr | None, dim_sub_e
except TypeError as e:
result = f"Dimension Error: {e}"
result_latex = result
print(result)

return result, result_latex, custom_units_defined, custom_units, custom_units_latex

Expand Down
21 changes: 21 additions & 0 deletions tests/test_basic.spec.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -781,12 +781,21 @@ test('Test zero canceling bug with exponent', async () => {
});

test('Test floating point exponent rounding', async () => {
// check matching equivalent dims for adding
await page.setLatex(0, String.raw`1\left\lbrack m\right\rbrack+1\left\lbrack\frac{N^{\frac13}}{m^{\frac23}}\right\rbrack\cdot1\left\lbrack\frac{m^{\frac53}}{N^{\frac13}}\right\rbrack=`);
await page.click('#add-math-cell');
await page.setLatex(1, String.raw`1\left\lbrack kg\cdot s^{.0000000000001}\right\rbrack+2\left\lbrack kg\right\rbrack=`);
await page.click('#add-math-cell');
await page.setLatex(2, String.raw`1\left\lbrack kg\cdot s^{.000000000001}\right\rbrack+2\left\lbrack kg\right\rbrack=`);

// check matching equivalent dims for sum function
await page.click('#add-math-cell');
await page.setLatex(3, String.raw`\mathrm{sum}\left(1\left\lbrack m\right\rbrack,1\left\lbrack\frac{N^{\frac13}}{m^{\frac23}}\right\rbrack\cdot3\left\lbrack\frac{m^{\frac53}}{N^{\frac13}}\right\rbrack\right)=`);
await page.click('#add-math-cell');
await page.setLatex(4, String.raw`\mathrm{sum}\left(1\left\lbrack K\cdot s^{.0000000000001}\right\rbrack,4\left\lbrack K\right\rbrack\right)=`);
await page.click('#add-math-cell');
await page.setLatex(5, String.raw`\mathrm{sum}\left(1\left\lbrack K\cdot s^{.000000000001}\right\rbrack,4\left\lbrack K\right\rbrack\right)=`);

await page.waitForSelector('text=Updating...', {state: 'detached'});

let content = await page.textContent('#result-value-0');
Expand All @@ -800,6 +809,18 @@ test('Test floating point exponent rounding', async () => {
expect(content).toBe('kg');

await expect(page.locator('#cell-2 >> text=Dimension Error')).toBeVisible();

content = await page.textContent('#result-value-3');
expect(parseLatexFloat(content)).toBeCloseTo(4, precision);
content = await page.textContent('#result-units-3');
expect(content).toBe('m');

content = await page.textContent('#result-value-4');
expect(parseLatexFloat(content)).toBeCloseTo(5, precision);
content = await page.textContent('#result-units-4');
expect(content).toBe('K');

await expect(page.locator('#cell-5 >> text=Dimension Error')).toBeVisible();
});

test('Test function notation with integrals', async () => {
Expand Down

0 comments on commit 67f7ba4

Please sign in to comment.