Skip to content

Commit

Permalink
Drafting infer_fieldnames_from_function_return_type
Browse files Browse the repository at this point in the history
  • Loading branch information
Yomguithereal committed Dec 19, 2023
1 parent 21fcd41 commit 01e19bd
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
26 changes: 25 additions & 1 deletion minet/scrape/classes/function.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,40 @@
from typing import Optional, Callable, Any, cast, Dict
from typing import Union, Optional, Callable, Any, cast, Dict, List

import inspect

from casanova import RowWrapper
from bs4 import SoupStrainer

from minet.types import get_type_hints, get_origin, get_args
from minet.scrape.classes.base import ScraperBase
from minet.scrape.soup import WonderfulSoup
from minet.scrape.straining import strainer_from_css
from minet.scrape.utils import ensure_soup
from minet.scrape.types import AnyScrapableTarget


def infer_fieldnames_from_function_return_type(fn: Callable) -> Optional[List[str]]:
if not callable(fn):
raise TypeError

return_type = get_type_hints(fn)["return"]

origin = get_origin(return_type)

if origin is Union:
args = get_args(return_type)

# Optionals
if len(args) == 2:
if args[1] is type(None):
return_type = args[0]

if return_type in (str, int, float, bool, type(None)):
return ["value"]

return None


class FunctionScraper(ScraperBase):
fn: Callable[[RowWrapper, WonderfulSoup], Any]
fieldnames = None
Expand Down
33 changes: 32 additions & 1 deletion test/scraper_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# =============================================================================
# Minet Scrape Unit Tests
# =============================================================================
from typing import Optional

import pytest
from bs4 import BeautifulSoup, Tag, SoupStrainer
from textwrap import dedent
Expand Down Expand Up @@ -31,6 +33,7 @@
ScraperValidationMixedConcernError,
ScraperValidationUnknownKeyError,
)
from minet.scrape.classes.function import infer_fieldnames_from_function_return_type

BASIC_HTML = """
<ul>
Expand Down Expand Up @@ -160,7 +163,7 @@
"""


class TestDefinitionScraper(object):
class TestDefinitionScraper:
def test_basics(self):
result = scrape({"iterator": "li"}, BASIC_HTML)

Expand Down Expand Up @@ -1055,3 +1058,31 @@ def clean(t):
text = get_display_text(elements)

assert text == "L'internationale."


class TestFunctionScraper:
def test_infer_fieldnames_from_function_return_type(self):
def basic_string() -> str:
return "ok"

def basic_int() -> int:
return 4

def basic_float() -> float:
return 4.0

def basic_bool() -> bool:
return True

def basic_void() -> None:
return

def basic_optional_scalar() -> Optional[str]:
return

assert infer_fieldnames_from_function_return_type(basic_string) == ["value"]
assert infer_fieldnames_from_function_return_type(basic_int) == ["value"]
assert infer_fieldnames_from_function_return_type(basic_float) == ["value"]
assert infer_fieldnames_from_function_return_type(basic_bool) == ["value"]
assert infer_fieldnames_from_function_return_type(basic_void) == ["value"]
assert infer_fieldnames_from_function_return_type(basic_optional_scalar) == ["value"]

0 comments on commit 01e19bd

Please sign in to comment.