Skip to content

Commit

Permalink
Adding hwf experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
Liby99 committed Feb 23, 2024
1 parent c9114bc commit 3dbe1e9
Show file tree
Hide file tree
Showing 33 changed files with 2,516 additions and 0 deletions.
147 changes: 147 additions & 0 deletions experiments/hwf/datagen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from typing import List
import os
import random
import functools
import json
from argparse import ArgumentParser
from tqdm import tqdm

def precedence(operator: str) -> int:
    """Return the precedence rank of an arithmetic operator symbol.

    ``+`` and ``-`` map to 2, ``*`` and ``/`` map to 1.

    Raises:
        Exception: if ``operator`` is not one of the four symbols above.
    """
    if operator in ("+", "-"):
        return 2
    if operator in ("*", "/"):
        return 1
    raise Exception(f"Unknown operator {operator}")

class Expression:
    """Base type of the expression-tree nodes defined in this file.

    Every method here is a placeholder that returns ``None``; the subclasses
    override them with real behaviour.
    """

    # Absolute path of the symbol image folder, resolved relative to this file.
    data_root = os.path.abspath(
        os.path.join(
            os.path.abspath(__file__),
            "../../data/HWF/Handwritten_Math_Symbols",
        )
    )

    def sample_images(self) -> List[str]:
        """Placeholder for sampling one image path per symbol."""
        return None

    def __str__(self) -> str:
        """Placeholder for the textual form of the expression."""
        return None

    def __len__(self) -> int:
        """Placeholder for the number of symbols in the expression."""
        return None

    def value(self) -> int:
        """Placeholder for the numeric value of the expression."""
        return None

    def precedence(self) -> int:
        """Placeholder for the precedence rank of the node."""
        return None


class Constant(Expression):
    """Leaf node of the expression tree holding a single digit."""

    def __init__(self, digit: int):
        super(Constant, self).__init__()
        self.digit = digit

    def __str__(self):
        return f"{self.digit}"

    def __len__(self):
        # A constant is always exactly one symbol.
        return 1

    def sample_images(self) -> List[str]:
        """Return a one-element list with a randomly chosen image of the digit."""
        return [random.choice(Constant.images_of_digit(self.digit))]

    def value(self) -> int:
        return self.digit

    def precedence(self) -> int:
        # Constants use rank 0, below the ranks returned for the operators.
        return 0

    @staticmethod
    @functools.lru_cache
    def images_of_digit(digit: int) -> List[str]:
        """List the image paths (``"<digit>/<file>"``) available for ``digit``.

        The directory listing is sorted because ``os.listdir`` returns entries
        in arbitrary order; without sorting, the same random seed could select
        different images on different machines.
        """
        folder = os.path.join(Expression.data_root, str(digit))
        return [f"{digit}/{name}" for name in sorted(os.listdir(folder))]


class BinaryOperation(Expression):
    """Inner node of the expression tree: ``lhs <operator> rhs``."""

    def __init__(self, operator: str, lhs: Expression, rhs: Expression):
        super(BinaryOperation, self).__init__()
        self.operator = operator
        self.lhs = lhs
        self.rhs = rhs

    def sample_images(self) -> List[str]:
        """Return image paths for the whole expression in reading order.

        The operator image is drawn first, then the left and right operands,
        so the sequence of random draws is operator, lhs, rhs.
        """
        symbol_image = [random.choice(BinaryOperation.images_of_symbol(self.operator))]
        left_images = self.lhs.sample_images()
        right_images = self.rhs.sample_images()
        return left_images + symbol_image + right_images

    def value(self) -> float:
        """Evaluate the expression.

        ``/`` is true division, so any expression containing a division
        yields a float; otherwise the result is an int.

        Raises:
            Exception: if the stored operator is not one of ``+ - * /``.
        """
        if self.operator == "+":
            return self.lhs.value() + self.rhs.value()
        elif self.operator == "-":
            return self.lhs.value() - self.rhs.value()
        elif self.operator == "*":
            return self.lhs.value() * self.rhs.value()
        elif self.operator == "/":
            return self.lhs.value() / self.rhs.value()
        else:
            raise Exception(f"Unknown operator {self.operator}")

    def __str__(self):
        # No parentheses are emitted; the textual form is a flat symbol sequence.
        return f"{self.lhs} {self.operator} {self.rhs}"

    def __len__(self):
        # Symbols of both operands plus one for the operator itself.
        return len(self.lhs) + 1 + len(self.rhs)

    def precedence(self) -> int:
        return precedence(self.operator)

    @staticmethod
    @functools.lru_cache
    def images_of_symbol(symbol: str) -> List[str]:
        """List the image paths (``"<folder>/<file>"``) available for ``symbol``.

        ``*`` and ``/`` are stored in the ``times`` and ``div`` folders. The
        listing is sorted because ``os.listdir`` returns entries in arbitrary
        order, which would otherwise make seeded sampling irreproducible.

        Raises:
            Exception: if ``symbol`` is not one of ``+ - * /``.
        """
        if symbol == "+":
            folder = "+"
        elif symbol == "-":
            folder = "-"
        elif symbol == "*":
            folder = "times"
        elif symbol == "/":
            folder = "div"
        else:
            raise Exception(f"Unknown symbol {symbol}")
        listing = sorted(os.listdir(os.path.join(Expression.data_root, folder)))
        return [f"{folder}/{name}" for name in listing]


class ExpressionGenerator:
    """Random generator of arithmetic expression trees by rejection sampling.

    Args:
        const_perc: probability of emitting a constant at any node.
        max_depth: depth at which only constants are emitted.
        max_length: largest allowed number of symbols in any (sub)expression.
        digits: candidate digits for constants.
        operators: candidate operator symbols.
        length: exact number of symbols required of the final expression, or
            ``None`` for no such requirement.

    Raises:
        ValueError: if ``length`` can never be produced. Expressions always
            have an odd, positive number of symbols and never more than
            ``max_length``; without this check ``generate_datapoint`` would
            loop forever.
    """

    def __init__(self, const_perc, max_depth, max_length, digits, operators, length):
        if length is not None and (length < 1 or length % 2 == 0 or length > max_length):
            raise ValueError(
                f"Cannot generate expressions of length {length}: the length must be "
                f"odd, positive and at most max_length ({max_length})"
            )
        self.const_perc = const_perc
        self.max_depth = max_depth
        self.max_length = max_length
        self.digits = digits
        self.operators = operators
        self.length = length

    def generate_expr(self, depth=0):
        """Try to generate one expression; return ``None`` when the attempt is rejected."""
        if depth >= self.max_depth or random.random() < self.const_perc:
            expr = Constant(random.choice(self.digits))
        else:
            symbol = random.choice(self.operators)
            rank = precedence(symbol)
            lhs = self.generate_expr(depth + 1)
            # The left operand may share the operator's rank (left associativity)
            # but must not be a looser-binding operation.
            if lhs is None or rank < lhs.precedence():
                return None
            rhs = self.generate_expr(depth + 1)
            # The right operand must bind strictly tighter. The flat string
            # "9 - 3 + 2" reads as (9 - 3) + 2, so a tree shaped 9 - (3 + 2)
            # would be labelled with a value that does not match its symbols.
            if rhs is None or rank <= rhs.precedence():
                return None
            # Reject division by zero before the node is built.
            if symbol == "/" and rhs.value() == 0:
                return None
            expr = BinaryOperation(symbol, lhs, rhs)
        if len(expr) > self.max_length:
            return None
        # The exact-length requirement applies to the root expression only.
        if depth == 0 and self.length is not None and len(expr) != self.length:
            return None
        return expr

    def generate_datapoint(self, id):
        """Retry until an expression is accepted and return it as a datapoint dict."""
        while True:
            e = self.generate_expr()
            if e is not None:
                return {"id": str(id), "img_paths": e.sample_images(), "expr": str(e), "res": e.value()}


if __name__ == "__main__":
    # Command-line entry point: generate random expressions and dump them as JSON.
    parser = ArgumentParser("hwf/datagen")
    parser.add_argument("--operators", action="store", default=["+", "-", "*", "/"], nargs="*")
    parser.add_argument("--digits", action="store", type=int, default=list(range(10)), nargs="*")
    parser.add_argument("--num-datapoints", type=int, default=100000)
    parser.add_argument("--max-depth", type=int, default=3)
    parser.add_argument("--max-length", type=int, default=7)
    parser.add_argument("--length", type=int)
    parser.add_argument("--constant-percentage", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--output", type=str, default="dataset.json")
    args = parser.parse_args()

    # Parameters
    random.seed(args.seed)
    data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF"))

    # Generate datapoints
    generator = ExpressionGenerator(args.constant_percentage, args.max_depth, args.max_length, args.digits, args.operators, args.length)
    data = [generator.generate_datapoint(i) for i in tqdm(range(args.num_datapoints))]

    # Dump data; the context manager closes (and flushes) the file deterministically
    with open(os.path.join(data_root, args.output), "w") as output_file:
        json.dump(data, output_file)
20 changes: 20 additions & 0 deletions experiments/hwf/datamerge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from argparse import ArgumentParser
import os
import json

if __name__ == "__main__":
    # Command-line entry point: concatenate several JSON list files into one.
    parser = ArgumentParser("hwf/datamerge")
    parser.add_argument("inputs", action="store", nargs="*")
    parser.add_argument("--output", type=str, default="expr_merged.json")
    args = parser.parse_args()

    # Get the list of files; input and output names are relative to data_root
    data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF"))
    all_data = []
    for file in args.inputs:
        print(f"Loading file {file}")
        # Close each input as soon as it has been parsed
        with open(os.path.join(data_root, file)) as input_file:
            all_data += json.load(input_file)

    # Dump the result
    with open(os.path.join(data_root, args.output), "w") as output_file:
        json.dump(all_data, output_file)
27 changes: 27 additions & 0 deletions experiments/hwf/dataslice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from argparse import ArgumentParser
import os
import json
import random

if __name__ == "__main__":
    # Command-line entry point: keep a random fraction of a JSON list file.
    parser = ArgumentParser()
    parser.add_argument("--input", type=str, default="expr_train.json")
    parser.add_argument("--output", type=str, default="expr_train_0.5.json")
    parser.add_argument("--perc", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=1234)
    args = parser.parse_args()

    # Set random seed
    random.seed(args.seed)

    # Load input file
    data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF/"))
    with open(os.path.join(data_root, args.input)) as source:
        input_file = json.load(source)

    # Shuffle the entries and keep only the first args.perc fraction of them
    random.shuffle(input_file)
    end_index = int(len(input_file) * args.perc)
    input_file = input_file[0:end_index]

    # Output the file
    with open(os.path.join(data_root, args.output), "w") as target:
        json.dump(input_file, target)
19 changes: 19 additions & 0 deletions experiments/hwf/datastats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from argparse import ArgumentParser
import os
import json

if __name__ == "__main__":
    # Command-line entry point: print how many datapoints exist per expression length.
    parser = ArgumentParser("hwf/datastats")
    parser.add_argument("--dataset", type=str, default="expr_train.json")
    args = parser.parse_args()

    # Get the dataset
    data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF"))
    with open(os.path.join(data_root, args.dataset)) as dataset_file:
        data = json.load(dataset_file)

    # Compute stats: number of image paths -> number of datapoints with that count
    lengths = {}
    for datapoint in data:
        num_symbols = len(datapoint["img_paths"])
        lengths[num_symbols] = lengths.get(num_symbols, 0) + 1
    print(lengths)
Binary file added experiments/hwf/docs/+_96.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added experiments/hwf/docs/1_47.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added experiments/hwf/docs/3_91.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added experiments/hwf/docs/5_237.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added experiments/hwf/docs/div_942.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions experiments/hwf/examples/hwf_length_3_all.scl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import "../scl/hwf_eval.scl"

// Candidate symbols without tags: each position 0, 1 and 2 is paired with
// every digit "0".."9" and every operator "+", "-", "*", "/".
rel symbol = {
(0, "0"), (0, "1"), (0, "2"), (0, "3"), (0, "4"), (0, "5"), (0, "6"), (0, "7"), (0, "8"), (0, "9"), (0, "+"), (0, "-"), (0, "*"), (0, "/"),
(1, "0"), (1, "1"), (1, "2"), (1, "3"), (1, "4"), (1, "5"), (1, "6"), (1, "7"), (1, "8"), (1, "9"), (1, "+"), (1, "-"), (1, "*"), (1, "/"),
(2, "0"), (2, "1"), (2, "2"), (2, "3"), (2, "4"), (2, "5"), (2, "6"), (2, "7"), (2, "8"), (2, "9"), (2, "+"), (2, "-"), (2, "*"), (2, "/"),
}

// Declares the expression length as 3, matching positions 0..2 above.
// NOTE(review): this file contains no `query`; confirm whether the imported
// program already declares one.
rel length(3)
11 changes: 11 additions & 0 deletions experiments/hwf/examples/hwf_length_3_all_prob.scl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import "../scl/hwf_eval.scl"

// Candidate symbols for positions 0, 1 and 2: every digit and operator,
// each fact tagged with 0.07 (presumably a probability, per the file name).
// Fourteen tags of 0.07 per position sum to 0.98, not exactly 1.
rel symbol = {
0.07::(0, "0"), 0.07::(0, "1"), 0.07::(0, "2"), 0.07::(0, "3"), 0.07::(0, "4"), 0.07::(0, "5"), 0.07::(0, "6"), 0.07::(0, "7"), 0.07::(0, "8"), 0.07::(0, "9"), 0.07::(0, "+"), 0.07::(0, "-"), 0.07::(0, "*"), 0.07::(0, "/"),

0.07::(1, "0"), 0.07::(1, "1"), 0.07::(1, "2"), 0.07::(1, "3"), 0.07::(1, "4"), 0.07::(1, "5"), 0.07::(1, "6"), 0.07::(1, "7"), 0.07::(1, "8"), 0.07::(1, "9"), 0.07::(1, "+"), 0.07::(1, "-"), 0.07::(1, "*"), 0.07::(1, "/"),

0.07::(2, "0"), 0.07::(2, "1"), 0.07::(2, "2"), 0.07::(2, "3"), 0.07::(2, "4"), 0.07::(2, "5"), 0.07::(2, "6"), 0.07::(2, "7"), 0.07::(2, "8"), 0.07::(2, "9"), 0.07::(2, "+"), 0.07::(2, "-"), 0.07::(2, "*"), 0.07::(2, "/"),
}

// Declares the expression length as 3, matching positions 0..2 above.
// NOTE(review): this file contains no `query`; confirm whether the imported
// program already declares one.
rel length(3)
9 changes: 9 additions & 0 deletions experiments/hwf/examples/hwf_length_3_top_3.scl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import "../scl/hwf_parser.scl"

// Three untagged candidate symbols for each of the positions 0, 1 and 2.
rel symbol = {
(0, "0"), (0, "4"), (0, "/"),
(1, "4"), (1, "*"), (1, "/"),
(2, "2"), (2, "8"), (2, "/"),
}

// Declares the expression length as 3, matching positions 0..2 above.
rel length(3)
11 changes: 11 additions & 0 deletions experiments/hwf/examples/hwf_length_3_top_3_prob.scl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import "../scl/hwf_eval.scl"

// Three tagged candidate symbols per position; the tags on each row sum to 1
// (presumably per-position probabilities, per the file name).
rel symbol = {
0.9::(0, "1"), 0.05::(0, "4"), 0.05::(0, "/"),
0.1::(1, "4"), 0.8::(1, "*"), 0.1::(1, "/"),
0.2::(2, "2"), 0.7::(2, "8"), 0.1::(2, "/"),
}

// Declares the expression length as 3, matching positions 0..2 above.
rel length(3)

// `result` is not defined in this file; presumably it comes from the imported
// program - TODO confirm.
query result
15 changes: 15 additions & 0 deletions experiments/hwf/examples/hwf_length_7_all.scl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import "../scl/hwf_eval.scl"

// Candidate symbols for positions 0..6: every digit and operator, each fact
// tagged with 0.07 (presumably a probability - TODO confirm).
// Fourteen tags of 0.07 per position sum to 0.98, not exactly 1.
rel symbol = {
0.07::(0, "0"), 0.07::(0, "1"), 0.07::(0, "2"), 0.07::(0, "3"), 0.07::(0, "4"), 0.07::(0, "5"), 0.07::(0, "6"), 0.07::(0, "7"), 0.07::(0, "8"), 0.07::(0, "9"), 0.07::(0, "+"), 0.07::(0, "-"), 0.07::(0, "*"), 0.07::(0, "/"),
0.07::(1, "0"), 0.07::(1, "1"), 0.07::(1, "2"), 0.07::(1, "3"), 0.07::(1, "4"), 0.07::(1, "5"), 0.07::(1, "6"), 0.07::(1, "7"), 0.07::(1, "8"), 0.07::(1, "9"), 0.07::(1, "+"), 0.07::(1, "-"), 0.07::(1, "*"), 0.07::(1, "/"),
0.07::(2, "0"), 0.07::(2, "1"), 0.07::(2, "2"), 0.07::(2, "3"), 0.07::(2, "4"), 0.07::(2, "5"), 0.07::(2, "6"), 0.07::(2, "7"), 0.07::(2, "8"), 0.07::(2, "9"), 0.07::(2, "+"), 0.07::(2, "-"), 0.07::(2, "*"), 0.07::(2, "/"),
0.07::(3, "0"), 0.07::(3, "1"), 0.07::(3, "2"), 0.07::(3, "3"), 0.07::(3, "4"), 0.07::(3, "5"), 0.07::(3, "6"), 0.07::(3, "7"), 0.07::(3, "8"), 0.07::(3, "9"), 0.07::(3, "+"), 0.07::(3, "-"), 0.07::(3, "*"), 0.07::(3, "/"),
0.07::(4, "0"), 0.07::(4, "1"), 0.07::(4, "2"), 0.07::(4, "3"), 0.07::(4, "4"), 0.07::(4, "5"), 0.07::(4, "6"), 0.07::(4, "7"), 0.07::(4, "8"), 0.07::(4, "9"), 0.07::(4, "+"), 0.07::(4, "-"), 0.07::(4, "*"), 0.07::(4, "/"),
0.07::(5, "0"), 0.07::(5, "1"), 0.07::(5, "2"), 0.07::(5, "3"), 0.07::(5, "4"), 0.07::(5, "5"), 0.07::(5, "6"), 0.07::(5, "7"), 0.07::(5, "8"), 0.07::(5, "9"), 0.07::(5, "+"), 0.07::(5, "-"), 0.07::(5, "*"), 0.07::(5, "/"),
0.07::(6, "0"), 0.07::(6, "1"), 0.07::(6, "2"), 0.07::(6, "3"), 0.07::(6, "4"), 0.07::(6, "5"), 0.07::(6, "6"), 0.07::(6, "7"), 0.07::(6, "8"), 0.07::(6, "9"), 0.07::(6, "+"), 0.07::(6, "-"), 0.07::(6, "*"), 0.07::(6, "/"),
}

// Declares the expression length as 7, matching positions 0..6 above.
rel length(7)

// `result` is not defined in this file; presumably it comes from the imported
// program - TODO confirm.
query result
19 changes: 19 additions & 0 deletions experiments/hwf/examples/hwf_length_7_all_prob.scl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import "../scl/hwf_eval.scl"

// The seven symbol positions of the expression.
rel symbol_ids = {0, 1, 2, 3, 4, 5, 6}

// Candidate symbols for positions 0..6: every digit and operator, each fact
// tagged with 0.07 (presumably a probability, per the file name).
rel all_symbol = {
0.07::(0, "0"), 0.07::(0, "1"), 0.07::(0, "2"), 0.07::(0, "3"), 0.07::(0, "4"), 0.07::(0, "5"), 0.07::(0, "6"), 0.07::(0, "7"), 0.07::(0, "8"), 0.07::(0, "9"), 0.07::(0, "+"), 0.07::(0, "-"), 0.07::(0, "*"), 0.07::(0, "/"),
0.07::(1, "0"), 0.07::(1, "1"), 0.07::(1, "2"), 0.07::(1, "3"), 0.07::(1, "4"), 0.07::(1, "5"), 0.07::(1, "6"), 0.07::(1, "7"), 0.07::(1, "8"), 0.07::(1, "9"), 0.07::(1, "+"), 0.07::(1, "-"), 0.07::(1, "*"), 0.07::(1, "/"),
0.07::(2, "0"), 0.07::(2, "1"), 0.07::(2, "2"), 0.07::(2, "3"), 0.07::(2, "4"), 0.07::(2, "5"), 0.07::(2, "6"), 0.07::(2, "7"), 0.07::(2, "8"), 0.07::(2, "9"), 0.07::(2, "+"), 0.07::(2, "-"), 0.07::(2, "*"), 0.07::(2, "/"),
0.07::(3, "0"), 0.07::(3, "1"), 0.07::(3, "2"), 0.07::(3, "3"), 0.07::(3, "4"), 0.07::(3, "5"), 0.07::(3, "6"), 0.07::(3, "7"), 0.07::(3, "8"), 0.07::(3, "9"), 0.07::(3, "+"), 0.07::(3, "-"), 0.07::(3, "*"), 0.07::(3, "/"),
0.07::(4, "0"), 0.07::(4, "1"), 0.07::(4, "2"), 0.07::(4, "3"), 0.07::(4, "4"), 0.07::(4, "5"), 0.07::(4, "6"), 0.07::(4, "7"), 0.07::(4, "8"), 0.07::(4, "9"), 0.07::(4, "+"), 0.07::(4, "-"), 0.07::(4, "*"), 0.07::(4, "/"),
0.07::(5, "0"), 0.07::(5, "1"), 0.07::(5, "2"), 0.07::(5, "3"), 0.07::(5, "4"), 0.07::(5, "5"), 0.07::(5, "6"), 0.07::(5, "7"), 0.07::(5, "8"), 0.07::(5, "9"), 0.07::(5, "+"), 0.07::(5, "-"), 0.07::(5, "*"), 0.07::(5, "/"),
0.07::(6, "0"), 0.07::(6, "1"), 0.07::(6, "2"), 0.07::(6, "3"), 0.07::(6, "4"), 0.07::(6, "5"), 0.07::(6, "6"), 0.07::(6, "7"), 0.07::(6, "8"), 0.07::(6, "9"), 0.07::(6, "+"), 0.07::(6, "-"), 0.07::(6, "*"), 0.07::(6, "/"),
}

// Applies the `categorical<1>` sampler to `all_symbol`, grouped by the position
// ids in `symbol_ids`; presumably this keeps one sampled symbol per position -
// TODO confirm against the Scallop sampler documentation.
// NOTE(review): the facts here live in `all_symbol` / `sampled_symbol`; no
// `symbol` relation is declared in this file, unlike the sibling examples.
rel sampled_symbol(n, v) = (n, v) := categorical<1>(n, v: all_symbol(n, v) where n: symbol_ids(n))

// Declares the expression length as 7, matching positions 0..6 above.
rel length(7)

// Output the sampled symbols rather than an evaluation result.
query sampled_symbol
15 changes: 15 additions & 0 deletions experiments/hwf/examples/hwf_unique_length_7_all.scl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import "../scl/hwf_unique_parser.scl"

// Candidate symbols without tags: each position 0..6 is paired with every
// digit "0".."9" and every operator "+", "-", "*", "/".
rel symbol = {
(0, "0"), (0, "1"), (0, "2"), (0, "3"), (0, "4"), (0, "5"), (0, "6"), (0, "7"), (0, "8"), (0, "9"), (0, "+"), (0, "-"), (0, "*"), (0, "/"),
(1, "0"), (1, "1"), (1, "2"), (1, "3"), (1, "4"), (1, "5"), (1, "6"), (1, "7"), (1, "8"), (1, "9"), (1, "+"), (1, "-"), (1, "*"), (1, "/"),
(2, "0"), (2, "1"), (2, "2"), (2, "3"), (2, "4"), (2, "5"), (2, "6"), (2, "7"), (2, "8"), (2, "9"), (2, "+"), (2, "-"), (2, "*"), (2, "/"),
(3, "0"), (3, "1"), (3, "2"), (3, "3"), (3, "4"), (3, "5"), (3, "6"), (3, "7"), (3, "8"), (3, "9"), (3, "+"), (3, "-"), (3, "*"), (3, "/"),
(4, "0"), (4, "1"), (4, "2"), (4, "3"), (4, "4"), (4, "5"), (4, "6"), (4, "7"), (4, "8"), (4, "9"), (4, "+"), (4, "-"), (4, "*"), (4, "/"),
(5, "0"), (5, "1"), (5, "2"), (5, "3"), (5, "4"), (5, "5"), (5, "6"), (5, "7"), (5, "8"), (5, "9"), (5, "+"), (5, "-"), (5, "*"), (5, "/"),
(6, "0"), (6, "1"), (6, "2"), (6, "3"), (6, "4"), (6, "5"), (6, "6"), (6, "7"), (6, "8"), (6, "9"), (6, "+"), (6, "-"), (6, "*"), (6, "/"),
}

// Declares the expression length as 7, matching positions 0..6 above.
rel length(7)

// `result` is not defined in this file; presumably it comes from the imported
// program - TODO confirm.
query result
18 changes: 18 additions & 0 deletions experiments/hwf/examples/hwf_with_disjunction.scl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import "../scl/hwf_parser.scl"

// There are 7 characters and 4 ways to interpret:
// (((9 - 3) - 2) + 8), Result: 12 <-- CORRECT
// (((9 / 3) - 2) + 8), Result: 9
// ((9 - (3 / 2)) + 8), Result: 15.5
// (((9 / 3) / 2) + 8), Result: 9.5
//
// The 4 interpretations come from positions 1 and 3, which each list two
// alternatives ("-" or "/"); all other positions list a single symbol.
// Alternatives inside one set are separated by `;` (presumably marking them
// as mutually exclusive, per the file name - TODO confirm).
rel symbol :- {1.0000::(0, "9")}
rel symbol :- {0.9323::(1, "-"); 0.0677::(1, "/")}
rel symbol :- {1.0000::(2, "3")}
rel symbol :- {0.9085::(3, "-"); 0.0915::(3, "/")}
rel symbol :- {1.0000::(4, "2")}
rel symbol :- {0.9960::(5, "+")}
rel symbol :- {1.0000::(6, "8")}

// Declares the expression length as 7, matching positions 0..6 above.
rel length(7)

// `result` is not defined in this file; presumably it comes from the imported
// program - TODO confirm.
query result
27 changes: 27 additions & 0 deletions experiments/hwf/examples/run_hwf_unique_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import torch
import scallopy

# Folder containing this script; the Scallop program path is built from it
this_file_path = os.path.abspath(os.path.join(__file__, "../"))

# Set up the Scallop context and load the parser program into it
# NOTE(review): this script is named run_hwf_unique_parser.py but loads
# hwf_parser.scl - confirm that hwf_unique_parser.scl was not intended.
scl_program = os.path.join(this_file_path, "../scl/hwf_parser.scl")
ctx = scallopy.ScallopContext(provenance="difftopbottomkclauses")
ctx.import_file(scl_program)

# Symbol facts as (tag, position, symbol); each tag is wrapped in a tensor below
symbol_candidates = [
    (0.2, 0, "3"), (0.5, 0, "5"),
    (0.1, 1, "*"), (0.3, 1, "/"),
    (0.01, 2, "4"), (0.8, 2, "2"),
]
symbol_facts = [
    (torch.tensor(tag), (position, symbol))
    for tag, position, symbol in symbol_candidates
]
ctx.add_facts("symbol", symbol_facts)

# Length fact: a single untagged (None) fact stating the expression has 3 symbols
length_facts = [(None, (3,))]
ctx.add_facts("length", length_facts)

# Execute the loaded program
ctx.run(debug_input_provenance=True)

# Show everything derived for the `result` relation
results = list(ctx.relation("result"))
print(results)
Loading

0 comments on commit 3dbe1e9

Please sign in to comment.