Minor updates to makefiles, readme, and documentations
Showing 12 changed files with 275 additions and 3 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,34 @@ | ||
# Fact with Probability | ||
|
||
We can associate facts with probabilities using the `::` symbol. | ||
This can be done with the set syntax and also individual fact syntax: | ||
|
||
``` scl | ||
rel color = {0.1::"red", 0.8::"green", 0.1::"blue"} | ||
// or | ||
rel 0.1::color("red") | ||
rel 0.8::color("green") | ||
rel 0.1::color("blue") | ||
``` | ||
|
||
## Mutually exclusive facts | ||
|
||
Within the set annotation, if we replace the commas (`,`) with semicolons (`;`), we specify mutually exclusive facts. | ||
When encoding a categorical distribution, mutual exclusion should be the default choice. | ||
Suppose we have two MNIST digits, each of which can be classified as a number from 0 to 9. | ||
If we identify each image by its ID, say `A` and `B`, we should write the following program in Scallop: | ||
|
||
``` scl | ||
type ImageID = A | B | ||
type digit(img_id: ImageID, number: i32) | ||
rel digit = {0.01::(A, 0); 0.86::(A, 1); ...; 0.03::(A, 9)} | ||
rel digit = {0.75::(B, 0); 0.03::(B, 1); ...; 0.02::(B, 9)} | ||
``` | ||
|
||
Notice that we have specified two sets of digits, each being a mutual exclusion, as indicated by the semicolon separator (`;`). | ||
This means that each of `A` and `B` can be classified as exactly one of the 10 numbers, never more than one. | ||
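
The effect of mutual exclusion can be sketched in plain Python. This is only an illustration of the semantics described above, not Scallop's implementation; the three-class distributions and the helper names (`prob_sum`, `prob_both_independent`, `prob_both_exclusive`) are assumptions made up for the example.

```python
from itertools import product

# Hypothetical categorical distributions for images A and B (three classes only).
dist_a = {0: 0.1, 1: 0.8, 2: 0.1}
dist_b = {0: 0.7, 1: 0.2, 2: 0.1}


def prob_sum(da, db, target):
    """Probability that the two digits add up to `target`.

    Each image takes exactly one label (mutual exclusion), and the two
    images are assumed to be independent of each other.
    """
    return sum(
        pa * pb
        for (a, pa), (b, pb) in product(da.items(), db.items())
        if a + b == target
    )


def prob_both_independent(dist, x, y):
    """If the facts were independent (comma-separated), two different
    labels of the same image could hold at the same time."""
    return dist[x] * dist[y]


def prob_both_exclusive(dist, x, y):
    """With mutual exclusion (semicolon-separated), two different labels
    of the same image can never hold together."""
    return dist[x] if x == y else 0.0


print(prob_sum(dist_a, dist_b, 1))
print(prob_both_independent(dist_a, 0, 1))
print(prob_both_exclusive(dist_a, 0, 1))
```

With independent facts, "A is 0 and A is 1" would receive a non-zero probability, which is meaningless for a classifier output; the exclusive version assigns it zero.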
|
||
## Specifying mutually exclusive facts in Scallopy |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,83 @@ | ||
# Logic and Probability | ||
# Probabilistic Rules | ||
|
||
In Scallop, rules can have probabilities too, just like facts and tuples. | ||
For instance, you might write the following probabilistic rule to denote that "when an earthquake happens, there is an 80% chance that an alarm will go off": | ||
|
||
``` scl | ||
rel 0.8::alarm() = earthquake() | ||
``` | ||
|
||
Combining the above rule with the fact that an earthquake happens with a 10% probability, we obtain that the alarm goes off with probability 0.8 × 0.1 = 0.08. | ||
Note that this result is obtained with the `topkproofs` or `addmultprob` provenance; a provenance such as `minmaxprob` gives a different result. | ||
|
||
``` scl | ||
rel 0.1::earthquake() | ||
query alarm // 0.08::alarm() | ||
``` | ||
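
The difference between provenances can be checked with a small Python calculation. This sketch assumes that `addmultprob` combines a conjunction by multiplying probabilities and that `minmaxprob` combines it by taking the minimum; the function names are hypothetical and do not come from Scallop.

```python
def addmult_conjunction(p, q):
    # Assumption: multiplication models the conjunction of two independent events.
    return p * q


def minmax_conjunction(p, q):
    # Assumption: a min/max style provenance takes the minimum for a conjunction.
    return min(p, q)


rule_tag = 0.8    # tag of the rule `alarm() = earthquake()`
earthquake = 0.1  # tag of the fact `earthquake()`

alarm_addmult = addmult_conjunction(rule_tag, earthquake)
alarm_minmax = minmax_conjunction(rule_tag, earthquake)

print(f"multiplicative: {alarm_addmult:.2f}")
print(f"min/max:        {alarm_minmax:.2f}")
```

Under these assumptions the multiplicative combination reproduces the 0.08 above, while the min/max combination yields 0.1.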
|
||
## Rule tags with expressions | ||
|
||
What is special about the probabilities of rules is that the probabilities could be expressions depending on values within the rule. | ||
For instance, here is a set of rules that say that the probability of a path depends on the length of a path, which falls off when the length increases: | ||
|
||
``` scl | ||
// A few edges | ||
rel edge = {(0, 1), (1, 2)} | ||
// Compute the length of the paths (note that we encode length with floating point numbers) | ||
rel path(x, y, 1.0) = edge(x, y) | ||
rel path(x, z, l + 1.0) = path(x, y, l) and edge(y, z) | ||
// Compute the probabilistic path using the fall off (1 / length) | ||
rel 1.0 / l :: prob_path(x, y) = path(x, y, l) | ||
// Perform the query with arbitrary | ||
query prob_path // prob_path: {1.0::(0, 1), 0.5::(0, 2), 1.0::(1, 2)} | ||
``` | ||
|
||
Here, since `path(0, 1)` and `path(1, 2)` have length 1, their probability is `1 / 1 = 1`. | ||
However, `path(0, 2)` has length 2 so its probability is `1 / 2 = 0.5`. | ||
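
The same computation can be traced in plain Python. This is a sketch of what the rules above derive, not of how Scallop evaluates them; it assumes an acyclic graph (otherwise path lengths grow without bound) and a single path per node pair, so no probabilities need to be combined across alternative paths.

```python
edges = {(0, 1), (1, 2)}


def path_lengths(edges):
    """Derive all (source, target, length) triples, mirroring the two `path` rules."""
    paths = {(x, y, 1.0) for (x, y) in edges}
    frontier = set(paths)
    while frontier:
        # Extend every newly found path by one edge.
        new = {
            (x, z, length + 1.0)
            for (x, y, length) in frontier
            for (y2, z) in edges
            if y == y2
        } - paths
        paths |= new
        frontier = new
    return paths


# Tag each path with the fall-off 1 / length, as the `prob_path` rule does.
prob_path = {(x, y): 1.0 / length for (x, y, length) in path_lengths(edges)}

for pair in sorted(prob_path):
    print(pair, prob_path[pair])
```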
|
||
As can be seen, with the support for having expressions in the tag, we can encode more custom probabilistic rules in Scallop. | ||
Internally, this is implemented through the use of custom foreign predicates. | ||
|
||
## Rule tags that are not floating points | ||
|
||
In general, Scallop supports many kinds of tags, including but not limited to probabilities (floating-point numbers). | ||
For instance, we can use boolean tags as well: | ||
|
||
``` scl | ||
rel constraint(x == y) = digit_1(x) and digit_2(y) | ||
rel b::sat() = constraint(b) | ||
``` | ||
|
||
The relation `constraint` has type `(bool)`, and therefore the variable `b` in the second rule has type boolean as well. | ||
With the second rule, we lift the boolean value into the boolean tag associated with the nullary relation `sat`. | ||
|
||
## Associating rules with tags from Scallopy | ||
|
||
> We elaborate on this topic in the Scallopy section as well | ||
You can associate rules with tags from Scallopy as well, so that we are not confined to Scallop's syntax. | ||
For instance, the following Python program creates a new Scallop context and inserts a rule with a tag of 0.8. | ||
|
||
``` py | ||
ctx = scallopy.Context(provenance="topkproofs") | ||
ctx.add_rule("alarm() = earthquake()", tag=0.8) | ||
``` | ||
|
||
Of course, the tag doesn't need to be a simple constant floating point, since we are operating within the domain of Python. | ||
How about using a PyTorch tensor? Certainly! | ||
|
||
``` py | ||
ctx = scallopy.Context(provenance="topkproofs") | ||
ctx.add_rule("alarm() = earthquake()", tag=torch.tensor(0.8, requires_grad=True)) | ||
``` | ||
|
||
Notice that we have specified that `requires_grad=True`. | ||
This means that if any Scallop output depends on the tag of this rule, PyTorch back-propagation can accumulate gradients on this tensor of 0.8. | ||
An optimizer can then update the tag, essentially treating it as a learnable parameter. | ||
Some additional setup is needed for the optimization to actually happen. | ||
For instance, you need to tell the optimizer that this tensor is a parameter. | ||
We defer this discussion to a later section. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
.PHONY: build install develop | ||
|
||
build: | ||
python -m build | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
.PHONY: build install develop | ||
|
||
build: | ||
python -m build | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
[project] | ||
name = "scallop-plip" | ||
version = "0.0.1" | ||
dependencies = [ | ||
"transformers", | ||
"torch", | ||
] | ||
|
||
[tool.setuptools.packages.find] | ||
where = ["src"] | ||
|
||
[project.entry-points."scallop.plugin.setup_arg_parser"] | ||
plip = "scallop_plip:setup_arg_parser" | ||
|
||
[project.entry-points."scallop.plugin.configure"] | ||
plip = "scallop_plip:configure" | ||
|
||
[project.entry-points."scallop.plugin.load_into_context"] | ||
plip = "scallop_plip:load_into_context" |
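
The three `[project.entry-points."…"]` tables register the plugin's hooks under named entry-point groups. As a rough sketch of how a host application could discover such hooks, the standard-library `importlib.metadata` module can list the entry points of a group. The helper `discover_plugins` is hypothetical, and whether Scallop's own loader works this way is an assumption.

```python
from importlib.metadata import entry_points


def discover_plugins(group):
    """Return a {name: entry_point} mapping for every installed entry point in `group`."""
    eps = entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+: entry points are selectable by group.
        selected = eps.select(group=group)
    else:
        # Python 3.8 / 3.9: `entry_points()` returns a dict keyed by group name.
        selected = eps.get(group, [])
    return {ep.name: ep for ep in selected}


plugins = discover_plugins("scallop.plugin.load_into_context")
for name, ep in plugins.items():
    # `ep.load()` would import the target, e.g. `scallop_plip:load_into_context`.
    print(name, "->", ep.value)
```

If the `scallop-plip` package is installed, the mapping would be expected to contain a `plip` key; without it, the mapping for this group is simply empty.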
File renamed without changes.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import scallopy | ||
|
||
from .config import setup_arg_parser, configure | ||
from .plip import plip | ||
|
||
def load_into_context(ctx: scallopy.Context): | ||
ctx.register_foreign_attribute(plip) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import sys | ||
from scallop_gpu import get_device | ||
|
||
|
||
_DEFAULT_PLIP_MODEL_CHECKPOINT = "vinid/plip" | ||
_PLIP_MODEL_CHECKPOINT = _DEFAULT_PLIP_MODEL_CHECKPOINT | ||
_PLIP_MODEL = None | ||
_PLIP_PREPROCESS = None | ||
|
||
|
||
def setup_arg_parser(parser): | ||
parser.add_argument("--plip-model-checkpoint", type=str, default=_DEFAULT_PLIP_MODEL_CHECKPOINT) | ||
|
||
|
||
def configure(args): | ||
global _PLIP_MODEL_CHECKPOINT | ||
_PLIP_MODEL_CHECKPOINT = args["plip_model_checkpoint"] | ||
|
||
|
||
def get_plip_model(debug=False): | ||
global _PLIP_MODEL | ||
global _PLIP_PREPROCESS | ||
|
||
if _PLIP_MODEL is None: | ||
try: | ||
if debug: | ||
print(f"[scallop-plip] Loading PLIP model `{_PLIP_MODEL_CHECKPOINT}`...") | ||
from transformers import CLIPProcessor, CLIPModel | ||
|
||
model = CLIPModel.from_pretrained(_PLIP_MODEL_CHECKPOINT) | ||
preprocess = CLIPProcessor.from_pretrained(_PLIP_MODEL_CHECKPOINT) | ||
|
||
# model, preprocess = plip.load(_PLIP_MODEL_CHECKPOINT, device=get_device()) | ||
_PLIP_MODEL = model | ||
_PLIP_PREPROCESS = preprocess | ||
|
||
if debug: | ||
print(f"[scallop-plip] Done!") | ||
except Exception as ex: | ||
if debug: | ||
print(ex, file=sys.stderr) | ||
return None | ||
|
||
if debug: | ||
print("[scallop-plip] Using loaded PLIP model") | ||
|
||
return (_PLIP_MODEL, _PLIP_PREPROCESS) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from typing import Tuple, List, Optional | ||
|
||
import scallopy | ||
import torch | ||
from PIL import Image | ||
|
||
from scallop_gpu import get_device | ||
|
||
from .config import get_plip_model | ||
|
||
ERR_HEAD = f"[@plip]" | ||
DELIMITER = ";" | ||
|
||
@scallopy.foreign_attribute | ||
def plip( | ||
item, | ||
labels: Optional[List[str]] = None, | ||
*, | ||
prompt: Optional[str] = None, | ||
score_threshold: float = 0, | ||
unknown_class: str = "?", | ||
debug: bool = False, | ||
): | ||
# Check if the annotation is on relation type decl | ||
assert item.is_relation_decl(), f"{ERR_HEAD} has to be an attribute of a relation type declaration" | ||
assert len(item.relation_decls()) == 1, f"{ERR_HEAD} cannot be an attribute on multiple relations" | ||
|
||
# Get the relation name and check argument types | ||
relation_decl = item.relation_decl(0) | ||
args = [arg for arg in relation_decl.arg_bindings] | ||
|
||
# Check the argument types | ||
assert len(args) >= 1 and args[0].ty.is_tensor() and args[0].adornment.is_bound(), f"{ERR_HEAD} first argument has to be a bounded type `Tensor`" | ||
if labels is not None: | ||
assert len(args) == 2, f"{ERR_HEAD} relation has to be of arity-2 provided the labels" | ||
assert args[1].ty.is_string() and (args[1].adornment is None or args[1].adornment.is_free()), f"{ERR_HEAD} second argument has to be of free type `String`" | ||
else: | ||
assert len(args) == 3, f"{ERR_HEAD} relation has to be of arity-3 given that labels need to be passed in dynamically" | ||
assert args[1].ty.is_string() and args[1].adornment.is_bound(), f"{ERR_HEAD} second argument has to be a bounded type `String`" | ||
assert args[2].ty.is_string() and (args[2].adornment is None or args[2].adornment.is_free()), f"{ERR_HEAD} third argument has to be of free type `String`" | ||
|
||
@scallopy.foreign_predicate(name=relation_decl.name.name) | ||
def plip_classify(img: scallopy.Tensor) -> scallopy.Generator[float, Tuple[str]]: | ||
device = get_device() | ||
maybe_plip_model = get_plip_model(debug=debug) | ||
if maybe_plip_model is None: | ||
return | ||
|
||
# If successfully loaded plip, then initialize the plip models | ||
(plip_model, plip_preprocess) = maybe_plip_model | ||
|
||
# Enter non-training mode | ||
with torch.no_grad(): | ||
class_prompts = [get_class_prompt(prompt, class_str) for class_str in labels] | ||
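# `plip` here is the foreign attribute itself and has no `tokenize`; use the Hugging Face `CLIPProcessor` for both text and image | ||
inputs = plip_preprocess(text=class_prompts, images=Image.fromarray(img.numpy()), return_tensors="pt", padding=True).to(device) | ||
logits_per_image = plip_model.to(device)(**inputs).logits_per_image | ||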
probs = logits_per_image.softmax(dim=-1).cpu()[0] | ||
for prob, class_prompt, class_str in zip(probs, class_prompts, labels): | ||
if debug: | ||
print(f"[@plip_classifier] {prob:.4f} :: {class_prompt} (Class Name: {class_str})") | ||
if prob >= score_threshold: yield (prob, (class_str,)) | ||
else: yield (0.0, (unknown_class,)) | ||
|
||
# Generate the foreign predicate for dynamic labels | ||
@scallopy.foreign_predicate(name=relation_decl.name.name) | ||
def plip_classify_with_labels(img: scallopy.Tensor, list: scallopy.String) -> scallopy.Generator[float, Tuple[str]]: | ||
nonlocal labels | ||
labels = [item.strip() for item in list.split(DELIMITER)] | ||
return plip_classify(img) | ||
|
||
# Return the appropriate foreign predicate | ||
if labels is not None: return plip_classify | ||
else: return plip_classify_with_labels | ||
|
||
|
||
def get_class_prompt(prompt: Optional[str], class_name: str): | ||
if prompt: return prompt.replace("{{}}", class_name) | ||
else: return class_name |
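
The prompt templating and score thresholding used by the classifier above can be exercised without loading any model. The sketch below reuses `get_class_prompt` as written in the diff and replaces the model output with made-up logits; the `softmax` and `classify` helpers are hypothetical stand-ins for the logic inside `plip_classify`.

```python
import math
from typing import List, Optional, Tuple


def get_class_prompt(prompt: Optional[str], class_name: str) -> str:
    # `{{}}` in the prompt template is the placeholder for the class name.
    if prompt:
        return prompt.replace("{{}}", class_name)
    return class_name


def softmax(logits: List[float]) -> List[float]:
    # Subtract the maximum for numerical stability.
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(
    logits: List[float],
    labels: List[str],
    score_threshold: float = 0.0,
    unknown_class: str = "?",
) -> List[Tuple[float, str]]:
    """Keep labels scoring at least the threshold; map the rest to the unknown class with tag 0."""
    results = []
    for prob, label in zip(softmax(logits), labels):
        if prob >= score_threshold:
            results.append((prob, label))
        else:
            results.append((0.0, unknown_class))
    return results


labels = ["tumor", "normal"]
prompts = [get_class_prompt("an image of {{}} tissue", label) for label in labels]
results = classify([2.0, 0.0], labels, score_threshold=0.5)
print(prompts)
print(results)
```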