-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
102c73d
commit a32817d
Showing
4 changed files
with
81 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,81 @@ | ||
from streamlit.testing.v1 import AppTest | ||
from streamlit.testing.v1.element_tree import Tab | ||
|
||
from distribution_zoo import Lang | ||
|
||
from utils import at_normal | ||
from utils import get_test_output_dir | ||
|
||
import json | ||
import random | ||
import re | ||
|
||
|
||
def test_extract_normal(at_normal): | ||
def save_params(params: dict, dist_name: str) -> None: | ||
params_dir = get_test_output_dir() / dist_name | ||
params_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
info_section = at_normal.main[4] | ||
code_section = info_section.tabs[2] | ||
assert code_section.label == 'Code' | ||
with open(params_dir / 'params.json', 'w') as f: | ||
json.dump(params, f, indent=2) | ||
|
||
|
||
def extract_blocks_from_code_tab(code_tab: Tab, dist_name: str) -> None: | ||
lang_d_name = code_tab.label | ||
lang_fence = Lang.convert(lang_d_name, input_type='d_name', output_type='fence') | ||
lang_ext = Lang.convert(lang_d_name, input_type='d_name', output_type='ext') | ||
|
||
code_dir = get_test_output_dir() / dist_name / lang_fence | ||
code_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
cpp_section = code_section[1].tabs[0] | ||
assert cpp_section.label == 'C++' | ||
block_pattern = re.compile(f'```{lang_fence}(.*?)```', re.DOTALL) | ||
|
||
pre_md = cpp_section[0].markdown[0].value | ||
code_md = cpp_section[1].value | ||
pre_md = code_tab.markdown[0].value | ||
code_md = code_tab.markdown[1].value | ||
|
||
cpp_block_pattern = re.compile(r'```cpp(.*?)```', re.DOTALL) | ||
pre_matches = re.findall(block_pattern, pre_md) | ||
code_matches = re.findall(block_pattern, code_md) | ||
|
||
pre_code_matches = re.findall(cpp_block_pattern, pre_md) | ||
main_code_matches = re.findall(cpp_block_pattern, code_md) | ||
assert len(pre_matches) == 1 or len(pre_matches) == 0 | ||
assert len(code_matches) == 3 | ||
|
||
assert len(pre_code_matches) == 1 | ||
assert len(main_code_matches) == 3 | ||
pre = pre_matches[0] if pre_matches else '' | ||
|
||
code_dir = get_test_output_dir() / 'normal' / 'cpp' | ||
code_dir.mkdir(exist_ok=True, parents=True) | ||
with open(code_dir / f'pdf{lang_ext}', 'w') as f: | ||
f.write(f'{pre}{code_matches[0]}') | ||
|
||
with open(code_dir / 'pdf.cpp', 'w') as f: | ||
f.write(f'{pre_code_matches[0]}{main_code_matches[0]}') | ||
with open(code_dir / f'logpdf{lang_ext}', 'w') as f: | ||
f.write(f'{pre}{code_matches[1]}') | ||
|
||
with open(code_dir / 'logpdf.cpp', 'w') as f: | ||
f.write(f'{pre_code_matches[0]}{main_code_matches[1]}') | ||
with open(code_dir / f'rvs{lang_ext}', 'w') as f: | ||
f.write(f'{pre}{code_matches[2]}') | ||
|
||
with open(code_dir / 'rvs.cpp', 'w') as f: | ||
f.write(f'{pre_code_matches[0]}{main_code_matches[2]}') | ||
|
||
def test_extract_normal(at_normal: AppTest): | ||
|
||
test_params = { | ||
'param_range_start': round(random.uniform(-10.0, -5.0), 1), | ||
'param_range_end': round(random.uniform(5.0, 10.0), 1), | ||
'param_mean': round(random.uniform(-5.0, 5.0), 1), | ||
'param_std': round(random.uniform(0.5, 15.0), 1), | ||
} | ||
|
||
save_params(test_params, 'normal') | ||
|
||
at_normal.slider(key='normal_range').set_value((test_params['param_range_start'], test_params['param_range_end'])).run() | ||
at_normal.slider(key='normal_mean').set_value(test_params['param_mean']).run() | ||
at_normal.slider(key='normal_std').set_value(test_params['param_std']).run() | ||
|
||
info_section = at_normal.main[4] | ||
code_section = info_section.tabs[2] | ||
assert code_section.label == 'Code' | ||
|
||
for tab in code_section[1].tabs: | ||
extract_blocks_from_code_tab(tab, 'normal') | ||
|
||
def test_code_files_exist(): | ||
assert (get_test_output_dir() / 'normal' / 'params.json').is_file() | ||
|
||
assert (get_test_output_dir() / 'normal' / 'cpp' / 'pdf.cpp').is_file() | ||
assert (get_test_output_dir() / 'normal' / 'cpp' / 'logpdf.cpp').is_file() | ||
assert (get_test_output_dir() / 'normal' / 'cpp' / 'rvs.cpp').is_file() | ||
for fence in ['cpp', 'python']: | ||
ext = Lang.convert(fence, input_type='fence', output_type='ext') | ||
assert (get_test_output_dir() / 'normal' / fence / f'pdf{ext}').is_file() | ||
assert (get_test_output_dir() / 'normal' / fence / f'logpdf{ext}').is_file() | ||
assert (get_test_output_dir() / 'normal' / fence / f'rvs{ext}').is_file() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters