Skip to content

Commit

Permalink
Merge pull request #32 from boostcampaitech4lv23nlp2/feat/preprocess
Browse files Browse the repository at this point in the history
Feat/preprocess add translation and replace
  • Loading branch information
FacerAin authored Nov 23, 2022
2 parents 45294b5 + e579f8b commit 2ac4275
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 57 deletions.
27 changes: 0 additions & 27 deletions .github/workflows/check-code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,6 @@ name: check-code
on: [pull_request]

jobs:
check-test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8

- name: Cache pip
uses: actions/cache@v2
with:
# This path is specific to Ubuntu
path: ~/.cache/pip
# Look to see if there is a cache hit for the corresponding requirements file
key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
restore-keys: |
${{ runner.os }}-pip-
${{ runner.os }}-
- name: Install dependencies
run: |
python3 -m pip install --upgrade pip
- name: Check Test
run: |
make test

check-lint:
runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ mlflow==2.0.1
streamlit==1.14.1
seaborn==0.12.0
ipykernel==6.17.1
hanja==0.13.3
numpy==1.19.2
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[flake8]
extend-ignore = E203, W503, E501, E231, E402, E731, E741, F401
extend-ignore = E203, W503, E501, E231, E402, E731, E741, F401, W605
max-line-length = 120

[tool:pytest]
Expand Down
4 changes: 2 additions & 2 deletions src/data_loader/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import torch

from src.utils import entity_representation
from src.utils import representation


def preprocessing_dataset(dataset):
Expand Down Expand Up @@ -38,7 +38,7 @@ def tokenized_dataset(self, dataset, tokenizer):
"""tokenizer에 따라 sentence를 tokenizing 합니다."""
concat_entity = []
for e01, e02, sentence in zip(dataset["subject_entity"], dataset["object_entity"], dataset["sentence"]):
temp = entity_representation(e01, e02, sentence, method=None)
temp = representation(e01, e02, sentence, entity_method=None, is_replace=False, translation_methods=[None])
concat_entity.append(temp)
tokenized_sentences = tokenizer(
concat_entity,
Expand Down
3 changes: 2 additions & 1 deletion src/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .arguments import DataTrainingArguments, ModelArguments, get_training_args
from .control_mlflow import save_model_remote, set_mlflow_logger
from .get_train_valid_split import get_train_valid_split
from .representation import entity_representation
from .preprocess import replace_symbol
from .representation import representation
from .set_seed import set_seed
from .utils import label_to_num, num_to_label
2 changes: 1 addition & 1 deletion src/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def get_training_args(
output_dir="./results",
save_total_limit=5,
save_strategy="epoch",
num_train_epochs=1,
num_train_epochs=20,
learning_rate=5e-5,
per_device_train_batch_size=128,
per_device_eval_batch_size=128,
Expand Down
17 changes: 17 additions & 0 deletions src/utils/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import re

symbol_map = {"\"`'‘‘’“”'ˈ′": "'"}


def remove_symbol():
pass


def replace_symbol(sentence):
for key, value in symbol_map.items():
sentence = re.sub(f"[{key}]", value, sentence)
return sentence


def remove_language():
pass
91 changes: 77 additions & 14 deletions src/utils/representation.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,53 @@
from typing import Tuple

import hanja

def extraction(subject: str) -> Tuple[int, int, str, str]:
from . import replace_symbol


def translation(sentence: str, method: str = None) -> str:
assert method in [
None,
"chinese",
], "입력하신 method는 없습니다."
if method is None:
return sentence

if method == "chinese":
return hanja.translate(sentence, "substitution")


def extraction(entity: str) -> dict:
"""
Args:
subject (str): subject or object
entity (str): subject or object
Returns:
Tuple[int,int,str,str]: return subject object idx or subject object
Dict[int,int,str,str]: return dict containing entity information
"""
subject_entity = subject[:-1].split(",")[-1].split(":")[1]
subject_length = len(subject.split(","))
sub_start_idx = int(subject.split(",", subject_length - 3)[subject_length - 3].split(",")[0].split(":")[1])
sub_end_idx = int(subject.split(",", subject_length - 3)[subject_length - 3].split(",")[1].split(":")[1])
subject = "".join(subject.split(",", subject_length - 3)[: subject_length - 3]).split(":")[1]
subject_entity = subject_entity.replace("'", "").strip()
entity_type = entity[:-1].split(",")[-1].split(":")[1]
entity_length = len(entity.split(","))
start_idx = int(entity.split(",", entity_length - 3)[entity_length - 3].split(",")[0].split(":")[1])
end_idx = int(entity.split(",", entity_length - 3)[entity_length - 3].split(",")[1].split(":")[1])
entity_word = "".join(entity.split(",", entity_length - 3)[: entity_length - 3]).split(":")[1]
entity_word = entity_word.replace("'", "").strip()
entity_type = entity_type.replace("'", "").strip()

entity_dict = {
"start_idx": start_idx,
"end_idx": end_idx,
"entity_type": entity_type,
"entity_word": entity_word,
}

return entity_dict

return sub_start_idx, sub_end_idx, subject, subject_entity

def unpack_entity_dict(start_idx, end_idx, entity_type, entity_word):
return start_idx, end_idx, entity_word, entity_type

def entity_representation(subject: str, object: str, sentence: str, method: str = None) -> str:

def entity_representation(subject_dict: dict, object_dict: dict, sentence: str, method: str = None) -> str:
"""
Args:
subject (str): subject dictionary
Expand All @@ -40,14 +68,14 @@ def entity_representation(subject: str, object: str, sentence: str, method: str
"typed_entity_marker_punct",
], "입력하신 method는 없습니다."

sub_start_idx, sub_end_idx, subject, subject_entity = extraction(subject)
obj_start_idx, obj_end_idx, object, object_entity = extraction(object)
sub_start_idx, sub_end_idx, subject, subject_entity = unpack_entity_dict(**subject_dict)
obj_start_idx, obj_end_idx, object, object_entity = unpack_entity_dict(**object_dict)

# entity representation

# baseline code
if method is None:
temp = subject + " [SEP]" + object + " [SEP] " + sentence
temp = subject + " [SEP] " + object + " [SEP] " + sentence

# entity mask
elif method == "entity_mask":
Expand Down Expand Up @@ -145,3 +173,38 @@ def entity_representation(subject: str, object: str, sentence: str, method: str
temp = temp.replace(f"<O:{object_entity}>", f"# ∧ {object_entity.lower()} ∧")

return temp


def representation(
subject: str,
object: str,
sentence: str,
entity_method: str = None,
translation_methods: list = [None],
is_replace=False,
) -> str:
"""
Args:
subject (str): subject dictionary
object (str): object dictionary
sentence (str): single sentence
entity_method (str, optional): entity representation. Defaults to None.
translation_methods (list, optional): translation methods: (None, chinese)
is_replace (bool, optional) replace symbol methods. Defaults to False.(True, False)
Returns:
str: single sentence
"""

subject_dict = extraction(subject)
object_dict = extraction(object)

tmp = entity_representation(subject_dict, object_dict, sentence, method=entity_method)

for translation_method in translation_methods:
tmp = translation(tmp, method=translation_method)

if is_replace:
tmp = replace_symbol(tmp)

return tmp
33 changes: 22 additions & 11 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
import unittest

from src.utils.preprocess import replace_symbol

class PreprocessTester(unittest.TestCase):
def test_run(self):
pass
test_quote_symbol_map = {"\"`'‘‘’“”'ˈ′": "'"}
test_bracket_symbol_map = {"[[": "<", "]]": ">", "\[《〈「˹「⟪≪<⌜『«": "<", "\]》〉」˼」⟫≫>⌟»": ">", "({": "(", ")}": ")"}
test_sentences = [
"비틀즈 [SEP] 조지 해리슨 [SEP] 〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.{}",
]

test_quote_answers = [
"비틀즈 [SEP] 조지 해리슨 [SEP] 〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.{}",
]

def test_chinese(self):
pass
test_bracket_answers = [
"비틀즈 [SEP] 조지 해리슨 [SEP] <Something>는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 <Abbey Road>에 담은 노래다.()",
]

def test_japanese(self):
pass

def test_etc_language(self):
pass
class PreprocessTester(unittest.TestCase):
def test_replace_symbol(self):
for sentence, answer in zip(test_sentences, test_quote_answers):
generate_sentence = replace_symbol(sentence, test_quote_symbol_map)
self.assertEqual(generate_sentence, answer)

def test_korean(self):
pass
def test_bracket_symbol(self):
for sentence, answer in zip(test_sentences, test_bracket_answers):
generate_sentence = replace_symbol(sentence, test_bracket_symbol_map)
self.assertEqual(generate_sentence, answer)
72 changes: 72 additions & 0 deletions tests/test_representation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import unittest

from src.utils.representation import representation

test_objects = [
[
"〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.",
"{'word': '비틀즈', 'start_idx': 24, 'end_idx': 26, 'type': 'ORG'}",
"{'word': '조지 해리슨', 'start_idx': 13, 'end_idx': 18, 'type': 'PER'}",
],
[
"K리그2에서 성적 1위를 달리고 있는 광주FC는 지난 26일 한국프로축구연맹으로부터 관중 유치 성과와 마케팅 성과를 인정받아 ‘풀 스타디움상’과 ‘플러스 스타디움상’을 수상했다.",
"{'word': '광주FC', 'start_idx': 21, 'end_idx': 24, 'type': 'ORG'}",
"{'word': '한국프로축구연맹', 'start_idx': 34, 'end_idx': 41, 'type': 'ORG'}",
],
[
"백한성(白漢成, 水原鶴人, 1899년 6월 15일 조선 충청도 공주 출생 ~ 1971년 10월 13일 대한민국 서울에서 별세.)은 대한민국의 정치가이며 법조인이다.",
"{'word': '백한성', 'start_idx': 0, 'end_idx': 2, 'type': 'PER'}",
"{'word': '조선 충청도 공주', 'start_idx': 28, 'end_idx': 36, 'type': 'LOC'}",
],
[
"KBS 전주방송총국(KBS 全州放送總局)은 전라북도 지역을 대상으로 하는 한국방송공사의 지역 방송 총국이다.",
"{'word': 'KBS 전주방송총국', 'start_idx': 0, 'end_idx': 9, 'type': 'ORG'}",
"{'word': 'KBS 全州放送總局', 'start_idx': 11, 'end_idx': 20, 'type': 'ORG'}",
],
]

none_answers = [
"비틀즈 [SEP] 조지 해리슨 [SEP] 〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.",
"광주FC [SEP] 한국프로축구연맹 [SEP] K리그2에서 성적 1위를 달리고 있는 광주FC는 지난 26일 한국프로축구연맹으로부터 관중 유치 성과와 마케팅 성과를 인정받아 ‘풀 스타디움상’과 ‘플러스 스타디움상’을 수상했다.",
"백한성 [SEP] 조선 충청도 공주 [SEP] 백한성(白漢成, 水原鶴人, 1899년 6월 15일 조선 충청도 공주 출생 ~ 1971년 10월 13일 대한민국 서울에서 별세.)은 대한민국의 정치가이며 법조인이다.",
"KBS 전주방송총국 [SEP] KBS 全州放送總局 [SEP] KBS 전주방송총국(KBS 全州放送總局)은 전라북도 지역을 대상으로 하는 한국방송공사의 지역 방송 총국이다.",
]


chinese_answers = [
"비틀즈 [SEP] 조지 해리슨 [SEP] 〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.",
"광주FC [SEP] 한국프로축구연맹 [SEP] K리그2에서 성적 1위를 달리고 있는 광주FC는 지난 26일 한국프로축구연맹으로부터 관중 유치 성과와 마케팅 성과를 인정받아 ‘풀 스타디움상’과 ‘플러스 스타디움상’을 수상했다.",
"백한성 [SEP] 조선 충청도 공주 [SEP] 백한성(백한성, 수원학인, 1899년 6월 15일 조선 충청도 공주 출생 ~ 1971년 10월 13일 대한민국 서울에서 별세.)은 대한민국의 정치가이며 법조인이다.",
"KBS 전주방송총국 [SEP] KBS 전주방송총국 [SEP] KBS 전주방송총국(KBS 전주방송총국)은 전라북도 지역을 대상으로 하는 한국방송공사의 지역 방송 총국이다.",
]

replace_symbole_answers = [
"비틀즈 [SEP] 조지 해리슨 [SEP] 〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.",
"광주FC [SEP] 한국프로축구연맹 [SEP] K리그2에서 성적 1위를 달리고 있는 광주FC는 지난 26일 한국프로축구연맹으로부터 관중 유치 성과와 마케팅 성과를 인정받아 ‘풀 스타디움상’과 ‘플러스 스타디움상’을 수상했다.",
"백한성 [SEP] 조선 충청도 공주 [SEP] 백한성(白漢成, 水原鶴人, 1899년 6월 15일 조선 충청도 공주 출생 ~ 1971년 10월 13일 대한민국 서울에서 별세.)은 대한민국의 정치가이며 법조인이다.",
"KBS 전주방송총국 [SEP] KBS 全州放送總局 [SEP] KBS 전주방송총국(KBS 全州放送總局)은 전라북도 지역을 대상으로 하는 한국방송공사의 지역 방송 총국이다.",
]


class RepresentationTester(unittest.TestCase):
def test_none(self):
for example_object, answer in zip(test_objects, none_answers):
sentence, subject, object = example_object
generate_text = representation(subject, object, sentence, entity_method=None)
self.assertEqual(generate_text, answer)

def test_chinese(self):
for example_object, answer in zip(test_objects, chinese_answers):
sentence, subject, object = example_object
generate_text = representation(
subject, object, sentence, entity_method=None, translation_methods=["chinese"]
)
self.assertEqual(generate_text, answer)

def test_replace(self):
for example_object, answer in zip(test_objects, replace_symbole_answers):
sentence, subject, object = example_object
generate_text = representation(
subject, object, sentence, entity_method=None, translation_methods=[None], is_replace=False
)
self.assertEqual(generate_text, answer)
31 changes: 31 additions & 0 deletions tests/test_translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest

from src.utils.representation import translation

test_sentences = [
"백한성(白漢成, 水原鶴人, 1899년 6월 15일 조선 충청도 공주 출생 ~ 1971년 10월 13일 대한민국 서울에서 별세.)은 대한민국의 정치가이며 법조인이다.",
'1904년 7월 1일, ""툰-운트 스포트버라인 바이어 04 레버쿠젠"" (Turn- und Spielverein Bayer 04 Leverkusen)의 이름으로 창단되었다.',
"헌강왕(憲康王, ~ 886년, 재위: 875년 ~ 886년)은 신라의 제49대 왕이다.",
"쇼니 씨(少弐氏)의 8대 당주로 쇼니 요리히사(少弐頼尚)의 둘째 아들이다.",
"버턴 릭터(Burton Richter, 1931년 3월 22일 ~ 2018년 7월 18일)는 노벨 물리학상을 받은 미국의 물리학자이다.",
"유한굉(劉漢宏, Liu Hanhong, ~ 887년)은 중국 당나라 말기에 활약했던 군벌로, 당초에는 당나라에 반기를 들었으나, 후에 당나라의 관직을 받고 의승군 절도사(義勝軍節度使, 본거지는 지금의 저장 성 사오싱 시)로서 절강 동부 일대를 지배하였다.",
]

chiense_sentences = [
"백한성(백한성, 수원학인, 1899년 6월 15일 조선 충청도 공주 출생 ~ 1971년 10월 13일 대한민국 서울에서 별세.)은 대한민국의 정치가이며 법조인이다.",
'1904년 7월 1일, ""툰-운트 스포트버라인 바이어 04 레버쿠젠"" (Turn- und Spielverein Bayer 04 Leverkusen)의 이름으로 창단되었다.',
"헌강왕(헌강왕, ~ 886년, 재위: 875년 ~ 886년)은 신라의 제49대 왕이다.",
"쇼니 씨(소이씨)의 8대 당주로 쇼니 요리히사(소이뢰상)의 둘째 아들이다.",
"버턴 릭터(Burton Richter, 1931년 3월 22일 ~ 2018년 7월 18일)는 노벨 물리학상을 받은 미국의 물리학자이다.",
"유한굉(유한굉, Liu Hanhong, ~ 887년)은 중국 당나라 말기에 활약했던 군벌로, 당초에는 당나라에 반기를 들었으나, 후에 당나라의 관직을 받고 의승군 절도사(의승군절도사, 본거지는 지금의 저장 성 사오싱 시)로서 절강 동부 일대를 지배하였다.",
]


class TranlsationTester(unittest.TestCase):
def test_run(self):
pass

def test_chinese(self):
for sentence, answer in zip(test_sentences, chiense_sentences):
translation_sentence = translation(sentence, method="chinese")
self.assertEqual(translation_sentence, answer)

0 comments on commit 2ac4275

Please sign in to comment.