Skip to content

Commit 3080071

Browse files
committed
Add cli method
1 parent 2f4125b commit 3080071

File tree

3 files changed

+52
-34
lines changed

3 files changed

+52
-34
lines changed

auto_heuristic/code.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ def _is_float(element: any) -> bool:
3030
return {k: f'"{k}"' for k in return_values}
3131

3232

33-
def decision_tree_to_python(tree: DecisionNode, feature_names: List[str], return_format: dict, performance: float) -> str:
33+
def decision_tree_to_python(
34+
tree: DecisionNode, feature_names: List[str], return_format: dict, performance: float
35+
) -> str:
3436
variable_names = {f: f.replace(" ", "_").lower() for f in feature_names}
3537
code = "def predict({}):".format(", ".join(variable_names.values())) + "\n"
36-
code += " # Accuracy: {}%".format(int(performance*100)) + "\n"
38+
code += " # Accuracy: {}%".format(int(performance * 100)) + "\n"
3739

3840
def _decision_node_to_python(node, depth=1):
3941
indent = " " * depth
@@ -60,7 +62,7 @@ def decision_tree_to_js(tree: DecisionNode, feature_names: List[str], return_for
6062
for f in feature_names
6163
}
6264
code = "function predict({}) {{".format(", ".join(variable_names.values())) + "\n"
63-
code += " // Accuracy: {}%".format(int(performance*100)) + "\n"
65+
code += " // Accuracy: {}%".format(int(performance * 100)) + "\n"
6466

6567
def _decision_node_to_js(node, depth=1):
6668
indent = " " * depth

cli.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from auto_heuristic import (
2+
load_dataset,
3+
get_model,
4+
extract_decision_tree,
5+
get_variable_list,
6+
decision_tree_to_python,
7+
decision_tree_to_js,
8+
format_return_value,
9+
)
10+
import argparse
11+
12+
13+
def auto_heuristic(csv_path: str, target_column: str, python_path: str = None, js_path: str = None):
14+
X, y, feature_names, class_names = load_dataset(csv_path, target_column)
15+
models = get_model(X, y)
16+
assert models, "No successful heuristic found"
17+
return_format = format_return_value(set(y))
18+
depth = max(models)
19+
model, score = models[depth]
20+
21+
print("Best depth:", depth)
22+
print("Score:", score)
23+
24+
formatted_tree = extract_decision_tree(model, feature_names, class_names)
25+
variable_list = get_variable_list(formatted_tree)
26+
27+
if python_path:
28+
python_code = decision_tree_to_python(formatted_tree, variable_list, return_format, score)
29+
with open(python_path, "w") as f:
30+
f.write(python_code + "\n")
31+
32+
if js_path:
33+
js_code = decision_tree_to_js(formatted_tree, variable_list, return_format, score)
34+
with open(js_path, "w") as f:
35+
f.write(js_code + "\n")
36+
37+
38+
if __name__ == "__main__":
39+
parser = argparse.ArgumentParser(description="Generate heuristic rules for a given CSV")
40+
parser.add_argument("file", metavar="N", type=str, help="CSV file to process")
41+
parser.add_argument("--target", type=str, required=True, help="Column to target")
42+
parser.add_argument("--python", type=str, required=False, help="Python file to generate")
43+
parser.add_argument("--js", type=str, required=False, help="JS file to generate")
44+
45+
args = parser.parse_args()
46+
47+
auto_heuristic(args.file, args.target, args.python, args.js)

test.py

-31
This file was deleted.

0 commit comments

Comments
 (0)