some small fixes, touchups, space behaviour
markus583 committed Oct 27, 2024
1 parent 6f4de68 commit 81ab1e8
Showing 4 changed files with 39 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -175,7 +175,7 @@ sat.predict_proba(text)

```python
# import library to register the custom models
-import wtpsplit
+import wtpsplit.models
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("segment-any-text/sat-3l-sm") # or some other model name; see https://huggingface.co/segment-any-text
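
# Hedged usage sketch (not in the original README): once wtpsplit.models is
# imported, the checkpoint loads as a regular Hugging Face module, so
# standard calls apply.
model.eval()  # switch to inference mode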
@@ -217,6 +217,7 @@ torch.save(
"dummy-dataset.pth"
)
```
+Note that there should not be any newlines within individual sentences! Your corpus should already be well-split.
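
As an illustration, a minimal sketch of enforcing this before saving (the `corpus` variable is hypothetical):

```python
# Hypothetical corpus: make sure no sentence contains a newline.
corpus = ["First sentence.", "Second sentence\nwith a stray newline."]
corpus = [s.replace("\n", " ").strip() for s in corpus]
assert not any("\n" in s for s in corpus)
```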

Create/adapt a config; provide the base model via `model_name_or_path` and the training data `.pth` via `text_path`:
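
The README's actual example config is elided from this diff; purely as a hypothetical sketch (field names beyond `model_name_or_path` and `text_path` are assumptions, not the repository's actual schema):

```python
# Hypothetical config sketch; only the two field names above come from the README.
import json

config = {
    "model_name_or_path": "segment-any-text/sat-3l",  # base model
    "text_path": "dummy-dataset.pth",  # training data from the step above
}
with open("config.json", "w") as f:
    json.dump(config, f, indent=2)
```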

17 changes: 17 additions & 0 deletions test.py
@@ -37,6 +37,23 @@ def test_strip_whitespace():
    assert splits == ["This is a test sentence", "This is another test sentence."]


+def test_strip_newline_behaviour():
+    sat = SaT("segment-any-text/sat-3l", hub_prefix=None)
+
+    splits = sat.split(
+        "Yes\nthis is a test sentence. This is another test sentence.",
+    )
+    assert splits == ["Yes", "this is a test sentence. ", "This is another test sentence."]
+
+
+def test_strip_newline_behaviour_as_spaces():
+    sat = SaT("segment-any-text/sat-3l", hub_prefix=None)
+
+    splits = sat.split(
+        "Yes\nthis is a test sentence. This is another test sentence.", treat_newline_as_space=True
+    )
+    assert splits == ["Yes\nthis is a test sentence. ", "This is another test sentence."]

def test_split_noisy():
    sat = SaT("segment-any-text/sat-12l-sm", hub_prefix=None)

20 changes: 19 additions & 1 deletion wtpsplit/__init__.py
@@ -688,7 +688,8 @@ def split(
    outer_batch_size=1000,
    paragraph_threshold: float = 0.5,
    strip_whitespace: bool = False,
-    do_paragraph_segmentation=False,
+    do_paragraph_segmentation: bool = False,
+    treat_newline_as_space: bool = False,
    verbose: bool = False,
):
    if isinstance(text_or_texts, str):
@@ -705,6 +706,7 @@
            paragraph_threshold=paragraph_threshold,
            strip_whitespace=strip_whitespace,
            do_paragraph_segmentation=do_paragraph_segmentation,
+            treat_newline_as_space=treat_newline_as_space,
            verbose=verbose,
        )
    )
@@ -721,6 +723,7 @@
        paragraph_threshold=paragraph_threshold,
        strip_whitespace=strip_whitespace,
        do_paragraph_segmentation=do_paragraph_segmentation,
+        treat_newline_as_space=treat_newline_as_space,
        verbose=verbose,
    )

@@ -736,6 +739,7 @@ def _split(
    remove_whitespace_before_inference: bool,
    outer_batch_size: int,
    do_paragraph_segmentation: bool,
+    treat_newline_as_space: bool,
    strip_whitespace: bool,
    verbose: bool,
):
@@ -791,4 +795,18 @@ def get_default_threshold(model_str: str):
        sentences = indices_to_sentences(
            text, np.where(probs > sentence_threshold)[0], strip_whitespace=strip_whitespace
        )
+        if not treat_newline_as_space:
+            # within the model, newlines in the text were ignored - they were treated as spaces.
+            # this is the default behavior: additionally split on newlines as provided in the input
+            new_sentences = []
+            for sentence in sentences:
+                new_sentences.extend(sentence.split("\n"))
+            sentences = new_sentences
+        else:
+            warnings.warn(
+                "treat_newline_as_space=True will lead to newlines in the output "
+                "if they were present in the input. Within the model, such newlines are "
+                "treated as spaces. "
+                "If you want to split on such newlines, set treat_newline_as_space=False."
+            )
        yield sentences
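
For reference, a usage sketch of the new flag, with inputs and outputs taken from the tests above:

```python
from wtpsplit import SaT

sat = SaT("segment-any-text/sat-3l", hub_prefix=None)
text = "Yes\nthis is a test sentence. This is another test sentence."

# Default: newlines in the input additionally split the output.
sat.split(text)
# -> ["Yes", "this is a test sentence. ", "This is another test sentence."]

# Opt out: keep input newlines inside sentences (emits a warning).
sat.split(text, treat_newline_as_space=True)
# -> ["Yes\nthis is a test sentence. ", "This is another test sentence."]
```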
2 changes: 1 addition & 1 deletion wtpsplit/train/evaluate.py
@@ -37,7 +37,7 @@ def compute_f1(pred, true):

def get_metrics(labels, preds, threshold: float = 0.01):
    # Compute precision-recall curve and AUC
-    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, preds)
+    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, sigmoid(preds))
    pr_auc = sklearn.metrics.auc(recall, precision)

    # Compute F1 scores for all thresholds
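
The fix passes probabilities rather than raw logits to the curve computation, so scores are comparable to the fixed default threshold of 0.01. A hedged sketch of the assumed `sigmoid` helper (the module presumably imports or defines an equivalent):

```python
import numpy as np

def sigmoid(x):
    # Map raw logits to (0, 1) so they are comparable to a probability threshold.
    return 1 / (1 + np.exp(-x))
```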
