From b50ccfb1486353e63e35725465de4b5b81ae3001 Mon Sep 17 00:00:00 2001 From: Manabu Niseki Date: Sat, 25 May 2024 10:51:16 +0900 Subject: [PATCH] refactor: renew factories (inject dependencies via __init__) --- backend/factories/abstract.py | 6 ++---- backend/factories/emailrep.py | 22 ++++++++++----------- backend/factories/eml.py | 3 +-- backend/factories/inquest.py | 17 ++++++++-------- backend/factories/oldid.py | 11 ++++++----- backend/factories/response.py | 18 ++++++++--------- backend/factories/spamassassin.py | 17 ++++++++++------ backend/factories/urlscan.py | 17 ++++++++-------- backend/factories/virustotal.py | 13 +++++++------ tests/conftest.py | 17 ++++++++-------- tests/factories/test_eml.py | 29 ++++++++++++++++------------ tests/factories/test_inquest.py | 10 +++++++--- tests/factories/test_oleid.py | 19 +++++++++++++----- tests/factories/test_spamassassin.py | 11 +++++++---- tests/factories/test_urlscan.py | 11 +++++++---- tests/factories/test_virustotal.py | 10 +++++++--- 16 files changed, 129 insertions(+), 102 deletions(-) diff --git a/backend/factories/abstract.py b/backend/factories/abstract.py index 26bc18c..761c516 100644 --- a/backend/factories/abstract.py +++ b/backend/factories/abstract.py @@ -3,14 +3,12 @@ class AbstractFactory(ABC): - @classmethod @abstractmethod - def call(cls, *args: typing.Any, **kwargs: typing.Any): + def call(self, *args: typing.Any, **kwargs: typing.Any): raise NotImplementedError() class AbstractAsyncFactory(ABC): - @classmethod @abstractmethod - async def call(cls, *args: typing.Any, **kwargs: typing.Any): + async def call(self, *args: typing.Any, **kwargs: typing.Any): raise NotImplementedError() diff --git a/backend/factories/emailrep.py b/backend/factories/emailrep.py index 3025460..7a0ce0c 100644 --- a/backend/factories/emailrep.py +++ b/backend/factories/emailrep.py @@ -10,8 +10,6 @@ from .abstract import AbstractAsyncFactory -NAME_OR_KEY = "EmailRep" - @future_safe async def lookup(email: str, *, client: clients.EmailRep) -> schemas.EmailRepLookup: @@ -19,7 +17,7 @@ async def lookup(email: str, *, client: clients.EmailRep) -> schemas.EmailRepLoo @future_safe -async def transform(lookup: schemas.EmailRepLookup, *, name_or_key: str = NAME_OR_KEY): +async def transform(lookup: schemas.EmailRepLookup, *, key_or_name: str): details: list[schemas.VerdictDetail] = [] malicious = False @@ -28,18 +26,20 @@ async def transform(lookup: schemas.EmailRepLookup, *, name_or_key: str = NAME_O malicious = True description = f"{lookup.email} is suspicious. See https://emailrep.io/{lookup.email} for details." - details.append(schemas.VerdictDetail(key=name_or_key, description=description)) - return schemas.Verdict(name=name_or_key, malicious=malicious, details=details) + details.append(schemas.VerdictDetail(key=key_or_name, description=description)) + return schemas.Verdict(name=key_or_name, malicious=malicious, details=details) class EmailRepVerdictFactory(AbstractAsyncFactory): - @classmethod - async def call( - cls, email: str, *, client: clients.EmailRep, name_or_key: str = NAME_OR_KEY - ) -> schemas.Verdict: + def __init__(self, client: clients.EmailRep, *, name: str = "EmailRep"): + self.client = client + self.name = name + + async def call(self, email: str, key: str | None = None) -> schemas.Verdict: + key_or_name: str = key or self.name f_result: FutureResultE[schemas.Verdict] = flow( - lookup(email, client=client), - bind(partial(transform, name_or_key=name_or_key)), + lookup(email, client=self.client), + bind(partial(transform, key_or_name=key_or_name)), ) result = await f_result.awaitable() return unsafe_perform_io(result.alt(raise_exception).unwrap()) diff --git a/backend/factories/eml.py b/backend/factories/eml.py index b74fb04..f42c67c 100644 --- a/backend/factories/eml.py +++ b/backend/factories/eml.py @@ -169,8 +169,7 @@ def transform(parsed: dict) -> schemas.Eml: class EmlFactory(AbstractFactory): - @classmethod - def call(cls, data: bytes) -> schemas.Eml: + def call(self, data: bytes) -> schemas.Eml: result: ResultE[schemas.Eml] = flow( to_eml(data), bind(parse), diff --git a/backend/factories/inquest.py b/backend/factories/inquest.py index 276b27d..381d0f2 100644 --- a/backend/factories/inquest.py +++ b/backend/factories/inquest.py @@ -9,8 +9,6 @@ from backend import clients, schemas, settings, types -NAME = "InQuest" - @future_safe async def lookup(sha256: str, *, client: clients.InQuest) -> schemas.InQuestLookup: @@ -36,7 +34,7 @@ async def bulk_lookup( @future_safe -async def transform(lookups: list[schemas.InQuestLookup], *, name: str = NAME): +async def transform(lookups: list[schemas.InQuestLookup], *, name: str): malicious_lookups = [lookup for lookup in lookups if lookup.malicious] if len(malicious_lookups) == 0: @@ -67,24 +65,25 @@ async def transform(lookups: list[schemas.InQuestLookup], *, name: str = NAME): class InQuestVerdictFactory: - @classmethod + def __init__(self, client: clients.InQuest, *, name: str = "InQuest"): + self.client = client + self.name = name + async def call( - cls, + self, sha256s: types.ListSet[str], *, - client: clients.InQuest, - name: str = NAME, max_per_second: float | None = settings.ASYNC_MAX_PER_SECOND, max_at_once: int | None = settings.ASYNC_MAX_AT_ONCE, ) -> schemas.Verdict: f_result: FutureResultE[schemas.Verdict] = flow( bulk_lookup( sha256s, - client=client, + client=self.client, max_at_once=max_at_once, max_per_second=max_per_second, ), - bind(partial(transform, name=name)), + bind(partial(transform, name=self.name)), ) result = await f_result.awaitable() return unsafe_perform_io(result.alt(raise_exception).unwrap()) diff --git a/backend/factories/oldid.py b/backend/factories/oldid.py index df17254..105ef35 100644 --- a/backend/factories/oldid.py +++ b/backend/factories/oldid.py @@ -8,8 +8,6 @@ from .abstract import AbstractFactory -NAME = "oleid" - @safe def parse(attachment: schemas.Attachment) -> OleID: @@ -74,9 +72,12 @@ def inner(oleid: OleID): class OleIDVerdictFactory(AbstractFactory): - @classmethod + def __init__(self, name: str = "oleid"): + self.name = name + def call( - cls, attachments: list[schemas.Attachment], *, name: str = NAME + self, + attachments: list[schemas.Attachment], ) -> schemas.Verdict: details = list( itertools.chain.from_iterable( @@ -95,4 +96,4 @@ def call( description="There is no suspicious OLE file in attachments.", ) ) - return schemas.Verdict(name=name, malicious=malicious, details=details) + return schemas.Verdict(name=self.name, malicious=malicious, details=details) diff --git a/backend/factories/response.py b/backend/factories/response.py index 5b8faa7..c931e54 100644 --- a/backend/factories/response.py +++ b/backend/factories/response.py @@ -28,7 +28,7 @@ def log_exception(exception: Exception): @future_safe async def parse(eml_file: bytes) -> schemas.Response: return schemas.Response( - eml=EmlFactory.call(eml_file), id=hashlib.sha256(eml_file).hexdigest() + eml=EmlFactory().call(eml_file), id=hashlib.sha256(eml_file).hexdigest() ) @@ -36,38 +36,38 @@ async def parse(eml_file: bytes) -> schemas.Response: async def get_spam_assassin_verdict( eml_file: bytes, *, client: clients.SpamAssassin ) -> schemas.Verdict: - return await SpamAssassinVerdictFactory.call(eml_file, client=client) + return await SpamAssassinVerdictFactory(client).call(eml_file) @future_safe async def get_oleid_verdict(attachments: list[schemas.Attachment]) -> schemas.Verdict: - return OleIDVerdictFactory.call(attachments) + return OleIDVerdictFactory().call(attachments) @future_safe async def get_email_rep_verdicts(from_, *, client: clients.EmailRep) -> schemas.Verdict: - return await EmailRepVerdictFactory.call(from_, client=client) + return await EmailRepVerdictFactory(client).call(from_) @future_safe async def get_urlscan_verdict( urls: types.ListSet[str], *, client: clients.UrlScan ) -> schemas.Verdict: - return await UrlScanVerdictFactory.call(urls, client=client) + return await UrlScanVerdictFactory(client).call(urls) @future_safe async def get_inquest_verdict( sha256s: types.ListSet[str], *, client: clients.InQuest ) -> schemas.Verdict: - return await InQuestVerdictFactory.call(sha256s, client=client) + return await InQuestVerdictFactory(client).call(sha256s) @future_safe async def get_vt_verdict( sha256s: types.ListSet[str], *, client: clients.VirusTotal ) -> schemas.Verdict: - return await VirusTotalVerdictFactory.call(sha256s, client=client) + return await VirusTotalVerdictFactory(client).call(sha256s) @future_safe @@ -109,9 +109,7 @@ async def set_verdicts( return response -class ResponseFactory( - AbstractAsyncFactory, -): +class ResponseFactory(AbstractAsyncFactory): @classmethod async def call( cls, diff --git a/backend/factories/spamassassin.py b/backend/factories/spamassassin.py index 4a1f070..aca6d64 100644 --- a/backend/factories/spamassassin.py +++ b/backend/factories/spamassassin.py @@ -1,3 +1,5 @@ +from functools import partial + from returns.functions import raise_exception from returns.future import FutureResultE, future_safe from returns.pipeline import flow @@ -8,8 +10,6 @@ from .abstract import AbstractAsyncFactory -NAME = "SpamAssassin" - @future_safe async def report( @@ -20,7 +20,7 @@ async def report( @future_safe async def transform( - report: schemas.SpamAssassinReport, *, name: str = NAME + report: schemas.SpamAssassinReport, *, name: str ) -> schemas.Verdict: details = [ schemas.VerdictDetail( @@ -39,12 +39,17 @@ async def transform( class SpamAssassinVerdictFactory(AbstractAsyncFactory): - @classmethod + def __init__(self, client: clients.SpamAssassin, *, name: str = "SpamAssassin"): + self.client = client + self.name = name + async def call( - cls, eml_file: bytes, *, client: clients.SpamAssassin + self, + eml_file: bytes, ) -> schemas.Verdict: f_result: FutureResultE[schemas.Verdict] = flow( - report(eml_file, client=client), bind(transform) + report(eml_file, client=self.client), + bind(partial(transform, name=self.name)), ) result = await f_result.awaitable() return unsafe_perform_io(result.alt(raise_exception).unwrap()) diff --git a/backend/factories/urlscan.py b/backend/factories/urlscan.py index 74fa038..20fd991 100644 --- a/backend/factories/urlscan.py +++ b/backend/factories/urlscan.py @@ -12,8 +12,6 @@ from .abstract import AbstractAsyncFactory -NAME = "urlscan.io" - @future_safe async def lookup(url: str, *, client: clients.UrlScan) -> schemas.UrlScanLookup: @@ -39,7 +37,7 @@ async def bulk_lookup( @future_safe -async def transform(lookups: list[schemas.UrlScanLookup], *, name: str = NAME): +async def transform(lookups: list[schemas.UrlScanLookup], *, name: str): results = itertools.chain.from_iterable([lookup.results for lookup in lookups]) malicious_results = [result for result in results if result.verdicts.malicious] @@ -71,24 +69,25 @@ async def transform(lookups: list[schemas.UrlScanLookup], *, name: str = NAME): class UrlScanVerdictFactory(AbstractAsyncFactory): - @classmethod + def __init__(self, client: clients.UrlScan, *, name: str = "urlscan.io"): + self.client = client + self.name = name + async def call( - cls, + self, urls: types.ListSet[str], *, - client: clients.UrlScan, - name: str = NAME, max_per_second: float | None = settings.ASYNC_MAX_PER_SECOND, max_at_once: int | None = settings.ASYNC_MAX_AT_ONCE, ): f_result: FutureResultE[schemas.Verdict] = flow( bulk_lookup( urls, - client=client, + client=self.client, max_at_once=max_at_once, max_per_second=max_per_second, ), - bind(partial(transform, name=name)), + bind(partial(transform, name=self.name)), ) result = await f_result.awaitable() return unsafe_perform_io(result.alt(raise_exception).unwrap()) diff --git a/backend/factories/virustotal.py b/backend/factories/virustotal.py index bc126eb..4227ae6 100644 --- a/backend/factories/virustotal.py +++ b/backend/factories/virustotal.py @@ -73,24 +73,25 @@ async def transform(objects: list[vt.Object], *, name: str = NAME) -> schemas.Ve class VirusTotalVerdictFactory(AbstractAsyncFactory): - @classmethod + def __init__(self, client: clients.VirusTotal, *, name: str = "VirusTotal"): + self.client = client + self.name = name + async def call( - cls, + self, sha256s: types.ListSet[str], *, - client: clients.VirusTotal, - name: str = NAME, max_per_second: float | None = settings.ASYNC_MAX_PER_SECOND, max_at_once: int | None = settings.ASYNC_MAX_AT_ONCE, ) -> schemas.Verdict: f_result: FutureResultE[schemas.Verdict] = flow( bulk_get_file_objects( sha256s, - client=client, + client=self.client, max_at_once=max_at_once, max_per_second=max_per_second, ), - bind(partial(transform, name=name)), + bind(partial(transform, name=self.name)), ) result = await f_result.awaitable() return unsafe_perform_io(result.alt(raise_exception).unwrap()) diff --git a/tests/conftest.py b/tests/conftest.py index 0965e98..906202f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,21 +26,20 @@ async def is_spam_assassin_responsive(port: int) -> bool: return False -if not ci.is_ci(): +if ci.is_ci(): @pytest.fixture(scope="session", autouse=True) - def docker_compose(docker_ip: str, docker_services: Services): # type: ignore + def docker_compose(): # type: ignore + return +else: + + @pytest.fixture(scope="session", autouse=True) + def docker_compose(docker_ip: str, docker_services: Services): port = docker_services.port_for("spamassassin", 783) docker_services.wait_until_responsive( timeout=60.0, pause=0.1, check=lambda: is_spam_assassin_responsive(port) ) -else: - - @pytest.fixture - def docker_compose(): - return - @pytest.fixture def spam_assassin() -> clients.SpamAssassin: @@ -119,7 +118,7 @@ def test_html() -> str: @pytest.fixture def docx_attachment(encrypted_docx_eml: bytes) -> schemas.Attachment: - eml = factories.EmlFactory.call(encrypted_docx_eml) + eml = factories.EmlFactory().call(encrypted_docx_eml) return eml.attachments[0] diff --git a/tests/factories/test_eml.py b/tests/factories/test_eml.py index 0c7e0cd..734a220 100644 --- a/tests/factories/test_eml.py +++ b/tests/factories/test_eml.py @@ -4,8 +4,13 @@ from backend.factories.eml import is_inline_forward_attachment -def test_sample(sample_eml: bytes): - eml = factories.EmlFactory.call(sample_eml) +@pytest.fixture() +def factory(): + return factories.EmlFactory() + + +def test_sample(sample_eml: bytes, factory: factories.EmlFactory): + eml = factory.call(sample_eml) assert eml.header.message_id is None assert eml.header.subject == "Winter promotions" assert eml.header.to == ["foo.bar@example.com"] @@ -14,8 +19,8 @@ def test_sample(sample_eml: bytes): assert len(eml.bodies) == 2 -def test_cc(cc_eml: bytes): - eml = factories.EmlFactory.call(cc_eml) +def test_cc(cc_eml: bytes, factory: factories.EmlFactory): + eml = factory.call(cc_eml) assert eml.header.message_id == "ecc38b11-aa06-44c9-b8de-283b06a1d89e@example.com" assert eml.header.subject == "To and Cc headers" assert eml.header.to == ["foo.bar@example.com", "info@example.com"] @@ -28,8 +33,8 @@ def test_cc(cc_eml: bytes): assert eml.attachments == [] -def test_multipart(multipart_eml: bytes): - eml = factories.EmlFactory.call(multipart_eml) +def test_multipart(multipart_eml: bytes, factory: factories.EmlFactory): + eml = factory.call(multipart_eml) assert eml.attachments is not None assert len(eml.attachments) == 1 @@ -38,8 +43,8 @@ def test_multipart(multipart_eml: bytes): assert first.hash.md5 == "f561388f7446cedd5b8b480311744b3c" -def test_encrypted_docx(encrypted_docx_eml: bytes): - eml = factories.EmlFactory.call(encrypted_docx_eml) +def test_encrypted_docx(encrypted_docx_eml: bytes, factory: factories.EmlFactory): + eml = factory.call(encrypted_docx_eml) assert eml.attachments is not None assert len(eml.attachments) == 1 @@ -50,14 +55,14 @@ def test_encrypted_docx(encrypted_docx_eml: bytes): ) -def test_emails(emails: list[bytes]): +def test_emails(emails: list[bytes], factory: factories.EmlFactory): for email in emails: - eml = factories.EmlFactory.call(email) + eml = factory.call(email) assert eml is not None -def test_complete_msg(complete_msg: bytes): - eml = factories.EmlFactory.call(complete_msg) +def test_complete_msg(complete_msg: bytes, factory: factories.EmlFactory): + eml = factory.call(complete_msg) assert eml.header.subject == "Test Multiple attachments complete email!!" diff --git a/tests/factories/test_inquest.py b/tests/factories/test_inquest.py index 97b0918..e601e71 100644 --- a/tests/factories/test_inquest.py +++ b/tests/factories/test_inquest.py @@ -13,13 +13,17 @@ async def client(): yield client +@pytest.fixture +def factory(client: clients.InQuest): + return factories.InQuestVerdictFactory(client) + + @vcr.use_cassette( "tests/fixtures/vcr_cassettes/inquest.yaml", filter_headers=["authorization"] ) # type: ignore @pytest.mark.asyncio -async def test_inquest_factory(client: clients.InQuest): - verdict = await factories.InQuestVerdictFactory.call( +async def test_inquest_factory(factory: factories.InQuestVerdictFactory): + verdict = await factory.call( ["e86c5988a3a6640fb90b90b9e9200e4cce0669594dbb5422622946208c124149"], - client=client, ) assert verdict.malicious is True diff --git a/tests/factories/test_oleid.py b/tests/factories/test_oleid.py index 2f633a5..abe8466 100644 --- a/tests/factories/test_oleid.py +++ b/tests/factories/test_oleid.py @@ -1,18 +1,27 @@ +import pytest + from backend import factories, schemas +@pytest.fixture +def factory(): + return factories.OleIDVerdictFactory() + + def get_attachments(eml_file: bytes) -> list[schemas.Attachment]: - eml = factories.EmlFactory.call(eml_file) + eml = factories.EmlFactory().call(eml_file) return eml.attachments -def test_encrypted_docx(encrypted_docx_eml: bytes): - verdict = factories.OleIDVerdictFactory.call(get_attachments(encrypted_docx_eml)) +def test_encrypted_docx( + encrypted_docx_eml: bytes, factory: factories.OleIDVerdictFactory +): + verdict = factory.call(get_attachments(encrypted_docx_eml)) assert verdict.malicious is True assert len(verdict.details) == 1 -def test_sample(sample_eml: bytes): - verdict = factories.OleIDVerdictFactory.call(get_attachments(sample_eml)) +def test_sample(sample_eml: bytes, factory: factories.OleIDVerdictFactory): + verdict = factory.call(get_attachments(sample_eml)) assert verdict.malicious is False assert len(verdict.details) == 1 diff --git a/tests/factories/test_spamassassin.py b/tests/factories/test_spamassassin.py index 512081c..d8ab19c 100644 --- a/tests/factories/test_spamassassin.py +++ b/tests/factories/test_spamassassin.py @@ -3,10 +3,13 @@ from backend import clients, factories +@pytest.fixture +def factory(spam_assassin: clients.SpamAssassin): + return factories.SpamAssassinVerdictFactory(spam_assassin) + + @pytest.mark.asyncio -async def test_sample(sample_eml: bytes, spam_assassin: clients.SpamAssassin): - verdict = await factories.SpamAssassinVerdictFactory.call( - sample_eml, client=spam_assassin - ) +async def test_sample(sample_eml: bytes, factory: factories.SpamAssassinVerdictFactory): + verdict = await factory.call(sample_eml) assert verdict.malicious is False assert len(verdict.details) > 0 diff --git a/tests/factories/test_urlscan.py b/tests/factories/test_urlscan.py index 73970c3..b528689 100644 --- a/tests/factories/test_urlscan.py +++ b/tests/factories/test_urlscan.py @@ -13,12 +13,15 @@ async def client(): yield client +@pytest.fixture +def factory(client: clients.UrlScan): + return factories.UrlScanVerdictFactory(client) + + @vcr.use_cassette( "tests/fixtures/vcr_cassettes/urlscan.yaml", filter_headers=["api-key"] ) # type: ignore @pytest.mark.asyncio -async def test_urlscan_factory(client: clients.UrlScan): - verdict = await factories.UrlScanVerdictFactory.call( - ["http://example.com"], client=client - ) +async def test_urlscan_factory(factory: factories.UrlScanVerdictFactory): + verdict = await factory.call(["http://example.com"]) assert verdict.malicious is False diff --git a/tests/factories/test_virustotal.py b/tests/factories/test_virustotal.py index 2cfd488..f359d40 100644 --- a/tests/factories/test_virustotal.py +++ b/tests/factories/test_virustotal.py @@ -11,11 +11,15 @@ async def client(): yield client +@pytest.fixture +def factory(client: clients.VirusTotal): + return factories.VirusTotalVerdictFactory(client) + + @pytest.mark.skip(reason="VCR cannot handle this...") @pytest.mark.asyncio -async def test_virus_total_factory(client: clients.VirusTotal): - verdict = await factories.VirusTotalVerdictFactory.call( +async def test_virus_total_factory(factory: factories.VirusTotalVerdictFactory): + verdict = await factory.call( ["275a021bbfb6489e54d471899f7db9d1663fc695ec2fe2a2c4538aabf651fd0f"], - client=client, ) assert verdict.malicious is True