Skip to content

Commit

Permalink
Extract all code from Normal dist
Browse files Browse the repository at this point in the history
  • Loading branch information
fcooper8472 committed Jan 3, 2024
1 parent 102c73d commit a32817d
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 35 deletions.
19 changes: 14 additions & 5 deletions .github/workflows/test-code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,32 @@ jobs:
python -m pip install poetry
poetry install
- name: Test with pytest
- name: Run code extraction
run: |
poetry run pytest tests/extract.py
- name: Install C++ deps
- name: C++ deps
run: |
sudo apt install libboost-math-dev
- name: Compile and run C++ files
run: |
g++ -o my_program pdf.cpp -lboost_math
g++ -o my_program logpdf.cpp -lboost_math
g++ -o my_program rvs.cpp -lboost_math
g++ -o pdf pdf.cpp
g++ -o logpdf logpdf.cpp
g++ -o rvs rvs.cpp
./pdf > pdf.out
./logpdf > logpdf.out
./rvs > rvs.out
cat rvs.out
working-directory: tests/test_output/normal/cpp

- name: Run Python files
run: |
python pdf.py
python logpdf.py
python rvs.py
cat rvs.out
working-directory: tests/test_output/normal/python
6 changes: 3 additions & 3 deletions distribution_zoo/cont_uni/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def __init__(self):
def sliders(self):

self.param_range_start, self.param_range_end = st.sidebar.slider(
'Range', min_value=self.range_min, max_value=self.range_max, value=(-10.0, 10.0), step=0.1
'Range', min_value=self.range_min, max_value=self.range_max, value=(-10.0, 10.0), step=0.1, key='normal_range'
)

self.param_mean = st.sidebar.slider(
r'Mean ($\mu$)', min_value=-30.0, max_value=30.0, value=0.0, step=0.1
r'Mean ($\mu$)', min_value=-30.0, max_value=30.0, value=0.0, step=0.1, key='normal_mean'
)

self.param_std = st.sidebar.slider(
r'Standard deviation ($\sigma$)', min_value=0.1, max_value=20.0, value=1.0, step=0.1
r'Standard deviation ($\sigma$)', min_value=0.1, max_value=20.0, value=1.0, step=0.1, key='normal_std'
)

def plot(self):
Expand Down
87 changes: 62 additions & 25 deletions tests/extract.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,81 @@
from streamlit.testing.v1 import AppTest
from streamlit.testing.v1.element_tree import Tab

from distribution_zoo import Lang

from utils import at_normal
from utils import get_test_output_dir

import json
import random
import re


def test_extract_normal(at_normal):
def save_params(params: dict, dist_name: str) -> None:
params_dir = get_test_output_dir() / dist_name
params_dir.mkdir(parents=True, exist_ok=True)

info_section = at_normal.main[4]
code_section = info_section.tabs[2]
assert code_section.label == 'Code'
with open(params_dir / 'params.json', 'w') as f:
json.dump(params, f, indent=2)


def extract_blocks_from_code_tab(code_tab: Tab, dist_name: str) -> None:
lang_d_name = code_tab.label
lang_fence = Lang.convert(lang_d_name, input_type='d_name', output_type='fence')
lang_ext = Lang.convert(lang_d_name, input_type='d_name', output_type='ext')

code_dir = get_test_output_dir() / dist_name / lang_fence
code_dir.mkdir(parents=True, exist_ok=True)

cpp_section = code_section[1].tabs[0]
assert cpp_section.label == 'C++'
block_pattern = re.compile(f'```{lang_fence}(.*?)```', re.DOTALL)

pre_md = cpp_section[0].markdown[0].value
code_md = cpp_section[1].value
pre_md = code_tab.markdown[0].value
code_md = code_tab.markdown[1].value

cpp_block_pattern = re.compile(r'```cpp(.*?)```', re.DOTALL)
pre_matches = re.findall(block_pattern, pre_md)
code_matches = re.findall(block_pattern, code_md)

pre_code_matches = re.findall(cpp_block_pattern, pre_md)
main_code_matches = re.findall(cpp_block_pattern, code_md)
assert len(pre_matches) == 1 or len(pre_matches) == 0
assert len(code_matches) == 3

assert len(pre_code_matches) == 1
assert len(main_code_matches) == 3
pre = pre_matches[0] if pre_matches else ''

code_dir = get_test_output_dir() / 'normal' / 'cpp'
code_dir.mkdir(exist_ok=True, parents=True)
with open(code_dir / f'pdf{lang_ext}', 'w') as f:
f.write(f'{pre}{code_matches[0]}')

with open(code_dir / 'pdf.cpp', 'w') as f:
f.write(f'{pre_code_matches[0]}{main_code_matches[0]}')
with open(code_dir / f'logpdf{lang_ext}', 'w') as f:
f.write(f'{pre}{code_matches[1]}')

with open(code_dir / 'logpdf.cpp', 'w') as f:
f.write(f'{pre_code_matches[0]}{main_code_matches[1]}')
with open(code_dir / f'rvs{lang_ext}', 'w') as f:
f.write(f'{pre}{code_matches[2]}')

with open(code_dir / 'rvs.cpp', 'w') as f:
f.write(f'{pre_code_matches[0]}{main_code_matches[2]}')

def test_extract_normal(at_normal: AppTest):

test_params = {
'param_range_start': round(random.uniform(-10.0, -5.0), 1),
'param_range_end': round(random.uniform(5.0, 10.0), 1),
'param_mean': round(random.uniform(-5.0, 5.0), 1),
'param_std': round(random.uniform(0.5, 15.0), 1),
}

save_params(test_params, 'normal')

at_normal.slider(key='normal_range').set_value((test_params['param_range_start'], test_params['param_range_end'])).run()
at_normal.slider(key='normal_mean').set_value(test_params['param_mean']).run()
at_normal.slider(key='normal_std').set_value(test_params['param_std']).run()

info_section = at_normal.main[4]
code_section = info_section.tabs[2]
assert code_section.label == 'Code'

for tab in code_section[1].tabs:
extract_blocks_from_code_tab(tab, 'normal')

def test_code_files_exist():
assert (get_test_output_dir() / 'normal' / 'params.json').is_file()

assert (get_test_output_dir() / 'normal' / 'cpp' / 'pdf.cpp').is_file()
assert (get_test_output_dir() / 'normal' / 'cpp' / 'logpdf.cpp').is_file()
assert (get_test_output_dir() / 'normal' / 'cpp' / 'rvs.cpp').is_file()
for fence in ['cpp', 'python']:
ext = Lang.convert(fence, input_type='fence', output_type='ext')
assert (get_test_output_dir() / 'normal' / fence / f'pdf{ext}').is_file()
assert (get_test_output_dir() / 'normal' / fence / f'logpdf{ext}').is_file()
assert (get_test_output_dir() / 'normal' / fence / f'rvs{ext}').is_file()
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def app_test() -> AppTest:
def at_normal() -> AppTest:
at = AppTest.from_file(script_path=str(app_file_path), default_timeout=100.0).run()
at.sidebar.selectbox(key='dist_class').select('Continuous Univariate').run()
at.sidebar.selectbox(key='dist').select('Normal').run()

at = at.sidebar.selectbox(key='dist').select('Normal').run()
assert at.main.header[0].value == 'Normal distribution'

return at


def get_test_output_dir():
print(Path(__file__).parent / 'test_output')
return Path(__file__).parent / 'test_output'

0 comments on commit a32817d

Please sign in to comment.