From 3dbe1e95cba52e697c47a065213474461088fca1 Mon Sep 17 00:00:00 2001
From: Ziyang Li
Date: Fri, 23 Feb 2024 09:21:53 -0500
Subject: [PATCH] Adding hwf experiments

---
 experiments/hwf/datagen.py                    | 147 ++++
 experiments/hwf/datamerge.py                  |  20 +
 experiments/hwf/dataslice.py                  |  27 ++
 experiments/hwf/datastats.py                  |  19 +
 experiments/hwf/docs/+_96.jpg                 | Bin 0 -> 832 bytes
 experiments/hwf/docs/1_47.jpg                 | Bin 0 -> 1132 bytes
 experiments/hwf/docs/3_91.jpg                 | Bin 0 -> 1439 bytes
 experiments/hwf/docs/5_237.jpg                | Bin 0 -> 1421 bytes
 experiments/hwf/docs/div_942.jpg              | Bin 0 -> 767 bytes
 experiments/hwf/examples/hwf_length_3_all.scl |   9 +
 .../hwf/examples/hwf_length_3_all_prob.scl    |  11 +
 .../hwf/examples/hwf_length_3_top_3.scl       |   9 +
 .../hwf/examples/hwf_length_3_top_3_prob.scl  |  11 +
 experiments/hwf/examples/hwf_length_7_all.scl |  15 +
 .../hwf/examples/hwf_length_7_all_prob.scl    |  19 +
 .../hwf/examples/hwf_unique_length_7_all.scl  |  15 +
 .../hwf/examples/hwf_with_disjunction.scl     |  18 +
 .../hwf/examples/run_hwf_unique_parser.py     |  27 ++
 experiments/hwf/plot_confusion_matrix.py      | 116 +++
 experiments/hwf/readme.md                     |  68 +++
 experiments/hwf/run.py                        | 255 +++++
 experiments/hwf/run_with_hwf_parser.py        | 285 ++++++++++++
 .../hwf/run_with_hwf_parser_w_sample.py       | 254 ++++++++++
 experiments/hwf/run_with_mbs.py               | 433 ++++++++++++++++++
 .../hwf/run_with_purely_discrete_sample.py    | 285 ++++++++++++
 experiments/hwf/scl/hwf_eval.scl              |  22 +
 experiments/hwf/scl/hwf_parser.scl            |  36 ++
 experiments/hwf/scl/hwf_parser_w_sample.scl   |  39 ++
 experiments/hwf/scl/hwf_parser_wo_hash.scl    |  41 ++
 experiments/hwf/scl/hwf_sample.scl            |   4 +
 experiments/hwf/scl/hwf_unique_parser.scl     |  30 ++
 experiments/hwf/test_hwf_model.py             | 148 ++++
 experiments/hwf/variants/inc_dec.py           | 153 +++++++
 33 files changed, 2516 insertions(+)
 create mode 100644 experiments/hwf/datagen.py
 create mode 100644 experiments/hwf/datamerge.py
 create mode 100644 experiments/hwf/dataslice.py
 create mode 100644 experiments/hwf/datastats.py
 create mode 100644 experiments/hwf/docs/+_96.jpg
 create mode 100644 experiments/hwf/docs/1_47.jpg
 create mode 100644 experiments/hwf/docs/3_91.jpg
 create mode 100644 experiments/hwf/docs/5_237.jpg
 create mode 100644 experiments/hwf/docs/div_942.jpg
 create mode 100644 experiments/hwf/examples/hwf_length_3_all.scl
 create mode 100644 experiments/hwf/examples/hwf_length_3_all_prob.scl
 create mode 100644 experiments/hwf/examples/hwf_length_3_top_3.scl
 create mode 100644 experiments/hwf/examples/hwf_length_3_top_3_prob.scl
 create mode 100644 experiments/hwf/examples/hwf_length_7_all.scl
 create mode 100644 experiments/hwf/examples/hwf_length_7_all_prob.scl
 create mode 100644 experiments/hwf/examples/hwf_unique_length_7_all.scl
 create mode 100644 experiments/hwf/examples/hwf_with_disjunction.scl
 create mode 100644 experiments/hwf/examples/run_hwf_unique_parser.py
 create mode 100644 experiments/hwf/plot_confusion_matrix.py
 create mode 100644 experiments/hwf/readme.md
 create mode 100644 experiments/hwf/run.py
 create mode 100644 experiments/hwf/run_with_hwf_parser.py
 create mode 100644 experiments/hwf/run_with_hwf_parser_w_sample.py
 create mode 100644 experiments/hwf/run_with_mbs.py
 create mode 100644 experiments/hwf/run_with_purely_discrete_sample.py
 create mode 100644 experiments/hwf/scl/hwf_eval.scl
 create mode 100644 experiments/hwf/scl/hwf_parser.scl
 create mode 100644 experiments/hwf/scl/hwf_parser_w_sample.scl
 create mode 100644 experiments/hwf/scl/hwf_parser_wo_hash.scl
 create mode 100644 experiments/hwf/scl/hwf_sample.scl
 create mode 100644 experiments/hwf/scl/hwf_unique_parser.scl
 create mode 100644 experiments/hwf/test_hwf_model.py
 create mode 100644 experiments/hwf/variants/inc_dec.py

diff --git a/experiments/hwf/datagen.py b/experiments/hwf/datagen.py
new file mode 100644
index 0000000..d4fc0fe
--- /dev/null
+++ b/experiments/hwf/datagen.py
@@ -0,0 +1,147 @@
+from typing import List
+import os
+import random
+import functools
+import json
+from argparse import ArgumentParser
+from tqdm import tqdm
+
+def precedence(operator: str) -> int:
+  if operator == "+" or operator == "-": return 2
+  elif operator == "*" or operator == "/": return 1
+  else: raise Exception(f"Unknown operator {operator}")
+
+class Expression:
+  data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF/Handwritten_Math_Symbols"))
+
+  def sample_images(self) -> List[str]: pass
+
+  def __str__(self) -> str: pass
+
+  def __len__(self) -> int: pass
+
+  def value(self) -> float: pass
+
+  def precedence(self) -> int: pass
+
+
+class Constant(Expression):
+  def __init__(self, digit: int):
+    super(Constant, self).__init__()
+    self.digit = digit
+
+  def __str__(self):
+    return f"{self.digit}"
+
+  def __len__(self):
+    return 1
+
+  def sample_images(self) -> List[str]:
+    imgs = Constant.images_of_digit(self.digit)
+    return [imgs[random.randint(0, len(imgs) - 1)]]
+
+  def value(self) -> int:
+    return self.digit
+
+  def precedence(self) -> int:
+    return 0
+
+  @functools.lru_cache
+  def images_of_digit(digit: int) -> List[str]:
+    return [f"{digit}/{f}" for f in os.listdir(os.path.join(Expression.data_root, str(digit)))]
+
+
+class BinaryOperation(Expression):
+  def __init__(self, operator: str, lhs: Expression, rhs: Expression):
+    self.operator = operator
+    self.lhs = lhs
+    self.rhs = rhs
+
+  def sample_images(self) -> List[str]:
+    imgs = BinaryOperation.images_of_symbol(self.operator)
+    s = [imgs[random.randint(0, len(imgs) - 1)]]
+    l = self.lhs.sample_images()
+    r = self.rhs.sample_images()
+    return l + s + r
+
+  def value(self) -> float:
+    if self.operator == "+": return self.lhs.value() + self.rhs.value()
+    elif self.operator == "-": return self.lhs.value() - self.rhs.value()
+    elif self.operator == "*": return self.lhs.value() * self.rhs.value()
+    elif self.operator == "/": return self.lhs.value() / self.rhs.value()
+    else: raise Exception(f"Unknown operator {self.operator}")
+
+  def __str__(self):
+    return f"{self.lhs} {self.operator} {self.rhs}"
+
+  def __len__(self):
+    return len(self.lhs) + 1 + len(self.rhs)
+
+  def precedence(self) -> int:
+    return precedence(self.operator)
+
+  @functools.lru_cache
+  def images_of_symbol(symbol: str) -> List[str]:
+    if symbol == "+": d = "+"
+    elif symbol == "-": d = "-"
+    elif symbol == "*": d = "times"
+    elif symbol == "/": d = "div"
+    else: raise Exception(f"Unknown symbol {symbol}")
+    return [f"{d}/{f}" for f in os.listdir(os.path.join(Expression.data_root, d))]
+
+
+class ExpressionGenerator:
+  def __init__(self, const_perc, max_depth, max_length, digits, operators, length):
+    self.const_perc = const_perc
+    self.max_depth = max_depth
+    self.max_length = max_length
+    self.digits = digits
+    self.operators = operators
+    self.length = length
+
+  def generate_expr(self, depth=0):
+    if depth >= self.max_depth or random.random() < self.const_perc:
+      digit = self.digits[random.randint(0, len(self.digits) - 1)]
+      expr = Constant(digit)
+    else:
+      symbol = self.operators[random.randint(0, len(self.operators) - 1)]
+      lhs = self.generate_expr(depth + 1)
+      if lhs is None or precedence(symbol) < lhs.precedence(): return None
+      rhs = self.generate_expr(depth + 1)
+      if rhs is None or precedence(symbol) < rhs.precedence(): return None
+      if symbol == "/" and rhs.value() == 0: return None
+      expr = BinaryOperation(symbol, lhs, rhs)
+    if len(expr) > self.max_length: return None
+    if depth == 0 and self.length is not None and len(expr) != self.length: return None
+    return expr
+
+  def generate_datapoint(self, id):
+    while True:
+      e = self.generate_expr()
+      if e is not None:
+        return {"id": str(id), "img_paths": e.sample_images(), "expr": str(e), "res": e.value()}
+
+
+if __name__ == "__main__":
+  parser = ArgumentParser("hwf/datagen")
+  parser.add_argument("--operators", action="store", default=["+", "-", "*", "/"], nargs="*")
+  parser.add_argument("--digits", action="store", type=int, default=list(range(10)), nargs="*")
+  parser.add_argument("--num-datapoints", type=int, default=100000)
+  parser.add_argument("--max-depth", type=int, default=3)
+  parser.add_argument("--max-length", type=int, default=7)
+  parser.add_argument("--length", type=int)
+  parser.add_argument("--constant-percentage", type=float, default=0.1)
+  parser.add_argument("--seed", type=int, default=1234)
+  parser.add_argument("--output", type=str, default="dataset.json")
+  args = parser.parse_args()
+
+  # Parameters
+  random.seed(args.seed)
+  data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF"))
+
+  # Generate datapoints
+  generator = ExpressionGenerator(args.constant_percentage, args.max_depth, args.max_length, args.digits, args.operators, args.length)
+  data = [generator.generate_datapoint(i) for i in tqdm(range(args.num_datapoints))]
+
+  # Dump data
+  json.dump(data, open(os.path.join(data_root, args.output), "w"))
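For orientation, here is a minimal sketch of how `ExpressionGenerator` is meant to be driven (illustrative only: the argument values are hypothetical, and since `generate_datapoint` samples image paths, the `Handwritten_Math_Symbols` folders described in `readme.md` must already exist on disk):

```
# Illustrative driver for datagen.py's ExpressionGenerator (hypothetical values)
import random
from datagen import ExpressionGenerator

random.seed(1234)
gen = ExpressionGenerator(
  const_perc=0.1, max_depth=3, max_length=7,
  digits=list(range(10)), operators=["+", "-", "*", "/"],
  length=3,  # force expressions of exactly three symbols, e.g. "3 * 5"
)
dp = gen.generate_datapoint(0)
# dp["img_paths"] holds one sampled image path per symbol;
# dp["res"] is the evaluated value of dp["expr"]
print(dp["expr"], "=", dp["res"])
```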
diff --git a/experiments/hwf/datamerge.py b/experiments/hwf/datamerge.py
new file mode 100644
index 0000000..a1d1d62
--- /dev/null
+++ b/experiments/hwf/datamerge.py
@@ -0,0 +1,20 @@
+from argparse import ArgumentParser
+import os
+import json
+
+if __name__ == "__main__":
+  parser = ArgumentParser("hwf/datamerge")
+  parser.add_argument("inputs", action="store", nargs="*")
+  parser.add_argument("--output", type=str, default="expr_merged.json")
+  args = parser.parse_args()
+
+  # Get the list of files
+  data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF"))
+  all_data = []
+  for file in args.inputs:
+    print(f"Loading file {file}")
+    data = json.load(open(os.path.join(data_root, file)))
+    all_data += data
+
+  # Dump the result
+  json.dump(all_data, open(os.path.join(data_root, args.output), "w"))

diff --git a/experiments/hwf/dataslice.py b/experiments/hwf/dataslice.py
new file mode 100644
index 0000000..4f12504
--- /dev/null
+++ b/experiments/hwf/dataslice.py
@@ -0,0 +1,27 @@
+from argparse import ArgumentParser
+import os
+import json
+import random
+
+if __name__ == "__main__":
+  parser = ArgumentParser()
+  parser.add_argument("--input", type=str, default="expr_train.json")
+  parser.add_argument("--output", type=str, default="expr_train_0.5.json")
+  parser.add_argument("--perc", type=float, default=0.5)
+  parser.add_argument("--seed", type=int, default=1234)
+  args = parser.parse_args()
+
+  # Set random seed
+  random.seed(args.seed)
+
+  # Load input file
+  data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF/"))
+  input_file = json.load(open(os.path.join(data_root, args.input)))
+
+  # Shuffle the file and pick only the top args.perc
+  random.shuffle(input_file)
+  end_index = int(len(input_file) * args.perc)
+  input_file = input_file[0:end_index]
+
+  # Output the file
+  json.dump(input_file, open(os.path.join(data_root, args.output), "w"))
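A typical pipeline over these helper scripts might look as follows (the JSON file names are hypothetical; all three scripts resolve paths relative to `experiments/data/HWF/`):

```
$ python datagen.py --length 3 --num-datapoints 10000 --output expr_len3.json
$ python datagen.py --length 7 --num-datapoints 10000 --output expr_len7.json
$ python datamerge.py expr_len3.json expr_len7.json --output expr_merged.json
$ python dataslice.py --input expr_merged.json --output expr_merged_0.5.json --perc 0.5
```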
diff --git a/experiments/hwf/datastats.py b/experiments/hwf/datastats.py
new file mode 100644
index 0000000..370119c
--- /dev/null
+++ b/experiments/hwf/datastats.py
@@ -0,0 +1,19 @@
+from argparse import ArgumentParser
+import os
+import json
+
+if __name__ == "__main__":
+  parser = ArgumentParser("hwf/datastats")
+  parser.add_argument("--dataset", type=str, default="expr_train.json")
+  args = parser.parse_args()
+
+  # Get the dataset
+  data_root = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data/HWF"))
+  data = json.load(open(os.path.join(data_root, args.dataset)))
+
+  # Compute stats
+  lengths = {}
+  for datapoint in data:
+    if len(datapoint["img_paths"]) in lengths: lengths[len(datapoint["img_paths"])] += 1
+    else: lengths[len(datapoint["img_paths"])] = 1
+  print(lengths)

diff --git a/experiments/hwf/docs/+_96.jpg b/experiments/hwf/docs/+_96.jpg
new file mode 100644
GIT binary patch (literal 832; binary JPEG data omitted)

diff --git a/experiments/hwf/docs/1_47.jpg b/experiments/hwf/docs/1_47.jpg
new file mode 100644
GIT binary patch (literal 1132; binary JPEG data omitted)

diff --git a/experiments/hwf/docs/3_91.jpg b/experiments/hwf/docs/3_91.jpg
new file mode 100644
GIT binary patch (literal 1439; binary JPEG data omitted)

diff --git a/experiments/hwf/docs/5_237.jpg b/experiments/hwf/docs/5_237.jpg
new file mode 100644
GIT binary patch (literal 1421; binary JPEG data omitted)

diff --git a/experiments/hwf/docs/div_942.jpg b/experiments/hwf/docs/div_942.jpg
new file mode 100644
GIT binary patch (literal 767; binary JPEG data omitted)

diff --git a/experiments/hwf/examples/hwf_length_3_all.scl b/experiments/hwf/examples/hwf_length_3_all.scl
new file mode 100644
index 0000000..4a6034f
--- /dev/null
+++ b/experiments/hwf/examples/hwf_length_3_all.scl
@@ -0,0 +1,9 @@
+import "../scl/hwf_eval.scl"
+
+rel symbol = {
+  (0, "0"), (0, "1"), (0, "2"), (0, "3"), (0, "4"), (0, "5"), (0, "6"), (0, "7"), (0, "8"), (0, "9"), (0, "+"), (0, "-"), (0, "*"), (0, "/"),
+  (1, "0"), (1, "1"), (1, "2"), (1, "3"), (1, "4"), (1, "5"), (1, "6"), (1, "7"), (1, "8"), (1, "9"), (1, "+"), (1, "-"), (1, "*"), (1, "/"),
+  (2, "0"), (2, "1"), (2, "2"), (2, "3"), (2, "4"), (2, "5"), (2, "6"), (2, "7"), (2, "8"), (2, "9"), (2, "+"), (2, "-"), (2, "*"), (2, "/"),
+}
+
+rel length(3)
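The example above can also be driven from Python. A minimal sketch (assuming `scallopy` is installed, the script sits next to the `.scl` file, and the default provenance suffices; it mirrors `examples/run_hwf_unique_parser.py` further below):

```
# Minimal sketch: load hwf_length_3_all.scl and enumerate every result
# derivable from the unweighted length-3 symbol grid
import os
import scallopy

ctx = scallopy.ScallopContext()
ctx.import_file(os.path.join(os.path.dirname(os.path.abspath(__file__)), "hwf_length_3_all.scl"))
ctx.run()
print(list(ctx.relation("result")))
```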
0.07::(2, "*"), 0.07::(2, "/"), +} + +rel length(3) diff --git a/experiments/hwf/examples/hwf_length_3_top_3.scl b/experiments/hwf/examples/hwf_length_3_top_3.scl new file mode 100644 index 0000000..aedf13e --- /dev/null +++ b/experiments/hwf/examples/hwf_length_3_top_3.scl @@ -0,0 +1,9 @@ +import "../scl/hwf_parser.scl" + +rel symbol = { + (0, "0"), (0, "4"), (0, "/"), + (1, "4"), (1, "*"), (1, "/"), + (2, "2"), (2, "8"), (2, "/"), +} + +rel length(3) diff --git a/experiments/hwf/examples/hwf_length_3_top_3_prob.scl b/experiments/hwf/examples/hwf_length_3_top_3_prob.scl new file mode 100644 index 0000000..57543ed --- /dev/null +++ b/experiments/hwf/examples/hwf_length_3_top_3_prob.scl @@ -0,0 +1,11 @@ +import "../scl/hwf_eval.scl" + +rel symbol = { + 0.9::(0, "1"), 0.05::(0, "4"), 0.05::(0, "/"), + 0.1::(1, "4"), 0.8::(1, "*"), 0.1::(1, "/"), + 0.2::(2, "2"), 0.7::(2, "8"), 0.1::(2, "/"), +} + +rel length(3) + +query result diff --git a/experiments/hwf/examples/hwf_length_7_all.scl b/experiments/hwf/examples/hwf_length_7_all.scl new file mode 100644 index 0000000..cd98ee6 --- /dev/null +++ b/experiments/hwf/examples/hwf_length_7_all.scl @@ -0,0 +1,15 @@ +import "../scl/hwf_eval.scl" + +rel symbol = { + 0.07::(0, "0"), 0.07::(0, "1"), 0.07::(0, "2"), 0.07::(0, "3"), 0.07::(0, "4"), 0.07::(0, "5"), 0.07::(0, "6"), 0.07::(0, "7"), 0.07::(0, "8"), 0.07::(0, "9"), 0.07::(0, "+"), 0.07::(0, "-"), 0.07::(0, "*"), 0.07::(0, "/"), + 0.07::(1, "0"), 0.07::(1, "1"), 0.07::(1, "2"), 0.07::(1, "3"), 0.07::(1, "4"), 0.07::(1, "5"), 0.07::(1, "6"), 0.07::(1, "7"), 0.07::(1, "8"), 0.07::(1, "9"), 0.07::(1, "+"), 0.07::(1, "-"), 0.07::(1, "*"), 0.07::(1, "/"), + 0.07::(2, "0"), 0.07::(2, "1"), 0.07::(2, "2"), 0.07::(2, "3"), 0.07::(2, "4"), 0.07::(2, "5"), 0.07::(2, "6"), 0.07::(2, "7"), 0.07::(2, "8"), 0.07::(2, "9"), 0.07::(2, "+"), 0.07::(2, "-"), 0.07::(2, "*"), 0.07::(2, "/"), + 0.07::(3, "0"), 0.07::(3, "1"), 0.07::(3, "2"), 0.07::(3, "3"), 0.07::(3, "4"), 0.07::(3, "5"), 0.07::(3, "6"), 0.07::(3, "7"), 0.07::(3, "8"), 0.07::(3, "9"), 0.07::(3, "+"), 0.07::(3, "-"), 0.07::(3, "*"), 0.07::(3, "/"), + 0.07::(4, "0"), 0.07::(4, "1"), 0.07::(4, "2"), 0.07::(4, "3"), 0.07::(4, "4"), 0.07::(4, "5"), 0.07::(4, "6"), 0.07::(4, "7"), 0.07::(4, "8"), 0.07::(4, "9"), 0.07::(4, "+"), 0.07::(4, "-"), 0.07::(4, "*"), 0.07::(4, "/"), + 0.07::(5, "0"), 0.07::(5, "1"), 0.07::(5, "2"), 0.07::(5, "3"), 0.07::(5, "4"), 0.07::(5, "5"), 0.07::(5, "6"), 0.07::(5, "7"), 0.07::(5, "8"), 0.07::(5, "9"), 0.07::(5, "+"), 0.07::(5, "-"), 0.07::(5, "*"), 0.07::(5, "/"), + 0.07::(6, "0"), 0.07::(6, "1"), 0.07::(6, "2"), 0.07::(6, "3"), 0.07::(6, "4"), 0.07::(6, "5"), 0.07::(6, "6"), 0.07::(6, "7"), 0.07::(6, "8"), 0.07::(6, "9"), 0.07::(6, "+"), 0.07::(6, "-"), 0.07::(6, "*"), 0.07::(6, "/"), +} + +rel length(7) + +query result diff --git a/experiments/hwf/examples/hwf_length_7_all_prob.scl b/experiments/hwf/examples/hwf_length_7_all_prob.scl new file mode 100644 index 0000000..501a886 --- /dev/null +++ b/experiments/hwf/examples/hwf_length_7_all_prob.scl @@ -0,0 +1,19 @@ +import "../scl/hwf_eval.scl" + +rel symbol_ids = {0, 1, 2, 3, 4, 5, 6} + +rel all_symbol = { + 0.07::(0, "0"), 0.07::(0, "1"), 0.07::(0, "2"), 0.07::(0, "3"), 0.07::(0, "4"), 0.07::(0, "5"), 0.07::(0, "6"), 0.07::(0, "7"), 0.07::(0, "8"), 0.07::(0, "9"), 0.07::(0, "+"), 0.07::(0, "-"), 0.07::(0, "*"), 0.07::(0, "/"), + 0.07::(1, "0"), 0.07::(1, "1"), 0.07::(1, "2"), 0.07::(1, "3"), 0.07::(1, "4"), 0.07::(1, "5"), 0.07::(1, "6"), 0.07::(1, "7"), 0.07::(1, "8"), 
0.07::(1, "9"), 0.07::(1, "+"), 0.07::(1, "-"), 0.07::(1, "*"), 0.07::(1, "/"), + 0.07::(2, "0"), 0.07::(2, "1"), 0.07::(2, "2"), 0.07::(2, "3"), 0.07::(2, "4"), 0.07::(2, "5"), 0.07::(2, "6"), 0.07::(2, "7"), 0.07::(2, "8"), 0.07::(2, "9"), 0.07::(2, "+"), 0.07::(2, "-"), 0.07::(2, "*"), 0.07::(2, "/"), + 0.07::(3, "0"), 0.07::(3, "1"), 0.07::(3, "2"), 0.07::(3, "3"), 0.07::(3, "4"), 0.07::(3, "5"), 0.07::(3, "6"), 0.07::(3, "7"), 0.07::(3, "8"), 0.07::(3, "9"), 0.07::(3, "+"), 0.07::(3, "-"), 0.07::(3, "*"), 0.07::(3, "/"), + 0.07::(4, "0"), 0.07::(4, "1"), 0.07::(4, "2"), 0.07::(4, "3"), 0.07::(4, "4"), 0.07::(4, "5"), 0.07::(4, "6"), 0.07::(4, "7"), 0.07::(4, "8"), 0.07::(4, "9"), 0.07::(4, "+"), 0.07::(4, "-"), 0.07::(4, "*"), 0.07::(4, "/"), + 0.07::(5, "0"), 0.07::(5, "1"), 0.07::(5, "2"), 0.07::(5, "3"), 0.07::(5, "4"), 0.07::(5, "5"), 0.07::(5, "6"), 0.07::(5, "7"), 0.07::(5, "8"), 0.07::(5, "9"), 0.07::(5, "+"), 0.07::(5, "-"), 0.07::(5, "*"), 0.07::(5, "/"), + 0.07::(6, "0"), 0.07::(6, "1"), 0.07::(6, "2"), 0.07::(6, "3"), 0.07::(6, "4"), 0.07::(6, "5"), 0.07::(6, "6"), 0.07::(6, "7"), 0.07::(6, "8"), 0.07::(6, "9"), 0.07::(6, "+"), 0.07::(6, "-"), 0.07::(6, "*"), 0.07::(6, "/"), +} + +rel sampled_symbol(n, v) = (n, v) := categorical<1>(n, v: all_symbol(n, v) where n: symbol_ids(n)) + +rel length(7) + +query sampled_symbol diff --git a/experiments/hwf/examples/hwf_unique_length_7_all.scl b/experiments/hwf/examples/hwf_unique_length_7_all.scl new file mode 100644 index 0000000..4c5969b --- /dev/null +++ b/experiments/hwf/examples/hwf_unique_length_7_all.scl @@ -0,0 +1,15 @@ +import "../scl/hwf_unique_parser.scl" + +rel symbol = { + (0, "0"), (0, "1"), (0, "2"), (0, "3"), (0, "4"), (0, "5"), (0, "6"), (0, "7"), (0, "8"), (0, "9"), (0, "+"), (0, "-"), (0, "*"), (0, "/"), + (1, "0"), (1, "1"), (1, "2"), (1, "3"), (1, "4"), (1, "5"), (1, "6"), (1, "7"), (1, "8"), (1, "9"), (1, "+"), (1, "-"), (1, "*"), (1, "/"), + (2, "0"), (2, "1"), (2, "2"), (2, "3"), (2, "4"), (2, "5"), (2, "6"), (2, "7"), (2, "8"), (2, "9"), (2, "+"), (2, "-"), (2, "*"), (2, "/"), + (3, "0"), (3, "1"), (3, "2"), (3, "3"), (3, "4"), (3, "5"), (3, "6"), (3, "7"), (3, "8"), (3, "9"), (3, "+"), (3, "-"), (3, "*"), (3, "/"), + (4, "0"), (4, "1"), (4, "2"), (4, "3"), (4, "4"), (4, "5"), (4, "6"), (4, "7"), (4, "8"), (4, "9"), (4, "+"), (4, "-"), (4, "*"), (4, "/"), + (5, "0"), (5, "1"), (5, "2"), (5, "3"), (5, "4"), (5, "5"), (5, "6"), (5, "7"), (5, "8"), (5, "9"), (5, "+"), (5, "-"), (5, "*"), (5, "/"), + (6, "0"), (6, "1"), (6, "2"), (6, "3"), (6, "4"), (6, "5"), (6, "6"), (6, "7"), (6, "8"), (6, "9"), (6, "+"), (6, "-"), (6, "*"), (6, "/"), +} + +rel length(7) + +query result diff --git a/experiments/hwf/examples/hwf_with_disjunction.scl b/experiments/hwf/examples/hwf_with_disjunction.scl new file mode 100644 index 0000000..f50e9e7 --- /dev/null +++ b/experiments/hwf/examples/hwf_with_disjunction.scl @@ -0,0 +1,18 @@ +import "../scl/hwf_parser.scl" + +// There are 7 characters and 4 ways to interpret: +// (((9 - 3) - 2) + 8), Result: 12 <-- CORRECT +// (((9 / 3) - 2) + 8), Result: 9 +// ((9 - (3 / 2)) + 8), Result: 15.5 +// (((9 / 3) / 2) + 8), Result: 9.5 +rel symbol :- {1.0000::(0, "9")} +rel symbol :- {0.9323::(1, "-"); 0.0677::(1, "/")} +rel symbol :- {1.0000::(2, "3")} +rel symbol :- {0.9085::(3, "-"); 0.0915::(3, "/")} +rel symbol :- {1.0000::(4, "2")} +rel symbol :- {0.9960::(5, "+")} +rel symbol :- {1.0000::(6, "8")} + +rel length(7) + +query result diff --git 
diff --git a/experiments/hwf/examples/run_hwf_unique_parser.py b/experiments/hwf/examples/run_hwf_unique_parser.py
new file mode 100644
index 0000000..11412f4
--- /dev/null
+++ b/experiments/hwf/examples/run_hwf_unique_parser.py
@@ -0,0 +1,27 @@
+import os
+import torch
+import scallopy
+
+this_file_path = os.path.abspath(os.path.join(__file__, "../"))
+
+# Create scallop context
+ctx = scallopy.ScallopContext(provenance="difftopbottomkclauses")
+ctx.import_file(os.path.join(this_file_path, "../scl/hwf_parser.scl"))
+
+# The symbol facts
+ctx.add_facts("symbol", [
+  (torch.tensor(0.2), (0, "3")), (torch.tensor(0.5), (0, "5")),
+  (torch.tensor(0.1), (1, "*")), (torch.tensor(0.3), (1, "/")),
+  (torch.tensor(0.01), (2, "4")), (torch.tensor(0.8), (2, "2")),
+])
+
+# The length facts
+ctx.add_facts("length", [
+  (None, (3,))
+])
+
+# Run the context
+ctx.run(debug_input_provenance=True)
+
+# Inspect the result
+print(list(ctx.relation("result")))
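For orientation, the six probabilistic facts above encode eight candidate readings; a plain-Python enumeration of their values and weights (no Scallop involved) looks like this:

```
# The most likely reading is 5 / 2 = 2.5 with weight 0.5 * 0.3 * 0.8 = 0.12
from itertools import product

for (p0, s0), (p1, s1), (p2, s2) in product(
  [(0.2, "3"), (0.5, "5")], [(0.1, "*"), (0.3, "/")], [(0.01, "4"), (0.8, "2")]
):
  print(f"{s0} {s1} {s2} = {eval(f'{s0} {s1} {s2}')}, weight = {p0 * p1 * p2:.4f}")
```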
action="store_true") + args = parser.parse_args() + + # Directories + random.seed(args.seed) + torch.manual_seed(args.seed) + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) + model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model")) + + # Load mnist dataset + dataset = HWFSymbolDataset(data_dir, amount=args.amount) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size) + + # Load model + hwf_net = torch.load(open(os.path.join(model_dir, args.model_name), "rb")) + symbol_net: SymbolNet = hwf_net.symbol_cnn + symbol_net.eval() + + # Get prediction result + y_true, y_pred = [], [] + with torch.no_grad(): + for (imgs, digits) in dataloader: + pred_digits = numpy.argmax(symbol_net(imgs), axis=1) + y_true += [d.item() for d in digits] + y_pred += [d.item() for d in pred_digits] + + # Compute accuracy if asked + if args.accuracy: + acc = float(len([() for (x, y) in zip(y_true, y_pred) if x == y and x != 0])) / float(len([() for x in y_true if x != 0])) + print(f"Accuracy: {acc:4f}") + + # Compute confusion matrix + cm = confusion_matrix(y_true, y_pred) + + # Plot image or print + if args.plot_image: + df_cm = pd.DataFrame(cm, index=list(range(14)), columns=list(range(14))) + plt.figure(figsize=(10,7)) + sn.heatmap(df_cm, annot=True, cmap=plt.cm.Blues) + plt.ylabel("Actual") + plt.xlabel("Predicted") + plt.savefig(args.image_file) + else: + print(cm) diff --git a/experiments/hwf/readme.md b/experiments/hwf/readme.md new file mode 100644 index 0000000..e269ed7 --- /dev/null +++ b/experiments/hwf/readme.md @@ -0,0 +1,68 @@ +# Hand-written Formula (HWF) + +![1](docs/1_47.jpg) ![+](docs/+_96.jpg) ![3](docs/3_91.jpg) ![div/](docs/div_942.jpg) ![5](docs/5_237.jpg) + +This hand-written formula tests the weak-supervised learning setting for parsing and evaluating hand-written formula. +The above example should be representing a formula `1 + 3 / 5` and be evaluated to `1.6`. +In fact the five images and the final result `1.6` are all the training procedure takes in. +Scallop will serve as a differentiable and probabilistic formula parser and evaluator in this case, taking in proabilistic symbol recognition distributions and return the most likely evaluated result. + +The context free grammar of this simple expression language is defined formally below: + +``` +DIGIT ::= 0 | 1 | ... | 9 +MULT_EXPR ::= DIGIT + | MULT_EXPR * DIGIT + | MULT_EXPR / DIGIT +ADD_EXPR ::= MULT_EXPR + | ADD_EXPR + MULT_EXPR + | ADD_EXPR - MULT_EXPR +``` + +The Scallop implementation of the probabilistic parser for this expression language is written in [this file](scl/hwf_parser.scl). +If you wish to see an example of how a program works, checkout [examples/hwf_length_3_top_3_prob.scl](examples/hwf_length_3_top_3_prob.scl) and run it with + +``` +$ scli examples/hwf_length_3_top_3_prob.scl --provenance topkproofs +``` + +This file encodes a length-3 expression that most likely evaluates to `1 * 8 = 8`. +But certainly other results will be produced with lower probability too: + +``` +result: { + 0.063::(0.125), + 0.020870000000000003::(0.5), + 0.14400000000000004::(2), + 0.504::(8), + 0.028::(32) +} +``` + +## Experiment + +To run the experiment, please make sure you have the dataset downloaded. +The folder structure should look like + +``` +[SCALLOP_V2_DIR]/experiments/data/ +> HWF/ + > Handwritten_Math_Symbols/ + > -/ + > -_66.jpg + > -_87.jpg + > +/ + > +_10.jpg + > +_20.jpg + > 0/... + > 1/... + > ... 
diff --git a/experiments/hwf/readme.md b/experiments/hwf/readme.md
new file mode 100644
index 0000000..e269ed7
--- /dev/null
+++ b/experiments/hwf/readme.md
@@ -0,0 +1,68 @@
+# Hand-written Formula (HWF)
+
+![1](docs/1_47.jpg) ![+](docs/+_96.jpg) ![3](docs/3_91.jpg) ![div/](docs/div_942.jpg) ![5](docs/5_237.jpg)
+
+This hand-written formula (HWF) task tests the weakly supervised learning setting for parsing and evaluating hand-written formulas.
+The example above represents the formula `1 + 3 / 5`, which evaluates to `1.6`.
+In fact, the five images and the final result `1.6` are all that the training procedure takes in.
+Scallop serves as a differentiable and probabilistic formula parser and evaluator in this case, taking in probabilistic symbol recognition distributions and returning the most likely evaluated results.
+
+The context-free grammar of this simple expression language is formally defined below:
+
+```
+DIGIT     ::= 0 | 1 | ... | 9
+MULT_EXPR ::= DIGIT
+            | MULT_EXPR * DIGIT
+            | MULT_EXPR / DIGIT
+ADD_EXPR  ::= MULT_EXPR
+            | ADD_EXPR + MULT_EXPR
+            | ADD_EXPR - MULT_EXPR
+```
+
+The Scallop implementation of the probabilistic parser for this expression language is written in [this file](scl/hwf_parser.scl).
+To see an example of how a program works, check out [examples/hwf_length_3_top_3_prob.scl](examples/hwf_length_3_top_3_prob.scl) and run it with
+
+```
+$ scli examples/hwf_length_3_top_3_prob.scl --provenance topkproofs
+```
+
+This file encodes a length-3 expression that most likely evaluates to `1 * 8 = 8`.
+Other results are also produced, with lower probabilities:
+
+```
+result: {
+  0.063::(0.125),
+  0.020870000000000003::(0.5),
+  0.14400000000000004::(2),
+  0.504::(8),
+  0.028::(32)
+}
+```
+
+## Experiment
+
+To run the experiment, please make sure you have downloaded the dataset.
+The folder structure should look like
+
+```
+[SCALLOP_V2_DIR]/experiments/data/
+> HWF/
+  > Handwritten_Math_Symbols/
+    > -/
+      > -_66.jpg
+      > -_87.jpg
+    > +/
+      > +_10.jpg
+      > +_20.jpg
+    > 0/...
+    > 1/...
+    > ...
+  > expr_train.json
+  > expr_test.json
+```
+
+Then do
+
+```
+$ python run_with_hwf_parser.py
+```

diff --git a/experiments/hwf/run.py b/experiments/hwf/run.py
new file mode 100644
index 0000000..4e12578
--- /dev/null
+++ b/experiments/hwf/run.py
@@ -0,0 +1,255 @@
+import os
+import json
+import random
+from argparse import ArgumentParser
+from tqdm import tqdm
+import math
+
+import torch
+from torch import nn, optim
+import torch.nn.functional as F
+import torchvision
+from PIL import Image
+
+import scallopy
+import math
+
+class HWFDataset(torch.utils.data.Dataset):
+  def __init__(self, root: str, prefix: str, split: str):
+    super(HWFDataset, self).__init__()
+    self.root = root
+    self.split = split
+    self.metadata = json.load(open(os.path.join(root, f"HWF/{prefix}_{split}.json")))
+    self.img_transform = torchvision.transforms.Compose([
+      torchvision.transforms.ToTensor(),
+      torchvision.transforms.Normalize((0.5,), (1,))])
+
+  def __getitem__(self, index):
+    sample = self.metadata[index]
+
+    # Input is a sequence of images
+    img_seq = []
+    for img_path in sample["img_paths"]:
+      img_full_path = os.path.join(self.root, "HWF/Handwritten_Math_Symbols", img_path)
+      img = Image.open(img_full_path).convert("L")
+      img = self.img_transform(img)
+      img_seq.append(img)
+    img_seq_len = len(img_seq)
+
+    # Output is the "res" in the sample of metadata
+    res = sample["res"]
+
+    # Return (input, output) pair
+    return (img_seq, img_seq_len, res)
+
+  def __len__(self):
+    return len(self.metadata)
+
+  @staticmethod
+  def collate_fn(batch):
+    max_len = max([img_seq_len for (_, img_seq_len, _) in batch])
+    zero_img = torch.zeros_like(batch[0][0][0])
+    pad_zero = lambda img_seq: img_seq + [zero_img] * (max_len - len(img_seq))
+    img_seqs = torch.stack([torch.stack(pad_zero(img_seq)) for (img_seq, _, _) in batch])
+    img_seq_len = torch.stack([torch.tensor(img_seq_len).long() for (_, img_seq_len, _) in batch])
+    results = torch.stack([torch.tensor(res) for (_, _, res) in batch])
+    return (img_seqs, img_seq_len, results)
+
+
+def hwf_loader(data_dir, batch_size, prefix):
+  train_loader = torch.utils.data.DataLoader(HWFDataset(data_dir, prefix, "train"), collate_fn=HWFDataset.collate_fn, batch_size=batch_size, shuffle=True)
+  test_loader = torch.utils.data.DataLoader(HWFDataset(data_dir, prefix, "test"), collate_fn=HWFDataset.collate_fn, batch_size=batch_size, shuffle=True)
+  return (train_loader, test_loader)
+
+
+class SymbolNet(nn.Module):
+  def __init__(self):
+    super(SymbolNet, self).__init__()
+    self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1)
+    self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1)
+    self.fc1 = nn.Linear(30976, 128)
+    self.fc2 = nn.Linear(128, 14)
+
+  def forward(self, x):
+    x = self.conv1(x)
+    x = F.relu(x)
+    x = self.conv2(x)
+    x = F.max_pool2d(x, 2)
+    x = F.dropout(x, p=0.25, training=self.training)
+    x = torch.flatten(x, 1)
+    x = self.fc1(x)
+    x = F.relu(x)
+    x = F.dropout(x, p=0.5, training=self.training)
+    x = self.fc2(x)
+    return F.softmax(x, dim=1)
+
+
+class HWFNet(nn.Module):
+  def __init__(self, no_sample_k, sample_k, provenance, k):
+    super(HWFNet, self).__init__()
+    self.no_sample_k = no_sample_k
+    self.sample_k = sample_k
+    self.provenance = provenance
+    self.symbols = [str(i) for i in range(10)] + ["+", "-", "*", "/"]
+
+    # Symbol embedding
+    self.symbol_cnn = SymbolNet()
+
+    # Scallop context
+    self.eval_formula = scallopy.Module(
+      file=os.path.abspath(os.path.join(os.path.abspath(__file__), f"../scl/hwf_eval.scl")),
+      non_probabilistic=["length"],
input_mappings={"symbol": scallopy.Map({0: range(7), 1: self.symbols}, retain_k=sample_k, sample_dim=1, sample_strategy="categorical")}, + output_relation="result", + ) + + def forward(self, img_seq, img_seq_len): + batch_size, formula_length, _, _, _ = img_seq.shape + length = [[(l.item(),)] for l in img_seq_len] + symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, formula_length, -1) + (mapping, probs) = self.eval_formula(symbol=symbol, length=length) + return ([v for (v,) in mapping], probs) + + +class Trainer(): + def __init__(self, train_loader, test_loader, device, model_root, model_name, learning_rate, no_sample_k, sample_k, provenance, k): + self.network = HWFNet(no_sample_k, sample_k, provenance, k).to(device) + self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate) + self.train_loader = train_loader + self.test_loader = test_loader + self.device = device + self.loss_fn = F.binary_cross_entropy + self.model_root = model_root + self.model_name = model_name + self.min_test_loss = 100000000.0 + + def eval_result_eq(self, a, b, threshold=0.01): + result = abs(a - b) < threshold + return result + + def train_epoch(self, epoch): + self.network.train() + num_items = 0 + train_loss = 0 + total_correct = 0 + iter = tqdm(self.train_loader, total=len(self.train_loader)) + for (i, (img_seq, img_seq_len, label)) in enumerate(iter): + (output_mapping, y_pred) = self.network(img_seq.to(device), img_seq_len.to(device)) + y_pred = y_pred.to("cpu") + + # Normalize label format + batch_size, num_outputs = y_pred.shape + y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # Compute loss + loss = self.loss_fn(y_pred, y) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + if not math.isnan(loss.item()): + train_loss += loss.item() + + # Collect index and compute accuracy + if num_outputs > 0: + y_index = torch.argmax(y, dim=1) + y_pred_index = torch.argmax(y_pred, dim=1) + correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())).item() + else: + correct_count = 0 + + # Stats + num_items += batch_size + total_correct += correct_count + perc = 100. * total_correct / num_items + avg_loss = train_loss / (i + 1) + + # Prints + iter.set_description(f"[Train Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + def test_epoch(self, epoch): + self.network.eval() + num_items = 0 + test_loss = 0 + total_correct = 0 + with torch.no_grad(): + iter = tqdm(self.test_loader, total=len(self.test_loader)) + for i, (img_seq, img_seq_len, label) in enumerate(iter): + (output_mapping, y_pred) = self.network(img_seq.to(device), img_seq_len.to(device)) + y_pred = y_pred.to("cpu") + + # Normalize label format + batch_size, num_outputs = y_pred.shape + y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # Compute loss + loss = self.loss_fn(y_pred, y) + if not math.isnan(loss.item()): + test_loss += loss.item() + + # Collect index and compute accuracy + if num_outputs > 0: + y_index = torch.argmax(y, dim=1) + y_pred_index = torch.argmax(y_pred, dim=1) + correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())).item() + else: + correct_count = 0 + + # Stats + num_items += batch_size + total_correct += correct_count + perc = 100. 
* total_correct / num_items + avg_loss = test_loss / (i + 1) + + # Prints + iter.set_description(f"[Test Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + # Save model + if test_loss < self.min_test_loss: + self.min_test_loss = test_loss + torch.save(self.network, os.path.join(self.model_root, self.model_name)) + + def train(self, n_epochs): + # self.test_epoch(0) + for epoch in range(1, n_epochs + 1): + self.train_epoch(epoch) + self.test_epoch(epoch) + + +if __name__ == "__main__": + # Command line arguments + parser = ArgumentParser("hwf") + parser.add_argument("--model-name", type=str, default="hwf.pkl") + parser.add_argument("--n-epochs", type=int, default=100) + parser.add_argument("--no-sample-k", action="store_true") + parser.add_argument("--sample-k", type=int, default=3) + parser.add_argument("--dataset-prefix", type=str, default="expr") + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--learning-rate", type=float, default=0.0001) + parser.add_argument("--loss-fn", type=str, default="bce") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--do-not-use-hash", action="store_true") + parser.add_argument("--provenance", type=str, default="difftopkproofs") + parser.add_argument("--top-k", type=int, default=3) + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--jit", action="store_true") + parser.add_argument("--recompile", action="store_true") + args = parser.parse_args() + + # Parameters + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.cuda: + if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") + else: raise Exception("No cuda available") + else: device = torch.device("cpu") + + # Data + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) + model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model/hwf")) + if not os.path.exists(model_dir): os.makedirs(model_dir) + train_loader, test_loader = hwf_loader(data_dir, batch_size=args.batch_size, prefix=args.dataset_prefix) + + # Training + trainer = Trainer(train_loader, test_loader, device, model_dir, args.model_name, args.learning_rate, args.no_sample_k, args.sample_k, args.provenance, args.top_k) + trainer.train(args.n_epochs) diff --git a/experiments/hwf/run_with_hwf_parser.py b/experiments/hwf/run_with_hwf_parser.py new file mode 100644 index 0000000..7edd880 --- /dev/null +++ b/experiments/hwf/run_with_hwf_parser.py @@ -0,0 +1,285 @@ +import os +import json +import random +from argparse import ArgumentParser +from tqdm import tqdm +import math + +import torch +from torch import nn, optim +import torch.nn.functional as F +import torchvision +from PIL import Image + +import scallopy +import math + +class HWFDataset(torch.utils.data.Dataset): + def __init__(self, root: str, prefix: str, split: str): + super(HWFDataset, self).__init__() + self.root = root + self.split = split + self.metadata = json.load(open(os.path.join(root, f"HWF/{prefix}_{split}.json"))) + self.img_transform = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5,), (1,)) + ]) + + def __getitem__(self, index): + sample = self.metadata[index] + + # Input is a sequence of images + img_seq = [] + for img_path in sample["img_paths"]: + img_full_path = os.path.join(self.root, "HWF/Handwritten_Math_Symbols", img_path) + img = 
Image.open(img_full_path).convert("L") + img = self.img_transform(img) + img_seq.append(img) + img_seq_len = len(img_seq) + + # Output is the "res" in the sample of metadata + res = sample["res"] + + # Return (input, output) pair + return (img_seq, img_seq_len, res) + + def __len__(self): + return len(self.metadata) + + @staticmethod + def collate_fn(batch): + max_len = max([img_seq_len for (_, img_seq_len, _) in batch]) + zero_img = torch.zeros_like(batch[0][0][0]) + pad_zero = lambda img_seq: img_seq + [zero_img] * (max_len - len(img_seq)) + img_seqs = torch.stack([torch.stack(pad_zero(img_seq)) for (img_seq, _, _) in batch]) + img_seq_len = torch.stack([torch.tensor(img_seq_len).long() for (_, img_seq_len, _) in batch]) + results = torch.stack([torch.tensor(res) for (_, _, res) in batch]) + return (img_seqs, img_seq_len, results) + + +def hwf_loader(data_dir, batch_size, prefix): + train_loader = torch.utils.data.DataLoader(HWFDataset(data_dir, prefix, "train"), collate_fn=HWFDataset.collate_fn, batch_size=batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(HWFDataset(data_dir, prefix, "test"), collate_fn=HWFDataset.collate_fn, batch_size=batch_size, shuffle=True) + return (train_loader, test_loader) + + +class SymbolNet(nn.Module): + def __init__(self): + super(SymbolNet, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) + self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) + self.fc1 = nn.Linear(30976, 128) + self.fc2 = nn.Linear(128, 14) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.max_pool2d(x, 2) + x = F.dropout(x, p=0.25, training=self.training) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = F.dropout(x, p=0.5, training=self.training) + x = self.fc2(x) + return F.softmax(x, dim=1) + + +class HWFNet(nn.Module): + def __init__(self, no_sample_k, sample_k, provenance, k, debug=False): + super(HWFNet, self).__init__() + self.no_sample_k = no_sample_k + self.sample_k = sample_k + self.provenance = provenance + self.debug = debug + + # Symbol embedding + self.symbol_cnn = SymbolNet() + + # Scallop context + self.scallop_file = "hwf_eval.scl" if not args.do_not_use_hash else "hwf_parser_wo_hash.scl" + self.symbols = [str(i) for i in range(10)] + ["+", "-", "*", "/"] + self.ctx = scallopy.ScallopContext(provenance=provenance, k=k) + self.ctx.import_file(os.path.abspath(os.path.join(os.path.abspath(__file__), f"../scl/{self.scallop_file}"))) + self.ctx.set_non_probabilistic("length") + self.ctx.set_input_mapping("symbol", [(i, s) for i in range(7) for s in self.symbols]) + if self.debug: + self.eval_formula = self.ctx.forward_function("result", dispatch="single", debug_provenance=True) + else: + self.eval_formula = self.ctx.forward_function("result", jit=args.jit, recompile=args.recompile) + + def forward(self, img_seq, img_seq_len): + batch_size, formula_length, _, _, _ = img_seq.shape + length = [[(l.item(),)] for l in img_seq_len] + if self.no_sample_k: return self._forward_with_no_sampling(batch_size, img_seq, length) + else: return self._forward_with_sampling(batch_size, formula_length, img_seq, img_seq_len, length) + + def _forward_with_no_sampling(self, batch_size, img_seq, length): + symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, -1) + (mapping, probs) = self.eval_formula(symbol=symbol, length=length) + return ([v for (v,) in mapping], probs) + + def _forward_with_sampling(self, batch_size, formula_length, img_seq, img_seq_len, 
length): + symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, formula_length, -1) + symbol_facts = [[] for _ in range(batch_size)] + disjunctions = [[] for _ in range(batch_size)] + for task_id in range(batch_size): + for symbol_id in range(img_seq_len[task_id]): + # Compute the distribution and sample + symbols_distr = symbol[task_id, symbol_id] # Get the predicted distrubution + categ = torch.distributions.Categorical(symbols_distr) # Create a categorical distribution + sample_ids = [k.item() for k in categ.sample((self.sample_k,))] # Sample from this distribution + sample_ids = list(dict.fromkeys(sample_ids)) # Deduplicate the ids + + # Create facts + curr_symbol_facts = [(symbols_distr[k], (symbol_id, self.symbols[k])) for k in sample_ids] + + # Generate disjunction from facts + disjunctions[task_id].append([len(symbol_facts[task_id]) + i for i in range(len(curr_symbol_facts))]) + symbol_facts[task_id] += curr_symbol_facts + (mapping, probs) = self.eval_formula(symbol=symbol_facts, length=length, disjunctions={"symbol": disjunctions}) + return ([v for (v,) in mapping], probs) + + +class Trainer(): + def __init__(self, train_loader, test_loader, device, model_root, model_name, learning_rate, no_sample_k, sample_k, provenance, k): + self.network = HWFNet(no_sample_k, sample_k, provenance, k).to(device) + self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate) + self.train_loader = train_loader + self.test_loader = test_loader + self.device = device + self.loss_fn = F.binary_cross_entropy + self.model_root = model_root + self.model_name = model_name + self.min_test_loss = 100000000.0 + + def eval_result_eq(self, a, b, threshold=0.01): + result = abs(a - b) < threshold + return result + + def train_epoch(self, epoch): + self.network.train() + num_items = 0 + train_loss = 0 + total_correct = 0 + iter = tqdm(self.train_loader, total=len(self.train_loader)) + for (i, (img_seq, img_seq_len, label)) in enumerate(iter): + (output_mapping, y_pred) = self.network(img_seq.to(device), img_seq_len.to(device)) + y_pred = y_pred.to("cpu") + + # Normalize label format + batch_size, num_outputs = y_pred.shape + y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # Compute loss + loss = self.loss_fn(y_pred, y) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + if not math.isnan(loss.item()): + train_loss += loss.item() + + # Collect index and compute accuracy + if num_outputs > 0: + y_index = torch.argmax(y, dim=1) + y_pred_index = torch.argmax(y_pred, dim=1) + correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())).item() + else: + correct_count = 0 + + # Stats + num_items += batch_size + total_correct += correct_count + perc = 100. 
* total_correct / num_items + avg_loss = train_loss / (i + 1) + + # Prints + iter.set_description(f"[Train Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + def test_epoch(self, epoch): + self.network.eval() + num_items = 0 + test_loss = 0 + total_correct = 0 + with torch.no_grad(): + iter = tqdm(self.test_loader, total=len(self.test_loader)) + for i, (img_seq, img_seq_len, label) in enumerate(iter): + (output_mapping, y_pred) = self.network(img_seq.to(device), img_seq_len.to(device)) + y_pred = y_pred.to("cpu") + + # Normalize label format + batch_size, num_outputs = y_pred.shape + y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # Compute loss + loss = self.loss_fn(y_pred, y) + if not math.isnan(loss.item()): + test_loss += loss.item() + + # Collect index and compute accuracy + if num_outputs > 0: + y_index = torch.argmax(y, dim=1) + y_pred_index = torch.argmax(y_pred, dim=1) + correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())).item() + else: + correct_count = 0 + + # Stats + num_items += batch_size + total_correct += correct_count + perc = 100. * total_correct / num_items + avg_loss = test_loss / (i + 1) + + # Prints + iter.set_description(f"[Test Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + # Save model + if test_loss < self.min_test_loss: + self.min_test_loss = test_loss + torch.save(self.network, os.path.join(self.model_root, self.model_name)) + + def train(self, n_epochs): + # self.test_epoch(0) + for epoch in range(1, n_epochs + 1): + self.train_epoch(epoch) + self.test_epoch(epoch) + + +if __name__ == "__main__": + # Command line arguments + parser = ArgumentParser("hwf") + parser.add_argument("--model-name", type=str, default="hwf.pkl") + parser.add_argument("--n-epochs", type=int, default=100) + parser.add_argument("--no-sample-k", action="store_true") + parser.add_argument("--sample-k", type=int, default=7) + parser.add_argument("--dataset-prefix", type=str, default="expr") + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--learning-rate", type=float, default=0.0001) + parser.add_argument("--loss-fn", type=str, default="bce") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--do-not-use-hash", action="store_true") + parser.add_argument("--provenance", type=str, default="difftopbottomkclauses") + parser.add_argument("--top-k", type=int, default=3) + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--jit", action="store_true") + parser.add_argument("--recompile", action="store_true") + args = parser.parse_args() + + # Parameters + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.cuda: + if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") + else: raise Exception("No cuda available") + else: device = torch.device("cpu") + + # Data + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) + model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model/hwf")) + if not os.path.exists(model_dir): os.makedirs(model_dir) + train_loader, test_loader = hwf_loader(data_dir, batch_size=args.batch_size, prefix=args.dataset_prefix) + + # Training + trainer = Trainer(train_loader, test_loader, device, model_dir, 
args.model_name, args.learning_rate, args.no_sample_k, args.sample_k, args.provenance, args.top_k) + trainer.train(args.n_epochs) diff --git a/experiments/hwf/run_with_hwf_parser_w_sample.py b/experiments/hwf/run_with_hwf_parser_w_sample.py new file mode 100644 index 0000000..a39f8c6 --- /dev/null +++ b/experiments/hwf/run_with_hwf_parser_w_sample.py @@ -0,0 +1,254 @@ +import os +import json +import random +from argparse import ArgumentParser +from tqdm import tqdm +import math + +import torch +from torch import nn, optim +import torch.nn.functional as F +import torchvision +from PIL import Image + +import scallopy +import math + +class HWFDataset(torch.utils.data.Dataset): + def __init__(self, root: str, prefix: str, split: str): + super(HWFDataset, self).__init__() + self.root = root + self.split = split + self.metadata = json.load(open(os.path.join(root, f"HWF/{prefix}_{split}.json"))) + self.img_transform = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5,), (1,)) + ]) + + def __getitem__(self, index): + sample = self.metadata[index] + + # Input is a sequence of images + img_seq = [] + for img_path in sample["img_paths"]: + img_full_path = os.path.join(self.root, "HWF/Handwritten_Math_Symbols", img_path) + img = Image.open(img_full_path).convert("L") + img = self.img_transform(img) + img_seq.append(img) + img_seq_len = len(img_seq) + + # Output is the "res" in the sample of metadata + res = sample["res"] + + # Return (input, output) pair + return (img_seq, img_seq_len, res) + + def __len__(self): + return len(self.metadata) + + @staticmethod + def collate_fn(batch): + max_len = max([img_seq_len for (_, img_seq_len, _) in batch]) + zero_img = torch.zeros_like(batch[0][0][0]) + pad_zero = lambda img_seq: img_seq + [zero_img] * (max_len - len(img_seq)) + img_seqs = torch.stack([torch.stack(pad_zero(img_seq)) for (img_seq, _, _) in batch]) + img_seq_len = torch.stack([torch.tensor(img_seq_len).long() for (_, img_seq_len, _) in batch]) + results = torch.stack([torch.tensor(res) for (_, _, res) in batch]) + return (img_seqs, img_seq_len, results) + + +def hwf_loader(data_dir, batch_size, prefix): + train_loader = torch.utils.data.DataLoader(HWFDataset(data_dir, prefix, "train"), collate_fn=HWFDataset.collate_fn, batch_size=batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(HWFDataset(data_dir, prefix, "test"), collate_fn=HWFDataset.collate_fn, batch_size=batch_size, shuffle=True) + return (train_loader, test_loader) + + +class SymbolNet(nn.Module): + def __init__(self): + super(SymbolNet, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) + self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) + self.fc1 = nn.Linear(30976, 128) + self.fc2 = nn.Linear(128, 14) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.max_pool2d(x, 2) + x = F.dropout(x, p=0.25, training=self.training) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = F.dropout(x, p=0.5, training=self.training) + x = self.fc2(x) + return F.softmax(x, dim=1) + + +class HWFNet(nn.Module): + def __init__(self, provenance, k, debug=False): + super(HWFNet, self).__init__() + self.provenance = provenance + self.debug = debug + + # Symbol embedding + self.symbol_cnn = SymbolNet() + + # Scallop context + self.scallop_file = "hwf_parser.scl" if not args.do_not_use_hash else "hwf_parser_wo_hash.scl" + self.symbols = [str(i) for i in range(10)] + ["+", "-", "*", "/"] + self.ctx = 
scallopy.ScallopContext(provenance=provenance, k=k) + self.ctx.import_file(os.path.abspath(os.path.join(os.path.abspath(__file__), f"../scl/{self.scallop_file}"))) + self.ctx.set_non_probabilistic("length") + self.ctx.set_input_mapping("symbol", [(i, s) for i in range(7) for s in self.symbols]) + if self.debug: self.eval_formula = self.ctx.forward_function("result", dispatch="single", debug_provenance=True) + else: self.eval_formula = self.ctx.forward_function("result", jit=args.jit, recompile=args.recompile) + + def forward(self, img_seq, img_seq_len): + batch_size, _, _, _, _ = img_seq.shape + length = [[(l.item(),)] for l in img_seq_len] + symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, -1) + (mapping, probs) = self.eval_formula(symbol=symbol, length=length) + return ([v for (v,) in mapping], probs) + + +class Trainer(): + def __init__(self, train_loader, test_loader, device, model_root, model_name, learning_rate, provenance, k): + self.network = HWFNet(provenance, k).to(device) + self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate) + self.train_loader = train_loader + self.test_loader = test_loader + self.device = device + self.loss_fn = F.binary_cross_entropy + self.model_root = model_root + self.model_name = model_name + self.min_test_loss = 100000000.0 + + def eval_result_eq(self, a, b, threshold=0.01): + result = abs(a - b) < threshold + return result + + def train_epoch(self, epoch): + self.network.train() + num_items = 0 + train_loss = 0 + total_correct = 0 + iter = tqdm(self.train_loader, total=len(self.train_loader)) + for (i, (img_seq, img_seq_len, label)) in enumerate(iter): + (output_mapping, y_pred) = self.network(img_seq.to(device), img_seq_len.to(device)) + y_pred = y_pred.to("cpu") + + # Normalize label format + batch_size, num_outputs = y_pred.shape + y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # Compute loss + loss = self.loss_fn(y_pred, y) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + if not math.isnan(loss.item()): + train_loss += loss.item() + + # Collect index and compute accuracy + if num_outputs > 0: + y_index = torch.argmax(y, dim=1) + y_pred_index = torch.argmax(y_pred, dim=1) + correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())).item() + else: + correct_count = 0 + + # Stats + num_items += batch_size + total_correct += correct_count + perc = 100. 
* total_correct / num_items + avg_loss = train_loss / (i + 1) + + # Prints + iter.set_description(f"[Train Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + def test_epoch(self, epoch): + self.network.eval() + num_items = 0 + test_loss = 0 + total_correct = 0 + with torch.no_grad(): + iter = tqdm(self.test_loader, total=len(self.test_loader)) + for i, (img_seq, img_seq_len, label) in enumerate(iter): + (output_mapping, y_pred) = self.network(img_seq.to(device), img_seq_len.to(device)) + y_pred = y_pred.to("cpu") + + # Normalize label format + batch_size, num_outputs = y_pred.shape + y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # Compute loss + loss = self.loss_fn(y_pred, y) + if not math.isnan(loss.item()): + test_loss += loss.item() + + # Collect index and compute accuracy + if num_outputs > 0: + y_index = torch.argmax(y, dim=1) + y_pred_index = torch.argmax(y_pred, dim=1) + correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())).item() + else: + correct_count = 0 + + # Stats + num_items += batch_size + total_correct += correct_count + perc = 100. * total_correct / num_items + avg_loss = test_loss / (i + 1) + + # Prints + iter.set_description(f"[Test Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + # Save model + if test_loss < self.min_test_loss: + self.min_test_loss = test_loss + torch.save(self.network, os.path.join(self.model_root, self.model_name)) + + def train(self, n_epochs): + # self.test_epoch(0) + for epoch in range(1, n_epochs + 1): + self.train_epoch(epoch) + self.test_epoch(epoch) + + +if __name__ == "__main__": + # Command line arguments + parser = ArgumentParser("hwf") + parser.add_argument("--model-name", type=str, default="hwf.pkl") + parser.add_argument("--n-epochs", type=int, default=100) + parser.add_argument("--dataset-prefix", type=str, default="expr") + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--learning-rate", type=float, default=0.0001) + parser.add_argument("--loss-fn", type=str, default="bce") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--do-not-use-hash", action="store_true") + parser.add_argument("--provenance", type=str, default="difftopkproofs") + parser.add_argument("--top-k", type=int, default=3) + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--jit", action="store_true") + parser.add_argument("--recompile", action="store_true") + args = parser.parse_args() + + # Parameters + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.cuda: + if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") + else: raise Exception("No cuda available") + else: device = torch.device("cpu") + + # Data + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) + model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model/hwf")) + if not os.path.exists(model_dir): os.makedirs(model_dir) + train_loader, test_loader = hwf_loader(data_dir, batch_size=args.batch_size, prefix=args.dataset_prefix) + + # Training + trainer = Trainer(train_loader, test_loader, device, model_dir, args.model_name, args.learning_rate, args.provenance, args.top_k) + trainer.train(args.n_epochs) diff --git 
a/experiments/hwf/run_with_mbs.py b/experiments/hwf/run_with_mbs.py new file mode 100644 index 0000000..b95a7aa --- /dev/null +++ b/experiments/hwf/run_with_mbs.py @@ -0,0 +1,433 @@ +import os +import json +import random +from argparse import ArgumentParser +from tqdm import tqdm +from queue import PriorityQueue +import math + +import torch +from torch import nn, optim +import torch.nn.functional as F +import torchvision +from PIL import Image + +import scallopy + +class HWFDataset(torch.utils.data.Dataset): + def __init__(self, root: str, prefix: str, split: str): + super(HWFDataset, self).__init__() + self.root = root + self.split = split + self.metadata = json.load(open(os.path.join(root, f"HWF/{prefix}_{split}.json"))) + self.img_transform = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5,), (1,)) + ]) + + def __getitem__(self, index): + sample = self.metadata[index] + + # Input is a sequence of images + img_seq = [] + for img_path in sample["img_paths"]: + img_full_path = os.path.join(self.root, "HWF/Handwritten_Math_Symbols", img_path) + img = Image.open(img_full_path).convert("L") + img = self.img_transform(img) + img_seq.append(img) + img_seq_len = len(img_seq) + + # Output is the "res" in the sample of metadata + res = sample["res"] + + # Return (input, output) pair + return (img_seq, img_seq_len, res) + + def __len__(self): + return len(self.metadata) + + @staticmethod + def collate_fn(batch): + max_len = max([img_seq_len for (_, img_seq_len, _) in batch]) + zero_img = torch.zeros_like(batch[0][0][0]) + pad_zero = lambda img_seq: img_seq + [zero_img] * (max_len - len(img_seq)) + img_seqs = torch.stack([torch.stack(pad_zero(img_seq)) for (img_seq, _, _) in batch]) + img_seq_len = torch.stack([torch.tensor(img_seq_len).long() for (_, img_seq_len, _) in batch]) + results = torch.stack([torch.tensor(res) for (_, _, res) in batch]) + return (img_seqs, img_seq_len, results) + + +def hwf_loader(data_dir, batch_size, prefix): + train_loader = torch.utils.data.DataLoader( + HWFDataset(data_dir, prefix, "train"), + collate_fn=HWFDataset.collate_fn, + batch_size=batch_size, + shuffle=True, + ) + test_loader = torch.utils.data.DataLoader( + HWFDataset(data_dir, prefix, "test"), + collate_fn=HWFDataset.collate_fn, + batch_size=batch_size, + shuffle=True, + ) + return (train_loader, test_loader) + + +class ASTLeaf: + def __init__(self, prob, id, symbol): + self.prob = prob + self.id = id + self.symbol = symbol + + def node_id(self): + return self.id + + def is_operator(self): + return self.symbol in ["+", "-", "*", "/"] + + def contains(self, other): + if isinstance(other, ASTLeaf): return self.id == other.id and self.symbol == other.symbol + else: return False + + def probability(self): + return self.prob + + def __repr__(self): + return f"{self.symbol}" + + def __eq__(self, other): + if isinstance(other, ASTLeaf): return self.id == other.id and self.symbol == other.symbol + else: return False + + +class ASTNode: + def __init__(self, lhs, op, rhs): + self.lhs = lhs + self.op = op + self.rhs = rhs + + def node_id(self): + return self.op.node_id() + + def is_operator(self): + return False + + def contains(self, other): + if isinstance(other, ASTNode): + return self == other or self.lhs.contains(other) or self.rhs.contains(other) + else: + return self.lhs.contains(other) or self.rhs.contains(other) + + def probability(self): + return self.lhs.probability() * self.op.probability() * self.rhs.probability() + + def __repr__(self): + 
return f"({self.lhs} {self.op} {self.rhs})" + + def __eq__(self, other): + if isinstance(other, ASTNode): return self.lhs == other.lhs and self.op == other.op and self.rhs == other.rhs + else: return False + + +class ASTTag: + def __init__(self, asts): + deduped = [] + for ast in asts: + if ast not in deduped: + deduped.append(ast) + self.asts = deduped + + def probability(self): + if len(self.asts) == 0: + return 0.0 + else: + for ast in self.asts: + if isinstance(ast, ASTNode) or isinstance(ast, ASTLeaf): + return ast.probability() + + def __repr__(self): + return f"asttag({self.asts})" + + def filter_valid_proofs(self): + self.asts = [ast for ast in self.asts if type(ast) != list] + + def one_bs(self, expected): + self.filter_valid_proofs() + for ast in self.asts: + q = PriorityQueue() # Priority queue + q.put((1, (ast, expected))) + while True: + (_, (a, alpha_a)) = q.get() + print(a, alpha_a) + exit(0) + + +class MBSSemiring(scallopy.ScallopProvenance): + def base(self, info): + if info is not None: + (prob, id, symbol) = info + leaf = ASTLeaf(prob, id, symbol) + return ASTTag([leaf]) + else: + return self.one() + + def is_valid(self, tag: ASTNode): + return len(tag.asts) > 0 + + def zero(self): + return ASTTag([]) + + def one(self): + return ASTTag([[]]) + + def add(self, t1: ASTTag, t2: ASTTag): + return ASTTag(t1.asts + t2.asts) + + def mult(self, t1: ASTTag, t2: ASTTag): + joined_asts = [] + for a1 in t1.asts: + for a2 in t2.asts: + if type(a1) == list and type(a2) == list: joined_ast = a1 + a2 + elif type(a1) == list: joined_ast = [a for a in a1 if not a2.contains(a)] + [a2] + elif type(a2) == list: joined_ast = [a1] + [a for a in a2 if not a1.contains(a)] + else: + if a1.contains(a2): joined_ast = [a1] + elif a2.contains(a1): joined_ast = [a2] + else: joined_ast = [a1, a2] + + # Join things + if len(joined_ast) == 3: + joined_ast = sorted(joined_ast, key=lambda n: n.node_id()) + if joined_ast[1].is_operator(): + joined_ast = ASTNode(joined_ast[0], joined_ast[1], joined_ast[2]) + joined_asts.append(joined_ast) + elif len(joined_ast) == 1: joined_asts.append(joined_ast[0]) + elif len(joined_ast) < 3: joined_asts.append(joined_ast) + else: raise Exception("Should not happen") + + # Create the tag + return ASTTag(joined_asts) + + def aggregate_unique(self, elements): + max_prob = 0.0 + max_elem = None + for (tag, tup) in elements: + tag_prob = tag.probability() + if tag_prob > max_prob: + max_prob = tag_prob + max_elem = (tag, tup) + if max_elem is not None: return [max_elem] + else: return [] + + +class SymbolNet(nn.Module): + def __init__(self): + super(SymbolNet, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) + self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) + self.dropout1 = nn.Dropout2d(0.25) + self.dropout2 = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(30976, 128) + self.fc2 = nn.Linear(128, 14) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + return F.softmax(x, dim=1) + + +class HWFNet(nn.Module): + def __init__(self): + super(HWFNet, self).__init__() + + # Symbol embedding + self.symbol_cnn = SymbolNet() + + # Scallop context + self.symbols = [str(i) for i in range(10)] + ["+", "-", "*", "/"] + self.ctx = scallopy.ScallopContext(provenance="custom", custom_provenance=MBSSemiring()) + 
self.ctx.import_file(os.path.abspath(os.path.join(os.path.abspath(__file__), "../scl/hwf_unique_parser.scl"))) + self.ctx.set_non_probabilistic("length") + self.eval_formula = self.ctx.forward_function("result") + + def forward(self, img_seq, img_seq_len): + batch_size, formula_length, _, _, _ = img_seq.shape + length = [[(l.item(),)] for l in img_seq_len] + symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, formula_length, -1) + symbol_facts = [[] for _ in range(batch_size)] + for task_id in range(batch_size): + for symbol_id in range(img_seq_len[task_id]): + symbols_distr = symbol[task_id, symbol_id] + curr_symbol_facts = [((p, symbol_id, self.symbols[k]), (symbol_id, self.symbols[k])) for (k, p) in enumerate(symbols_distr)] + symbol_facts[task_id] += curr_symbol_facts + (result_mapping, tags) = self.eval_formula(symbol=symbol_facts, length=length) + return self._extract_result(result_mapping, tags, batch_size) + + def _extract_result(self, result_mapping, tags, batch_size): + result = [] + for task_id in range(batch_size): + max_prob = 0.0 + max_tag = None + max_result_id = None + for i, tag in enumerate(tags[task_id]): + if tag is not None: + p = tag.probability() + if p is not None and p > max_prob: + max_prob = p + max_tag = tag + max_result_id = i + if max_result_id is not None: result.append((result_mapping[max_result_id][0], max_tag)) + else: result.append(None) + return result + + +class Trainer(): + def __init__(self, train_loader, test_loader, device, model_root, model_name, learning_rate): + self.network = HWFNet().to(device) + self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate) + self.train_loader = train_loader + self.test_loader = test_loader + self.device = device + self.loss_fn = F.binary_cross_entropy + self.model_root = model_root + self.model_name = model_name + self.min_test_loss = 100000000.0 + + def eval_result_eq(self, a, b, threshold=0.01): + result = abs(a - b) < threshold + return result + + def train_epoch(self, epoch): + self.network.train() + num_items = 0 + train_loss = 0 + total_correct = 0 + iter = tqdm(self.train_loader, total=len(self.train_loader)) + for (i, (img_seq, img_seq_len, label)) in enumerate(iter): + self.optimizer.zero_grad() + + # Run the network and get the results + result = self.network(img_seq.to(device), img_seq_len.to(device)) + + for (task_id, res) in enumerate(result): + if res is None: continue # _extract_result may yield None when no tag survives + (y_pred, tag) = res + y = label[task_id] + tag.one_bs(y) + + + + + + # # Normalize label format + # batch_size, num_outputs = y_pred.shape + # y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # # Compute loss + # loss = self.loss_fn(y_pred, y) + # train_loss += loss.item() + # loss.backward() + # self.optimizer.step() + + # # Collect index and compute accuracy + # correct_count = 0 + # if num_outputs > 0: + # y_index = torch.argmax(y, dim=1) + # y_pred_index = torch.argmax(y_pred, dim=1) + # correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())).item() + + # # Stats + # num_items += batch_size + # total_correct += correct_count + # perc = 100. 
* total_correct / num_items + # avg_loss = train_loss / (i + 1) + + # # Prints + # iter.set_description(f"[Train Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + def test_epoch(self, epoch): + self.network.eval() + num_items = 0 + test_loss = 0 + total_correct = 0 + with torch.no_grad(): + iter = tqdm(self.test_loader, total=len(self.test_loader)) + for i, (img_seq, img_seq_len, label) in enumerate(iter): + (output_mapping, y_pred) = self.network(img_seq.to(device), img_seq_len.to(device)) + y_pred = y_pred.to("cpu") + + # Normalize label format + batch_size, num_outputs = y_pred.shape + + y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # Compute loss + loss = self.loss_fn(y_pred, y) + test_loss += loss.item() + + # Collect index and compute accuracy + if num_outputs > 0: + y_index = torch.argmax(y, dim=1) + y_pred_index = torch.argmax(y_pred, dim=1) + correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())).item() + else: + correct_count = 0 + + # Stats + num_items += batch_size + total_correct += correct_count + perc = 100. * total_correct / num_items + avg_loss = test_loss / (i + 1) + + # Prints + iter.set_description(f"[Test Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + # Save model + if test_loss < self.min_test_loss: + self.min_test_loss = test_loss + torch.save(self.network, os.path.join(self.model_root, self.model_name)) + + def train(self, n_epochs): + # self.test_epoch(0) + for epoch in range(1, n_epochs + 1): + self.train_epoch(epoch) + self.test_epoch(epoch) + + +if __name__ == "__main__": + # Command line arguments + parser = ArgumentParser("hwf") + parser.add_argument("--model-name", type=str, default="hwf.pkl") + parser.add_argument("--n-epochs", type=int, default=100) + parser.add_argument("--dataset-prefix", type=str, default="expr") + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--learning-rate", type=float, default=0.0001) + parser.add_argument("--loss-fn", type=str, default="bce") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--gpu", type=int, default=0) + args = parser.parse_args() + + # Parameters + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.cuda: + if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") + else: raise Exception("No cuda available") + else: device = torch.device("cpu") + + # Data + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) + model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model/hwf")) + if not os.path.exists(model_dir): os.makedirs(model_dir) + train_loader, test_loader = hwf_loader(data_dir, batch_size=args.batch_size, prefix=args.dataset_prefix) + + # Training + trainer = Trainer(train_loader, test_loader, device, model_dir, args.model_name, args.learning_rate) + trainer.train(args.n_epochs) diff --git a/experiments/hwf/run_with_purely_discrete_sample.py b/experiments/hwf/run_with_purely_discrete_sample.py new file mode 100644 index 0000000..7edd880 --- /dev/null +++ b/experiments/hwf/run_with_purely_discrete_sample.py @@ -0,0 +1,285 @@ +import os +import json +import random +from argparse import ArgumentParser +from tqdm import tqdm +import math + +import torch +from torch import 
nn, optim +import torch.nn.functional as F +import torchvision +from PIL import Image + +import scallopy +import math + +class HWFDataset(torch.utils.data.Dataset): + def __init__(self, root: str, prefix: str, split: str): + super(HWFDataset, self).__init__() + self.root = root + self.split = split + self.metadata = json.load(open(os.path.join(root, f"HWF/{prefix}_{split}.json"))) + self.img_transform = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5,), (1,)) + ]) + + def __getitem__(self, index): + sample = self.metadata[index] + + # Input is a sequence of images + img_seq = [] + for img_path in sample["img_paths"]: + img_full_path = os.path.join(self.root, "HWF/Handwritten_Math_Symbols", img_path) + img = Image.open(img_full_path).convert("L") + img = self.img_transform(img) + img_seq.append(img) + img_seq_len = len(img_seq) + + # Output is the "res" in the sample of metadata + res = sample["res"] + + # Return (input, output) pair + return (img_seq, img_seq_len, res) + + def __len__(self): + return len(self.metadata) + + @staticmethod + def collate_fn(batch): + max_len = max([img_seq_len for (_, img_seq_len, _) in batch]) + zero_img = torch.zeros_like(batch[0][0][0]) + pad_zero = lambda img_seq: img_seq + [zero_img] * (max_len - len(img_seq)) + img_seqs = torch.stack([torch.stack(pad_zero(img_seq)) for (img_seq, _, _) in batch]) + img_seq_len = torch.stack([torch.tensor(img_seq_len).long() for (_, img_seq_len, _) in batch]) + results = torch.stack([torch.tensor(res) for (_, _, res) in batch]) + return (img_seqs, img_seq_len, results) + + +def hwf_loader(data_dir, batch_size, prefix): + train_loader = torch.utils.data.DataLoader(HWFDataset(data_dir, prefix, "train"), collate_fn=HWFDataset.collate_fn, batch_size=batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(HWFDataset(data_dir, prefix, "test"), collate_fn=HWFDataset.collate_fn, batch_size=batch_size, shuffle=True) + return (train_loader, test_loader) + + +class SymbolNet(nn.Module): + def __init__(self): + super(SymbolNet, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) + self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) + self.fc1 = nn.Linear(30976, 128) + self.fc2 = nn.Linear(128, 14) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.max_pool2d(x, 2) + x = F.dropout(x, p=0.25, training=self.training) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = F.dropout(x, p=0.5, training=self.training) + x = self.fc2(x) + return F.softmax(x, dim=1) + + +class HWFNet(nn.Module): + def __init__(self, no_sample_k, sample_k, provenance, k, debug=False): + super(HWFNet, self).__init__() + self.no_sample_k = no_sample_k + self.sample_k = sample_k + self.provenance = provenance + self.debug = debug + + # Symbol embedding + self.symbol_cnn = SymbolNet() + + # Scallop context + self.scallop_file = "hwf_eval.scl" if not args.do_not_use_hash else "hwf_parser_wo_hash.scl" + self.symbols = [str(i) for i in range(10)] + ["+", "-", "*", "/"] + self.ctx = scallopy.ScallopContext(provenance=provenance, k=k) + self.ctx.import_file(os.path.abspath(os.path.join(os.path.abspath(__file__), f"../scl/{self.scallop_file}"))) + self.ctx.set_non_probabilistic("length") + self.ctx.set_input_mapping("symbol", [(i, s) for i in range(7) for s in self.symbols]) + if self.debug: + self.eval_formula = self.ctx.forward_function("result", dispatch="single", debug_provenance=True) + else: + 
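+      # Note: `args` is read from module scope here, so this class assumes the
+      # command-line flags below have already been parsed. The returned callable
+      # evaluates the Scallop program end-to-end and yields (mapping, probs):
+      # `mapping` lists the derived `result` tuples and `probs` is a
+      # differentiable batch_size x len(mapping) tensor, consumed as y_pred.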
self.eval_formula = self.ctx.forward_function("result", jit=args.jit, recompile=args.recompile) + + def forward(self, img_seq, img_seq_len): + batch_size, formula_length, _, _, _ = img_seq.shape + length = [[(l.item(),)] for l in img_seq_len] + if self.no_sample_k: return self._forward_with_no_sampling(batch_size, img_seq, length) + else: return self._forward_with_sampling(batch_size, formula_length, img_seq, img_seq_len, length) + + def _forward_with_no_sampling(self, batch_size, img_seq, length): + symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, -1) + (mapping, probs) = self.eval_formula(symbol=symbol, length=length) + return ([v for (v,) in mapping], probs) + + def _forward_with_sampling(self, batch_size, formula_length, img_seq, img_seq_len, length): + symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, formula_length, -1) + symbol_facts = [[] for _ in range(batch_size)] + disjunctions = [[] for _ in range(batch_size)] + for task_id in range(batch_size): + for symbol_id in range(img_seq_len[task_id]): + # Compute the distribution and sample + symbols_distr = symbol[task_id, symbol_id] # Get the predicted distrubution + categ = torch.distributions.Categorical(symbols_distr) # Create a categorical distribution + sample_ids = [k.item() for k in categ.sample((self.sample_k,))] # Sample from this distribution + sample_ids = list(dict.fromkeys(sample_ids)) # Deduplicate the ids + + # Create facts + curr_symbol_facts = [(symbols_distr[k], (symbol_id, self.symbols[k])) for k in sample_ids] + + # Generate disjunction from facts + disjunctions[task_id].append([len(symbol_facts[task_id]) + i for i in range(len(curr_symbol_facts))]) + symbol_facts[task_id] += curr_symbol_facts + (mapping, probs) = self.eval_formula(symbol=symbol_facts, length=length, disjunctions={"symbol": disjunctions}) + return ([v for (v,) in mapping], probs) + + +class Trainer(): + def __init__(self, train_loader, test_loader, device, model_root, model_name, learning_rate, no_sample_k, sample_k, provenance, k): + self.network = HWFNet(no_sample_k, sample_k, provenance, k).to(device) + self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate) + self.train_loader = train_loader + self.test_loader = test_loader + self.device = device + self.loss_fn = F.binary_cross_entropy + self.model_root = model_root + self.model_name = model_name + self.min_test_loss = 100000000.0 + + def eval_result_eq(self, a, b, threshold=0.01): + result = abs(a - b) < threshold + return result + + def train_epoch(self, epoch): + self.network.train() + num_items = 0 + train_loss = 0 + total_correct = 0 + iter = tqdm(self.train_loader, total=len(self.train_loader)) + for (i, (img_seq, img_seq_len, label)) in enumerate(iter): + (output_mapping, y_pred) = self.network(img_seq.to(device), img_seq_len.to(device)) + y_pred = y_pred.to("cpu") + + # Normalize label format + batch_size, num_outputs = y_pred.shape + y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # Compute loss + loss = self.loss_fn(y_pred, y) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + if not math.isnan(loss.item()): + train_loss += loss.item() + + # Collect index and compute accuracy + if num_outputs > 0: + y_index = torch.argmax(y, dim=1) + y_pred_index = torch.argmax(y_pred, dim=1) + correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, 
torch.zeros(batch_size).bool())).item() + else: + correct_count = 0 + + # Stats + num_items += batch_size + total_correct += correct_count + perc = 100. * total_correct / num_items + avg_loss = train_loss / (i + 1) + + # Prints + iter.set_description(f"[Train Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + def test_epoch(self, epoch): + self.network.eval() + num_items = 0 + test_loss = 0 + total_correct = 0 + with torch.no_grad(): + iter = tqdm(self.test_loader, total=len(self.test_loader)) + for i, (img_seq, img_seq_len, label) in enumerate(iter): + (output_mapping, y_pred) = self.network(img_seq.to(device), img_seq_len.to(device)) + y_pred = y_pred.to("cpu") + + # Normalize label format + batch_size, num_outputs = y_pred.shape + y = torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1) + + # Compute loss + loss = self.loss_fn(y_pred, y) + if not math.isnan(loss.item()): + test_loss += loss.item() + + # Collect index and compute accuracy + if num_outputs > 0: + y_index = torch.argmax(y, dim=1) + y_pred_index = torch.argmax(y_pred, dim=1) + correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())).item() + else: + correct_count = 0 + + # Stats + num_items += batch_size + total_correct += correct_count + perc = 100. * total_correct / num_items + avg_loss = test_loss / (i + 1) + + # Prints + iter.set_description(f"[Test Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)") + + # Save model + if test_loss < self.min_test_loss: + self.min_test_loss = test_loss + torch.save(self.network, os.path.join(self.model_root, self.model_name)) + + def train(self, n_epochs): + # self.test_epoch(0) + for epoch in range(1, n_epochs + 1): + self.train_epoch(epoch) + self.test_epoch(epoch) + + +if __name__ == "__main__": + # Command line arguments + parser = ArgumentParser("hwf") + parser.add_argument("--model-name", type=str, default="hwf.pkl") + parser.add_argument("--n-epochs", type=int, default=100) + parser.add_argument("--no-sample-k", action="store_true") + parser.add_argument("--sample-k", type=int, default=7) + parser.add_argument("--dataset-prefix", type=str, default="expr") + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--learning-rate", type=float, default=0.0001) + parser.add_argument("--loss-fn", type=str, default="bce") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--do-not-use-hash", action="store_true") + parser.add_argument("--provenance", type=str, default="difftopbottomkclauses") + parser.add_argument("--top-k", type=int, default=3) + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--jit", action="store_true") + parser.add_argument("--recompile", action="store_true") + args = parser.parse_args() + + # Parameters + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.cuda: + if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") + else: raise Exception("No cuda available") + else: device = torch.device("cpu") + + # Data + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) + model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model/hwf")) + if not os.path.exists(model_dir): os.makedirs(model_dir) + train_loader, test_loader = 
hwf_loader(data_dir, batch_size=args.batch_size, prefix=args.dataset_prefix) + + # Training + trainer = Trainer(train_loader, test_loader, device, model_dir, args.model_name, args.learning_rate, args.no_sample_k, args.sample_k, args.provenance, args.top_k) + trainer.train(args.n_epochs) diff --git a/experiments/hwf/scl/hwf_eval.scl b/experiments/hwf/scl/hwf_eval.scl new file mode 100644 index 0000000..a4b39cb --- /dev/null +++ b/experiments/hwf/scl/hwf_eval.scl @@ -0,0 +1,22 @@ +type symbol(index: usize, symbol: String) +type length(n: usize) + +rel digit = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} + +type factor(value: f32, begin: usize, end: usize) +rel factor(x as f32, b, b + 1) = symbol(b, x) and digit(x) + +type mult_div(value: f32, begin: usize, end: usize) +rel mult_div(x, b, r) = factor(x, b, r) +rel mult_div(x * y, b, e) = mult_div(x, b, m) and symbol(m, "*") and factor(y, m + 1, e) +rel mult_div(x / y, b, e) = mult_div(x, b, m) and symbol(m, "/") and factor(y, m + 1, e) + +type add_minus(value: f32, begin: usize, end: usize) +rel add_minus(x, b, r) = mult_div(x, b, r) +rel add_minus(x + y, b, e) = add_minus(x, b, m) and symbol(m, "+") and mult_div(y, m + 1, e) +rel add_minus(x - y, b, e) = add_minus(x, b, m) and symbol(m, "-") and mult_div(y, m + 1, e) + +type result(value: f32) +rel result(y) = add_minus(y, 0, l) and length(l) + +query result diff --git a/experiments/hwf/scl/hwf_parser.scl b/experiments/hwf/scl/hwf_parser.scl new file mode 100644 index 0000000..d0f8db6 --- /dev/null +++ b/experiments/hwf/scl/hwf_parser.scl @@ -0,0 +1,36 @@ +// Inputs +type symbol(usize, String) +type length(usize) + +// Facts for lexing +rel digit = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} +rel mult_div = {"*", "/"} +rel plus_minus = {"+", "-"} + +// Parsing +type value_node(id: u64, string: String, begin: usize, end: usize) +rel value_node($hash(x, d), d, x, x + 1) = symbol(x, d), digit(d), length(n), x < n + +type mult_div_node(id: u64, string: String, left_node: u64, right_node: u64, begin: usize, end: usize) +rel mult_div_node(id, string, 0, 0, b, e) = value_node(id, string, b, e) +rel mult_div_node($hash(id, s, l, r), s, l, r, b, e) = + symbol(id, s), mult_div(s), mult_div_node(l, _, _, _, b, id), value_node(r, _, id + 1, e) + +type plus_minus_node(id: u64, string: String, left_node: u64, right_node: u64, begin: usize, end: usize) +rel plus_minus_node(id, string, l, r, b, e) = mult_div_node(id, string, l, r, b, e) +rel plus_minus_node($hash(id, s, l, r), s, l, r, b, e) = + symbol(id, s), plus_minus(s), plus_minus_node(l, _, _, _, b, id), mult_div_node(r, _, _, _, id + 1, e) + +type root_node(id: u64) +rel root_node(id) = plus_minus_node(id, _, _, _, 0, l), length(l) + +// Evaluate AST +@demand("bf") +rel eval(x, s as f64) = value_node(x, s, _, _) +rel eval(x, y1 + y2) = plus_minus_node(x, "+", l, r, _, _), eval(l, y1), eval(r, y2) +rel eval(x, y1 - y2) = plus_minus_node(x, "-", l, r, _, _), eval(l, y1), eval(r, y2) +rel eval(x, y1 * y2) = mult_div_node(x, "*", l, r, _, _), eval(l, y1), eval(r, y2) +rel eval(x, y1 / y2) = mult_div_node(x, "/", l, r, _, _), eval(l, y1), eval(r, y2), y2 != 0.0 + +// Compute result +rel result(y) = eval(e, y), root_node(e) diff --git a/experiments/hwf/scl/hwf_parser_w_sample.scl b/experiments/hwf/scl/hwf_parser_w_sample.scl new file mode 100644 index 0000000..0150e28 --- /dev/null +++ b/experiments/hwf/scl/hwf_parser_w_sample.scl @@ -0,0 +1,39 @@ +// Inputs +type symbol(usize, String) +type length(usize) + +// Facts for lexing +rel digit = 
{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} +rel mult_div = {"*", "/"} +rel plus_minus = {"+", "-"} + +// Sampling +rel sampled_symbol(id, sym) :- sym = top<7>(s: symbol(id, s), length(n), id < n) + +// Parsing +type value_node(id: u64, string: String, begin: usize, end: usize) +rel value_node($hash(x, d), d, x, x + 1) = sampled_symbol(x, d), digit(d) + +type mult_div_node(id: u64, string: String, left_node: u64, right_node: u64, begin: usize, end: usize) +rel mult_div_node(id, string, 0, 0, b, e) = value_node(id, string, b, e) +rel mult_div_node($hash(id, s, l, r), s, l, r, b, e) = + sampled_symbol(id, s), mult_div(s), mult_div_node(l, _, _, _, b, id), value_node(r, _, id + 1, e) + +type plus_minus_node(id: u64, string: String, left_node: u64, right_node: u64, begin: usize, end: usize) +rel plus_minus_node(id, string, l, r, b, e) = mult_div_node(id, string, l, r, b, e) +rel plus_minus_node($hash(id, s, l, r), s, l, r, b, e) = + sampled_symbol(id, s), plus_minus(s), plus_minus_node(l, _, _, _, b, id), mult_div_node(r, _, _, _, id + 1, e) + +type root_node(id: u64) +rel root_node(id) = plus_minus_node(id, _, _, _, 0, l), length(l) + +// Evaluate AST +@demand("bf") +rel eval(x, s as f64) = value_node(x, s, _, _) +rel eval(x, y1 + y2) = plus_minus_node(x, "+", l, r, _, _), eval(l, y1), eval(r, y2) +rel eval(x, y1 - y2) = plus_minus_node(x, "-", l, r, _, _), eval(l, y1), eval(r, y2) +rel eval(x, y1 * y2) = mult_div_node(x, "*", l, r, _, _), eval(l, y1), eval(r, y2) +rel eval(x, y1 / y2) = mult_div_node(x, "/", l, r, _, _), eval(l, y1), eval(r, y2), y2 != 0.0 + +// Compute result +rel result(y) = eval(e, y), root_node(e) diff --git a/experiments/hwf/scl/hwf_parser_wo_hash.scl b/experiments/hwf/scl/hwf_parser_wo_hash.scl new file mode 100644 index 0000000..59f9424 --- /dev/null +++ b/experiments/hwf/scl/hwf_parser_wo_hash.scl @@ -0,0 +1,41 @@ +// Inputs +type symbol(u64, String) +type length(u64) + +// Facts for lexing +rel digit = {("0", 0.0), ("1", 1.0), ("2", 2.0), ("3", 3.0), ("4", 4.0), ("5", 5.0), ("6", 6.0), ("7", 7.0), ("8", 8.0), ("9", 9.0)} +rel mult_div = {"*", "/"} +rel plus_minus = {"+", "-"} + +// Symbol ID for node index calculation +rel symbol_id = {("+", 1), ("-", 2), ("*", 3), ("/", 4)} + +// Node ID Hashing +@demand("bbbbf") +rel node_id_hash(x, s, l, r, x + sid * n + l * 4 * n + r * 4 * n * n) = symbol_id(s, sid), length(n) + +// Parsing +rel value_node(x, v) = + symbol(x, d), digit(d, v), length(n), x < n +rel mult_div_node(x, "v", x, x, x, x, x) = + value_node(x, _) +rel mult_div_node(h, s, x, l, end, begin, end) = + symbol(x, s), mult_div(s), node_id_hash(x, s, l, end, h), + mult_div_node(l, _, _, _, _, begin, x - 1), + value_node(end, _), end == x + 1 +rel plus_minus_node(x, t, i, l, r, begin, end) = + mult_div_node(x, t, i, l, r, begin, end) +rel plus_minus_node(h, s, x, l, r, begin, end) = + symbol(x, s), plus_minus(s), node_id_hash(x, s, l, r, h), + plus_minus_node(l, _, _, _, _, begin, x - 1), + mult_div_node(r, _, _, _, _, x + 1, end) + +// Evaluate AST +rel eval(x, y, x, x) = value_node(x, y) +rel eval(x, y1 + y2, b, e) = plus_minus_node(x, "+", i, l, r, b, e), eval(l, y1, b, i - 1), eval(r, y2, i + 1, e) +rel eval(x, y1 - y2, b, e) = plus_minus_node(x, "-", i, l, r, b, e), eval(l, y1, b, i - 1), eval(r, y2, i + 1, e) +rel eval(x, y1 * y2, b, e) = mult_div_node(x, "*", i, l, r, b, e), eval(l, y1, b, i - 1), eval(r, y2, i + 1, e) +rel eval(x, y1 / y2, b, e) = mult_div_node(x, "/", i, l, r, b, e), eval(l, y1, b, i - 1), eval(r, y2, i + 1, e), y2 != 0.0 + +// 
Compute result +rel result(y) = eval(e, y, 0, n - 1), length(n) diff --git a/experiments/hwf/scl/hwf_sample.scl b/experiments/hwf/scl/hwf_sample.scl new file mode 100644 index 0000000..b589ae9 --- /dev/null +++ b/experiments/hwf/scl/hwf_sample.scl @@ -0,0 +1,4 @@ +@py_eval type $py_eval_number(s: String) -> f32 + +type symbol(id: i32, ) +type length(i32) diff --git a/experiments/hwf/scl/hwf_unique_parser.scl b/experiments/hwf/scl/hwf_unique_parser.scl new file mode 100644 index 0000000..8b007a6 --- /dev/null +++ b/experiments/hwf/scl/hwf_unique_parser.scl @@ -0,0 +1,30 @@ +// Inputs +type symbol(i32, String) +type length(i32) + +// Facts for lexing +rel digit = {("0", 0.0), ("1", 1.0), ("2", 2.0), ("3", 3.0), ("4", 4.0), ("5", 5.0), ("6", 6.0), ("7", 7.0), ("8", 8.0), ("9", 9.0)} +rel mult_div = {"*", "/"} +rel plus_minus = {"+", "-"} + +// Parsing +rel value_node(x, v) = + v = unique(v: symbol(x, d), digit(d, v)) +rel mult_div_node(x, "v", -1, -1, x, x) = value_node(x, _) +rel mult_div_node(x, s, l, r, l_begin, r) = + s = unique(s: symbol(x, s), mult_div(s)), + mult_div_node(l, _, _, _, l_begin, l_end), l_end == x - 1, value_node(r, _), r == x + 1 +rel plus_minus_node(x, t, l, r, begin, end) = mult_div_node(x, t, l, r, begin, end) +rel plus_minus_node(x, s, l, r, l_b, r_e) = + s = unique(s: symbol(x, s), plus_minus(s)), + plus_minus_node(l, _, _, _, l_b, l_e), l_e == x - 1, mult_div_node(r, _, _, _, r_b, r_e), r_b == x + 1 + +// Evaluate AST +rel eval(x, y, x, x) = value_node(x, y) +rel eval(x, y1 + y2, b, e) = plus_minus_node(x, "+", l, r, b, e), eval(l, y1, b, x - 1), eval(r, y2, x + 1, e) +rel eval(x, y1 - y2, b, e) = plus_minus_node(x, "-", l, r, b, e), eval(l, y1, b, x - 1), eval(r, y2, x + 1, e) +rel eval(x, y1 * y2, b, e) = plus_minus_node(x, "*", l, r, b, e), eval(l, y1, b, x - 1), eval(r, y2, x + 1, e) +rel eval(x, y1 / y2, b, e) = plus_minus_node(x, "/", l, r, b, e), eval(l, y1, b, x - 1), eval(r, y2, x + 1, e), y2 != 0.0 + +// Compute result +rel result(y) = eval(e, y, 0, n - 1), length(n) diff --git a/experiments/hwf/test_hwf_model.py b/experiments/hwf/test_hwf_model.py new file mode 100644 index 0000000..f03ca43 --- /dev/null +++ b/experiments/hwf/test_hwf_model.py @@ -0,0 +1,148 @@ +import os +import json +import random +from argparse import ArgumentParser +from tqdm import tqdm +import math + +import torch +from torch import nn, optim +import torch.nn.functional as F +import torchvision +from PIL import Image + +import scallopy +import math + +from run_with_hwf_parser import HWFNet, SymbolNet + +class HWFDataset(torch.utils.data.Dataset): + def __init__(self, root: str, prefix: str, split: str): + super(HWFDataset, self).__init__() + self.root = root + self.split = split + self.metadata = json.load(open(os.path.join(root, f"HWF/{prefix}_{split}.json"))) + self.img_transform = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5,), (1,)) + ]) + + def __getitem__(self, index): + sample = self.metadata[index] + + # Input is a sequence of images + img_seq = [] + for img_path in sample["img_paths"]: + img_full_path = os.path.join(self.root, "HWF/Handwritten_Math_Symbols", img_path) + img = Image.open(img_full_path).convert("L") + img = self.img_transform(img) + img_seq.append(img) + img_seq_len = len(img_seq) + + # Output is the "res" in the sample of metadata + res = sample["res"] + + # Output the expression + expr = sample["expr"] + + # Return (input, output) pair + return (img_seq, img_seq_len, expr, res) + + def 
__len__(self): + return len(self.metadata) + + @staticmethod + def collate_fn(batch): + max_len = max([img_seq_len for (_, img_seq_len, _, _) in batch]) + zero_img = torch.zeros_like(batch[0][0][0]) + pad_zero = lambda img_seq: img_seq + [zero_img] * (max_len - len(img_seq)) + img_seqs = torch.stack([torch.stack(pad_zero(img_seq)) for (img_seq, _, _, _) in batch]) + img_seq_len = torch.stack([torch.tensor(img_seq_len).long() for (_, img_seq_len, _, _) in batch]) + results = torch.stack([torch.tensor(res) for (_, _, _, res) in batch]) + exprs = [expr for (_, _, expr, _) in batch] + return (img_seqs, img_seq_len, exprs, results) + + +def hwf_loader(data_dir, batch_size, prefix): + return torch.utils.data.DataLoader( + HWFDataset(data_dir, prefix, "test"), + collate_fn=HWFDataset.collate_fn, + batch_size=batch_size, + shuffle=False, + ) + +def eval_result_eq(a, b, threshold=0.01): + return abs(a - b) < threshold + +if __name__ == "__main__": + # Command line arguments + parser = ArgumentParser("test_hwf_model") + parser.add_argument("--model-name", type=str, default="hwf.pkl") + parser.add_argument("--dataset-prefix", type=str, default="expr") + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--seed", type=int, default=12345) + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + # Parameters + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.cuda: + if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") + else: raise Exception("No cuda available") + else: device = torch.device("cpu") + + # Data + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) + test_loader = hwf_loader(data_dir, batch_size=args.batch_size, prefix=args.dataset_prefix) + + # Model + model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model/hwf")) + hwf_net = torch.load(open(os.path.join(model_dir, args.model_name), "rb")) + + # Testing + hwf_net.eval() + with torch.no_grad(): + processed_count, correct_count = 0, 0 + dataset_iter = test_loader if args.verbose else tqdm(test_loader) + for img_seq, img_seq_len, exprs, labels in dataset_iter: + batch_size, formula_length, _, _, _ = img_seq.shape + + # Predict per character symbol + symbol_distr = hwf_net.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, formula_length, -1) + exprs_pred = ["".join([hwf_net.symbols[torch.argmax(s).item()] for s in task_symbols_distrs]) for task_symbols_distrs in symbol_distr] + + # Do the prediction + (output_mapping, y_pred) = hwf_net(img_seq, img_seq_len) + y_pred = y_pred.to("cpu") + + # Get the predictions + y_pred_index = torch.argmax(y_pred, dim=1) + + # Iterate through all examples + batch_correct_count = 0 + for i in range(batch_size): + expr = exprs[i] + expr_pred = exprs_pred[i] + y = labels[i] + y_pred = output_mapping[y_pred_index[i]] + is_correct = eval_result_eq(y, y_pred) + if is_correct: + batch_correct_count += 1 + correct_str = "[Correct]" if is_correct else "[Incorrect]" + if args.verbose: + print(f"{correct_str} Ground Truth Expr: {expr}, Predicted Expr: {expr_pred}, Ground Truth: {y}, Computed: {y_pred}") + + # Update the progress bar + processed_count += batch_size + correct_count += batch_correct_count + if not args.verbose: + perc = float(correct_count) / float(processed_count) * 100.0 + dataset_iter.set_description(f"Correct: 
{correct_count}/{processed_count} ({perc:.4f}%)") + + # If verbose, print the accuracy at the end + if args.verbose: + perc = float(correct_count) / float(processed_count) * 100.0 + print(f"Overall Correctness: {correct_count}/{processed_count} ({perc:.4f}%)") diff --git a/experiments/hwf/variants/inc_dec.py b/experiments/hwf/variants/inc_dec.py new file mode 100644 index 0000000..ca62338 --- /dev/null +++ b/experiments/hwf/variants/inc_dec.py @@ -0,0 +1,153 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torchvision +import json +import scallopy +import random +import os +from PIL import Image +from tqdm import tqdm +from typing import * + +data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../../data")) + +class ECExprDataset(torch.utils.data.Dataset): + def __init__(self, train: bool = True): + split = "train" if train else "test" + self.metadata = json.load(open(os.path.join(data_dir, f"HWF/inc_dec_expr_{split}.json"))) + self.img_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5,), (1,))]) + + def __len__(self): + return len(self.metadata) + + def __getitem__(self, idx): + datapoint = self.metadata[idx] + imgs = [] + for img_path in datapoint["img_paths"]: + img_full_path = os.path.join(data_dir, "HWF/Handwritten_Math_Symbols", img_path) + img = Image.open(img_full_path).convert("L") + img = self.img_transform(img) + imgs.append(img) + res = datapoint["res"] + return (tuple(imgs), res) + +def EC_expr_loader(batch_size): + train_loader = torch.utils.data.DataLoader(ECExprDataset(train=True), batch_size=batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(ECExprDataset(train=False), batch_size=batch_size, shuffle=True) + return train_loader, test_loader + +class ConvolutionNeuralNetEC(nn.Module): + def __init__(self, num_classes): + super(ConvolutionNeuralNetEC, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) + self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) + self.fc1 = nn.Linear(7744, 128) + self.fc2 = nn.Linear(128, num_classes) + + def forward(self, x): + x = F.max_pool2d(self.conv1(x), 2) + x = F.max_pool2d(self.conv2(x), 2) + x = torch.flatten(x, 1) + x = F.relu(self.fc1(x)) + x = F.dropout(x, p = 0.5, training=self.training) + x = self.fc2(x) + return F.softmax(x, dim=1) + +class ECExprNet(nn.Module): + def __init__(self): + super(ECExprNet, self).__init__() + + # Symbol Recognition CNN(s) + self.digit_cnn = ConvolutionNeuralNetEC(10) + self.symbol_cnn = ConvolutionNeuralNetEC(2) + + # Scallop Context + self.compute = scallopy.ScallopForwardFunction( + program=""" + type digit(i32), op1(String), op2(String) + rel result(a + 1) = digit(a) and op1("+") and op2("+") + rel result(a - 1) = digit(a) and op1("-") and op2("-") + """, + provenance="difftopkproofs", + input_mappings={ + "digit": list(range(10)), + "op1": ["+", "-"], + "op2": ["+", "-"], + }, + output_relation="result", + output_mapping=list(range(-1, 11)), + ) + + def forward(self, x: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): + return self.compute( + digit=self.digit_cnn(x[0]), + op1=self.symbol_cnn(x[1]), + op2=self.symbol_cnn(x[2]), + ) + +class ECExprTrainer(): + def __init__(self, train_loader, test_loader, learning_rate): + self.network = ECExprNet() + self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate) + self.train_loader = train_loader + self.test_loader = test_loader + + 
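+  # The loss below scores the 12-way distribution over output_mapping (the
+  # candidate results -1..10) against a one-hot encoding of the ground truth.
+  # A small worked example, assuming a single target of 3:
+  #   output_mapping = [-1, 0, 1, ..., 10]         # 12 candidate outputs
+  #   gt             = [0, 0, 0, 0, 1, 0, ..., 0]  # 1.0 at index 4, since output_mapping[4] == 3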
def loss(self, output, ground_truth): + output_mapping = list(range(-1, 11)) + (_, dim) = output.shape + gt = torch.stack([torch.tensor([1.0 if output_mapping[i] == t else 0.0 for i in range(dim)]) for t in ground_truth]) + return F.binary_cross_entropy(output, gt) + + def train_epoch(self, epoch): + self.network.train() + train_loss = 0.0 + iter = tqdm(self.train_loader, total=len(self.train_loader)) + for (batch_id, (data, target)) in enumerate(iter): + self.optimizer.zero_grad() + output = self.network(data) + loss = self.loss(output, target) + loss.backward() + self.optimizer.step() + train_loss += loss.item() + avg_loss = train_loss / (batch_id + 1) + iter.set_description(f"[Train Epoch {epoch}] Batch Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}") + + def test_epoch(self, epoch): + self.network.eval() + num_items = 0 + test_loss = 0 + correct = 0 + with torch.no_grad(): + iter = tqdm(self.test_loader, total=len(self.test_loader)) + for (batch_id, (data, target)) in enumerate(iter): + output = self.network(data) + test_loss += self.loss(output, target).item() + avg_loss = test_loss / (batch_id + 1) + pred = output.data.max(1, keepdim=True)[1] - 1 + correct += pred.eq(target.data.view_as(pred)).sum() + num_items += pred.shape[0] + perc = 100. * correct / num_items + iter.set_description(f"[Test Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {correct}/{num_items} ({perc:.2f}%)") + + def train(self, n_epochs): + self.test_epoch(0) + for epoch in range(1, n_epochs + 1): + self.train_epoch(epoch) + self.test_epoch(epoch) + +# Parameters +n_epochs = 3 +batch_size = 32 +learning_rate = 0.001 +seed = 1234 + +# Random seed +torch.manual_seed(seed) +random.seed(seed) + +# Dataloaders +train_loader, test_loader = EC_expr_loader(batch_size) +trainer = ECExprTrainer(train_loader, test_loader, learning_rate) +trainer.train(n_epochs)
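All of the run scripts in this patch construct their supervision the same way: the output mapping returned by the Scallop forward function enumerates candidate results, and the scalar label is turned into a (soft) one-hot target by numeric comparison. A minimal self-contained sketch of that step, with illustrative names that are not part of the patch:

import torch

def eval_result_eq(a, b, threshold=0.01):
  # Two results match when numerically close (tolerates float division).
  return abs(a - b) < threshold

def bce_targets(labels, output_mapping):
  # labels: ground-truth results; output_mapping: candidate results from the
  # Scallop forward function. Returns a (batch, num_candidates) tensor with
  # 1.0 wherever a candidate matches the label.
  return torch.tensor([
    [1.0 if eval_result_eq(l, m) else 0.0 for m in output_mapping]
    for l in labels
  ])

# Example: candidates [2.0, 3.0, 4.0] and labels [3.0, 5.0] produce
# [[0., 1., 0.], [0., 0., 0.]] -- an all-zero row means no candidate matched.
print(bce_targets([3.0, 5.0], [2.0, 3.0, 4.0]))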