Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Semantic Role Labeling task #301

Draft
wants to merge 50 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
07a4f2b
Adding span_srl task with tests, usage and documentation
ahmeshaf Jul 21, 2023
22cba55
Fixing minor issues
ahmeshaf Jul 25, 2023
917868b
adding example usage of SRL
ahmeshaf Jul 25, 2023
0ff46ad
Merging main
ahmeshaf Jul 25, 2023
b6f4f52
Fixing format warnings
ahmeshaf Jul 25, 2023
d803b23
Fixing format warnings
ahmeshaf Jul 25, 2023
53a494c
Fixing format warnings
ahmeshaf Jul 25, 2023
dd7d9fb
Fixing format warnings
ahmeshaf Jul 25, 2023
0ad6063
Fix Literal ImportError
ahmeshaf Jul 25, 2023
15412b5
Fix Label assignment
ahmeshaf Jul 25, 2023
fd19441
Fix the template's preamble
ahmeshaf Jul 25, 2023
b56c1d0
Black formatting
ahmeshaf Jul 25, 2023
d6564f7
imports in alphabetical order
ahmeshaf Jul 27, 2023
de68696
alignment_mode should be a Literal.
ahmeshaf Jul 27, 2023
ed07c83
Update spacy_llm/tasks/srl_task.py
ahmeshaf Jul 27, 2023
472d5c7
Update spacy_llm/tasks/templates/span-srl.v1.jinja
ahmeshaf Jul 27, 2023
55a8018
Update spacy_llm/tests/tasks/test_span_srl.py
ahmeshaf Jul 27, 2023
355241a
reformatting
ahmeshaf Jul 27, 2023
a63d610
Merge branch 'main' of github.com:ahmeshaf/spacy-llm
ahmeshaf Jul 27, 2023
84d17df
reformatting
ahmeshaf Jul 27, 2023
666c3ee
adding test on srl roles
ahmeshaf Jul 28, 2023
cb81bdf
SRLTask inherits SpanTask
ahmeshaf Jul 28, 2023
c6d0dfd
Merge branch 'explosion:main' into main
ahmeshaf Aug 1, 2023
6ab4723
Added label definitions rendering in prompt
ahmeshaf Aug 1, 2023
037f36f
Reformatting
ahmeshaf Aug 1, 2023
d6faecd
Restructuring SRLExample and ARGRelItem
ahmeshaf Aug 2, 2023
2a4e862
added expected response
ahmeshaf Aug 2, 2023
b380478
Removing print statement
ahmeshaf Aug 2, 2023
8fc6b8d
Added few-shot span-srl
ahmeshaf Aug 7, 2023
73bf0f6
Add examples path in srl docs
ahmeshaf Aug 7, 2023
824aa82
removing whitespaces causing commit check failures
ahmeshaf Aug 8, 2023
6d5efc9
Make SRLExample hashable to remove duplicate examples
ahmeshaf Aug 10, 2023
2a9ede5
Add doc-tailored examples in generate_prompts
ahmeshaf Aug 11, 2023
be50655
Added defs for alignment modes
ahmeshaf Aug 16, 2023
0970e64
fix serialization issue of pred_item
ahmeshaf Aug 23, 2023
3e0a50e
Update spacy_llm/tests/tasks/test_span_srl.py
rmitsch Sep 18, 2023
30fc8e1
Merge branch 'main' into feat/srl
rmitsch Sep 21, 2023
6e40fba
Refactor to fit SRLTask into new task structure.
rmitsch Sep 22, 2023
deba894
Format.
rmitsch Sep 22, 2023
c57c058
Format.
rmitsch Sep 22, 2023
4daa3af
Allow arbitrary types in SRLExample.
rmitsch Sep 22, 2023
9ef0565
Format.
rmitsch Sep 22, 2023
1bdb0b4
Fix typing issues.
rmitsch Sep 22, 2023
3e9d72f
Format.
rmitsch Sep 22, 2023
0338389
Merge pull request #1 from explosion/main
ahmeshaf Oct 6, 2023
c6b8e55
Fixing pydantic parsing error for dicts
ahmeshaf Oct 6, 2023
e654805
adding params/returns documentation
ahmeshaf Oct 6, 2023
6d3eb19
black formatting
ahmeshaf Oct 6, 2023
8aaba47
Merge branch 'feat/srl' of github.com:ahmeshaf/spacy-llm into feat/srl
ahmeshaf Oct 6, 2023
574c602
Merge branch 'explosion:main' into feat/srl
ahmeshaf Feb 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions spacy_llm/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .rel import RELTask, make_rel_task
from .sentiment import SentimentTask, make_sentiment_task
from .spancat import SpanCatTask, make_spancat_task_v3
from .srl import SRLTask, make_srl_task
from .summarization import SummarizationTask, make_summarization_task
from .textcat import TextCatTask, make_textcat_task

Expand All @@ -17,6 +18,7 @@
"spacy.REL.v1",
"spacy.Sentiment.v1",
"spacy.SpanCat.v3",
"spacy.SRL.v1",
"spacy.Summarization.v1",
"spacy.TextCat.v3",
)
Expand All @@ -42,6 +44,7 @@
"make_rel_task",
"make_sentiment_task",
"make_spancat_task_v3",
"make_srl_task",
"make_summarization_task",
"make_textcat_task",
"BuiltinTask",
Expand All @@ -51,6 +54,7 @@
"RELTask",
"SentimentTask",
"SpanCatTask",
"SRLTask",
"SummarizationTask",
"TextCatTask",
]
5 changes: 5 additions & 0 deletions spacy_llm/tasks/srl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .registry import make_srl_task
from .task import SRLTask
from .util import SRLExample

__all__ = ["make_srl_task", "SRLExample", "SRLTask"]
154 changes: 154 additions & 0 deletions spacy_llm/tasks/srl/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import re
from typing import Iterable, List, Tuple, Any, Dict

from pydantic import ValidationError
from spacy.tokens import Doc
from wasabi import msg

from ..util.parsing import find_substrings
from .task import SRLTask
from .util import PredicateItem, RoleItem, SpanItem


def _format_response(task: SRLTask, arg_lines) -> List[Tuple[str, str]]:
"""Parse raw string response into a structured format.
task (SRLTask): Task to format responses for.
arg_lines ():
RETURNS (List[Tuple[str, str]]): Formatted response.
"""
output = []
# this ensures unique arguments in the sentence for a predicate
found_labels = set()
for line in arg_lines:
try:
if line.strip() and ":" in line:
label, phrase = line.strip().split(":", 1)

# label is of the form "ARG-n (def)"
label = label.split("(")[0].strip()

# strip any surrounding quotes
phrase = phrase.strip("'\" -")

norm_label = task.normalizer(label)
if norm_label in task.label_dict and norm_label not in found_labels:
if phrase.strip():
_phrase = phrase.strip()
found_labels.add(norm_label)
output.append((task.label_dict[norm_label], _phrase))
except ValidationError:
msg.warn(
"Validation issue",
line,
show=task.verbose,
)
return output


def parse_responses_v1(
    task: SRLTask, docs: Iterable[Doc], responses: Iterable[str]
) -> Iterable[
    Tuple[List[Dict[str, Any]], List[Tuple[Dict[str, Any], List[Dict[str, Any]]]]]
]:
    """
    Parse LLM responses by extracting predicate-argument blocks from the generated response.

    task (SRLTask): Task configuration (predicate key, matching options, labels).
    docs (Iterable[Doc]): Docs the responses correspond to, in order.
    responses (Iterable[str]): Raw LLM responses, one per doc.
    RETURNS (Iterable[Tuple[List[Dict], List[Tuple[Dict, List[Dict]]]]]): Per doc,
        the serialized predicate items and, for each predicate, its role items.

    For example, the LLM response for doc "A sentence with multiple predicates (p1, p2)":

        Step 1: Extract the Predicates for the Text
        Predicates: p1, p2

        Step 2: For each Predicate, extract the Semantic Roles in 'Text'
        Text: A sentence with multiple predicates (p1, p2)
        Predicate: p1
        ARG-0: a0_1
        ARG-1: a1_1
        ARG-M-TMP: a_t_1
        ARG-M-LOC: a_l_1

        Predicate: p2
        ARG-0: a0_2
        ARG-1: a1_2
        ARG-M-TMP: a_t_2

    The parsing first finds the text boundaries of each predicate's information by
    identifying the header lines "Predicate: p1" and "Predicate: p2", which delimit a
    block of lines per predicate. Each block is then parsed line by line to extract
    the predicate and its arguments.
    """
    for doc, prompt_response in zip(docs, responses):
        predicates: List[Dict[str, Any]] = []
        relations: List[Tuple[Dict[str, Any], List[Dict[str, Any]]]] = []
        lines = prompt_response.split("\n")

        # match lines that start with {Predicate:, Predicate 1:, Predicate1:}
        pred_patt = r"^" + re.escape(task.predicate_key) + r"\b\s*\d*[:\-\s]"
        # Require ":" as well: the predicate string below is taken from the text
        # after the first colon, so a colon-less header (e.g. "Predicate - p1",
        # which the pattern alone would match) would raise an IndexError.
        pred_matches = [
            (i, line)
            for i, line in enumerate(lines)
            if re.search(pred_patt, line) and ":" in line
        ]

        # No predicate header in this response: yield empty results for the doc
        # instead of crashing (unpacking zip(*[]) raised ValueError here before).
        if not pred_matches:
            yield predicates, relations
            continue

        pred_indices = [i for i, _ in pred_matches]

        # extract the predicate strings
        pred_strings = [
            line.split(":", 1)[1].strip("'\" ") for _, line in pred_matches
        ]

        # extract the line ranges (s, e) of predicate's content.
        # then extract the pred content lines using the ranges
        pred_indices.append(len(lines))
        pred_ranges = zip(pred_indices[:-1], pred_indices[1:])
        pred_contents = [lines[s:e] for s, e in pred_ranges]

        # assign the spans of the predicates and args
        # then create RoleItems from the identified predicates and arguments
        for pred_str, pred_content_lines in zip(pred_strings, pred_contents):
            pred_offsets = list(
                find_substrings(
                    doc.text, [pred_str], case_sensitive=True, single_match=True
                )
            )

            # ignore the args if the predicate is not found in the doc text
            if len(pred_offsets):
                p_start_char, p_end_char = pred_offsets[0]
                pred_item = PredicateItem(
                    text=pred_str, start_char=p_start_char, end_char=p_end_char
                ).dict()
                predicates.append(pred_item)

                roles = []

                for label, phrase in _format_response(task, pred_content_lines):
                    arg_offsets = find_substrings(
                        doc.text,
                        [phrase],
                        case_sensitive=task.case_sensitive_matching,
                        single_match=task.single_match,
                    )
                    for start, end in arg_offsets:
                        arg_item = SpanItem(
                            text=phrase, start_char=start, end_char=end
                        ).dict()
                        arg_rel_item = RoleItem(
                            predicate=pred_item, role=arg_item, label=label
                        ).dict()
                        roles.append(arg_rel_item)

                relations.append((pred_item, roles))

        yield predicates, relations
72 changes: 72 additions & 0 deletions spacy_llm/tasks/srl/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Callable, Dict, List, Optional, Type, Union

from ...compat import Literal
from ...registry import registry
from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser
from ...util import split_labels
from .parser import parse_responses_v1
from .task import DEFAULT_SPAN_SRL_TEMPLATE_V1, SRLTask
from .util import SRLExample, score


@registry.llm_tasks("spacy.SRL.v1")
def make_srl_task(
    template: str = DEFAULT_SPAN_SRL_TEMPLATE_V1,
    parse_responses: Optional[TaskResponseParser[SRLTask]] = None,
    prompt_example_type: Optional[Type[FewshotExample]] = None,
    scorer: Optional[Scorer] = None,
    examples: ExamplesConfigType = None,
    labels: Union[List[str], str] = [],  # noqa: B006 -- never mutated, passed to split_labels()
    label_definitions: Optional[Dict[str, str]] = None,
    normalizer: Optional[Callable[[str], str]] = None,
    alignment_mode: Literal["strict", "contract", "expand"] = "contract",
    case_sensitive_matching: bool = True,
    single_match: bool = True,
    verbose: bool = False,
    predicate_key: str = "Predicate",
) -> SRLTask:
    """SRL.v1 task factory.

    template (str): Prompt template passed to the model.
    parse_responses (Optional[TaskResponseParser[SRLTask]]): Callable for parsing LLM responses for
        this task. Defaults to parse_responses_v1.
    prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples.
        Defaults to SRLExample.
    scorer (Optional[Scorer]): Scorer function. Defaults to the SRL score() helper.
    examples (ExamplesConfigType): Optional callable that reads a file containing task examples for
        few-shot learning. If None is passed, then zero-shot learning will be used.
    labels (Union[List[str], str]): List of labels or comma-separated string of labels to pass
        to the template. Leave empty to populate it at initialization time (only if examples
        are provided).
    label_definitions (Optional[Dict[str, str]]): Map of label -> description
        of the label to help the language model output the entities wanted.
        It is usually easier to provide these definitions rather than
        full examples, although both can be provided.
    normalizer (Optional[Callable[[str], str]]): Optional normalizer function.
    alignment_mode (Literal["strict", "contract", "expand"]): How character indices snap to token
        boundaries. Options: "strict" (no snapping), "contract" (span of all tokens completely
        within the character span), "expand" (span of all tokens at least partially covered by
        the character span). Defaults to "contract".
    case_sensitive_matching (bool): Whether substring matching against the doc text is
        case-sensitive.
    single_match (bool): If False, allow one substring to match multiple times in
        the text. If True, returns the first hit.
    verbose (bool): Whether to show parsing warnings.
    predicate_key (str): The string used for predicate header lines in the template
        and the LLM response (e.g. "Predicate").
    RETURNS (SRLTask): The configured SRL task.
    """
    labels_list = split_labels(labels)
    raw_examples = examples() if callable(examples) else examples
    example_type = prompt_example_type or SRLExample
    srl_examples = [example_type(**eg) for eg in raw_examples] if raw_examples else None

    return SRLTask(
        template=template,
        parse_responses=parse_responses or parse_responses_v1,
        prompt_example_type=example_type,
        prompt_examples=srl_examples,
        scorer=scorer or score,
        labels=labels_list,
        label_definitions=label_definitions,
        normalizer=normalizer,
        verbose=verbose,
        alignment_mode=alignment_mode,
        case_sensitive_matching=case_sensitive_matching,
        single_match=single_match,
        predicate_key=predicate_key,
    )
Loading
Loading