Skip to content

Commit

Permalink
Adding scallop plugins; bumping versions
Browse files Browse the repository at this point in the history
  • Loading branch information
Liby99 committed Feb 23, 2024
1 parent bcf2d21 commit dbc8a1a
Show file tree
Hide file tree
Showing 63 changed files with 1,337 additions and 623 deletions.
2 changes: 1 addition & 1 deletion etc/codegen/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "scallop-codegen"
version = "0.2.1"
version = "0.2.2"
authors = ["Ziyang Li <[email protected]>"]
edition = "2018"

Expand Down
17 changes: 10 additions & 7 deletions etc/scallop-cli/examples/vision/tag_faces.scl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// Input; an image directory
rel img_dir = {
"etc/scallop-cli/res/testing/images/face_example_3.jpg",
// Input; an image url
rel img_url = {
"https://m.media-amazon.com/images/M/MV5BOWFhYjE4NzMtOWJmZi00NzEyLTg5NTctYmIxMTU1ZDIxMDAyXkEyXkFqcGdeQXVyNTE1NjY5Mg@@._V1_.jpg",
}

// Input; a face tagging prompt
rel prompt = {
"the 7 main characters of Star Trek: The Next Generation"
}

rel image($load_image(img_dir)) = img_dir(img_dir)
rel image($load_image_url(img_url)) = img_url(img_url)

// Use face detection model to extract faces from the image
@face_detection(["cropped-image", "bbox-x", "bbox-y", "bbox-w", "bbox-h"], enlarge_face_factor=1.2, dump_image=false)
Expand All @@ -32,8 +32,11 @@ rel identity(id, name) = name := top<1>(name: face_name(img, list, name), face_i
// Tag the faces
rel tag_image(0, img) = image(img)
rel tag_image(id + 1, $tag_image(img, x, y, w, h, name, "green", 10, 32)) = tag_image(id, img), identity(id, name), face_bbox(id, x, y, w, h)
rel save_image($save_image(img)) = tag_image(n as u32, img), n := count!(id: identity(id, name))
rel output_image_url($upload_imgur(img)) = tag_image(n as u32, img), n := count!(id: identity(id, name))

// Upload the image (Scallop is lazy)
query output_image_url

// Inspect intermediate relations
query identity
query face_bbox
query save_image
query face_bbox
11 changes: 11 additions & 0 deletions etc/scallop-cli/makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
build:
python -m build

install: build
find dist -name "*.whl" -print | xargs pip install --force-reinstall

develop:
pip install --editable .

clean:
rm -rf dist
172 changes: 8 additions & 164 deletions etc/scallop-cli/scallop/cli.py
Original file line number Diff line number Diff line change
@@ -1,168 +1,12 @@
import sys

# Import torch (optional)
try: import torch
except: pass

# Argument parser
import argparse

# Prompting for REPL
import prompt_toolkit
from tabulate import tabulate

# Scallop modules
import scallopy
import scallopy_ext


TABLE_FMT = "rounded_outline"
MULTILINE_TABLE_FMT = "fancy_grid"
MAX_COL_WIDTH = 120


def argument_parser(plugin_registry: scallopy_ext.PluginRegistry):
parser = argparse.ArgumentParser("scallop", description="Scallop language command line interface")
parser.add_argument("file", nargs="?", default=None, help="The file to execute")
parser.add_argument("-p", "--provenance", type=str, default="unit", help="The provenance to pick")
parser.add_argument("-k", "--top-k", default=3, type=int, help="The `k` to use when applying `top-k` related provenance")
parser.add_argument("-m", "--module", type=str, default=None, help="Load module in interactive mode")
parser.add_argument("--iter-limit", type=int, default=100, help="Iteration limit")
parser.add_argument("--debug-front", action="store_true", help="Dump Front IR")
parser.add_argument("--debug-back", action="store_true", help="Dump Back IR")
parser.add_argument("--debug-ram", action="store_true", help="Dump RAM IR")
parser.add_argument("--dump-loaded-plugins", action="store_true", help="Dump loaded scallopy plugins")

# Setup using plugin registry
plugin_registry.setup_argument_parser(parser)

# Return the final parser
return parser


def cmd_args(plugin_registry: scallopy_ext.PluginRegistry):
parser = argument_parser(plugin_registry)
return parser.parse_known_args()


def print_relation(ctx: scallopy.ScallopContext, relation_name: str):
# First get the relation
if ctx.is_probabilistic():
relation = [(f"{prob:.4f}", *tup) for (prob, tup) in ctx.relation(relation_name)]
else:
relation = list(ctx.relation(relation_name))

# Check if relation is empty
if len(relation) == 0:
print("{}")
return

# Get arity and max column width
arity = len(relation[0])
max_column_width = [MAX_COL_WIDTH for _ in range(arity)]

# Then get the headers
headers = ctx.relation_field_names(relation_name)
if headers and ctx.is_probabilistic():
headers = ["prob"] + headers
max_column_width = [None] + max_column_width

# Check if we need multi-line table
need_multi_line_table = False
for tup in relation:
for val in tup:
if type(val) == str and len(val) > MAX_COL_WIDTH:
need_multi_line_table = True

table_fmt = MULTILINE_TABLE_FMT if need_multi_line_table else TABLE_FMT
if headers: to_print = tabulate(relation, headers=headers, tablefmt=table_fmt, maxcolwidths=max_column_width)
else: to_print = tabulate(relation, tablefmt=table_fmt, maxcolwidths=max_column_width)
print(to_print)


def interpret(ctx, args):
try:
ctx.import_file(args.file)
except Exception as e:
print(e, file=sys.stderr)

# Run the context
ctx.run()

# Print the results
to_output_relations = [r for r in ctx.relations() if ctx.has_relation(r)]
for relation in sorted(to_output_relations):
print(f"{relation}:")
print_relation(ctx, relation)


def repl(ctx, args):
# Initialize key bindings
from prompt_toolkit.key_binding import KeyBindings

bindings = KeyBindings()

@bindings.add('c-c')
def _(event): print(); exit()

@bindings.add('c-d')
def _(event): print(); exit()

# If module is specified in args, load the module
if args.module is not None:
ctx.import_file(args.module)

# Create a prompt session
prompt = prompt_toolkit.PromptSession(key_bindings=bindings)
while True:
user_input = prompt.prompt('scl> ')
try:
queries = ctx.add_item(user_input)
if len(queries) > 0:
ctx.run()
for query in queries:
if len(queries) > 1:
print(f"{query}:")
print_relation(ctx, query)
except Exception as err:
print(err, file=sys.stderr)


def setup_context(ctx: scallopy.ScallopContext, args):
# Iteration limit
ctx.set_iter_limit(args.iter_limit)

# Debug
if args.debug_front:
ctx.set_debug_front()
if args.debug_back:
ctx.set_debug_back()
if args.debug_ram:
ctx.set_debug_ram()


def main():
plugin_registry = scallopy_ext.PluginRegistry()

# Parse command line arguments
args, unknown_args = cmd_args(plugin_registry)

# Configure environments
plugin_registry.configure(args, unknown_args)
if args.dump_loaded_plugins:
plugin_registry.dump_loaded_plugins()

# Create a scallopy context
ctx = scallopy.ScallopContext(provenance=args.provenance, k=args.top_k)
setup_context(ctx, args)

# Load the scallopy extension library
plugin_registry.load_into_ctx(ctx)

# Check if the user has provided a file
if args.file is not None:
# If so, interpret the file directly
interpret(ctx, args)
parser = argparse.ArgumentParser("scallop", description="Scallop language command line interface")
parser.add_argument("first_arg", nargs="?", default=None)
args, _ = parser.parse_known_args()
if args.first_arg == "create-plugin":
from .create_plugin import main as create_plugin_main
create_plugin_main()
else:
# Otherwise, enter REPL
repl(ctx, args)
from .run_scallop import main as run_scallop_main
run_scallop_main()
Loading

0 comments on commit dbc8a1a

Please sign in to comment.