Conversation
…s in SumUnit and ProbabilisticCircuit
tomsch420
left a comment
Provide unit tests and a notebook in the docs that showcases pruning and growing (refactor your examples into a markdown notebook; see the doc notebooks for reference).
tomsch420
left a comment
This is way nicer than before, but it is not fully finished yet. The comments tell you what to do. Also read this: https://testing.googleblog.com/2017/11/obsessed-with-primitives.html and replace the dict types with proper registries.
Pull Request Overview
This PR implements pruning and growing functionality for probabilistic circuits based on the Sparse Probabilistic Circuits paper by Dang et al. The implementation enables structural optimization of circuits by removing less important edges (pruning) and adding new components with noise (growing).
- Adds flow analysis capability to compute edge importance based on data flows
- Implements pruning method to remove low-importance edges based on flow analysis
- Implements growing method to duplicate circuit structure with noise injection
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `src/probabilistic_model/probabilistic_circuit/rx/probabilistic_circuit.py` | Adds `prune()` and `grow()` methods to the `ProbabilisticCircuit` class |
| `src/probabilistic_model/probabilistic_circuit/rx/flow_analyzer.py` | New `CircuitFlowAnalyzer` class for computing edge flows through circuits |
| `doc/references.bib` | Adds citation for the sparse probabilistic circuits paper |
| `doc/pruning_growing.md` | Comprehensive tutorial demonstrating pruning and growing techniques |
| `doc/_toc.yml` | Adds the new tutorial to the documentation table of contents |
```python
    def __repr__(self):
        return f"{self.__class__.__name__} with {len(self.nodes())} nodes and {len(self.edges())} edges"

    def prune(self, dataset: np.ndarray, pruning_percentage: float) -> Self:
```
[nitpick] The prune method modifies the circuit in-place and returns self, but this could be confusing since users might expect it to return a new pruned circuit. Consider either making this operation purely in-place (return None) or making it return a new circuit instance to avoid ambiguity.
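The two conventions the reviewer contrasts can be sketched with Python's built-ins as an analogy: `list.sort()` mutates in place and returns `None`, while `sorted()` leaves the input alone and returns a new object. `PruneableThing` below is a hypothetical stand-in, not part of the library under review.

```python
# Hypothetical illustration of the two API conventions; not the library's code.
class PruneableThing:
    def __init__(self, edges):
        self.edges = list(edges)

    def prune_inplace(self, keep: int) -> None:
        """Mutates self; returning None signals the mutation (like list.sort)."""
        self.edges = self.edges[:keep]

    def pruned(self, keep: int) -> "PruneableThing":
        """Leaves self untouched and returns a new instance (like sorted)."""
        return PruneableThing(self.edges[:keep])

thing = PruneableThing([1, 2, 3, 4])
copy = thing.pruned(2)      # thing.edges is still [1, 2, 3, 4]
thing.prune_inplace(2)      # now thing.edges == [1, 2]
```

Mutate-and-return-`self` enables method chaining but blurs ownership; either of the two conventions above avoids that ambiguity.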
```python
        self.normalize()
        return self

    def grow(self, noise_variance: float) -> Self:
```
[nitpick] Similar to the prune method, the grow method modifies the circuit in-place and returns self, which could be confusing. Consider making the API consistent by either returning None for in-place operations or returning new instances.
```python
parent_in_new_values = any(parent is value for value in old2new.values())
child_in_new_values = any(child is value for value in old2new.values())
parent_not_in_keys = not any(parent is key for key in old2new.keys())
child_not_in_keys = not any(child is key for key in old2new.keys())
if parent_in_new_values or child_in_new_values or parent_not_in_keys or child_not_in_keys:
```
These identity checks using 'any()' with generator expressions are inefficient for large circuits. Since old2new is a dictionary, use 'parent in old2new' and 'parent in old2new.values()' instead, or better yet, use sets to track which nodes belong to which group.
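A minimal sketch of the suggested speed-up, assuming node objects use the default identity-based `__eq__`/`__hash__` (so `x in mapping` matches an `is`-scan but runs in O(1)). The names `old2new`, `parent`, and `child` mirror the snippet above; the `Unit` class here is a stand-in.

```python
# Sketch only: Unit is a placeholder for a circuit node class.
class Unit:
    pass

old2new = {Unit(): Unit() for _ in range(1000)}

# Build an identity-keyed set once instead of scanning .values() per edge.
new_node_ids = set(map(id, old2new.values()))

parent, child = Unit(), Unit()

# O(1) membership tests replace the any(... is ...) generator scans:
parent_is_new = id(parent) in new_node_ids
parent_is_old = parent in old2new  # identity semantics with default __eq__
if parent_is_new or not parent_is_old:
    pass  # edge touches a node outside the old -> new mapping
```

Using `id()` keys keeps the check strictly identity-based even if a node class later overrides `__eq__`.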
```python
from collections import defaultdict
from typing import Dict, Tuple, TYPE_CHECKING
import numpy as np
import tqdm

if TYPE_CHECKING:
    from .probabilistic_circuit import ProbabilisticCircuit, Unit
```
Import statements should be grouped and sorted. Standard library imports (collections, typing) should come first, followed by third-party imports (numpy, tqdm), then local imports. Consider using 'from tqdm import tqdm' for cleaner code.
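One possible reordering that follows PEP 8 grouping (standard library, then third-party, then local imports); the third-party lines are commented out here so the sketch runs standalone.

```python
# Sketch of the reviewer's suggested import grouping, not the merged code.
from collections import defaultdict            # standard library
from typing import TYPE_CHECKING, Dict, Tuple  # standard library

# import numpy as np        # third-party
# from tqdm import tqdm     # third-party, per the reviewer's suggestion

if TYPE_CHECKING:  # local import, only needed for type annotations
    from .probabilistic_circuit import ProbabilisticCircuit, Unit
```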
```python
def compute_flows(self, dataset: np.ndarray) -> Dict[Tuple['Unit', 'Unit'], float]:
    """
    Compute the flow of information through the circuit for a given dataset.

    :param dataset: The input dataset.
    :return: Dictionary mapping edge tuples to their flow values.
    """
    edge_flows = defaultdict(float)

    for x in tqdm.tqdm(dataset, desc="Computing circuit flows"):
```
Hard-coded progress bar description should be configurable or removed for library code. Consider adding a 'show_progress' parameter to allow users to control progress bar display.
Suggested change:

```diff
-def compute_flows(self, dataset: np.ndarray) -> Dict[Tuple['Unit', 'Unit'], float]:
-    """
-    Compute the flow of information through the circuit for a given dataset.
-
-    :param dataset: The input dataset.
-    :return: Dictionary mapping edge tuples to their flow values.
-    """
-    edge_flows = defaultdict(float)
-
-    for x in tqdm.tqdm(dataset, desc="Computing circuit flows"):
+def compute_flows(self, dataset: np.ndarray, show_progress: bool = False) -> Dict[Tuple['Unit', 'Unit'], float]:
+    """
+    Compute the flow of information through the circuit for a given dataset.
+
+    :param dataset: The input dataset.
+    :param show_progress: If True, display a progress bar during computation. Default is False.
+    :return: Dictionary mapping edge tuples to their flow values.
+    """
+    edge_flows = defaultdict(float)
+
+    iterator = tqdm.tqdm(dataset, desc="Computing circuit flows") if show_progress else dataset
+    for x in iterator:
```
```python
if hasattr(parent, '_flow_likelihood') and parent._flow_likelihood > 0:
    if hasattr(node, '_flow_likelihood'):
        contribution = node._flow_likelihood / parent._flow_likelihood
        node_flows[node] += contribution * node_flows[parent]
```
Division by zero protection is insufficient. While checking parent._flow_likelihood > 0, there's no check for node._flow_likelihood existence before the division. This could cause issues if node doesn't have the attribute but parent does.
Suggested change:

```diff
-if hasattr(parent, '_flow_likelihood') and parent._flow_likelihood > 0:
-    if hasattr(node, '_flow_likelihood'):
-        contribution = node._flow_likelihood / parent._flow_likelihood
-        node_flows[node] += contribution * node_flows[parent]
+if hasattr(parent, '_flow_likelihood') and parent._flow_likelihood > 0 and hasattr(node, '_flow_likelihood'):
+    contribution = node._flow_likelihood / parent._flow_likelihood
+    node_flows[node] += contribution * node_flows[parent]
```
```python
weight = 0.0

if hasattr(child, '_flow_likelihood'):
    edge_flow = (np.exp(weight) * child._flow_likelihood / node._flow_likelihood * node_flows[node])
```
Potential division by zero if node._flow_likelihood is zero or very close to zero. Add a check to ensure node._flow_likelihood > 0 before performing the division.
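A minimal sketch of the guard the reviewer asks for: only divide when the denominator is strictly positive, otherwise treat the edge flow as zero. The attribute name `_flow_likelihood` comes from the snippet above; the `Node` class and the `safe_edge_flow` helper are illustrative stand-ins, and `math.exp` substitutes for `np.exp` to keep the sketch dependency-free.

```python
# Hypothetical guard, not the library's code.
import math

class Node:
    def __init__(self, flow_likelihood):
        self._flow_likelihood = flow_likelihood

def safe_edge_flow(weight, child, node, node_flow):
    # Divide only when the denominator is strictly positive.
    if getattr(node, '_flow_likelihood', 0.0) > 0:
        return math.exp(weight) * child._flow_likelihood / node._flow_likelihood * node_flow
    return 0.0  # degenerate case: no probability mass flowed through this node

print(safe_edge_flow(0.0, Node(0.5), Node(1.0), 2.0))  # 1.0
print(safe_edge_flow(0.0, Node(0.5), Node(0.0), 2.0))  # 0.0 instead of ZeroDivisionError
```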
This PR adds pruning and growing functionality to the probabilistic circuits implementation. See Dang et al. Sparse Probabilistic Circuits via Pruning and Growing for more details.
Try `prune` and `grow` separately using the example scripts in the respective folder. The `jpt_pruning_growing_demo.py` and `jpt_pruning_growing_mnist.py` demo scripts show how pruning and growing can be applied to a learned JPT.