-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
33 changed files
with
2,516 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from typing import List | ||
import os | ||
import random | ||
import functools | ||
import json | ||
from argparse import ArgumentParser | ||
from tqdm import tqdm | ||
|
||
def precedence(operator: str) -> int: | ||
if operator == "+" or operator == "-": return 2 | ||
elif operator == "*" or operator == "/": return 1 | ||
else: raise Exception(f"Unknown operator {operator}") | ||
|
||
class Expression: | ||
data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF/Handwritten_Math_Symbols")) | ||
|
||
def sample_images(self) -> List[str]: pass | ||
|
||
def __str__(self) -> str: pass | ||
|
||
def __len__(self) -> int: pass | ||
|
||
def value(self) -> int: pass | ||
|
||
def precedence(self) -> int: pass | ||
|
||
|
||
class Constant(Expression): | ||
def __init__(self, digit: int): | ||
super(Constant, self).__init__() | ||
self.digit = digit | ||
|
||
def __str__(self): | ||
return f"{self.digit}" | ||
|
||
def __len__(self): | ||
return 1 | ||
|
||
def sample_images(self) -> List[str]: | ||
imgs = Constant.images_of_digit(self.digit) | ||
return [imgs[random.randint(0, len(imgs) - 1)]] | ||
|
||
def value(self) -> int: | ||
return self.digit | ||
|
||
def precedence(self) -> int: | ||
return 0 | ||
|
||
@functools.lru_cache | ||
def images_of_digit(digit: int) -> List[str]: | ||
return [f"{digit}/{f}" for f in os.listdir(os.path.join(Expression.data_root, str(digit)))] | ||
|
||
|
||
class BinaryOperation(Expression): | ||
def __init__(self, operator: str, lhs: Expression, rhs: Expression): | ||
self.operator = operator | ||
self.lhs = lhs | ||
self.rhs = rhs | ||
|
||
def sample_images(self) -> List[str]: | ||
imgs = BinaryOperation.images_of_symbol(self.operator) | ||
s = [imgs[random.randint(0, len(imgs) - 1)]] | ||
l = self.lhs.sample_images() | ||
r = self.rhs.sample_images() | ||
return l + s + r | ||
|
||
def value(self) -> str: | ||
if self.operator == "+": return self.lhs.value() + self.rhs.value() | ||
elif self.operator == "-": return self.lhs.value() - self.rhs.value() | ||
elif self.operator == "*": return self.lhs.value() * self.rhs.value() | ||
elif self.operator == "/": return self.lhs.value() / self.rhs.value() | ||
else: raise Exception(f"Unknown operator {self.operator}") | ||
|
||
def __str__(self): | ||
return f"{self.lhs} {self.operator} {self.rhs}" | ||
|
||
def __len__(self): | ||
return len(self.lhs) + 1 + len(self.rhs) | ||
|
||
def precedence(self) -> int: | ||
return precedence(self.operator) | ||
|
||
@functools.lru_cache | ||
def images_of_symbol(symbol: str) -> List[str]: | ||
if symbol == "+": d = "+" | ||
elif symbol == "-": d = "-" | ||
elif symbol == "*": d = "times" | ||
elif symbol == "/": d = "div" | ||
else: raise Exception(f"Unknown symbol {symbol}") | ||
return [f"{d}/{f}" for f in os.listdir(os.path.join(Expression.data_root, d))] | ||
|
||
|
||
class ExpressionGenerator: | ||
def __init__(self, const_perc, max_depth, max_length, digits, operators, length): | ||
self.const_perc = const_perc | ||
self.max_depth = max_depth | ||
self.max_length = max_length | ||
self.digits = digits | ||
self.operators = operators | ||
self.length = length | ||
|
||
def generate_expr(self, depth=0): | ||
if depth >= self.max_depth or random.random() < self.const_perc: | ||
digit = self.digits[random.randint(0, len(self.digits) - 1)] | ||
expr = Constant(digit) | ||
else: | ||
symbol = self.operators[random.randint(0, len(self.operators) - 1)] | ||
lhs = self.generate_expr(depth + 1) | ||
if lhs is None or precedence(symbol) < lhs.precedence(): return None | ||
rhs = self.generate_expr(depth + 1) | ||
if rhs is None or precedence(symbol) < rhs.precedence(): return None | ||
if symbol == "/" and rhs.value() == 0: return None | ||
expr = BinaryOperation(symbol, lhs, rhs) | ||
if len(expr) > self.max_length: return None | ||
if depth == 0 and self.length is not None and len(expr) != self.length: return None | ||
return expr | ||
|
||
def generate_datapoint(self, id): | ||
while True: | ||
e = self.generate_expr() | ||
if e is not None: | ||
return {"id": str(id), "img_paths": e.sample_images(), "expr": str(e), "res": e.value()} | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser("hwf/datagen") | ||
parser.add_argument("--operators", action="store", default=["+", "-", "*", "/"], nargs="*") | ||
parser.add_argument("--digits", action="store", type=int, default=list(range(10)), nargs="*") | ||
parser.add_argument("--num-datapoints", type=int, default=100000) | ||
parser.add_argument("--max-depth", type=int, default=3) | ||
parser.add_argument("--max-length", type=int, default=7) | ||
parser.add_argument("--length", type=int) | ||
parser.add_argument("--constant-percentage", type=float, default=0.1) | ||
parser.add_argument("--seed", type=int, default=1234) | ||
parser.add_argument("--output", type=str, default="dataset.json") | ||
args = parser.parse_args() | ||
|
||
# Parameters | ||
random.seed(args.seed) | ||
data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF")) | ||
|
||
# Generate datapoints | ||
generator = ExpressionGenerator(args.constant_percentage, args.max_depth, args.max_length, args.digits, args.operators, args.length) | ||
data = [generator.generate_datapoint(i) for i in tqdm(range(args.num_datapoints))] | ||
|
||
# Dump data | ||
json.dump(data, open(os.path.join(data_root, args.output), "w")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from argparse import ArgumentParser | ||
import os | ||
import json | ||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser("hwf/datamerge") | ||
parser.add_argument("inputs", action="store", nargs="*") | ||
parser.add_argument("--output", type=str, default="expr_merged.json") | ||
args = parser.parse_args() | ||
|
||
# Get the list of files | ||
data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF")) | ||
all_data = [] | ||
for file in args.inputs: | ||
print(f"Loading file {file}") | ||
data = json.load(open(os.path.join(data_root, file))) | ||
all_data += data | ||
|
||
# Dump the result | ||
json.dump(all_data, open(os.path.join(data_root, args.output), "w")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from argparse import ArgumentParser | ||
import os | ||
import json | ||
import random | ||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser() | ||
parser.add_argument("--input", type=str, default="expr_train.json") | ||
parser.add_argument("--output", type=str, default="expr_train_0.5.json") | ||
parser.add_argument("--perc", type=float, default=0.5) | ||
parser.add_argument("--seed", type=int, default=1234) | ||
args = parser.parse_args() | ||
|
||
# Set random seed | ||
random.seed(args.seed) | ||
|
||
# Load input file | ||
data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF/")) | ||
input_file = json.load(open(os.path.join(data_root, args.input))) | ||
|
||
# Shuffle the file and pick only the top arg.perc | ||
random.shuffle(input_file) | ||
end_index = int(len(input_file) * args.perc) | ||
input_file = input_file[0:end_index] | ||
|
||
# Output the file | ||
json.dump(input_file, open(os.path.join(data_root, args.output), "w")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from argparse import ArgumentParser | ||
import os | ||
import json | ||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser("hwf/datastats") | ||
parser.add_argument("--dataset", type=str, default="expr_train.json") | ||
args = parser.parse_args() | ||
|
||
# Get the dataset | ||
data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF")) | ||
data = json.load(open(os.path.join(data_root, args.dataset))) | ||
|
||
# Compute stats | ||
lengths = {} | ||
for datapoint in data: | ||
if len(datapoint["img_paths"]) in lengths: lengths[len(datapoint["img_paths"])] += 1 | ||
else: lengths[len(datapoint["img_paths"])] = 1 | ||
print(lengths) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import "../scl/hwf_eval.scl" | ||
|
||
rel symbol = { | ||
(0, "0"), (0, "1"), (0, "2"), (0, "3"), (0, "4"), (0, "5"), (0, "6"), (0, "7"), (0, "8"), (0, "9"), (0, "+"), (0, "-"), (0, "*"), (0, "/"), | ||
(1, "0"), (1, "1"), (1, "2"), (1, "3"), (1, "4"), (1, "5"), (1, "6"), (1, "7"), (1, "8"), (1, "9"), (1, "+"), (1, "-"), (1, "*"), (1, "/"), | ||
(2, "0"), (2, "1"), (2, "2"), (2, "3"), (2, "4"), (2, "5"), (2, "6"), (2, "7"), (2, "8"), (2, "9"), (2, "+"), (2, "-"), (2, "*"), (2, "/"), | ||
} | ||
|
||
rel length(3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import "../scl/hwf_eval.scl" | ||
|
||
rel symbol = { | ||
0.07::(0, "0"), 0.07::(0, "1"), 0.07::(0, "2"), 0.07::(0, "3"), 0.07::(0, "4"), 0.07::(0, "5"), 0.07::(0, "6"), 0.07::(0, "7"), 0.07::(0, "8"), 0.07::(0, "9"), 0.07::(0, "+"), 0.07::(0, "-"), 0.07::(0, "*"), 0.07::(0, "/"), | ||
|
||
0.07::(1, "0"), 0.07::(1, "1"), 0.07::(1, "2"), 0.07::(1, "3"), 0.07::(1, "4"), 0.07::(1, "5"), 0.07::(1, "6"), 0.07::(1, "7"), 0.07::(1, "8"), 0.07::(1, "9"), 0.07::(1, "+"), 0.07::(1, "-"), 0.07::(1, "*"), 0.07::(1, "/"), | ||
|
||
0.07::(2, "0"), 0.07::(2, "1"), 0.07::(2, "2"), 0.07::(2, "3"), 0.07::(2, "4"), 0.07::(2, "5"), 0.07::(2, "6"), 0.07::(2, "7"), 0.07::(2, "8"), 0.07::(2, "9"), 0.07::(2, "+"), 0.07::(2, "-"), 0.07::(2, "*"), 0.07::(2, "/"), | ||
} | ||
|
||
rel length(3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import "../scl/hwf_parser.scl" | ||
|
||
rel symbol = { | ||
(0, "0"), (0, "4"), (0, "/"), | ||
(1, "4"), (1, "*"), (1, "/"), | ||
(2, "2"), (2, "8"), (2, "/"), | ||
} | ||
|
||
rel length(3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import "../scl/hwf_eval.scl" | ||
|
||
rel symbol = { | ||
0.9::(0, "1"), 0.05::(0, "4"), 0.05::(0, "/"), | ||
0.1::(1, "4"), 0.8::(1, "*"), 0.1::(1, "/"), | ||
0.2::(2, "2"), 0.7::(2, "8"), 0.1::(2, "/"), | ||
} | ||
|
||
rel length(3) | ||
|
||
query result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import "../scl/hwf_eval.scl" | ||
|
||
rel symbol = { | ||
0.07::(0, "0"), 0.07::(0, "1"), 0.07::(0, "2"), 0.07::(0, "3"), 0.07::(0, "4"), 0.07::(0, "5"), 0.07::(0, "6"), 0.07::(0, "7"), 0.07::(0, "8"), 0.07::(0, "9"), 0.07::(0, "+"), 0.07::(0, "-"), 0.07::(0, "*"), 0.07::(0, "/"), | ||
0.07::(1, "0"), 0.07::(1, "1"), 0.07::(1, "2"), 0.07::(1, "3"), 0.07::(1, "4"), 0.07::(1, "5"), 0.07::(1, "6"), 0.07::(1, "7"), 0.07::(1, "8"), 0.07::(1, "9"), 0.07::(1, "+"), 0.07::(1, "-"), 0.07::(1, "*"), 0.07::(1, "/"), | ||
0.07::(2, "0"), 0.07::(2, "1"), 0.07::(2, "2"), 0.07::(2, "3"), 0.07::(2, "4"), 0.07::(2, "5"), 0.07::(2, "6"), 0.07::(2, "7"), 0.07::(2, "8"), 0.07::(2, "9"), 0.07::(2, "+"), 0.07::(2, "-"), 0.07::(2, "*"), 0.07::(2, "/"), | ||
0.07::(3, "0"), 0.07::(3, "1"), 0.07::(3, "2"), 0.07::(3, "3"), 0.07::(3, "4"), 0.07::(3, "5"), 0.07::(3, "6"), 0.07::(3, "7"), 0.07::(3, "8"), 0.07::(3, "9"), 0.07::(3, "+"), 0.07::(3, "-"), 0.07::(3, "*"), 0.07::(3, "/"), | ||
0.07::(4, "0"), 0.07::(4, "1"), 0.07::(4, "2"), 0.07::(4, "3"), 0.07::(4, "4"), 0.07::(4, "5"), 0.07::(4, "6"), 0.07::(4, "7"), 0.07::(4, "8"), 0.07::(4, "9"), 0.07::(4, "+"), 0.07::(4, "-"), 0.07::(4, "*"), 0.07::(4, "/"), | ||
0.07::(5, "0"), 0.07::(5, "1"), 0.07::(5, "2"), 0.07::(5, "3"), 0.07::(5, "4"), 0.07::(5, "5"), 0.07::(5, "6"), 0.07::(5, "7"), 0.07::(5, "8"), 0.07::(5, "9"), 0.07::(5, "+"), 0.07::(5, "-"), 0.07::(5, "*"), 0.07::(5, "/"), | ||
0.07::(6, "0"), 0.07::(6, "1"), 0.07::(6, "2"), 0.07::(6, "3"), 0.07::(6, "4"), 0.07::(6, "5"), 0.07::(6, "6"), 0.07::(6, "7"), 0.07::(6, "8"), 0.07::(6, "9"), 0.07::(6, "+"), 0.07::(6, "-"), 0.07::(6, "*"), 0.07::(6, "/"), | ||
} | ||
|
||
rel length(7) | ||
|
||
query result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import "../scl/hwf_eval.scl" | ||
|
||
rel symbol_ids = {0, 1, 2, 3, 4, 5, 6} | ||
|
||
rel all_symbol = { | ||
0.07::(0, "0"), 0.07::(0, "1"), 0.07::(0, "2"), 0.07::(0, "3"), 0.07::(0, "4"), 0.07::(0, "5"), 0.07::(0, "6"), 0.07::(0, "7"), 0.07::(0, "8"), 0.07::(0, "9"), 0.07::(0, "+"), 0.07::(0, "-"), 0.07::(0, "*"), 0.07::(0, "/"), | ||
0.07::(1, "0"), 0.07::(1, "1"), 0.07::(1, "2"), 0.07::(1, "3"), 0.07::(1, "4"), 0.07::(1, "5"), 0.07::(1, "6"), 0.07::(1, "7"), 0.07::(1, "8"), 0.07::(1, "9"), 0.07::(1, "+"), 0.07::(1, "-"), 0.07::(1, "*"), 0.07::(1, "/"), | ||
0.07::(2, "0"), 0.07::(2, "1"), 0.07::(2, "2"), 0.07::(2, "3"), 0.07::(2, "4"), 0.07::(2, "5"), 0.07::(2, "6"), 0.07::(2, "7"), 0.07::(2, "8"), 0.07::(2, "9"), 0.07::(2, "+"), 0.07::(2, "-"), 0.07::(2, "*"), 0.07::(2, "/"), | ||
0.07::(3, "0"), 0.07::(3, "1"), 0.07::(3, "2"), 0.07::(3, "3"), 0.07::(3, "4"), 0.07::(3, "5"), 0.07::(3, "6"), 0.07::(3, "7"), 0.07::(3, "8"), 0.07::(3, "9"), 0.07::(3, "+"), 0.07::(3, "-"), 0.07::(3, "*"), 0.07::(3, "/"), | ||
0.07::(4, "0"), 0.07::(4, "1"), 0.07::(4, "2"), 0.07::(4, "3"), 0.07::(4, "4"), 0.07::(4, "5"), 0.07::(4, "6"), 0.07::(4, "7"), 0.07::(4, "8"), 0.07::(4, "9"), 0.07::(4, "+"), 0.07::(4, "-"), 0.07::(4, "*"), 0.07::(4, "/"), | ||
0.07::(5, "0"), 0.07::(5, "1"), 0.07::(5, "2"), 0.07::(5, "3"), 0.07::(5, "4"), 0.07::(5, "5"), 0.07::(5, "6"), 0.07::(5, "7"), 0.07::(5, "8"), 0.07::(5, "9"), 0.07::(5, "+"), 0.07::(5, "-"), 0.07::(5, "*"), 0.07::(5, "/"), | ||
0.07::(6, "0"), 0.07::(6, "1"), 0.07::(6, "2"), 0.07::(6, "3"), 0.07::(6, "4"), 0.07::(6, "5"), 0.07::(6, "6"), 0.07::(6, "7"), 0.07::(6, "8"), 0.07::(6, "9"), 0.07::(6, "+"), 0.07::(6, "-"), 0.07::(6, "*"), 0.07::(6, "/"), | ||
} | ||
|
||
rel sampled_symbol(n, v) = (n, v) := categorical<1>(n, v: all_symbol(n, v) where n: symbol_ids(n)) | ||
|
||
rel length(7) | ||
|
||
query sampled_symbol |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import "../scl/hwf_unique_parser.scl" | ||
|
||
rel symbol = { | ||
(0, "0"), (0, "1"), (0, "2"), (0, "3"), (0, "4"), (0, "5"), (0, "6"), (0, "7"), (0, "8"), (0, "9"), (0, "+"), (0, "-"), (0, "*"), (0, "/"), | ||
(1, "0"), (1, "1"), (1, "2"), (1, "3"), (1, "4"), (1, "5"), (1, "6"), (1, "7"), (1, "8"), (1, "9"), (1, "+"), (1, "-"), (1, "*"), (1, "/"), | ||
(2, "0"), (2, "1"), (2, "2"), (2, "3"), (2, "4"), (2, "5"), (2, "6"), (2, "7"), (2, "8"), (2, "9"), (2, "+"), (2, "-"), (2, "*"), (2, "/"), | ||
(3, "0"), (3, "1"), (3, "2"), (3, "3"), (3, "4"), (3, "5"), (3, "6"), (3, "7"), (3, "8"), (3, "9"), (3, "+"), (3, "-"), (3, "*"), (3, "/"), | ||
(4, "0"), (4, "1"), (4, "2"), (4, "3"), (4, "4"), (4, "5"), (4, "6"), (4, "7"), (4, "8"), (4, "9"), (4, "+"), (4, "-"), (4, "*"), (4, "/"), | ||
(5, "0"), (5, "1"), (5, "2"), (5, "3"), (5, "4"), (5, "5"), (5, "6"), (5, "7"), (5, "8"), (5, "9"), (5, "+"), (5, "-"), (5, "*"), (5, "/"), | ||
(6, "0"), (6, "1"), (6, "2"), (6, "3"), (6, "4"), (6, "5"), (6, "6"), (6, "7"), (6, "8"), (6, "9"), (6, "+"), (6, "-"), (6, "*"), (6, "/"), | ||
} | ||
|
||
rel length(7) | ||
|
||
query result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import "../scl/hwf_parser.scl" | ||
|
||
// There are 7 characters and 4 ways to interpret: | ||
// (((9 - 3) - 2) + 8), Result: 12 <-- CORRECT | ||
// (((9 / 3) - 2) + 8), Result: 9 | ||
// ((9 - (3 / 2)) + 8), Result: 15.5 | ||
// (((9 / 3) / 2) + 8), Result: 9.5 | ||
rel symbol :- {1.0000::(0, "9")} | ||
rel symbol :- {0.9323::(1, "-"); 0.0677::(1, "/")} | ||
rel symbol :- {1.0000::(2, "3")} | ||
rel symbol :- {0.9085::(3, "-"); 0.0915::(3, "/")} | ||
rel symbol :- {1.0000::(4, "2")} | ||
rel symbol :- {0.9960::(5, "+")} | ||
rel symbol :- {1.0000::(6, "8")} | ||
|
||
rel length(7) | ||
|
||
query result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import os | ||
import torch | ||
import scallopy | ||
|
||
this_file_path = os.path.abspath(os.path.join(__file__, "../")) | ||
|
||
# Create scallop context | ||
ctx = scallopy.ScallopContext(provenance="difftopbottomkclauses") | ||
ctx.import_file(os.path.join(this_file_path, "../scl/hwf_parser.scl")) | ||
|
||
# The symbols facts | ||
ctx.add_facts("symbol", [ | ||
(torch.tensor(0.2), (0, "3")), (torch.tensor(0.5), (0, "5")), | ||
(torch.tensor(0.1), (1, "*")), (torch.tensor(0.3), (1, "/")), | ||
(torch.tensor(0.01), (2, "4")), (torch.tensor(0.8), (2, "2")), | ||
]) | ||
|
||
# The length facts | ||
ctx.add_facts("length", [ | ||
(None, (3,)) | ||
]) | ||
|
||
# Run the context | ||
ctx.run(debug_input_provenance=True) | ||
|
||
# Inspect the result | ||
print(list(ctx.relation("result"))) |
Oops, something went wrong.