diff --git a/public/dimensional_analysis.py b/public/dimensional_analysis.py index d7dfed95..6da84e70 100644 --- a/public/dimensional_analysis.py +++ b/public/dimensional_analysis.py @@ -952,22 +952,12 @@ def custom_latex(expression: Expr) -> str: _range = Function("_range") def ensure_dims_all_compatible(*args): - if args[0].is_zero: - if all(arg.is_zero for arg in args): - first_arg = sympify('0') - else: - first_arg = sympify('1') - else: - first_arg = args[0] - - if len(args) == 1: - return first_arg - - first_arg_dims = custom_get_dimensional_dependencies(first_arg) - if all(custom_get_dimensional_dependencies(arg) == first_arg_dims for arg in args[1:]): - return first_arg - - raise TypeError('All input arguments to function need to have compatible units') + try: + # try adding, will only succeed for compatible units + custom_get_dimensional_dependencies(custom_add_dims(*args)) + except TypeError: + raise TypeError('All input arguments to function need to have compatible units') + return args[0] def ensure_dims_all_compatible_scalar_or_matrix(*args): if len(args) == 1 and is_matrix(args[0]): @@ -1752,7 +1742,6 @@ def dimensional_analysis(dimensional_analysis_expression: Expr | None, dim_sub_e except TypeError as e: result = f"Dimension Error: {e}" result_latex = result - print(result) return result, result_latex, custom_units_defined, custom_units, custom_units_latex diff --git a/tests/test_basic.spec.mjs b/tests/test_basic.spec.mjs index 2793042d..e0153732 100644 --- a/tests/test_basic.spec.mjs +++ b/tests/test_basic.spec.mjs @@ -786,6 +786,10 @@ test('Test floating point exponent rounding', async () => { await page.setLatex(1, String.raw`1\left\lbrack kg\cdot s^{.0000000000001}\right\rbrack+2\left\lbrack kg\right\rbrack=`); await page.click('#add-math-cell'); await page.setLatex(2, String.raw`1\left\lbrack kg\cdot s^{.000000000001}\right\rbrack+2\left\lbrack kg\right\rbrack=`); + await page.click('#add-math-cell'); + await page.setLatex(3, String.raw`\mathrm{sum}\left(1\left\lbrack K\cdot s^{.0000000000001}\right\rbrack,3\left\lbrack K\right\rbrack\right)=`); + await page.click('#add-math-cell'); + await page.setLatex(4, String.raw`\mathrm{sum}\left(1\left\lbrack K\cdot s^{.000000000001}\right\rbrack,3\left\lbrack K\right\rbrack\right)=`); await page.waitForSelector('text=Updating...', {state: 'detached'}); @@ -799,7 +803,14 @@ test('Test floating point exponent rounding', async () => { content = await page.textContent('#result-units-1'); expect(content).toBe('kg'); - await expect(page.locator('#cell-2 >> text=Dimension Error')).toBeVisible(); + await expect(page.locator('#cell-2 >> text=Dimension Error: Only equivalent dimensions can be added or subtracted')).toBeVisible(); + + content = await page.textContent('#result-value-3'); + expect(parseLatexFloat(content)).toBeCloseTo(4, precision); + content = await page.textContent('#result-units-3'); + expect(content).toBe('K'); + + await expect(page.locator('#cell-4 >> text=Dimension Error: All input arguments to function need to have compatible units')).toBeVisible(); }); test('Test function notation with integrals', async () => {