From 45c1d96e27c12f4481796a8d48a69095e63a7d3b Mon Sep 17 00:00:00 2001 From: Marcel Robeer Date: Fri, 9 Oct 2020 15:00:57 +0200 Subject: [PATCH] enable include_detail in context_word_embs.py --- nlpaug/augmenter/word/context_word_embs.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/nlpaug/augmenter/word/context_word_embs.py b/nlpaug/augmenter/word/context_word_embs.py index 4aa72a5..6b25431 100755 --- a/nlpaug/augmenter/word/context_word_embs.py +++ b/nlpaug/augmenter/word/context_word_embs.py @@ -275,10 +275,12 @@ def insert(self, data): augmented_text += ' ' + tail_text augmented_texts.append(augmented_text) - if isinstance(data, list): - return augmented_texts + augmented_texts = augmented_texts if isinstance(data, list) else augmented_texts[0] + + if self.include_detail: + return augmented_texts, head_doc.get_change_logs() else: - return augmented_texts[0] + return augmented_texts def substitute(self, data): if not data: @@ -414,10 +416,12 @@ def substitute(self, data): augmented_text += ' ' + tail_text augmented_texts.append(augmented_text) - if isinstance(data, list): - return augmented_texts + augmented_texts = augmented_texts if isinstance(data, list) else augmented_texts[0] + + if self.include_detail: + return augmented_texts, head_doc.get_change_logs() else: - return augmented_texts[0] + return augmented_texts @classmethod def get_model(cls, model_path, device='cuda', force_reload=False, temperature=1.0, top_k=None, top_p=0.0,