-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding scallop plugins; bumping versions
- Loading branch information
Showing
63 changed files
with
1,337 additions
and
623 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[package] | ||
name = "scallop-codegen" | ||
version = "0.2.1" | ||
version = "0.2.2" | ||
authors = ["Ziyang Li <[email protected]>"] | ||
edition = "2018" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
build: | ||
python -m build | ||
|
||
install: build | ||
find dist -name "*.whl" -print | xargs pip install --force-reinstall | ||
|
||
develop: | ||
pip install --editable . | ||
|
||
clean: | ||
rm -rf dist |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,168 +1,12 @@ | ||
import sys | ||
|
||
# Import torch (optional) | ||
try: import torch | ||
except: pass | ||
|
||
# Argument parser | ||
import argparse | ||
|
||
# Prompting for REPL | ||
import prompt_toolkit | ||
from tabulate import tabulate | ||
|
||
# Scallop modules | ||
import scallopy | ||
import scallopy_ext | ||
|
||
|
||
TABLE_FMT = "rounded_outline" | ||
MULTILINE_TABLE_FMT = "fancy_grid" | ||
MAX_COL_WIDTH = 120 | ||
|
||
|
||
def argument_parser(plugin_registry: scallopy_ext.PluginRegistry): | ||
parser = argparse.ArgumentParser("scallop", description="Scallop language command line interface") | ||
parser.add_argument("file", nargs="?", default=None, help="The file to execute") | ||
parser.add_argument("-p", "--provenance", type=str, default="unit", help="The provenance to pick") | ||
parser.add_argument("-k", "--top-k", default=3, type=int, help="The `k` to use when applying `top-k` related provenance") | ||
parser.add_argument("-m", "--module", type=str, default=None, help="Load module in interactive mode") | ||
parser.add_argument("--iter-limit", type=int, default=100, help="Iteration limit") | ||
parser.add_argument("--debug-front", action="store_true", help="Dump Front IR") | ||
parser.add_argument("--debug-back", action="store_true", help="Dump Back IR") | ||
parser.add_argument("--debug-ram", action="store_true", help="Dump RAM IR") | ||
parser.add_argument("--dump-loaded-plugins", action="store_true", help="Dump loaded scallopy plugins") | ||
|
||
# Setup using plugin registry | ||
plugin_registry.setup_argument_parser(parser) | ||
|
||
# Return the final parser | ||
return parser | ||
|
||
|
||
def cmd_args(plugin_registry: scallopy_ext.PluginRegistry): | ||
parser = argument_parser(plugin_registry) | ||
return parser.parse_known_args() | ||
|
||
|
||
def print_relation(ctx: scallopy.ScallopContext, relation_name: str): | ||
# First get the relation | ||
if ctx.is_probabilistic(): | ||
relation = [(f"{prob:.4f}", *tup) for (prob, tup) in ctx.relation(relation_name)] | ||
else: | ||
relation = list(ctx.relation(relation_name)) | ||
|
||
# Check if relation is empty | ||
if len(relation) == 0: | ||
print("{}") | ||
return | ||
|
||
# Get arity and max column width | ||
arity = len(relation[0]) | ||
max_column_width = [MAX_COL_WIDTH for _ in range(arity)] | ||
|
||
# Then get the headers | ||
headers = ctx.relation_field_names(relation_name) | ||
if headers and ctx.is_probabilistic(): | ||
headers = ["prob"] + headers | ||
max_column_width = [None] + max_column_width | ||
|
||
# Check if we need multi-line table | ||
need_multi_line_table = False | ||
for tup in relation: | ||
for val in tup: | ||
if type(val) == str and len(val) > MAX_COL_WIDTH: | ||
need_multi_line_table = True | ||
|
||
table_fmt = MULTILINE_TABLE_FMT if need_multi_line_table else TABLE_FMT | ||
if headers: to_print = tabulate(relation, headers=headers, tablefmt=table_fmt, maxcolwidths=max_column_width) | ||
else: to_print = tabulate(relation, tablefmt=table_fmt, maxcolwidths=max_column_width) | ||
print(to_print) | ||
|
||
|
||
def interpret(ctx, args): | ||
try: | ||
ctx.import_file(args.file) | ||
except Exception as e: | ||
print(e, file=sys.stderr) | ||
|
||
# Run the context | ||
ctx.run() | ||
|
||
# Print the results | ||
to_output_relations = [r for r in ctx.relations() if ctx.has_relation(r)] | ||
for relation in sorted(to_output_relations): | ||
print(f"{relation}:") | ||
print_relation(ctx, relation) | ||
|
||
|
||
def repl(ctx, args): | ||
# Initialize key bindings | ||
from prompt_toolkit.key_binding import KeyBindings | ||
|
||
bindings = KeyBindings() | ||
|
||
@bindings.add('c-c') | ||
def _(event): print(); exit() | ||
|
||
@bindings.add('c-d') | ||
def _(event): print(); exit() | ||
|
||
# If module is specified in args, load the module | ||
if args.module is not None: | ||
ctx.import_file(args.module) | ||
|
||
# Create a prompt session | ||
prompt = prompt_toolkit.PromptSession(key_bindings=bindings) | ||
while True: | ||
user_input = prompt.prompt('scl> ') | ||
try: | ||
queries = ctx.add_item(user_input) | ||
if len(queries) > 0: | ||
ctx.run() | ||
for query in queries: | ||
if len(queries) > 1: | ||
print(f"{query}:") | ||
print_relation(ctx, query) | ||
except Exception as err: | ||
print(err, file=sys.stderr) | ||
|
||
|
||
def setup_context(ctx: scallopy.ScallopContext, args): | ||
# Iteration limit | ||
ctx.set_iter_limit(args.iter_limit) | ||
|
||
# Debug | ||
if args.debug_front: | ||
ctx.set_debug_front() | ||
if args.debug_back: | ||
ctx.set_debug_back() | ||
if args.debug_ram: | ||
ctx.set_debug_ram() | ||
|
||
|
||
def main(): | ||
plugin_registry = scallopy_ext.PluginRegistry() | ||
|
||
# Parse command line arguments | ||
args, unknown_args = cmd_args(plugin_registry) | ||
|
||
# Configure environments | ||
plugin_registry.configure(args, unknown_args) | ||
if args.dump_loaded_plugins: | ||
plugin_registry.dump_loaded_plugins() | ||
|
||
# Create a scallopy context | ||
ctx = scallopy.ScallopContext(provenance=args.provenance, k=args.top_k) | ||
setup_context(ctx, args) | ||
|
||
# Load the scallopy extension library | ||
plugin_registry.load_into_ctx(ctx) | ||
|
||
# Check if the user has provided a file | ||
if args.file is not None: | ||
# If so, interpret the file directly | ||
interpret(ctx, args) | ||
parser = argparse.ArgumentParser("scallop", description="Scallop language command line interface") | ||
parser.add_argument("first_arg", nargs="?", default=None) | ||
args, _ = parser.parse_known_args() | ||
if args.first_arg == "create-plugin": | ||
from .create_plugin import main as create_plugin_main | ||
create_plugin_main() | ||
else: | ||
# Otherwise, enter REPL | ||
repl(ctx, args) | ||
from .run_scallop import main as run_scallop_main | ||
run_scallop_main() |
Oops, something went wrong.