Skip to content

Commit

Permalink
feat: add result to DimValues dict
Browse files Browse the repository at this point in the history
This allows functions like range that need to know the result of the calculation to set the correct dims. Range now works with inputs that have consistent units. New test added for this functionality.
  • Loading branch information
mgreminger committed Jan 8, 2025
1 parent 0255ce2 commit 14ecb9c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 21 deletions.
44 changes: 27 additions & 17 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,10 @@ class CombinedExpressionScatter(TypedDict):
CombinedExpression = CombinedExpressionBlank | CombinedExpressionNoRange | CombinedExpressionRange | \
CombinedExpressionScatter

class DimValues(TypedDict):
args: list[Expr]
result: Expr

# maps from mathjs dimensions object to sympy dimensions
dim_map: dict[int, Dimension] = {
0: mass,
Expand Down Expand Up @@ -1073,6 +1077,9 @@ def custom_range(*args: Expr):

return Matrix(values)

def custom_range_dims(dim_values: DimValues, *args: Expr):
return Matrix([ensure_dims_all_compatible(*args)]*len(cast(Matrix, dim_values["result"])))

class PlaceholderFunction(TypedDict):
dim_func: Callable | Function
sympy_func: object
Expand All @@ -1087,13 +1094,13 @@ def IndexMatrix(expression: Expr, i: Expr, j: Expr) -> Expr:

return expression[i-1, j-1] # type: ignore

def IndexMatrix_dims(dim_values: list[Expr], expression: Expr, i: Expr, j: Expr) -> Expr:
def IndexMatrix_dims(dim_values: DimValues, expression: Expr, i: Expr, j: Expr) -> Expr:
if custom_get_dimensional_dependencies(i) != {} or \
custom_get_dimensional_dependencies(j) != {}:
raise TypeError('Matrix Index Not Dimensionless')

i_value = dim_values[1]
j_value = dim_values[2]
i_value = dim_values["args"][1]
j_value = dim_values["args"][2]

return expression[i_value-1, j_value-1] # type: ignore

Expand Down Expand Up @@ -1167,10 +1174,10 @@ def custom_pow(base: Expr, exponent: Expr):
else:
return Pow(base, exponent)

def custom_pow_dims(dim_values: list[Expr], base: Expr, exponent: Expr):
def custom_pow_dims(dim_values: DimValues, base: Expr, exponent: Expr):
if custom_get_dimensional_dependencies(exponent) != {}:
raise TypeError('Exponent Not Dimensionless')
return Pow(base.evalf(PRECISION), (dim_values[1]).evalf(PRECISION))
return Pow(base.evalf(PRECISION), (dim_values["args"][1]).evalf(PRECISION))

CP = None

Expand Down Expand Up @@ -1481,7 +1488,7 @@ def get_next_id(self):
cast(Function, Function('_round')) : {"dim_func": ensure_unitless_in, "sympy_func": custom_round},
cast(Function, Function('_Derivative')) : {"dim_func": custom_derivative_dims, "sympy_func": custom_derivative},
cast(Function, Function('_Integral')) : {"dim_func": custom_integral_dims, "sympy_func": custom_integral},
cast(Function, Function('_range')) : {"dim_func": custom_range, "sympy_func": custom_range},
cast(Function, Function('_range')) : {"dim_func": custom_range_dims, "sympy_func": custom_range},
cast(Function, Function('_factorial')) : {"dim_func": factorial, "sympy_func": CustomFactorial},
cast(Function, Function('_add')) : {"dim_func": custom_add_dims, "sympy_func": Add},
cast(Function, Function('_Pow')) : {"dim_func": custom_pow_dims, "sympy_func": custom_pow},
Expand All @@ -1502,11 +1509,12 @@ def replace_sympy_funcs_with_placeholder_funcs(expression: Expr) -> Expr:

return expression


def replace_placeholder_funcs(expr: Expr,
func_key: Literal["dim_func"] | Literal["sympy_func"],
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function],
dim_values_dict: dict[tuple[Basic,...], list[Expr]],
dim_values_dict: dict[tuple[Basic,...], DimValues],
function_parents: list[Basic],
data_table_subs: DataTableSubs | None) -> Expr:

Expand Down Expand Up @@ -1536,15 +1544,17 @@ def replace_placeholder_funcs(expr: Expr,
if func_key == "sympy_func":
child_expr = expr.args[1]
function_parents_snapshot = list(function_parents)
dim_values = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in child_expr.args]
dim_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in child_expr.args]
result = cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(*dim_args))
if data_table_subs is not None and len(data_table_subs.subs_stack) > 0:
dim_values_snapshot = list(dim_values)
for i, value in enumerate(dim_values_snapshot):
dim_values_snapshot[i] = cast(Expr, value.subs({key: cast(Matrix, value)[0,0] for key, value in data_table_subs.subs_stack[-1].items()}))
dim_values_dict[(expr.args[0], *function_parents_snapshot)] = dim_values_snapshot
dim_args_snapshot = list(dim_args)
for i, value in enumerate(dim_args_snapshot):
dim_args_snapshot[i] = cast(Expr, value.subs({key: cast(Matrix, value)[0,0] for key, value in data_table_subs.subs_stack[-1].items()}))
result_snapshot = cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(*dim_args_snapshot))
dim_values_dict[(expr.args[0], *function_parents_snapshot)] = DimValues(args=dim_args_snapshot, result=result_snapshot)
else:
dim_values_dict[(expr.args[0], *function_parents_snapshot)] = dim_values
return cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(*dim_values))
dim_values_dict[(expr.args[0], *function_parents_snapshot)] = DimValues(args=dim_args, result=result)
return result
else:
child_expr = expr.args[1]
dim_values = dim_values_dict[(expr.args[0],*function_parents)]
Expand Down Expand Up @@ -1607,7 +1617,7 @@ def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr],
expression: Expr,
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function],
dim_values_dict: dict[tuple[Basic,...], list[Expr]]) -> tuple[Expr | None, Exception | None]:
dim_values_dict: dict[tuple[Basic,...], DimValues]) -> tuple[Expr | None, Exception | None]:

expression_with_parameter_subs = cast(Expr, expression.xreplace(parameter_subs))

Expand Down Expand Up @@ -2342,9 +2352,9 @@ def get_evaluated_expression(expression: Expr,
parameter_subs: dict[Symbol, Expr],
simplify_symbolic_expressions: bool,
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function]) -> tuple[ExprWithAssumptions, str | list[list[str]], dict[tuple[Basic,...],list[Expr]]]:
placeholder_set: set[Function]) -> tuple[ExprWithAssumptions, str | list[list[str]], dict[tuple[Basic,...],DimValues]]:
expression = cast(Expr, expression.xreplace(parameter_subs))
dim_values_dict: dict[tuple[Basic,...],list[Expr]] = {}
dim_values_dict: dict[tuple[Basic,...], DimValues] = {}
expression = replace_placeholder_funcs(expression,
"sympy_func",
placeholder_map,
Expand Down
9 changes: 7 additions & 2 deletions src/parser/LatexToSympy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { type Insertion, type Replacement, applyEdits,

import { RESERVED, GREEK_CHARS, UNASSIGNABLE, COMPARISON_MAP,
UNITS_WITH_OFFSET, TYPE_PARSING_ERRORS, BUILTIN_FUNCTION_MAP,
ZERO_PLACEHOLDER } from "./constants.js";
BUILTIN_FUNCTION_NEEDS_VALUES, ZERO_PLACEHOLDER } from "./constants.js";

import { MAX_MATRIX_COLS } from "../constants";

Expand Down Expand Up @@ -1273,7 +1273,12 @@ export class LatexToSympy extends LatexParserVisitor<string | Statement | UnitBl
if (!BUILTIN_FUNCTION_MAP.has(originalFunctionName)) {
return `${functionName}(${argumentString})`;
} else {
return `${BUILTIN_FUNCTION_MAP.get(originalFunctionName)}(${argumentString})`;
const functionPlaceholderName = BUILTIN_FUNCTION_MAP.get(originalFunctionName);
if(!BUILTIN_FUNCTION_NEEDS_VALUES.has(originalFunctionName)) {
return `${functionPlaceholderName}(${argumentString})`;
} else {
return `_dim_needs_values_wrapper(__unique_marker_${this.equationIndex}_${this.dimNeedsValuesIndex++},${functionPlaceholderName}(${argumentString}))`;
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/parser/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ export const BUILTIN_FUNCTION_MAP = new Map([
['stdevp', '_stdevp']
]);

export const BUILTIN_FUNCTION_NEEDS_VALUES= new Set(['range',]);

export const COMPARISON_MAP = new Map([
["<", "_StrictLessThan"],
["\\le", "_LessThan"],
Expand Down
13 changes: 11 additions & 2 deletions tests/test_matrix_functions.spec.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,17 @@ test('Test range that includes zero value multiplied by dimensioned value', asyn
expect(content).toBe(String.raw`\begin{bmatrix} 0\left\lbrack m\right\rbrack \\ 1\left\lbrack m\right\rbrack \\ 2\left\lbrack m\right\rbrack \\ 3\left\lbrack m\right\rbrack \\ 4\left\lbrack m\right\rbrack \\ 5\left\lbrack m\right\rbrack \end{bmatrix}`);
});

test('Test range input needs to be unitless', async () => {
await page.setLatex(0, String.raw`\mathrm{range}\left(1\left\lbrack m\right\rbrack,2\left\lbrack m\right\rbrack,.1\left\lbrack m\right\rbrack\right)=`);
test('Test range with consistent units', async () => {
await page.setLatex(0, String.raw`\mathrm{range}\left(1\left\lbrack m\right\rbrack,2\left\lbrack m\right\rbrack,1\left\lbrack m\right\rbrack\right)=`);

await page.waitForSelector('text=Updating...', {state: 'detached'});

let content = await page.textContent(`#result-value-0`);
expect(content).toBe(String.raw`\begin{bmatrix} 1\left\lbrack m\right\rbrack \\ 2\left\lbrack m\right\rbrack \end{bmatrix}`);
});

test('Test range with inconsistent units', async () => {
await page.setLatex(0, String.raw`\mathrm{range}\left(1\left\lbrack m\right\rbrack,2\left\lbrack s\right\rbrack,1\left\lbrack m\right\rbrack\right)=`);

await page.waitForSelector('text=Updating...', {state: 'detached'});

Expand Down

0 comments on commit 14ecb9c

Please sign in to comment.