diff --git a/spacy_llm/tasks/__init__.py b/spacy_llm/tasks/__init__.py
index 426546ca..1f495483 100644
--- a/spacy_llm/tasks/__init__.py
+++ b/spacy_llm/tasks/__init__.py
@@ -11,6 +11,7 @@
 from .rel import RELTask, make_rel_task
 from .sentiment import SentimentTask, make_sentiment_task
 from .spancat import SpanCatTask, make_spancat_task_v3
+from .srl import SRLTask, make_srl_task
 from .summarization import SummarizationTask, make_summarization_task
 from .textcat import TextCatTask, make_textcat_task
 from .translation import TranslationTask, make_translation_task
@@ -22,6 +23,7 @@
     "spacy.REL.v1",
     "spacy.Sentiment.v1",
     "spacy.SpanCat.v3",
+    "spacy.SRL.v1",
     "spacy.Summarization.v1",
     "spacy.TextCat.v3",
     "spacy.Translation.v1",
@@ -51,6 +53,7 @@
     "make_rel_task",
     "make_sentiment_task",
     "make_spancat_task_v3",
+    "make_srl_task",
     "make_summarization_task",
     "make_textcat_task",
     "make_translation_task",
@@ -64,6 +67,7 @@
     "SentimentTask",
     "ShardingNoopTask",
     "SpanCatTask",
+    "SRLTask",
     "SummarizationTask",
     "TextCatTask",
     "TranslationTask",
diff --git a/spacy_llm/tasks/srl/__init__.py b/spacy_llm/tasks/srl/__init__.py
new file mode 100644
index 00000000..4967a3f1
--- /dev/null
+++ b/spacy_llm/tasks/srl/__init__.py
@@ -0,0 +1,5 @@
+from .registry import make_srl_task
+from .task import SRLTask
+from .util import SRLExample
+
+__all__ = ["make_srl_task", "SRLExample", "SRLTask"]
diff --git a/spacy_llm/tasks/srl/parser.py b/spacy_llm/tasks/srl/parser.py
new file mode 100644
index 00000000..55f43356
--- /dev/null
+++ b/spacy_llm/tasks/srl/parser.py
@@ -0,0 +1,161 @@
+import re
+from typing import Any, Dict, Iterable, List, Tuple
+
+from pydantic import ValidationError
+from spacy.tokens import Doc
+from wasabi import msg
+
+from ..util.parsing import find_substrings
+from .task import SRLTask
+from .util import PredicateItem, RoleItem, SpanItem
+
+
+def _format_response(task: SRLTask, arg_lines: List[str]) -> List[Tuple[str, str]]:
+    """Parse raw string response into a structured format.
+    task (SRLTask): Task to format responses for.
+    arg_lines (List[str]): The string responses corresponding to the roles of a predicate.
+    RETURNS (List[Tuple[str, str]]): Formatted response.
+    """
+    output = []
+    # this ensures unique arguments in the sentence for a predicate
+    found_labels = set()
+    for line in arg_lines:
+        try:
+            if line.strip() and ":" in line:
+                label, phrase = line.strip().split(":", 1)
+
+                # label is of the form "ARG-n (def)"
+                label = label.split("(")[0].strip()
+
+                # strip any surrounding quotes
+                phrase = phrase.strip("'\" -")
+
+                norm_label = task.normalizer(label)
+                if norm_label in task.label_dict and norm_label not in found_labels:
+                    if phrase.strip():
+                        _phrase = phrase.strip()
+                        found_labels.add(norm_label)
+                        output.append((task.label_dict[norm_label], _phrase))
+        except ValidationError:
+            msg.warn(
+                "Validation issue",
+                line,
+                show=task.verbose,
+            )
+    return output
+
+
+def parse_responses_v1(
+    task: SRLTask, docs: Iterable[Doc], responses: Iterable[str]
+) -> Iterable[
+    Tuple[List[Dict[str, Any]], List[Tuple[Dict[str, Any], List[Dict[str, Any]]]]]
+]:
+    """
+    Parse the LLM response by extracting predicate-argument blocks from the generated response.
+    For example,
+    LLM response for doc: "A sentence with multiple predicates (p1, p2)"
+
+    Step 1: Extract the Predicates for the Text
+    Predicates: p1, p2
+
+    Step 2: For each Predicate, extract the Semantic Roles in 'Text'
+    Text: A sentence with multiple predicates (p1, p2)
+    Predicate: p1
+    ARG-0: a0_1
+    ARG-1: a1_1
+    ARG-M-TMP: a_t_1
+    ARG-M-LOC: a_l_1
+
+    Predicate: p2
+    ARG-0: a0_2
+    ARG-1: a1_2
+    ARG-M-TMP: a_t_2
+
+    The first parsing step is to find the text boundaries of the information
+    for each predicate. This is done by identifying the lines "Predicate: p1" and "Predicate: p2",
+    which gives us the text for each predicate as follows:
+
+    Predicate: p1
+    ARG-0: a0_1
+    ARG-1: a1_1
+    ARG-M-TMP: a_t_1
+    ARG-M-LOC: a_l_1
+
+    and,
+
+    Predicate: p2
+    ARG-0: a0_2
+    ARG-1: a1_2
+    ARG-M-TMP: a_t_2
+
+    Once these blocks are separated out, each one is parsed line by line to extract the predicate
+    and its arguments.
+    task (SRLTask): Task instance.
+    docs (Iterable[Doc]): Corresponding Doc instances.
+    responses (Iterable[str]): LLM responses.
+    RETURNS (Iterable[Tuple[List[Dict[str, Any]], List[Tuple[Dict[str, Any], List[Dict[str, Any]]]]]]): Predicates
+        to assign to each doc, and relations to assign to each doc.
+    """
+    for doc, prompt_response in zip(docs, responses):
+        predicates: List[Dict[str, Any]] = []
+        relations: List[Tuple[Dict[str, Any], List[Dict[str, Any]]]] = []
+        lines = prompt_response.split("\n")
+
+        # match lines that start with {Predicate:, Predicate 1:, Predicate1:}
+        pred_patt = r"^" + re.escape(task.predicate_key) + r"\b\s*\d*[:\-\s]"
+        pred_matches = [
+            (i, line) for i, line in enumerate(lines) if re.search(pred_patt, line)
+        ]
+        if not pred_matches:
+            # no predicate lines found in the response: nothing to assign for this doc
+            yield predicates, relations
+            continue
+
+        pred_indices, pred_lines = zip(*pred_matches)
+        pred_indices = list(pred_indices)
+
+        # extract the predicate strings
+        pred_strings = [line.split(":", 1)[1].strip("'\" ") for line in pred_lines]
+
+        # extract the line ranges (s, e) of each predicate's content,
+        # then extract the predicate content lines using the ranges
+        pred_indices.append(len(lines))
+        pred_ranges = zip(pred_indices[:-1], pred_indices[1:])
+        pred_contents = [lines[s:e] for s, e in pred_ranges]
+
+        # assign the spans of the predicates and args,
+        # then create RoleItem entries from the identified predicates and arguments
+        for pred_str, pred_content_lines in zip(pred_strings, pred_contents):
+            pred_offsets = list(
+                find_substrings(
+                    doc.text, [pred_str], case_sensitive=True, single_match=True
+                )
+            )
+
+            # ignore the args if the predicate is not found
+            if len(pred_offsets):
+                p_start_char, p_end_char = pred_offsets[0]
+                pred_item = PredicateItem(
+                    text=pred_str, start_char=p_start_char, end_char=p_end_char
+                ).dict()
+                predicates.append(pred_item)
+
+                roles = []
+
+                for label, phrase in _format_response(task, pred_content_lines):
+                    arg_offsets = find_substrings(
+                        doc.text,
+                        [phrase],
+                        case_sensitive=task.case_sensitive_matching,
+                        single_match=task.single_match,
+                    )
+                    for start, end in arg_offsets:
+                        arg_item = SpanItem(
+                            text=phrase, start_char=start, end_char=end
+                        ).dict()
+                        arg_rel_item = RoleItem(role=arg_item, label=label).dict()
+                        roles.append(arg_rel_item)
+
+                relations.append((pred_item, roles))
+
+        yield predicates, relations
diff --git a/spacy_llm/tasks/srl/registry.py b/spacy_llm/tasks/srl/registry.py
new file mode 100644
index 00000000..d9c6ca50
--- /dev/null
+++ b/spacy_llm/tasks/srl/registry.py
@@ -0,0 +1,72 @@
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from ...compat import Literal
+from ...registry import registry
+from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser
+from ...util import split_labels
+from .parser import parse_responses_v1
+from .task import DEFAULT_SPAN_SRL_TEMPLATE_V1, SRLTask
+from .util import SRLExample, score
+
+
+@registry.llm_tasks("spacy.SRL.v1")
+def make_srl_task(
+    template: str = DEFAULT_SPAN_SRL_TEMPLATE_V1,
+    parse_responses: Optional[TaskResponseParser[SRLTask]] = None,
+    prompt_example_type: Optional[Type[FewshotExample]] = None,
+    scorer: Optional[Scorer] = None,
+    examples: ExamplesConfigType = None,
+    labels: Union[List[str], str] = [],
+    label_definitions: Optional[Dict[str, str]] = None,
+    normalizer: Optional[Callable[[str], str]] = None,
+    alignment_mode: Literal["strict", "contract", "expand"] = "contract",
+    case_sensitive_matching: bool = True,
+    single_match: bool = True,
+    verbose: bool = False,
+    predicate_key: str = "Predicate",
+):
+    """SRL.v1 task factory.
+
+    template (str): Prompt template passed to the model.
+    parse_responses (Optional[TaskResponseParser]): Callable for parsing LLM responses for this task.
+    prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples.
+    scorer (Optional[Scorer]): Scorer function.
+    examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that reads a file containing task examples for
+        few-shot learning. If None is passed, then zero-shot learning will be used.
+    labels (Union[List[str], str]): List of labels, or a comma-separated string of labels, to pass to the template.
+        Leave empty to populate it at initialization time (only if examples are provided).
+    label_definitions (Optional[Dict[str, str]]): Map of label -> description
+        of the label to help the language model output the entities wanted.
+        It is usually easier to provide these definitions rather than
+        full examples, although both can be provided.
+    normalizer (Optional[Callable[[str], str]]): Optional normalizer function.
+    alignment_mode (Literal["strict", "contract", "expand"]): How character indices snap to token boundaries.
+        Options: "strict" (no snapping), "contract" (span of all tokens completely within the character span),
+        "expand" (span of all tokens at least partially covered by the character span).
+        Defaults to "contract".
+    case_sensitive_matching (bool): Whether to match extracted substrings in the text case-sensitively.
+    single_match (bool): If False, allow one substring to match multiple times in
+        the text. If True, returns the first hit.
+    verbose (bool): Controls the verbosity of warnings emitted during response parsing.
+    predicate_key (str): Label used for predicate lines in the prompt and response (defaults to "Predicate").
+    """
+    labels_list = split_labels(labels)
+    raw_examples = examples() if callable(examples) else examples
+    example_type = prompt_example_type or SRLExample
+    srl_examples = [example_type(**eg) for eg in raw_examples] if raw_examples else None
+
+    return SRLTask(
+        template=template,
+        parse_responses=parse_responses or parse_responses_v1,
+        prompt_example_type=example_type,
+        prompt_examples=srl_examples,
+        scorer=scorer or score,
+        labels=labels_list,
+        label_definitions=label_definitions,
+        normalizer=normalizer,
+        verbose=verbose,
+        alignment_mode=alignment_mode,
+        case_sensitive_matching=case_sensitive_matching,
+        single_match=single_match,
+        predicate_key=predicate_key,
+    )
diff --git a/spacy_llm/tasks/srl/task.py b/spacy_llm/tasks/srl/task.py
new file mode 100644
index 00000000..9dd3f804
--- /dev/null
+++ b/spacy_llm/tasks/srl/task.py
@@ -0,0 +1,216 @@
+import warnings
+from typing import Callable, Dict, Iterable, List, Optional, Type
+
+import jinja2
+from spacy.language import Language
+from spacy.tokens import Doc
+from spacy.training import Example
+
+from ...compat import Literal, Self
+from ...ty import FewshotExample, Scorer, TaskResponseParser
+from ..span import SpanTask
+from ..templates import read_template
+from .util import SRLExample
+
+DEFAULT_SPAN_SRL_TEMPLATE_V1 = read_template("span-srl.v1")
+
+
+class SRLTask(SpanTask):
+    def __init__(
+        self,
+        template: str,
+        parse_responses: TaskResponseParser[Self],
+        prompt_example_type: Type[FewshotExample],
+        prompt_examples: Optional[List[FewshotExample]],
+        scorer: Scorer,
+        labels: List[str],
+        label_definitions: Optional[Dict[str, str]],
+        normalizer: Optional[Callable[[str], str]],
+        alignment_mode: Literal["strict", "contract", "expand"],  # noqa: F821
+        case_sensitive_matching: bool,
+        single_match: bool,
+        verbose: bool,
+        predicate_key: str,
+    ):
+        """
+        template (str): Prompt template passed to the model.
+        parse_responses (TaskResponseParser): Callable for parsing LLM responses for this task.
+        prompt_example_type (Type[FewshotExample]): Type to use for fewshot examples.
+        prompt_examples (Optional[List[FewshotExample]]): Optional list of few-shot examples to include in prompts.
+        scorer (Scorer): Scorer function.
+        labels (List[str]): List of labels to pass to the template.
+            Leave empty to populate it at initialization time (only if examples are provided).
+        label_definitions (Optional[Dict[str, str]]): Map of label -> description
+            of the label to help the language model output the entities wanted.
+            It is usually easier to provide these definitions rather than
+            full examples, although both can be provided.
+        normalizer (Optional[Callable[[str], str]]): Optional normalizer function.
+        alignment_mode (str): "strict", "contract" or "expand".
+        case_sensitive_matching (bool): Whether to match extracted substrings in the text case-sensitively.
+        single_match (bool): If False, allow one substring to match multiple times in
+            the text. If True, returns the first hit.
+        verbose (bool): Controls the verbosity of warnings emitted during response parsing.
+        predicate_key (str): Label used for predicate lines in the prompt and response (defaults to "Predicate").
+        """
+        super().__init__(
+            parse_responses=parse_responses,
+            prompt_example_type=prompt_example_type,
+            labels=labels,
+            template=template,
+            label_definitions=label_definitions,
+            prompt_examples=prompt_examples,
+            normalizer=normalizer,
+            alignment_mode=alignment_mode,
+            case_sensitive_matching=case_sensitive_matching,
+            single_match=single_match,
+            description=None,
+            allow_overlap=False,
+            check_label_consistency=SRLTask._check_srl_label_consistency,
+        )
+
+        self._predicate_key = predicate_key
+        self._verbose = verbose
+        self._scorer = scorer
+        self._check_extensions()
+
+    @classmethod
+    def _check_extensions(cls):
+        """Register the `predicates` and `relations` Doc extensions,
+        if they don't exist yet."""
+        if not Doc.has_extension("predicates"):
+            Doc.set_extension("predicates", default=[])
+
+        if not Doc.has_extension("relations"):
+            Doc.set_extension("relations", default=[])
+
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable["Example"]],
+        nlp: Language,
+        labels: List[str] = [],
+        n_prompt_examples: int = 0,
+    ) -> None:
+        self._check_extensions()
+
+        super()._initialize(
+            get_examples=get_examples,
+            nlp=nlp,
+            labels=labels,
+            n_prompt_examples=n_prompt_examples,
+        )
+
+    def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[str]:
+        # todo Simplify after **kwargs ditching PR has been merged.
+        environment = jinja2.Environment()
+        _template = environment.from_string(self._template)
+        for doc in docs:
+            predicates = None
+            if len(doc._.predicates):
+                predicates = ", ".join([p["text"] for p in doc._.predicates])
+
+            doc_examples = self._prompt_examples
+
+            # check if there are doc-tailored examples
+            if doc.has_extension("egs") and doc._.egs is not None and len(doc._.egs):
+                doc_examples = doc._.egs
+
+            prompt = _template.render(
+                text=doc.text,
+                labels=list(self._label_dict.values()),
+                label_definitions=self._label_definitions,
+                predicates=predicates,
+                examples=doc_examples,
+            )
+
+            yield prompt
+
+    @property
+    def _cfg_keys(self) -> List[str]:
+        return [
+            "_label_dict",
+            "_template",
+            "_label_definitions",
+            "_verbose",
+            "_predicate_key",
+            "_alignment_mode",
+            "_case_sensitive_matching",
+            "_single_match",
+        ]
+
+    def parse_responses(
+        self, docs: Iterable[Doc], responses: Iterable[str]
+    ) -> Iterable[Doc]:
+        for doc, (predicates, relations) in zip(
+            docs, self._parse_responses(self, docs, responses)
+        ):
+            doc._.predicates = predicates
+            doc._.relations = relations
+            yield doc
+
+    def _extract_labels_from_example(self, example: Example) -> List[str]:
+        if hasattr(example, "relations"):
+            return [r.label for p, rs in example.relations for r in rs]
+        return []
+
+    @classmethod
+    def _check_srl_label_consistency(cls, task: Self) -> List[FewshotExample]:
+        """Checks consistency of labels between examples and defined labels. Emits warning on inconsistency.
+
+        Note: it's unusual for a SpanTask to have its own label consistency check implementation (and an example type
+        not derived from SpanExample). This should be cleaned up and unified.
+
+        task (SRLTask): SRLTask instance.
+        RETURNS (List[FewshotExample]): List of SRLExamples with valid labels.
+        """
+        assert task.prompt_examples
+        assert issubclass(task.prompt_example_type, SRLExample)
+
+        srl_examples = [
+            task.prompt_example_type(**eg.dict()) for eg in task.prompt_examples
+        ]
+        example_labels = {
+            task.normalizer(r.label): r.label
+            for example in srl_examples
+            for p, rs in example.relations
+            for r in rs
+        }
+        unspecified_labels = {
+            example_labels[key]
+            for key in (set(example_labels.keys()) - set(task.label_dict.keys()))
+        }
+        if not set(example_labels.keys()) <= set(task.label_dict.keys()):
+            warnings.warn(
+                f"Examples contain labels that are not specified in the task configuration. The latter contains the "
+                f"following labels: {sorted(list(set(task.label_dict.values())))}. Labels in examples missing from "
+                f"the task configuration: {sorted(list(unspecified_labels))}. Please ensure your label specification "
+                f"and example labels are consistent."
+            )
+
+        # Return examples without non-declared roles: roles within a predicate that have undeclared
+        # role labels are discarded.
+        return [
+            task.prompt_example_type(
+                text=example.text,
+                predicates=example.predicates,
+                relations=[
+                    (
+                        p,
+                        [
+                            r
+                            for r in rs
+                            if task.normalizer(r.label) in task.label_dict
+                        ],
+                    )
+                    for p, rs in example.relations
+                ],
+            )
+            for example in srl_examples
+        ]
+
+    @property
+    def predicate_key(self) -> str:
+        return self._predicate_key
+
+    @property
+    def verbose(self) -> bool:
+        return self._verbose
diff --git a/spacy_llm/tasks/srl/util.py b/spacy_llm/tasks/srl/util.py
new file mode 100644
index 00000000..94654dfa
--- /dev/null
+++ b/spacy_llm/tasks/srl/util.py
@@ -0,0 +1,131 @@
+from collections import defaultdict
+from typing import Any, Dict, Iterable, List, Tuple
+
+from pydantic import BaseModel
+from spacy.training import Example
+from typing_extensions import Self
+
+from ...ty import FewshotExample
+
+
+class SpanItem(BaseModel):
+    text: str
+    start_char: int
+    end_char: int
+
+    def __hash__(self):
+        return hash((self.text, self.start_char, self.end_char))
+
+
+class PredicateItem(SpanItem):
+    roleset_id: str = ""
+
+    def __hash__(self):
+        return hash((self.text, self.start_char, self.end_char, self.roleset_id))
+
+
+class RoleItem(BaseModel):
+    role: SpanItem
+    label: str
+
+    def __hash__(self):
+        return hash((self.role, self.label))
+
+
+class SRLExample(FewshotExample):
+    text: str
+    predicates: List[PredicateItem]
+    relations: List[Tuple[PredicateItem, List[RoleItem]]]
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __hash__(self):
+        return hash((self.text,) + tuple(self.predicates))
+
+    def __str__(self):
+        preds = ", ".join([p.text for p in self.predicates])
+        rels = [
+            (p.text, [(r.label, r.role.text) for r in rs]) for p, rs in self.relations
+        ]
+        return f"Predicates: {preds}\nRelations: {str(rels)}"
+
+    @classmethod
+    def generate(cls, example: Example, **kwargs) -> Self:
+        return cls(
+            text=example.reference.text,
+            predicates=example.reference._.predicates,
+            relations=example.reference._.relations,
+        )
+
+
+def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
+    """Score SRL accuracy in examples.
+    examples (Iterable[Example]): Examples to score.
+    RETURNS (Dict[str, Any]): Dict with metric name -> score.
+    """
+    pred_predicates_spans = set()
+    gold_predicates_spans = set()
+
+    pred_relation_tuples = set()
+    gold_relation_tuples = set()
+
+    for i, eg in enumerate(examples):
+        pred_doc = eg.predicted
+        gold_doc = eg.reference
+
+        pred_predicates_spans.update(
+            [(i, PredicateItem(**dict(p))) for p in pred_doc._.predicates]
+        )
+        gold_predicates_spans.update(
+            [(i, PredicateItem(**dict(p))) for p in gold_doc._.predicates]
+        )
+
+        pred_relation_tuples.update(
+            [
+                (i, PredicateItem(**dict(p)), RoleItem(**dict(r)))
+                for p, rs in pred_doc._.relations
+                for r in rs
+            ]
+        )
+        gold_relation_tuples.update(
+            [
+                (i, PredicateItem(**dict(p)), RoleItem(**dict(r)))
+                for p, rs in gold_doc._.relations
+                for r in rs
+            ]
+        )
+
+    def _overlap_prf(gold: set, pred: set):
+        overlap = gold.intersection(pred)
+        p = 0.0 if not len(pred) else len(overlap) / len(pred)
+        r = 0.0 if not len(gold) else len(overlap) / len(gold)
+        f = 0.0 if not p or not r else 2 * p * r / (p + r)
+        return p, r, f
+
+    predicates_prf = _overlap_prf(gold_predicates_spans, pred_predicates_spans)
+    micro_rel_prf = _overlap_prf(gold_relation_tuples, pred_relation_tuples)
+
+    def _get_label2rels(rel_tuples: Iterable[Tuple[int, PredicateItem, RoleItem]]):
+        label2rels = defaultdict(set)
+        for tup in rel_tuples:
+            label_ = tup[-1].label
+            label2rels[label_].add(tup)
+        return label2rels
+
+    pred_label2relations = _get_label2rels(pred_relation_tuples)
+    gold_label2relations = _get_label2rels(gold_relation_tuples)
+
+    all_labels = set.union(
+        set(pred_label2relations.keys()), set(gold_label2relations.keys())
+    )
+    label2prf = {}
+    for label in all_labels:
+        pred_label_rels = pred_label2relations[label]
+        gold_label_rels = gold_label2relations[label]
+        label2prf[label] = _overlap_prf(gold_label_rels, pred_label_rels)
+
+    return {
+        "Predicates": predicates_prf,
+        "ARGs": {"Overall": micro_rel_prf, "PerLabel": label2prf},
+    }
diff --git a/spacy_llm/tasks/templates/span-srl.v1.jinja b/spacy_llm/tasks/templates/span-srl.v1.jinja
new file mode 100644
index 00000000..f8a5ec79
--- /dev/null
+++ b/spacy_llm/tasks/templates/span-srl.v1.jinja
@@ -0,0 +1,59 @@
+You are an expert Semantic Role Labeling (SRL) system. Your task is to accept Text as input and extract the Predicates and the Semantic Roles for each Predicate's ARGs in a step-by-step manner.
+{# whitespace #}
+{# whitespace #}
+{%- if predicates -%}
+Step 1: Use the following Predicates for the Text:
+Predicates: {{predicates}}
+{%- else -%}
+Step 1: Extract the Predicates for the Text in the following format:
+Predicates:
+{%- endif -%}
+{# whitespace #}
+{# whitespace #}
+Step 2: For each Predicate, extract only the following Semantic Roles in '''Text''' in this format:
+Text:
+Predicate:
+{# whitespace #}
+{%- for label, definition in label_definitions.items() -%}
+{{ label }}:
+{# whitespace #}
+{%- endfor -%}
+{# whitespace #}
+{%- if examples -%}
+{# whitespace #}
+Below are a few similar examples (only use these as a guide):
+{# whitespace #}
+{# whitespace #}
+{%- for example in examples -%}
+Example Text:
+'''
+{{ example.text }}
+'''
+{# whitespace #}
+Step 1: Extract the Predicates in '''Example Text''':
+{# whitespace #}
+Predicates: {{ example.predicates|map(attribute='text')|join(', ') }}
+{# whitespace #}
+Step 2: For each Predicate, extract the Semantic Roles in '''Example Text''':
+{# whitespace #}
+{%- for predicate, relations in example.relations -%}
+Predicate: {{predicate.text}}
+{# whitespace #}
+{%- for relation in relations -%}
+{{relation.label}}: {{relation.role.text}}
+{# whitespace #}
+{%- endfor -%}
+{# whitespace #}
+{# whitespace #}
+{%- endfor -%}
+{# whitespace #}
+{%- endfor -%}
+{# whitespace #}
+{%- endif -%}
+{# whitespace #}
+Here is the text that needs labeling:
+{# whitespace #}
+Text:
+'''
+{{text}}
+'''
\ No newline at end of file
diff --git a/spacy_llm/tests/tasks/examples/span_srl.jsonl b/spacy_llm/tests/tasks/examples/span_srl.jsonl
new file mode 100644
index 00000000..98f73555
--- /dev/null
+++ b/spacy_llm/tests/tasks/examples/span_srl.jsonl
@@ -0,0 +1 @@
+{"text": "Ben bought a house last year in Berlin .", "predicates": [{"text": "bought", "start_char": 4, "end_char": 10, "roleset_id": ""}], "relations": [[{"text": "bought", "start_char": 4, "end_char": 10, "roleset_id": ""}, [{"role": {"text": "Ben", "start_char": 0, "end_char": 3}, "label": "ARG-0"}, {"role": {"text": "a house", "start_char": 11, "end_char": 18}, "label": "ARG-1"}, {"role": {"text": "last year", "start_char": 19, "end_char": 28}, "label": "ARG-M-TMP"}, {"role": {"text": "in Berlin", "start_char": 29, "end_char": 38}, "label": "ARG-M-LOC"}]]]}
diff --git a/spacy_llm/tests/tasks/test_srl.py b/spacy_llm/tests/tasks/test_srl.py
new file mode 100644
index 00000000..64fa60ae
--- /dev/null
+++ b/spacy_llm/tests/tasks/test_srl.py
@@ -0,0 +1,173 @@
+from pathlib import Path
+
+import pytest
+from confection import Config
+from pytest import FixtureRequest
+
+from spacy_llm.pipeline import LLMWrapper
+from spacy_llm.tasks.srl import SRLExample
+from spacy_llm.tests.compat import has_openai_key
+from spacy_llm.ty import LabeledTask, LLMTask
+from spacy_llm.util import assemble_from_config, split_labels
+
+EXAMPLES_DIR = Path(__file__).parent / "examples"
+
+
+@pytest.fixture
+def zeroshot_cfg_string():
+    return """
+    [paths]
+    examples = null
+
+    [nlp]
+    lang = "en"
+    pipeline = ["llm"]
+
+    [components]
+
+    [components.llm]
+    factory = "llm"
+
+    [components.llm.task]
+    @llm_tasks = "spacy.SRL.v1"
+    labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC
+
+    [components.llm.task.label_definitions]
+    ARG-0 = "Agent"
+    ARG-1 = "Patient or Theme"
+    ARG-M-TMP = "Temporal Modifier"
+    ARG-M-LOC = "Location Modifier"
+
+    [components.llm.model]
+    @llm_models = "spacy.GPT-3-5.v1"
+    """
+
+
+@pytest.fixture
+def fewshot_cfg_string():
+    return f"""
+    [paths]
+    examples = null
+
+    [nlp]
+    lang = "en"
"en" + pipeline = ["llm"] + + [components] + + [components.llm] + factory = "llm" + + [components.llm.task] + @llm_tasks = "spacy.SRL.v1" + labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC + + [components.llm.task.label_definitions] + ARG-0 = "Agent" + ARG-1 = "Patient or Theme" + ARG-M-TMP = "Temporal Modifier" + ARG-M-LOC = "Location Modifier" + + [components.llm.task.examples] + @misc = "spacy.FewShotReader.v1" + path = {str((Path(__file__).parent / "examples" / "span_srl.jsonl"))} + + [components.llm.model] + @llm_models = "spacy.GPT-3-5.v1" + """ + + +@pytest.fixture +def task(): + text = "We love this sentence right now in Berlin" + predicate = {"text": "love", "start_char": 3, "end_char": 7} + srl_example = SRLExample( + **{ + "text": text, + "predicates": [predicate], + "relations": [ + ( + predicate, + [ + { + "label": "ARG-0", + "role": {"text": "We", "start_char": 0, "end_char": 2}, + }, + { + "label": "ARG-1", + "role": { + "text": "this sentence", + "start_char": 8, + "end_char": 21, + }, + }, + { + "label": "ARG-M-TMP", + "role": { + "text": "right now", + "start_char": 22, + "end_char": 31, + }, + }, + { + "label": "ARG-M-LOC", + "role": { + "text": "in Berlin", + "start_char": 32, + "end_char": 41, + }, + }, + ], + ) + ], + } + ) + return text, srl_example + + +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string"]) +def test_srl_config(cfg_string, request: FixtureRequest): + """Simple test to check if the config loads properly given different settings""" + cfg_string = request.getfixturevalue(cfg_string) + orig_config = Config().from_str(cfg_string) + nlp = assemble_from_config(orig_config) + assert nlp.pipe_names == ["llm"] + + pipe = nlp.get_pipe("llm") + assert isinstance(pipe, LLMWrapper) + assert isinstance(pipe.task, LLMTask) + + task = pipe.task + labels = orig_config["components"]["llm"]["task"]["labels"] + labels = sorted(split_labels(labels)) + assert isinstance(task, LabeledTask) + assert task.labels == tuple(labels) + assert set(pipe.labels) == set(task.labels) + assert nlp.pipe_labels["llm"] == list(task.labels) + + +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"]) +def test_srl_predict(task, cfg_string, request): + """Use OpenAI to get REL results.""" + cfg_string = request.getfixturevalue(cfg_string) + orig_config = Config().from_str(cfg_string) + nlp = assemble_from_config(orig_config) + + text, gold_example = task + doc = nlp(text) + + assert len(doc._.predicates) + assert len(doc._.relations) + + assert doc._.predicates[0]["text"] == gold_example.predicates[0].text + + predicated_roles = tuple( + sorted([r["role"]["text"] for p, rs in doc._.relations for r in rs]) + ) + gold_roles = tuple( + sorted([r.role.text for p, rs in gold_example.relations for r in rs]) + ) + + assert predicated_roles == gold_roles diff --git a/usage_examples/span_srl_openai/README.md b/usage_examples/span_srl_openai/README.md new file mode 100644 index 00000000..52643b22 --- /dev/null +++ b/usage_examples/span_srl_openai/README.md @@ -0,0 +1,63 @@ +# Semantic Role Labeling (SRL) using LLMs + +This example shows how you can use a model from OpenAI for SRL in +zero- and few-shot settings. + + +We leverage the OpenAI API to detect the predicates and argument roles in a sentence. +In the example below, we focus on the predicate "bought" and ARG-0, ARG-1, and ARG-M-LOC. 
+
+First, create a new API key from
+[openai.com](https://platform.openai.com/account/api-keys) or fetch an existing
+one. Record the secret key and make sure this is available as an environment
+variable:
+
+```sh
+export OPENAI_API_KEY="sk-..."
+export OPENAI_API_ORG="org-..."
+```
+
+Then, you can run the pipeline on a sample text via:
+
+```sh
+python run_pipeline.py [TEXT] [PATH TO CONFIG] [PATH TO FILE WITH EXAMPLES]
+```
+
+For example:
+
+```sh
+python run_pipeline.py \
+    "Laura bought an apartment last month in Berlin." \
+    ./zeroshot.cfg
+```
+or, for few-shot:
+```sh
+python run_pipeline.py \
+    "Laura bought an apartment last month in Berlin." \
+    ./fewshot.cfg \
+    ./examples.jsonl
+```
+
+LLM response:
+```sh
+LLM response for doc: Laura bought an apartment last month in Berlin.
+
+Step 1: Extract the Predicates for the Text
+Predicates: bought
+
+Step 2: For each Predicate, extract the Semantic Roles in 'Text'
+Text: Laura bought an apartment last month in Berlin.
+Predicate: bought
+ARG-0: Laura
+ARG-1: an apartment
+ARG-2:
+ARG-M-TMP: last month
+ARG-M-LOC: in Berlin
+```
+std output:
+```sh
+Text: Laura bought an apartment last month in Berlin.
+SRL Output:
+Predicates: ['bought']
+Relations: [('bought', [('ARG-0', 'Laura'), ('ARG-1', 'an apartment'), ('ARG-M-TMP', 'last month'), ('ARG-M-LOC', 'in Berlin')])]
+```
\ No newline at end of file
diff --git a/usage_examples/span_srl_openai/__init__.py b/usage_examples/span_srl_openai/__init__.py
new file mode 100644
index 00000000..06fab2f6
--- /dev/null
+++ b/usage_examples/span_srl_openai/__init__.py
@@ -0,0 +1,3 @@
+from .run_pipeline import run_pipeline
+
+__all__ = ["run_pipeline"]
diff --git a/usage_examples/span_srl_openai/examples.jsonl b/usage_examples/span_srl_openai/examples.jsonl
new file mode 100644
index 00000000..98f73555
--- /dev/null
+++ b/usage_examples/span_srl_openai/examples.jsonl
@@ -0,0 +1 @@
+{"text": "Ben bought a house last year in Berlin .", "predicates": [{"text": "bought", "start_char": 4, "end_char": 10, "roleset_id": ""}], "relations": [[{"text": "bought", "start_char": 4, "end_char": 10, "roleset_id": ""}, [{"role": {"text": "Ben", "start_char": 0, "end_char": 3}, "label": "ARG-0"}, {"role": {"text": "a house", "start_char": 11, "end_char": 18}, "label": "ARG-1"}, {"role": {"text": "last year", "start_char": 19, "end_char": 28}, "label": "ARG-M-TMP"}, {"role": {"text": "in Berlin", "start_char": 29, "end_char": 38}, "label": "ARG-M-LOC"}]]]}
diff --git a/usage_examples/span_srl_openai/fewshot.cfg b/usage_examples/span_srl_openai/fewshot.cfg
new file mode 100644
index 00000000..ffd1304a
--- /dev/null
+++ b/usage_examples/span_srl_openai/fewshot.cfg
@@ -0,0 +1,28 @@
+[paths]
+examples = null
+
+[nlp]
+lang = "en"
+pipeline = ["llm"]
+
+[components]
+
+[components.llm]
+factory = "llm"
+
+[components.llm.task]
+@llm_tasks = "spacy.SRL.v1"
+labels = ARG-0,ARG-1,ARG-M-TMP,ARG-M-LOC
+
+[components.llm.task.label_definitions]
+ARG-0 = "Agent"
+ARG-1 = "Patient or Theme"
+ARG-M-TMP = "Temporal Modifier"
+ARG-M-LOC = "Location Modifier"
+
+[components.llm.task.examples]
+@misc = "spacy.FewShotReader.v1"
+path = ${paths.examples}
+
+[components.llm.model]
+@llm_models = "spacy.GPT-3-5.v1"
\ No newline at end of file
diff --git a/usage_examples/span_srl_openai/run_pipeline.py b/usage_examples/span_srl_openai/run_pipeline.py
new file mode 100644
index 00000000..f20375b6
--- /dev/null
+++ b/usage_examples/span_srl_openai/run_pipeline.py
@@ -0,0 +1,52 @@
+import os
+import typer
+
+from pathlib import Path
+from spacy_llm.util import assemble
+from spacy_llm.tasks.srl.task import SRLExample
+from spacy_llm.tasks.srl.util import PredicateItem, RoleItem
+from typing import Optional
+from wasabi import msg
+
+Arg = typer.Argument
+Opt = typer.Option
+
+
+def run_pipeline(
+    # fmt: off
+    text: str = Arg("", help="Text to perform semantic role labeling on."),
+    config_path: Path = Arg(..., help="Path to the configuration file to use."),
+    examples_path: Optional[Path] = Arg(None, help="Path to the examples file to use (few-shot only)."),
+    verbose: bool = Opt(False, "--verbose", "-v", help="Show extra information."),
+    # fmt: on
+):
+    if not os.getenv("OPENAI_API_KEY", None):
+        msg.fail(
+            "OPENAI_API_KEY env variable was not found. "
+            "Set it by running 'export OPENAI_API_KEY=...' and try again.",
+            exits=1,
+        )
+
+    msg.text(f"Loading config from {config_path}", show=verbose)
+    nlp = assemble(
+        config_path,
+        overrides={}
+        if examples_path is None
+        else {"paths.examples": str(examples_path)},
+    )
+
+    doc = nlp(text)
+
+    predicates = [PredicateItem(**p) for p in doc._.predicates]
+    relations = [
+        (PredicateItem(**p), [RoleItem(**r) for r in rs]) for p, rs in doc._.relations
+    ]
+
+    doc_srl = SRLExample(text=doc.text, predicates=predicates, relations=relations)
+
+    msg.text(f"Text: {doc_srl.text}")
+    msg.text(f"SRL Output:\n{str(doc_srl)}\n")
+
+
+if __name__ == "__main__":
+    typer.run(run_pipeline)
diff --git a/usage_examples/span_srl_openai/zeroshot.cfg b/usage_examples/span_srl_openai/zeroshot.cfg
new file mode 100644
index 00000000..cf826712
--- /dev/null
+++ b/usage_examples/span_srl_openai/zeroshot.cfg
@@ -0,0 +1,23 @@
+[nlp]
+lang = "en"
+pipeline = ["llm"]
+
+[components]
+
+[components.llm]
+factory = "llm"
+
+[components.llm.task]
+@llm_tasks = "spacy.SRL.v1"
+labels = ARG-0,ARG-1,ARG-2,ARG-M-TMP,ARG-M-LOC
+
+[components.llm.task.label_definitions]
+ARG-0 = "Agent"
+ARG-1 = "Patient or Theme"
+ARG-2 = "ARG-2"
+ARG-M-TMP = "Temporal Modifier"
+ARG-M-LOC = "Location Modifier"
+
+[components.llm.model]
+@llm_models = "spacy.GPT-3-5.v1"
+config = {"temperature": 1}
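
For reviewers who want to try the new task end-to-end, here is a minimal usage sketch (not part of the diff above). It assumes `OPENAI_API_KEY` is exported and that this PR's `usage_examples/span_srl_openai/zeroshot.cfg` is on disk; the sample sentence is taken from the README.

```python
# Minimal sketch: assemble a pipeline with the new spacy.SRL.v1 task and read
# the two custom Doc extensions registered by SRLTask._check_extensions().
from spacy_llm.util import assemble

nlp = assemble("usage_examples/span_srl_openai/zeroshot.cfg")
doc = nlp("Laura bought an apartment last month in Berlin.")

# doc._.predicates holds PredicateItem dicts (text, start_char, end_char, roleset_id).
for pred in doc._.predicates:
    print(pred["text"], pred["start_char"], pred["end_char"])

# doc._.relations holds (predicate dict, [RoleItem dicts]) pairs; each RoleItem
# dict carries a "label" and a nested "role" span dict.
for pred, roles in doc._.relations:
    for role in roles:
        print(pred["text"], role["label"], role["role"]["text"])
```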