This project is a framework for High-Energy Physics (HEP) analysis that leverages automatic differentiation to optimise analysis selections for maximal statistical significance. It is built on top of the scientific Python ecosystem, including coffea, awkward-array, and uproot for data handling, and JAX for gradient-based optimisation.
The example analysis implemented here searches for a Z' boson decaying to a top-antitop quark pair (tt̄).
This section guides you through running the default analysis configuration provided in the repository.
Before running, you must set up the Python environment and install the required dependencies. The recommended method is to use Pixi.
- Install Pixi.
- Run pixi run start from the top level of the repository.
This launches a JupyterLab interface with the environment activated. The environment is described by the pixi.toml Pixi manifest and pinned by the pixi.lock lock file.
To use the environment interactively in your local machine's shell, run pixi shell from the top level of the repository.
The environment.yml file contains all the necessary packages. Create and activate the conda environment with the following commands:
conda env create -f environment.yml
conda activate zprime_diff_analysis
We also provide a requirements.txt file. If you install from it, you can still use conda for environment management:
# Create a new environment with Python 3.10 (or adjust version as needed)
conda create -n zprime_diff_analysis python=3.10
# Activate the environment
conda activate zprime_diff_analysis
# Install all dependencies from requirements.txt
pip install -r requirements.txt
Alternatively, you can use Python’s built-in virtual environment module:
# Create a virtual environment in a folder named .venv
python3 -m venv .venv
# Activate the virtual environment
# On macOS/Linux:
source .venv/bin/activate
# Install dependencies
pip install --upgrade pip
pip install -r requirements.txt
The analysis expects pre-processed data files. If you do not have them, you can generate them by running the pre-processing step. This will download the necessary data from the CERN Open Data Portal and skim it according to the configuration.
# This command overrides the default config to run only the pre-processing step.
# It may take a while to download and process the data.
python run.py general.run_preprocessing=True general.run_mva_training=False general.analysis=nondiff general.run_histogramming=False general.run_statistics=False
Once the pre-processed data is available, you can run the main analysis with a single command:
python run.py
The default configuration (user/configuration.py) is set up to perform a differentiable analysis. The command above will:
- MVA Pre-training: First, it trains a small, JAX-based neural network to distinguish between W+jets and ttbar background events. The trained model parameters are saved to disk.
- Differentiable Optimisation: It then runs the main analysis optimisation loop. The goal is to find the selection cuts that maximise the statistical significance of the Z' signal. At each step, it calculates the gradient of the significance with respect to the cut thresholds (e.g., met_threshold, btag_threshold) and uses the optax optimiser to update them.
- Outputs: The analysis will produce plots in the outputs/ directory showing the evolution of the parameters and significance during optimisation, along with the final histograms. The final optimised significance will be printed to the console.
This framework separates user-configurable code from framework code:
- user/ directory: This is where you make changes for your analysis.
  - user/configuration.py: Main configuration file. Modify this for your analysis settings.
  - user/cuts.py: Selection functions. Define your analysis regions and cuts here.
  - user/observables.py: Physics observables. Define what variables you want to compute.
  - user/systematics.py: Systematic variations. Define uncertainty sources.
- analysis/ and utils/ directories: Framework code. You typically don't need to modify these.
  - These contain the analysis infrastructure, plotting utilities, and technical implementation.
To adapt this framework for your analysis, focus on modifying the files in the user/ directory. The framework will handle the rest automatically.
- Core Concepts
- How to Implement an Analysis
- Configuration Reference
- Under the Hood: The Differentiable Engine
- The Differentiable Statistical Model in JAX
- Multi-Variate Analysis (MVA) Integration
- Extending the Analysis
- Non-Differentiable Analysis
- Directory Structure
- [Developer] Building the Documentation
The central idea is to treat the entire analysis chain, from event selection to statistical significance, as a single, differentiable function. The inputs to this function are not just the data, but also the analysis parameters themselves (e.g., selection cut thresholds). By calculating the gradient of the significance with respect to these parameters, we can use optimizers like optax to iteratively update them and find the optimal set of cuts.
- coffea & awkward-array: For handling complex, jagged data structures typical in HEP.
- JAX: For just-in-time (JIT) compilation and automatic differentiation of Python/NumPy code.
- relaxed: A JAX-based library for differentiable statistical models, providing a differentiable approximation of the profile likelihood ratio.
- optax: A library of gradient-based optimizers for JAX.
The analysis is orchestrated by the DifferentiableAnalysis class in analysis/diff.py. The workflow proceeds as follows:
1. Preprocessing: Raw NanoAOD files are skimmed to keep only necessary branches and apply a baseline selection. This is a one-time, non-differentiable step to reduce data volume.
2. MVA Pre-training (Optional): If configured, a Machine Learning model (e.g., a neural network) is trained on pre-selected data to serve as a powerful discriminator. The weights of this model can themselves become optimizable parameters.
3. Event Processing: For each event, object corrections and systematic variations are applied.
4. Differentiable Histogramming: Instead of making hard cuts, we apply "soft" selections using sigmoid functions. This results in a per-event weight. Histograms are filled using a Kernel Density Estimation (KDE) approach, which is smooth and differentiable.
5. Statistical Significance: The relaxed library is used to construct a statistical model from the histograms and compute an asymptotic significance (a differentiable quantity).
6. Gradient Calculation: JAX computes the gradient of the significance with respect to all optimizable parameters (cut thresholds, MVA weights, etc.).
7. Parameter optimisation: The optax optimizer takes a step in the direction of the gradient to update the parameters, aiming to maximize significance.
Steps 4-7 are repeated for a set number of iterations.
Implementing a new analysis or modifying the existing one primarily involves three steps:
- Modifying the central configuration file.
- Defining the analysis logic (observables and selections) in Python functions.
- Running the analysis workflow.
This file is the central hub for defining your entire analysis. The config dictionary controls every aspect of the workflow.
Key Sections:
- general: Global settings like integrated luminosity, which analysis steps to run (run_preprocessing, run_mva_training), and file paths.
- preprocess: Defines the branches to keep from the input NanoAOD files.
- good_object_masks: Defines baseline "good" object criteria (e.g., muon pT > 55 GeV) that are applied before any channel-specific logic. This is useful for creating a common object collection for all analysis channels.
- baseline_selection: A hard, non-differentiable selection applied to all events early on.
- channels: Defines the different analysis regions (e.g., signal region, control regions). For each channel, you specify:
  - name: A unique name for the channel.
  - selection: The selection function to apply for this channel.
  - observables: A list of variables to be histogrammed.
  - fit_observable: The specific observable used for the final statistical fit.
- ghost_observables: A powerful feature for computing derived quantities (e.g., ST, ΔR) once and attaching them to the event record. These can then be used by any downstream function.
- corrections & systematics: Define object and event-level corrections and systematic uncertainties. The framework supports both correctionlib and custom Python functions.
- mva: Configure MVA models. You can define the architecture, features, and training parameters for a JAX or TensorFlow/Keras network.
- jax: This is the core of the differentiable analysis.
  - params: A dictionary of all optimizable parameters and their initial values (e.g., 'met_threshold': 50.0).
  - soft_selection: Points to the Python function that implements your differentiable selection logic.
  - param_updates: Defines clamping functions to keep parameters within physical bounds during optimisation (e.g., jnp.clip(x + d, 0.0, 3.0)).
  - learning_rates: Allows you to set custom learning rates for different parameters.
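As a rough illustration of the param_updates idea, the sketch below applies a per-parameter clamping rule of the form (old_value, delta) -> new_value after each update. It uses plain Python floats in place of JAX arrays, and the helper names clamp and apply_updates as well as the bounds are assumptions made for this example, not part of the framework.

```python
# Minimal sketch of per-parameter clamping, assuming update rules of the
# form (old_value, delta) -> new_value. Plain floats stand in for JAX arrays;
# with JAX the same rule would typically be written with jnp.clip.

def clamp(lo, hi):
    """Build an update rule that keeps old_value + delta inside [lo, hi]."""
    return lambda x, d: min(max(x + d, lo), hi)

# Hypothetical bounds, chosen only for illustration.
param_updates = {
    "met_threshold": clamp(20.0, 150.0),
    "btag_threshold": clamp(0.0, 3.0),
}

def apply_updates(params, deltas, rules):
    """Apply each delta, using a clamping rule when one is defined."""
    new_params = {}
    for name, value in params.items():
        delta = deltas.get(name, 0.0)
        rule = rules.get(name, lambda x, d: x + d)  # default: unclamped step
        new_params[name] = rule(value, delta)
    return new_params

params = {"met_threshold": 50.0, "btag_threshold": 0.5, "free_param": 1.0}
deltas = {"met_threshold": -40.0, "btag_threshold": 0.25, "free_param": 2.0}
updated = apply_updates(params, deltas, param_updates)
```

Here the step on met_threshold would leave the allowed range, so it is held at the lower bound, while free_param has no rule and moves freely.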
The config file points to Python functions that contain the actual physics logic. These typically live in user/.
An observable function takes awkward-array collections as input and returns a flat array of the computed values.
Example: get_mtt
def get_mtt(
    muons: ak.Array,
    jets: ak.Array,
    fatjets: ak.Array,
    met: ak.Array,
) -> ak.Array:
    # ... logic to calculate four-vectors and sum them ...
    p4tot = p4mu + p4fj + p4j + p4met
    return p4tot.mass
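The four-vector sum elided in get_mtt can be illustrated for a single event with a standard-library sketch. The tuple layout (E, px, py, pz) and the helper names add_p4 and invariant_mass are assumptions for this example; the framework itself works on awkward arrays of Lorentz vectors.

```python
import math

def add_p4(*vectors):
    """Sum four-vectors given as (E, px, py, pz) tuples."""
    return tuple(sum(components) for components in zip(*vectors))

def invariant_mass(p4):
    """Invariant mass m = sqrt(E^2 - |p|^2), clipped at zero against rounding."""
    energy, px, py, pz = p4
    m2 = energy**2 - (px**2 + py**2 + pz**2)
    return math.sqrt(max(m2, 0.0))

# Two massless, back-to-back objects with 50 GeV each (illustrative values).
p4_a = (50.0, 50.0, 0.0, 0.0)
p4_b = (50.0, -50.0, 0.0, 0.0)
mass = invariant_mass(add_p4(p4_a, p4_b))
```

The momenta cancel in the sum, so the system mass equals the total energy of 100 GeV even though each object is massless on its own.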
There are two types of selection functions:
- Standard Selections: Used for non-differentiable analysis or initial hard cuts. They take awkward arrays and should return a coffea.analysis_tools.PackedSelection object.
- Differentiable "Soft" Selections: This is where the magic happens. Instead of returning a boolean mask, this function returns a continuous, per-event weight between 0 and 1. This is achieved by replacing hard cuts like met.pt > 50 with a sigmoid function, jax.nn.sigmoid((met.pt - 50) / scale).
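To make the difference between a hard and a soft cut concrete, here is a small standard-library sketch. It mirrors the sigmoid expression above with math.exp instead of jax.nn.sigmoid, and the function names and MET values are illustrative assumptions.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def hard_cut(value, threshold):
    """Step function: the weight jumps from 0 to 1 at the threshold."""
    return 1.0 if value > threshold else 0.0

def soft_cut(value, threshold, scale):
    """Smooth weight in (0, 1); scale sets how sharply it turns on."""
    return sigmoid((value - threshold) / scale)

def soft_cut_grad(value, threshold, scale):
    """Analytic derivative d(weight)/d(threshold) = -w * (1 - w) / scale."""
    w = soft_cut(value, threshold, scale)
    return -w * (1.0 - w) / scale

met_values = [20.0, 45.0, 50.0, 55.0, 120.0]
hard = [hard_cut(v, 50.0) for v in met_values]
soft = [soft_cut(v, 50.0, 25.0) for v in met_values]
```

The hard cut has zero gradient with respect to the threshold almost everywhere, whereas the soft cut has a non-zero gradient for every event, which is what lets a gradient-based optimiser move the threshold.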
Example: Zprime_softcuts_jax_workshop
This function takes JAX-backed awkward arrays and a params dictionary (containing the optimizable parameters defined in the config).
def Zprime_softcuts_jax_workshop(
    muons: ak.Array,
    jets: ak.Array,
    met: ak.Array,
    jet_mass: ak.Array,
    nn,
    params: dict
) -> jnp.ndarray:
    # ...
    # A differentiable cut on MET
    met_cut_weight = jax.nn.sigmoid(
        (ak.to_jax(met) - params["met_threshold"]) / 25.0
    )
    # A differentiable cut on a b-tagging score
    btag_cut_weight = jax.nn.sigmoid(
        (soft_b_counts - params["btag_threshold"]) * 10.0
    )
    # ... other cuts
    # Combine all weights multiplicatively (like a logical AND).
    # axis=0 multiplies across the stacked cuts and keeps one weight per event.
    selection_weight = jnp.prod(
        jnp.stack([met_cut_weight, btag_cut_weight, ...]), axis=0
    )
    return selection_weight
With the configuration and functions in place, you can run the analysis using a top-level script.
A typical script would:
- Load the base configuration from user/configuration.py.
- Optionally, override configuration settings from the command line.
- Construct the fileset of data samples.
- Instantiate the DifferentiableAnalysis class from analysis/diff.py.
- Call the main run_analysis_optimisation method.
# In a hypothetical run.py
import sys
from analysis.diff import DifferentiableAnalysis
from user.configuration import config
from utils.schema import load_config_with_restricted_cli
from utils.input_files import construct_fileset
if __name__ == "__main__":
    # Load base config and override with CLI args
    cfg = load_config_with_restricted_cli(config, sys.argv[1:])
    fileset = construct_fileset(n_files_max_per_sample=cfg.general.max_files)
    analysis = DifferentiableAnalysis(cfg)
    final_histograms, final_significance = analysis.run_analysis_optimisation(fileset)
    print(f"Optimisation complete! Final significance: {final_significance:.3f}")
You can override certain configuration options directly from the command line using a dot-list format. This is useful for quick tests and batch submissions without modifying the main configuration file.
Example:
python run.py general.max_files=10 general.run_systematics=False
Important Limitations: For safety and to prevent breaking the analysis logic, only a restricted set of configuration keys can be overridden from the command line. This is because the main Python configuration file contains complex objects like functions and lambdas, which cannot be expressed as simple command-line arguments.
The allowed top-level keys for CLI overrides are:
- general
- preprocess
- statistics
Attempting to override other keys (e.g., jax.params) will result in an error. To change these, you must edit the user/configuration.py file directly.
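The restriction described above can be sketched with a small dot-list parser. This is not the framework's load_config_with_restricted_cli; the function names, the literal parsing, and the toy configuration are assumptions made to illustrate the idea on a plain dictionary.

```python
import ast
import copy

ALLOWED_TOP_LEVEL_KEYS = {"general", "preprocess", "statistics"}

def parse_value(text):
    """Interpret 'True', '10' or '0.5' as Python literals, else keep the string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text

def apply_cli_overrides(config, args):
    """Return a copy of config with 'a.b=value' overrides applied."""
    result = copy.deepcopy(config)
    for arg in args:
        dotted_key, sep, raw_value = arg.partition("=")
        if not sep:
            raise ValueError(f"Expected key=value, got {arg!r}")
        keys = dotted_key.split(".")
        if keys[0] not in ALLOWED_TOP_LEVEL_KEYS:
            raise ValueError(f"Overriding {keys[0]!r} from the CLI is not allowed")
        node = result
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = parse_value(raw_value)
    return result

# Toy configuration, for illustration only.
base = {
    "general": {"max_files": -1, "run_systematics": True},
    "jax": {"params": {"met_threshold": 50.0}},
}
cfg = apply_cli_overrides(
    base, ["general.max_files=10", "general.run_systematics=False"]
)
```

Rejecting disallowed keys up front, rather than silently ignoring them, makes a mistyped or unsupported override fail loudly before any processing starts.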
The analysis is controlled by a central configuration dictionary, typically defined in user/configuration.py.
The structure of this configuration is validated against a Pydantic schema in utils/schema.py.
Below is a comprehensive reference for all available options, grouped by their top-level key.
Global settings that control the overall workflow of the analysis.
Parameter | Type | Default | Description |
---|---|---|---|
lumi | float | Required | Integrated luminosity in inverse picobarns (/pb). |
weights_branch | str | Required | Branch name containing event weights (e.g. genWeight). |
lumifile | str | Required | Path to the JSON file containing certified good luminosity sections (Golden JSON). |
analysis | str | "nondiff" | Analysis mode: "nondiff", "diff", or "both". |
max_files | int | -1 | Max number of files per dataset. -1 = unlimited. |
run_preprocessing | bool | False | Run NanoAOD skimming and filtering. |
run_histogramming | bool | True | Run histogramming for non-differentiable analysis. |
run_statistics | bool | True | Run statistical analysis step (e.g. cabinetry fit). |
run_systematics | bool | True | Process systematic variations for non-differentiable analysis. |
run_plots_only | bool | False | Generate plots from cached results only. |
run_mva_training | bool | False | Run MVA model pre-training. |
read_from_cache | bool | True | Read preprocessed data from cache if available. |
output_dir | str | "output/" | Root directory for all analysis outputs. |
preprocessor | str | "uproot" | Preprocessing engine: "uproot" or "dask". |
preprocessed_dir | str | None | Directory with pre-processed (skimmed) files. |
cache_dir | str | "/tmp/gradients_analysis/" | Cache directory for differentiable analysis. |
processes | list[str] | None | Limit analysis to specific processes. |
channels | list[str] | None | Limit analysis to specific channels. |
Settings for the initial data skimming and filtering step.
Parameter | Type | Default | Description |
---|---|---|---|
branches | dict | Required | Mapping of collection names to branch lists. |
ignore_missing | bool | False | Ignore missing branches if True. |
mc_branches | dict | Required | Additional branches for MC samples. |
Configuration for the differentiable analysis workflow.
Parameter | Type | Default | Description |
---|---|---|---|
soft_selection | dict | Required | Differentiable selection function. |
↳ function | Callable | Required | Selection function to apply. |
↳ use | list[str] | Required | Input variables passed to function. |
params | dict | Required | Optimizable parameters (e.g. {'met_threshold': 50.0}). |
optimize | bool | True | Run gradient-based optimisation if True. |
learning_rate | float | 0.01 | Default optimizer learning rate. |
max_iterations | int | 50 | Number of optimisation steps. |
param_updates | dict | {} | Parameter-specific clamping functions. |
↳ param_name | Callable | - | (old_value, delta) -> new_value function. |
learning_rates | dict | None | Parameter-specific learning rates. |
explicit_optimisation | bool | False | Use manual optimisation loop if True. |
List of MVA model configurations.
Parameter | Type | Default | Description |
---|---|---|---|
name | str | Required | Unique model name. |
framework | str | Required | "jax" or "keras". |
learning_rate | float | 0.01 | Learning rate for pre-training. |
grad_optimisation | dict | {} | MVA optimisation settings. |
↳ optimise | bool | False | Include MVA in global optimisation. |
↳ learning_rate | float | 0.001 | Learning rate for MVA in optimisation. |
layers | list[dict] | Required | Network architecture layers. |
↳ ndim | int | Required | Number of nodes. |
↳ activation | str | Required | Activation function. |
↳ weights | str | Required | Name for weights parameter. |
↳ bias | str | Required | Name for bias parameter. |
loss | Callable or str | Required | Loss function (callable for JAX, string for Keras). |
features | list[dict] | Required | Input features for the model. |
↳ name | str | Required | Feature name. |
↳ function | Callable | Required | Function to compute the feature. |
↳ use | list[str] | Required | Input dependencies for the feature. |
classes | list | Required | Target classes (e.g. ["wjets", {"ttbar": [...]}]). |
balance_strategy | str | "undersample" | "none", "undersample", "oversample", "class_weight". |
random_state | int | 42 | Random seed for reproducibility. |
epochs | int | 1000 | Pre-training epochs. |
batch_size | int | 32 | Training batch size. |
validation_split | float | 0.2 | Fraction for validation. |
log_interval | int | 100 | Log frequency during training. |
List of analysis channels or regions.
Parameter | Type | Default | Description |
---|---|---|---|
name | str | Required | Channel name (e.g. "signal_region"). |
observables | list[dict] | Required | Observables to histogram. |
↳ name | str | Required | Observable name. |
↳ binning | tuple | Required | Histogram binning. |
↳ function | Callable | Required | Function to compute observable. |
↳ use | list[str] | Required | Inputs to the observable function. |
fit_observable | str | Required | Observable used in statistical fit. |
selection | dict | None | Channel selection function. |
↳ function | Callable | Required | Selection function. |
↳ use | list[str] | Required | Inputs to selection function. |
use_in_diff | bool | False | Include in differentiable analysis. |
Parameter | Type | Default | Description |
---|---|---|---|
name | str | Required | Correction or systematic name. |
type | str | Required | "object" or "event". |
op | str | "mult" | Operation: "mult" or "add". |
target | tuple or list | None | Object/variable to modify (e.g. ("Jet", "pt")). |
use | list | [] | Inputs required for the variation function. |
Corrections Only | | | |
↳ file | str | Required | Path to correction file. |
↳ key | str | None | Key within the file. |
↳ use_correctionlib | bool | True | Use correctionlib. |
↳ transform | Callable | None | Transform arguments before evaluation. |
↳ up_and_down_idx | list[str] | ["up", "down"] | Variation labels in file. |
Systematics Only | | | |
↳ up_function | Callable | None | Function for "up" variation. |
↳ down_function | Callable | None | Function for "down" variation. |
↳ symmetrise | bool | False | Auto-generate "down" from "up" (not implemented). |
The core of the differentiable workflow is the _run_traced_analysis_chain method in analysis/diff.py. This function is what JAX traces and differentiates.
def _run_traced_analysis_chain(
    self,
    params: dict[str, Any],
    processed_data_events: dict,
) -> tuple[jnp.ndarray, dict[str, Any]]:
    # ...
    # 1. Collect histograms for all processes using the current `params`
    histograms_by_process = self._collect_histograms(...)
    # 2. Calculate significance from these histograms
    significance, mle_params = self._calculate_significance(histograms_by_process, params["fit"])
    # ...
    return significance, mle_params
The optimisation loop in run_analysis_optimisation then does the following:
# Define the objective function to be *maximized* (significance)
# Note: Optimizers typically *minimize*, so we would differentiate the *negative* significance.
# The `relaxed` library handles this internally.
def objective(params):
    # Returns (significance, mle_params); only the significance is differentiated.
    return self._run_traced_analysis_chain(params, processed_data)
# Get the gradient of the objective function w.r.t. the parameters.
# has_aux=True marks the second return value as auxiliary (non-differentiated) output.
gradients, mle_params = jax.grad(objective, has_aux=True)(all_parameters)
# Use an optimizer to update the parameters
# solver = OptaxSolver(fun=objective, opt=tx, ...)
# new_parameters, state = solver.update(parameters, state)
This loop continues until the significance converges or a maximum number of iterations is reached.
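The shape of such a loop can be shown on a toy problem with one cut threshold. Everything in this sketch is an assumption made for illustration: the event values, the figure of merit s / sqrt(b + 1), the learning rate, and the use of a central finite difference as a stand-in for jax.grad so that the example needs only the standard library.

```python
import math

SIGNAL = [80.0, 90.0, 100.0, 110.0, 120.0]         # toy observable values
BACKGROUND = [20.0, 30.0, 40.0, 50.0, 60.0, 70.0]
SCALE = 10.0                                        # softness of the cut

def soft_yield(values, threshold):
    """Sum of sigmoid weights for a soft cut 'value > threshold'."""
    return sum(1.0 / (1.0 + math.exp(-(v - threshold) / SCALE)) for v in values)

def significance(threshold):
    """Toy figure of merit s / sqrt(b + 1); the +1 avoids division by zero."""
    s = soft_yield(SIGNAL, threshold)
    b = soft_yield(BACKGROUND, threshold)
    return s / math.sqrt(b + 1.0)

def numerical_grad(fun, x, eps=1e-4):
    """Central finite difference, standing in for automatic differentiation."""
    return (fun(x + eps) - fun(x - eps)) / (2.0 * eps)

def optimise(threshold, learning_rate=100.0, n_steps=50, bounds=(0.0, 150.0)):
    """Gradient ascent on the threshold, clamped to a sensible range."""
    history = [significance(threshold)]
    for _ in range(n_steps):
        step = learning_rate * numerical_grad(significance, threshold)
        threshold = min(max(threshold + step, bounds[0]), bounds[1])
        history.append(significance(threshold))
    return threshold, history

best_threshold, history = optimise(threshold=30.0)
```

Starting below most of the background, the threshold drifts upward until the loss of signal balances the background rejection, and the recorded history gives the significance at each step.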
A key innovation of this framework is its end-to-end differentiable statistical model. This model is constructed in JAX, allowing the entire analysis—from event selection to statistical inference—to be differentiated.
Instead of standard histograms with hard bin counts, the analysis produces "soft" histograms using a Kernel Density Estimation (KDE) approach. For each event, instead of adding 1 to a single bin, a Gaussian kernel is placed at the event's observable value. The histogram's bin contents are then the integral of all event kernels over that bin's range. This process is smooth and differentiable with respect to both the event weights (from soft selections) and the observable values themselves.
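The binned-KDE idea can be written down in a few lines using the Gaussian cumulative distribution function. This is a standard-library sketch under assumed inputs (the bin edges, bandwidth, and event values are illustrative), not the framework's JAX implementation.

```python
import math

def normal_cdf(x):
    """Cumulative distribution function of the standard normal."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def kde_histogram(values, weights, bin_edges, bandwidth):
    """Fill bins with the integral of a Gaussian kernel centred on each event."""
    counts = [0.0] * (len(bin_edges) - 1)
    for value, weight in zip(values, weights):
        cdf_at_edges = [normal_cdf((edge - value) / bandwidth) for edge in bin_edges]
        for i in range(len(counts)):
            # Kernel mass between two edges is the difference of the CDF values.
            counts[i] += weight * (cdf_at_edges[i + 1] - cdf_at_edges[i])
    return counts

edges = [0.0, 100.0, 200.0, 300.0, 400.0]
values = [150.0, 200.0, 390.0]   # illustrative observable values
weights = [1.0, 0.5, 0.8]        # e.g. per-event soft-selection weights
soft_counts = kde_histogram(values, weights, edges, bandwidth=20.0)

# An event sitting exactly on a bin edge is shared equally by both neighbours.
on_edge = kde_histogram([200.0], [1.0], edges, bandwidth=20.0)
```

One consequence worth noting is that an event close to the outer histogram boundary, such as the one at 390 here, places part of its kernel outside the binned range, so the bin contents sum to slightly less than the total event weight.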
The statistical model is built upon these differentiable histograms. The goal is to construct a likelihood function L(data | params) that can be maximized to find the best-fit parameters.
The core of the model is the prediction for the expected number of events (expected_yields) in each histogram bin for each analysis channel. This is a function of the model parameters. In the current implementation (utils/jax_stats.py), the model is simplified and includes two global scalar parameters:
- mu ($\mu$): The signal strength parameter. $\mu=0$ corresponds to the background-only hypothesis, and $\mu=1$ corresponds to the nominal signal hypothesis.
- norm_ttbar_semilep ($\kappa_{t\bar{t}}$): A normalisation factor for the ttbar_semilep background process, applied uniformly across all channels.
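For a single channel $c$, the expected yield $\lambda_{c, i}$ in bin $i$ is:
$\lambda_{c, i}(\mu, \kappa_{t\bar{t}}) = \mu \, S_{c, i} + \kappa_{t\bar{t}} \, B^{t\bar{t}}_{c, i} + \sum_{k \in \text{other bkgs}} B^{k}_{c, i}$
where $S_{c, i}$ is the nominal signal yield, $B^{t\bar{t}}_{c, i}$ is the nominal ttbar_semilep yield, and $B^{k}_{c, i}$ are the yields of the remaining background processes, which are held fixed.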
Note that in this simplified model, systematic uncertainties are not incorporated via nuisance parameters in the differentiable fit. The framework can produce systematically-varied histograms, but they are used in the non-differentiable analysis path.
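As a sketch of how the two parameters enter the prediction, the example below assumes the expected yield takes the form mu * S + kappa * B_ttbar + B_other per bin, which is consistent with the parameter descriptions above but is not copied from utils/jax_stats.py. The function names and the yields are illustrative, and the Poisson negative log-likelihood is written with the standard library.

```python
import math

def expected_yields(mu, kappa_ttbar, signal, ttbar, other):
    """Per-bin expectation: mu * S + kappa_ttbar * B_ttbar + B_other."""
    return [mu * s + kappa_ttbar * t + o for s, t, o in zip(signal, ttbar, other)]

def poisson_nll(observed, expected):
    """Negative log-likelihood of independent Poisson-distributed bins."""
    return sum(
        lam - n * math.log(lam) + math.lgamma(n + 1.0)
        for n, lam in zip(observed, expected)
    )

# Illustrative nominal yields for a three-bin channel.
signal = [1.0, 4.0, 2.0]
ttbar = [50.0, 30.0, 10.0]
other = [20.0, 10.0, 5.0]

background_only = expected_yields(0.0, 1.0, signal, ttbar, other)
nominal = expected_yields(1.0, 1.0, signal, ttbar, other)

# If the observed counts equal the nominal prediction, the nominal
# hypothesis has the smaller negative log-likelihood.
nll_nominal = poisson_nll(nominal, nominal)
nll_bkg_only = poisson_nll(nominal, background_only)
```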
With the model for expected_yields defined, the final step is to perform statistical inference. This is where the relaxed library is used.
- Likelihood Construction: relaxed takes the JAX model of expected yields and the observed data histograms to construct the full likelihood function. For this simplified model, the likelihood is a product of Poisson probability mass functions over all bins:
  $L(\text{data} | \mu, \kappa_{t\bar{t}}) = \prod_{c \in \text{channels}} \prod_{i \in \text{bins}} \text{Pois}(N_{\text{obs}, c, i} | \lambda_{c, i}(\mu, \kappa_{t\bar{t}}))$
  Since the model in utils/jax_stats.py does not include systematic uncertainties as nuisance parameters, there are no constraint terms in the likelihood.
- Differentiable Hypothesis Test: The main goal is to test for the presence of a signal. This is done using a profile likelihood ratio test statistic, $q_0$. relaxed.infer.hypotest performs this test. It finds the values of the model parameters (in this case, $\kappa_{t\bar{t}}$) that maximize the likelihood for a given $\mu$ (profiling) and then computes $q_0$. The significance is then $Z = \sqrt{q_0}$.
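For intuition about $Z = \sqrt{q_0}$, the sketch below evaluates the well-known closed form for the expected discovery significance of a Poisson counting experiment on Asimov data, $q_0 = 2 \sum_i [(s_i + b_i) \ln(1 + s_i / b_i) - s_i]$. It assumes the background is known exactly, so it does not profile $\kappa_{t\bar{t}}$ and is not what relaxed computes internally. The function name and the yields are illustrative.

```python
import math

def asimov_significance(signal, background):
    """Z = sqrt(q0) with q0 = 2 * sum[(s + b) * ln(1 + s / b) - s] over bins."""
    q0 = 2.0 * sum(
        (s + b) * math.log1p(s / b) - s
        for s, b in zip(signal, background)
    )
    return math.sqrt(max(q0, 0.0))

# Single bin: slightly below the naive s / sqrt(b) = 1.0.
z_single = asimov_significance([10.0], [100.0])

# Several bins: the per-bin q0 values add, so significances add in quadrature.
z_multi = asimov_significance([1.0, 4.0, 2.0], [70.0, 40.0, 15.0])
```

Because every term is a smooth function of the bin yields, the same expression written with JAX operations would be differentiable with respect to the histogram contents.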
This means we can compute the gradient of the final significance Z with respect to any parameter in the chain, including the selection cut thresholds (met_threshold, etc.) and even the weights of a JAX-based MVA. This gradient is what drives the optimisation.
The framework includes support for integrating Machine Learning models (MVAs) into the analysis, with implementations for both JAX and TensorFlow/Keras. This allows for the use of non-linear discriminants that can be pre-trained and, in the case of JAX models, optimized in-situ with the rest of the analysis.
The core logic is handled by the JAXNetwork and TFNetwork classes in utils/mva.py.
This is a neural network implementation written purely in JAX, providing deep integration with the differentiable analysis workflow.
- Explicit Parameter Management: Unlike frameworks that encapsulate model weights, the JAXNetwork manages its weights and biases in a simple Python dictionary. This transparency is key to its integration. Parameter names follow a convention (e.g., __NN_my_model_W1) that allows the framework to automatically identify them.
- End-to-End optimisation: When MVA optimisation is enabled in the configuration (grad_optimisation.optimise: True), the network's parameters are added to the global set of variables that the main optimizer tunes. This means the optimizer can simultaneously adjust the MVA's weights to improve signal/background separation and tune the analysis selection cuts, all to directly maximize the final statistical significance.
- Full Control: The from-scratch implementation gives full control over the network's forward pass, loss function, and training loop, all within the JAX ecosystem.
- Configuration: The network architecture (layers, activations) is defined in user/configuration.py. Activations are provided as Python lambda functions, allowing for custom, non-standard activation functions if needed.
This class provides a wrapper around a standard tf.keras.Sequential model.
- Leverage Keras: It allows you to use the rich and mature Keras API for building and training models.
- Pre-training Only: The primary use case is to pre-train a powerful discriminator. The trained model is then used to compute a score for each event, which is used as a static input feature in the main analysis. The weights of a TF/Keras model are not part of the global significance optimisation.
- Configuration: The architecture is defined in the configuration file, with activations specified as strings (e.g., "relu", "tanh").
The framework handles MVAs in a two-stage process: an initial, one-off pre-training phase, followed by an optional, continuous optimisation phase that happens alongside the main analysis optimisation.
Pre-training phase:
- When: This happens once at the start of the analysis if general.run_mva_training is True.
- Data: The framework allows for a completely separate object definition for MVA training. In config.good_object_masks, you can define an mva key with different object selection criteria than the analysis key. This is useful for training on a broader, less-biased dataset.
- Process:
  - During event processing (_prepare_data), two parallel sets of object collections are created: one for the main analysis and one for MVA training, each with its own "good object" masks applied.
  - After all files are processed, the MVA-specific data is passed to _run_mva_training.
  - The models (both JAX and Keras) are trained on this dedicated dataset.
  - The resulting trained model (for Keras) or parameters (for JAX) are saved to disk.
Optimisation phase:
- When: This happens at every step of the main gradient-based optimisation loop.
- Inference (On-the-fly):
  - The MVA instance (containing the forward pass logic) and its input features are attached to a special collection in the event record.
  - The differentiable selection function (e.g., Zprime_softcuts_jax_workshop) must be designed to accept this collection and the global params dictionary as inputs.
  - Inside this traced function, the MVA's forward_pass is called on-the-fly, using the current state of the MVA's weights from the params dictionary. This ensures that the entire calculation, from MVA inputs to score, is part of the JAX computation graph.
- In-situ Optimisation (JAX only):
  - If an MVA is configured with framework: "jax" and grad_optimisation.optimise: True, its pre-trained parameters are included in the set of globally optimisable parameters.
  - Because the MVA's forward pass is executed on-the-fly within the main traced function, its output (the MVA score) is fully differentiable with respect to its weights and biases.
  - The gradient of the final statistical significance is therefore also calculated with respect to these MVA parameters.
  - The optax optimiser updates the MVA weights at each step, effectively "fine-tuning" the MVA to directly maximise the analysis significance, alongside all other selection cuts.
  - Keras models are used for inference only; their weights are not optimised during this stage.
Adding a new MVA to the analysis is a configuration-driven process:
- Add to Config: Create a new dictionary entry in the mva list in user/configuration.py. Give it a unique name.
- Define Architecture & Framework:
  - Set framework to "jax" or "keras".
  - Define the layers list, specifying the dimensions, activation functions, and names for weights/biases for each layer.
  - Define the loss function.
- Define Input Features:
  - In the features list for your MVA, define each input variable.
  - Each feature needs a name, a function to compute it (e.g., lambda mva: mva.n_jet), and a use key specifying the inputs to that function. The framework automatically computes "ghost observables" first, so you can define features that depend on them.
- Use the MVA in Selection:
  - The framework will automatically train the MVA (if run_mva_training is True) and compute its output score for every event.
  - This score is attached to a special object collection named after your MVA (e.g., wjets_vs_ttbar_nn).
  - You can then use this score in your differentiable selection function in user/cuts.py just like any other variable.
- Enable Gradient optimisation (JAX only):
  - To make the JAX MVA's weights optimizable, set grad_optimisation.optimise: True in its configuration.
  - The framework will automatically find all parameters with the __NN prefix and add them to the set of variables that the optimizer will tune. You can even set a custom learning rate for the MVA weights.
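The prefix-based bookkeeping described above might look roughly like the following. The helper names, the bias name __NN_my_model_b1, and the learning-rate values are assumptions chosen for illustration. Only the __NN prefix convention comes from the text.

```python
NN_PREFIX = "__NN"

def split_parameters(all_params):
    """Separate MVA parameters (identified by name prefix) from cut parameters."""
    mva = {k: v for k, v in all_params.items() if k.startswith(NN_PREFIX)}
    cuts = {k: v for k, v in all_params.items() if not k.startswith(NN_PREFIX)}
    return cuts, mva

def learning_rate_map(all_params, default_lr, mva_lr, overrides=None):
    """Assign one learning rate per parameter: override > MVA rate > default."""
    overrides = overrides or {}
    rates = {}
    for name in all_params:
        if name in overrides:
            rates[name] = overrides[name]
        elif name.startswith(NN_PREFIX):
            rates[name] = mva_lr
        else:
            rates[name] = default_lr
    return rates

# Hypothetical parameter set mixing cut thresholds and network weights.
all_params = {
    "met_threshold": 50.0,
    "btag_threshold": 0.5,
    "__NN_my_model_W1": [[0.1, -0.2], [0.3, 0.4]],
    "__NN_my_model_b1": [0.0, 0.0],
}
cuts, mva = split_parameters(all_params)
rates = learning_rate_map(
    all_params, default_lr=0.01, mva_lr=0.001, overrides={"met_threshold": 1.0}
)
```

Giving the network weights a smaller rate than the cut thresholds is a common choice, since the two kinds of parameter live on very different numerical scales.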
- Add to Config: Add a new key-value pair to config["jax"]["params"].
      "params": {
          'met_threshold': 50.0,
          'btag_threshold': 0.5,
          'my_new_cut': 100.0,  # <-- Add new parameter
      },
- Use in Soft Selection: Use params["my_new_cut"] in your soft selection function in user/cuts.py.
- (Optional) Add a Clamp: Add a rule for your new parameter in config["jax"]["param_updates"] to keep it within a sensible range.
- Add to Config: Add a new dictionary to the config["systematics"] list.
- Define Logic:
  - If it's a simple scale factor, you can define the up_function and down_function directly in the config (e.g., lambda: 1.05).
  - For more complex variations, define a function in user/systematics.py that takes an object collection and returns a per-object weight.
  - For correctionlib-based uncertainties, ensure the file and key are specified correctly.
- Specify Target: Define the target (which object and variable are affected) and the op (how the variation is applied, e.g., mult or add).
The framework will automatically propagate these systematics through the non-differentiable analysis path to produce varied histograms for tools like cabinetry.
Alongside the differentiable path, the framework fully supports a traditional, non-differentiable analysis via the NonDiffAnalysis class in analysis/nondiff.py. This path uses standard hard cuts and fills hist objects, which can then be used with tools like cabinetry for statistical inference. You can control which analysis runs via the analysis key in config.general.
├── user/ # USER-CONFIGURABLE MODULES - Modify these for your analysis
│ ├── __init__.py # Package initialization
│ ├── configuration.py # Main configuration file for the analysis
│ ├── cuts.py # Selection logic (both hard and soft/differentiable)
│ ├── observables.py # Physics observables and reconstruction functions
│ └── systematics.py # Systematic variation functions
├── analysis/ # FRAMEWORK CODE - Core analysis classes and pipeline logic
│ ├── base.py # Base class with common analysis logic (corrections, etc.)
│ ├── diff.py # Implements the full differentiable analysis workflow
│ └── nondiff.py # Implements a traditional, non-differentiable analysis
├── utils/ # FRAMEWORK CODE - Supporting utility functions
│ ├── mva.py # MVA (neural network) model definitions and training logic
│ ├── schema.py # Pydantic schemas for validating the configuration
│ ├── plot.py # Plotting utilities and visualization functions
│ ├── stats.py # Statistical analysis functions
│ ├── tools.py # General utility functions
│ ├── input_files.py # File handling utilities
│ ├── output_files.py # Output management utilities
│ └── ... # Other helper utilities
├── cabinetry/
│ └── ... # Configuration for the `cabinetry` statistical tool
├── corrections/
│ └── ... # Correction files (e.g., from `correctionlib`)
└── README.md
The framework separates user-configurable modules (`user/`) from framework code (`analysis/`, `utils/`):
- `user/` directory: Contains modules that users should modify for their specific analysis needs
- `analysis/` and `utils/` directories: Contain framework code that provides the analysis infrastructure
This separation ensures that users can focus on physics configuration while the framework handles the technical implementation details.
Understanding the logical flow of the differentiable analysis helps users see how their configuration choices in the `user/` directory affect the overall workflow. Here's a step-by-step breakdown:
user/configuration.py → Analysis Setup
- The analysis starts by loading your configuration from `user/configuration.py`
- This defines all analysis parameters, observables, cuts, and optimization settings
- The framework validates the configuration against the schema in `utils/schema.py`
Raw NanoAOD → Preprocessing → Cached Data
- If `general.run_preprocessing=True`, raw NanoAOD files are skimmed
- Only branches specified in `config.preprocess.branches` are kept
- Baseline selections from `config.baseline_selection` are applied
- Results are cached for faster subsequent runs
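The skimming logic described above amounts to dropping unneeded branches and events. A minimal sketch, assuming events are plain dictionaries rather than the awkward arrays the framework actually works with (the `skim` helper, the branch names, and the baseline requirement are hypothetical, and caching is left out):

```python
def skim(events, branches, baseline_selection):
    """Keep only the requested branches for events passing the baseline selection."""
    return [
        {name: event[name] for name in branches}
        for event in events
        if baseline_selection(event)
    ]

raw_events = [
    {"met": 20.0, "n_jets": 2, "run": 1, "unused": 0.1},
    {"met": 75.0, "n_jets": 4, "run": 1, "unused": 0.2},
    {"met": 60.0, "n_jets": 1, "run": 2, "unused": 0.3},
]
branches = ["met", "n_jets"]  # stands in for config.preprocess.branches

def baseline(event):
    """Stand-in for config.baseline_selection: require at least two jets."""
    return event["n_jets"] >= 2

skimmed = skim(raw_events, branches, baseline)
```

The output is smaller in both directions (fewer events, fewer fields per event), which is what makes caching it worthwhile for later runs.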
Cached Data → Feature Extraction → Model Training → Saved Model
- If `general.run_mva_training=True`, neural networks are pre-trained
- Features defined in `config.mva[].features` are computed using functions from `user/observables.py`
- Models are trained to distinguish between background processes
- Trained parameters are saved and later used in the main analysis
For each event batch:
Raw Objects → Corrections → Good Objects → Ghost Observables
- Object corrections from `config.corrections` are applied
- "Good object" masks from `config.good_object_masks` filter objects
- "Ghost observables" from `config.ghost_observables` are computed using `user/observables.py`
- This creates an event record with all necessary variables
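The object-level steps above can be pictured with a toy event. This sketch uses plain dictionaries instead of awkward arrays. The `good_jets` mask, its pT and eta requirements, and the `scalar_ht` observable are hypothetical stand-ins for what `config.good_object_masks` and `user/observables.py` would provide.

```python
def good_jets(jets, min_pt=30.0, max_abs_eta=2.4):
    """Hypothetical good-object mask: keep central jets above a pT threshold."""
    return [j for j in jets if j["pt"] > min_pt and abs(j["eta"]) < max_abs_eta]

def scalar_ht(jets):
    """Hypothetical derived observable: scalar sum of jet pT."""
    return sum(j["pt"] for j in jets)

event = {
    "jets": [
        {"pt": 120.0, "eta": 0.5},
        {"pt": 25.0, "eta": 1.0},   # fails the pT requirement
        {"pt": 60.0, "eta": 3.0},   # fails the eta requirement
        {"pt": 45.0, "eta": -1.2},
    ]
}

selected = good_jets(event["jets"])
# Store the derived quantity on the event so later steps can use it.
event["ht"] = scalar_ht(selected)
```

Computing derived quantities once and attaching them to the event record means that selections and histograms further down the chain can refer to them by name.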
For each optimization step:
Events → Soft Cuts → Event Weights → Histograms → Significance → Gradients → Parameter Updates
- Soft Cuts: Your selection function from `user/cuts.py` (e.g., `Zprime_softcuts_jax_workshop`) is called
- Event Weights: Instead of hard cuts, sigmoid functions produce continuous weights (0-1) per event
- Histograms: Events are binned using Kernel Density Estimation (KDE) - smooth and differentiable
- Significance: Statistical model computes discovery significance using the `relaxed` library
- Gradients: JAX computes gradients of significance w.r.t. all parameters in `config.jax.params`
- Updates: Optimizer (optax) updates parameters to maximize significance
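The loop above can be illustrated with a deliberately simplified, dependency-free sketch. It is not the framework's implementation: the gradient is derived by hand instead of by JAX, the significance is a single-bin S / sqrt(B) instead of a `relaxed` statistical model, and the update is plain gradient ascent instead of an optax optimizer. All function names and toy numbers are assumptions.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def yields(values, threshold, scale):
    """Sum of soft-cut weights and its derivative with respect to the threshold."""
    total, d_total = 0.0, 0.0
    for v in values:
        w = sigmoid((v - threshold) / scale)
        total += w
        d_total += -w * (1.0 - w) / scale  # d/dthreshold of sigmoid((v - t) / scale)
    return total, d_total

def significance_and_grad(threshold, signal, background, scale=10.0):
    """Toy significance Z = S / sqrt(B) and its hand-derived gradient dZ/dthreshold."""
    s, ds = yields(signal, threshold, scale)
    b, db = yields(background, threshold, scale)
    z = s / math.sqrt(b)
    dz = ds / math.sqrt(b) - 0.5 * s * db / b ** 1.5
    return z, dz

# Toy values of the cut variable for signal and background events.
signal = [80.0, 90.0, 100.0, 110.0, 120.0]
background = [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 100.0]

threshold = 50.0
z_start, _ = significance_and_grad(threshold, signal, background)

learning_rate = 50.0
for _ in range(20):
    _, grad = significance_and_grad(threshold, signal, background)
    threshold += learning_rate * grad  # gradient ascent: move towards larger Z

z_end, _ = significance_and_grad(threshold, signal, background)
```

Starting from a threshold of 50, the updates push the threshold upward because removing low-valued background events raises the toy significance faster than the loss of signal lowers it.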
config.jax.params → Selection Function → Event Weights → Final Significance
↑ ↓
Parameter Updates ←← Gradients ←← Statistical Model ←← Histograms
- Parameters you define in `config.jax.params` (e.g., `met_threshold: 50.0`) flow into your selection function
- Your selection function in `user/cuts.py` uses these parameters in sigmoid cuts
- The resulting event weights affect histogram shapes
- Changes in histograms affect the final statistical significance
- Gradients flow backward through this entire chain to update parameters
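To see why event weights can change histogram shapes smoothly, consider one common way of building a KDE-style histogram: each event is treated as a Gaussian kernel, and a bin receives the fraction of that kernel lying between its edges. Whether the framework uses exactly this form is an assumption. The `kde_histogram` helper, the bandwidth, and the toy inputs below are illustrative only.

```python
import math

def normal_cdf(z):
    """Cumulative distribution function of the standard normal."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def kde_histogram(values, weights, edges, bandwidth):
    """Weighted bin contents where each event is smeared by a Gaussian kernel."""
    counts = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        total = 0.0
        for x, w in zip(values, weights):
            # Fraction of the kernel centred on x that falls inside [lo, hi].
            total += w * (normal_cdf((hi - x) / bandwidth) - normal_cdf((lo - x) / bandwidth))
        counts.append(total)
    return counts

# Two events with soft-cut weights 1.0 and 0.5, split across two wide bins.
counts = kde_histogram(
    values=[1.0, 2.5],
    weights=[1.0, 0.5],
    edges=[-100.0, 2.0, 100.0],
    bandwidth=0.5,
)
```

Every bin content is a smooth function of both the event weights and the event values, so a small change in a cut parameter produces a small change in each bin instead of an event jumping between bins.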
For each channel in config.channels:
Selection → Observable Computation → Histogramming → Statistical Combination
- Each analysis channel (signal region, control regions) is processed
- Channel-specific selections from `config.channels[].selection` are applied
- Observables from `config.channels[].observables` are computed using `user/observables.py`
- All channels contribute to the final statistical model
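The per-channel loop can be sketched as follows. The channel dictionaries here are hypothetical: the `selection` and `observables` keys mirror the configuration fields named above, but the `name` key, the callables, and the event fields `n_btag` and `mtt` are assumptions made for the example.

```python
# Hypothetical channel definitions; the real ones live in config.channels.
channels = [
    {
        "name": "signal_region",
        "selection": lambda e: e["n_btag"] >= 1,
        "observables": {"mtt": lambda e: e["mtt"]},
    },
    {
        "name": "control_region",
        "selection": lambda e: e["n_btag"] == 0,
        "observables": {"mtt": lambda e: e["mtt"]},
    },
]

events = [
    {"n_btag": 2, "mtt": 900.0},
    {"n_btag": 0, "mtt": 400.0},
    {"n_btag": 1, "mtt": 1300.0},
]

# For each channel: select events, then compute every configured observable.
histogram_inputs = {}
for channel in channels:
    passed = [e for e in events if channel["selection"](e)]
    histogram_inputs[channel["name"]] = {
        obs_name: [compute(e) for e in passed]
        for obs_name, compute in channel["observables"].items()
    }
```

The resulting per-channel values are what would be histogrammed and handed to the combined statistical model.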
- Configuration (`user/configuration.py`): You define what gets optimized and how
- Observables (`user/observables.py`): You define what physics quantities to compute
- Cuts (`user/cuts.py`): You define how events are selected (both hard and soft cuts)
- Systematics (`user/systematics.py`): You define uncertainty sources
The framework handles the technical details (JAX tracing, gradient computation, optimisation) while you focus on the physics logic. Every function you write in the `user/` directory becomes part of a fully differentiable computation graph that can be optimised end-to-end.
This project uses Sphinx to generate documentation from the source code's docstrings and other reStructuredText files. The documentation is hosted on Read the Docs.
To build and view the documentation on your local machine, follow these steps:
- Install Dependencies: The documentation dependencies are listed in `docs/requirements.txt`. You can install them using `pip`. It is recommended to do this within the project's conda environment to keep dependencies organized.
  pip install -r docs/requirements.txt
- Build the HTML pages: Navigate to the `docs/` directory and use the provided `Makefile` to build the documentation.
  cd docs
  make html
- View the Documentation: The generated HTML files will be in the `docs/build/html/` directory. You can open the main page in your browser (e.g., by navigating to the file path in your browser's address bar or using a command like `open` on macOS or `xdg-open` on Linux).
The repository is configured to automatically build and deploy the documentation to Read the Docs on every push to the main branch. This process is controlled by the `.readthedocs.yaml` file at the root of the repository.
If you have forked this repository and wish to set up your own Read the Docs deployment:
- Sign up/in to Read the Docs: Go to readthedocs.org and create an account or log in. Make sure your GitHub account is connected.
- Import the Repository: From your Read the Docs dashboard, click on "Import a Project" and select your GitHub repository fork.
- Configuration: Read the Docs will automatically detect the `.readthedocs.yaml` file and configure the build process. You should not need to change any settings in the Read the Docs project configuration page.
- Trigger a Build: The first build should trigger automatically after importing. Subsequent builds will be triggered by pushing new commits to your repository's default branch. You can monitor the build status and logs in your Read the Docs project dashboard.