Skip to content

Commit

Permalink
Generalise the traversal algorithm to return node values
Browse files Browse the repository at this point in the history
Decouples the computation of genetic values from output in terms of
individuals to enable later generalisations.
  • Loading branch information
jeromekelleher committed Oct 3, 2024
1 parent 8444038 commit 7ab3653
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions tstrait/genetic_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,45 @@


@numba.njit
def _traversal_genetic_value(
nodes_individual,
def _compute_nodes_genetic_value(
left_child_array,
right_sib_array,
stack,
has_mutation,
num_individuals,
num_nodes,
effect_size,
): # pragma: no cover
"""
Numba to speed up the tree traversal algorithm to determine the genotype of
individuals.
Stack has to be Typed List in numba to use numba.
Compute the genetic value of each node for the specified set of mutations
encoded in the stack.
"""

genetic_value = np.zeros(num_individuals)
genetic_value = np.zeros(num_nodes)
while len(stack) > 0:
parent_node_id = stack.pop()
if parent_node_id == num_nodes:
individual_id = -1
else:
individual_id = nodes_individual[parent_node_id]
if individual_id > -1:
genetic_value[individual_id] += effect_size
genetic_value[parent_node_id] = effect_size
child_node_id = left_child_array[parent_node_id]
while child_node_id != -1:
if not has_mutation[child_node_id]:
stack.append(child_node_id)
child_node_id = right_sib_array[child_node_id]

return genetic_value


@numba.njit
def _accumulate_individual_values(
nodes_genetic_value, nodes_individual, num_nodes, num_individuals
): # pragma: no cover
"""
Accumulate the genetic values by summing their node contributions.
"""
individuals_genetic_value = np.zeros(num_individuals)
for u in range(num_nodes):
ind = nodes_individual[u]
if ind != -1:
individuals_genetic_value[ind] += nodes_genetic_value[u]
return individuals_genetic_value


class _GeneticValue:
"""GeneticValue class to compute genetic values of individuals.
Expand Down Expand Up @@ -75,20 +80,24 @@ def _individual_genetic_values(self, tree, site, causal_allele, effect_size):
stack.append(node)

if len(stack) == 0:
genetic_value = np.zeros(self.ts.num_individuals)
genetic_value = np.zeros(self.ts.num_nodes)
else:
genetic_value = _traversal_genetic_value(
nodes_individual=self.ts.nodes_individual,
genetic_value = _compute_nodes_genetic_value(
left_child_array=tree.left_child_array,
right_sib_array=tree.right_sib_array,
stack=stack,
has_mutation=has_mutation,
num_individuals=self.ts.num_individuals,
num_nodes=self.ts.num_nodes,
effect_size=effect_size,
)

return genetic_value
individuals_genetic_value = _accumulate_individual_values(
genetic_value,
self.ts.nodes_individual,
self.ts.num_nodes,
self.ts.num_individuals,
)
return individuals_genetic_value

def _run(self):
"""Computes genetic values of individuals.
Expand Down

0 comments on commit 7ab3653

Please sign in to comment.