Skip to content

Commit

Permalink
Merge pull request #197 from svalinn/fix_expand_list
Browse files Browse the repository at this point in the history
  • Loading branch information
gonuke authored Feb 19, 2025
2 parents 7925e5d + 5d93341 commit 8e0d597
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 8 deletions.
24 changes: 16 additions & 8 deletions parastell/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,39 +72,47 @@ def enforce_helical_symmetry(matrix):
return matrix


def expand_list(list, num):
def expand_list(list_to_expand, num):
"""Expands a list of ordered floats to a total number of entries by
linearly interpolating between entries, inserting a proportional number of
new entries between original entries. If num < len(list), no entries are
added.
new entries between original entries. If num <= len(list), list_to_expand
is not modified. It is possible that the result will have slightly more
or fewer elements than num, due to round off approximations.
Arguments:
list (iterable of float): list to be expanded.
list_to_expand (iterable of float): list to be expanded.
num (int): desired number of entries in expanded list.
Returns:
list_exp (iterable of float): expanded list.
"""
if len(list_to_expand) >= num:
return list_to_expand

list_exp = []

init_entry = list[0]
final_entry = list[-1]
init_entry = list_to_expand[0]
final_entry = list_to_expand[-1]
extent = final_entry - init_entry

avg_diff = extent / (num - 1)

for entry, next_entry in zip(list[:-1], list[1:]):
for entry, next_entry in zip(list_to_expand[:-1], list_to_expand[1:]):
# Only add entries to current block if difference between entry and
# next_entry is greater than desired average
num_new_entries = 0

if next_entry - entry > avg_diff:
num_new_entries = int(round(next_entry - entry / avg_diff))
# Goal is to create bins of approximately avg_diff width between
# entry and next_entry
num_new_entries = int(round((next_entry - entry) / avg_diff)) - 1

# Manually append first entry
list_exp = np.append(list_exp, entry)

# If num_new_entries == 0, don't add new entries
# First and last elements of new_entries are entry and next_entry,
# respectively
new_entries = np.linspace(
entry,
next_entry,
Expand Down
46 changes: 46 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,49 @@ def test_dagmc_renumbering():
assert len(combined_model.volumes) == num_vol_exp
assert max(combined_model.volumes_by_id.keys()) == max_vol_id_exp
assert all(mat in mats_exp for mat in mats)


def test_expand_list():
"""Tests utils.expand_list() to ensure returned arrays are the length
expected, and contain the expected values, by testing if:
* the expected entries are added to uniformly and non-uniformly spaced
lists
* entries are added when the requested size is less than or equal to
that of the input list (no entries should be added)
"""
# Make sure new entries are inserted as expected
test_values = np.linspace(1, 10, 10)
exp_expanded_list = np.linspace(1, 10, 19)
expanded_list = expand_list(test_values, 19)
assert np.allclose(exp_expanded_list, expanded_list)

# Make sure no changes are made if list already has the requested number of
# entries
expanded_list = expand_list(test_values, 10)
assert len(expanded_list) == len(test_values)
assert np.allclose(expanded_list, test_values)

# Make sure no changes are made if list has more than the requested number
# of entries
expanded_list = expand_list(test_values, 5)
assert len(expanded_list) == len(test_values)
assert np.allclose(expanded_list, test_values)

# Make sure it works with unevenly spaced entries
test_values = [1, 5, 6, 7, 10]
exp_expanded_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
expanded_list = expand_list(test_values, 10)
assert np.allclose(expanded_list, exp_expanded_list)

# Make sure it works with unevenly spaced entries that are not
# nicely divisible
test_values = [1, 4.5, 6, 7, 10]
expanded_list = expand_list(test_values, 5)
assert len(expanded_list) == 5

# int math makes this list one element longer than requested
test_values = [1, 4.5, 6, 7, 10]
expected_values = [1, 1.875, 2.75, 3.625, 4.5, 5.25, 6, 7, 8, 9, 10]
expanded_list = expand_list(test_values, 10)
assert len(expanded_list) == 11
assert np.allclose(expected_values, expanded_list)

0 comments on commit 8e0d597

Please sign in to comment.