Skip to content

Commit

Permalink
fix: fix user function regressions
Browse files Browse the repository at this point in the history
Code generation still has issues
  • Loading branch information
mgreminger committed Jan 6, 2025
1 parent 0aa0e43 commit 439bee2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 30 deletions.
66 changes: 39 additions & 27 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,7 @@ def get_next_id(self):
return self._next_id-1

dim_needs_values_wrapper = Function('_dim_needs_values_wrapper')
function_id_wrapper = Function('_function_id_wrapper')

global_placeholder_map: dict[Function, PlaceholderFunction] = {
cast(Function, Function('_StrictLessThan')) : {"dim_func": ensure_dims_all_compatible, "sympy_func": StrictLessThan},
Expand Down Expand Up @@ -1511,7 +1512,7 @@ def get_next_id(self):

global_placeholder_set = set(global_placeholder_map.keys())
dummy_var_placeholder_set = (Function('_Derivative'), Function('_Integral'))
dim_needs_values_wrapper_placeholder_set = (Function('_Pow'))
dim_needs_values_wrapper_placeholder_set = (Function('_Pow'), Function('_IndexMatrix'))
placeholder_inverse_map = { value["sympy_func"]: key for key, value in reversed(global_placeholder_map.items()) }
placeholder_inverse_set = set(placeholder_inverse_map.keys())

Expand All @@ -1528,8 +1529,14 @@ def replace_placeholder_funcs(expr: Expr,
func_key: Literal["dim_func"] | Literal["sympy_func"],
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function],
dim_values_dict: dict[int, list[Expr]],
dim_values_dict: dict[tuple[int,...], list[Expr]],
function_parents: list[int],
data_table_subs: DataTableSubs | None) -> Expr:

if (not is_matrix(expr)) and expr.func == function_id_wrapper:
function_parents.append(int(cast(Expr, expr.args[0])))
expr = cast(Expr, expr.args[1])

if is_matrix(expr):
rows = []
for i in range(expr.rows):
Expand All @@ -1538,42 +1545,47 @@ def replace_placeholder_funcs(expr: Expr,
for j in range(expr.cols):
row.append(replace_placeholder_funcs(cast(Expr, expr[i,j]), func_key,
placeholder_map, placeholder_set,
dim_values_dict, data_table_subs) )
dim_values_dict, function_parents,
data_table_subs) )

return cast(Expr, Matrix(rows))

expr = cast(Expr,expr)

if len(expr.args) == 0:
return expr

if expr.func == dim_needs_values_wrapper:
if func_key == "sympy_func":
child_expr = expr.args[1]
dim_values = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, data_table_subs) for arg in child_expr.args]
dim_values_dict[int(cast(Expr, expr.args[0]))] = dim_values
function_parents_snapshot = list(function_parents)
dim_values = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in child_expr.args]
if data_table_subs is not None and len(data_table_subs.subs_stack) > 0:
dim_values_snapshot = list(dim_values)
for i, value in enumerate(dim_values_snapshot):
dim_values_snapshot[i] = cast(Expr, value.subs({key: cast(Matrix, value)[0,0] for key, value in data_table_subs.subs_stack[-1].items()}))
dim_values_dict[(int(cast(Expr, expr.args[0])), *function_parents_snapshot)] = dim_values_snapshot
else:
dim_values_dict[(int(cast(Expr, expr.args[0])), *function_parents_snapshot)] = dim_values
return cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(*dim_values))
else:
child_expr = expr.args[1]
dim_values = dim_values_dict[int(cast(Expr, expr.args[0]))]
child_processed_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, data_table_subs) for arg in child_expr.args]
dim_values = dim_values_dict[(int(cast(Expr, expr.args[0])),*function_parents)]
child_processed_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in child_expr.args]
return cast(Expr, cast(Callable, placeholder_map[cast(Function, child_expr.func)][func_key])(dim_values, *child_processed_args))
elif expr.func in dummy_var_placeholder_set and func_key == "dim_func":
if expr.func in dim_needs_values_wrapper_placeholder_set:
# Reached a dim function that needs values to analyze dims (exponent, for example)
# This path will only be reached in the case of a expression resulting from a system solve
# Will not have values so fall back to raw sympy function
return cast(Expr, cast(Callable, placeholder_map[expr.func]["sympy_func"])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args))))
else:
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args))))
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args))))
elif expr.func in placeholder_set:
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, data_table_subs) for arg in expr.args)))
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in expr.args)))

elif data_table_subs is not None and expr.func == data_table_calc_wrapper:
if len(expr.args[0].atoms(data_table_id_wrapper)) == 0:
return replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, data_table_subs)
return replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs)

data_table_subs.subs_stack.append({})
data_table_subs.shortest_col_stack.append(None)

sub_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, data_table_subs)
sub_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs)

subs = data_table_subs.subs_stack.pop()
shortest_col = data_table_subs.shortest_col_stack.pop()
Expand All @@ -1594,7 +1606,7 @@ def replace_placeholder_funcs(expr: Expr,
return cast(Expr, Matrix([sub_expr,]*shortest_col))

elif data_table_subs is not None and expr.func == data_table_id_wrapper:
current_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, data_table_subs)
current_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs)
new_var = Symbol(f"_data_table_var_{data_table_subs.get_next_id()}")

if not is_matrix(current_expr):
Expand All @@ -1612,13 +1624,13 @@ def replace_placeholder_funcs(expr: Expr,
return cast(Expr, current_expr[0,0])

else:
return cast(Expr, expr.func(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, data_table_subs) for arg in expr.args)))
return cast(Expr, expr.func(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, dim_values_dict, function_parents, data_table_subs) for arg in expr.args)))

def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr],
expression: Expr,
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function],
dim_values_dict: dict[int, list[Expr]]) -> tuple[Expr | None, Exception | None]:
dim_values_dict: dict[tuple[int,...], list[Expr]]) -> tuple[Expr | None, Exception | None]:

expression_with_parameter_subs = cast(Expr, expression.xreplace(parameter_subs))

Expand All @@ -1628,7 +1640,7 @@ def get_dimensional_analysis_expression(parameter_subs: dict[Symbol, Expr],
try:
final_expression = replace_placeholder_funcs(expression_with_parameter_subs,
"dim_func", placeholder_map, placeholder_set,
dim_values_dict, DataTableSubs())
dim_values_dict, [], DataTableSubs())
except Exception as e:
error = e

Expand Down Expand Up @@ -1856,7 +1868,7 @@ def solve_system(statements: list[EqualityStatement], variables: list[str],
{unitless_sub_expression["name"]:unitless_sub_expression["expression"] for unitless_sub_expression in cast(list[UnitlessSubExpression], statement["unitlessSubExpressions"])})
equality = replace_placeholder_funcs(cast(Expr, equality),
"sympy_func",
placeholder_map, placeholder_set, {}, None)
placeholder_map, placeholder_set, {}, [], None)

system.append(cast(Expr, equality.doit()))

Expand Down Expand Up @@ -1947,7 +1959,7 @@ def solve_system_numerical(statements: list[EqualityStatement], variables: list[
equality = equality.subs(parameter_subs)
equality = replace_placeholder_funcs(cast(Expr, equality),
"sympy_func",
placeholder_map, placeholder_set, {}, None)
placeholder_map, placeholder_set, {}, [], None)
system.append(cast(Expr, equality.doit()))
new_statements.extend(statement["equalityUnitsQueries"])

Expand Down Expand Up @@ -2378,13 +2390,13 @@ def get_evaluated_expression(expression: Expr,
parameter_subs: dict[Symbol, Expr],
simplify_symbolic_expressions: bool,
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function]) -> tuple[ExprWithAssumptions, str | list[list[str]], dict[int,list[Expr]]]:
placeholder_set: set[Function]) -> tuple[ExprWithAssumptions, str | list[list[str]], dict[tuple[int,...],list[Expr]]]:
expression = cast(Expr, expression.xreplace(parameter_subs))
dim_values_dict: dict[int,list[Expr]] = {}
dim_values_dict: dict[tuple[int,...],list[Expr]] = {}
expression = replace_placeholder_funcs(expression,
"sympy_func",
placeholder_map,
placeholder_set, dim_values_dict,
placeholder_set, dim_values_dict, [],
DataTableSubs())
if not is_matrix(expression):
if simplify_symbolic_expressions:
Expand Down Expand Up @@ -2630,7 +2642,7 @@ def evaluate_statements(statements: list[InputAndSystemStatement],
final_expression = replace_placeholder_funcs(final_expression,
"sympy_func",
placeholder_map,
placeholder_set, {},
placeholder_set, {}, [],
None)

unitless_sub_expression_subs[symbols(unitless_sub_expression_name+current_function_name)] = final_expression
Expand Down
8 changes: 5 additions & 3 deletions src/parser/LatexToSympy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,10 @@ export class LatexToSympy extends LatexParserVisitor<string | Statement | UnitBl
if (ctx.id()) {
if (!ctx.CARET_SINGLE_CHAR_ID_UNDERSCORE_SUBSCRIPT()) {
base = this.visitId(ctx.id(), ctx.UNDERSCORE_SUBSCRIPT().toString());
this.params.push(base);
} else {
base = this.visitId(ctx.id(), ctx.CARET_SINGLE_CHAR_ID_UNDERSCORE_SUBSCRIPT().toString().slice(2));
this.params.push(base);
}

cursor = this.params.length;
Expand Down Expand Up @@ -1350,7 +1352,7 @@ export class LatexToSympy extends LatexParserVisitor<string | Statement | UnitBl
currentFunction = {
type: "assignment",
name: functionName,
sympy: variableName,
sympy: `_function_id_wrapper(${cantorPairing(this.equationIndex,this.functionIndex)},${variableName})`,
params: [variableName],
isUnitlessSubExpression: false,
isFunctionArgument: false,
Expand All @@ -1367,7 +1369,7 @@ export class LatexToSympy extends LatexParserVisitor<string | Statement | UnitBl
currentFunction = {
type: "assignment",
name: functionName,
sympy: variableName,
sympy: `_function_id_wrapper(${cantorPairing(this.equationIndex,this.functionIndex)},${variableName})`,
params: [variableName],
isUnitlessSubExpression: false,
isFunctionArgument: false,
Expand All @@ -1388,7 +1390,7 @@ export class LatexToSympy extends LatexParserVisitor<string | Statement | UnitBl
const unitsFunction: UserFunction = {
type: "assignment",
name: currentFunction.unitsQueryFunction,
sympy: variableName,
sympy: `_function_id_wrapper(${cantorPairing(this.equationIndex,this.functionIndex)},${variableName})`,
params: [variableName],
isUnitlessSubExpression: false,
isFunctionArgument: false,
Expand Down

0 comments on commit 439bee2

Please sign in to comment.