Skip to content

Commit

Permalink
first pass implement ML/target blending
Browse files Browse the repository at this point in the history
  • Loading branch information
miketynes committed Aug 2, 2024
1 parent f2f6c74 commit e72a960
Showing 1 changed file with 33 additions and 2 deletions.
35 changes: 33 additions & 2 deletions cascade/proxima/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class SerialLearningCalculator(Calculator):
Minimum fraction of timesteps to run the target function.
This value is used as the probability of running the target function
even if it need not be used based on the UQ metric.
n_blending_steps: int
How many timesteps to smoothly combine target and surrogate forces.
Once the threshold is satisfied, we apply a gradually shifting mixture of ML and
target forces, with the weight on the ML forces increasing at each step.
db_path: Path or str
Database in which to store the results of running the target calculator,
which are used to train the surrogate model
Expand All @@ -72,6 +76,7 @@ class SerialLearningCalculator(Calculator):
'train_recency_bias': 1.,
'target_ferr': 0.1, # TODO (wardlt): Make the error metric configurable
'min_target_fraction': 0.,
'n_blending_steps': 0,
'history_length': 8,
'db_path': None
}
Expand All @@ -94,6 +99,8 @@ class SerialLearningCalculator(Calculator):
"""Total number of calls to the calculator"""
target_invocations: int = 0
"""Total number of calls to the target calculator"""
seq_surrogate_invocations = 0
"""Number of surrogate invocations in a row (resets once threshold exceeded)"""
model_version: int = 0
"""How many times the model has been retrained"""

Expand Down Expand Up @@ -180,11 +187,20 @@ def calculate(
unc_metric = forces_diff.max()
logger.debug(f'Computed the uncertainty metric for the model to be: {unc_metric:.2e}')

# Use the result from the surrogate
# Check whether to use the result from the surrogate
uq_small_enough = self.threshold is not None and unc_metric < self.threshold
self.used_surrogate = uq_small_enough and (random() > self.parameters['min_target_fraction'])
self.total_invocations += 1

# Track blending parameters for surrogate/target
blend_with_target = self.n_blend_steps - self.seq_surrogate_invocations
if self.used_surrogate:
self.seq_surrogate_invocations += 1
else:
self.seq_surrogate_invocations = 0

# Case: fully use the surrogate
if self.used_surrogate and not blend_with_target:
logger.debug(f'The uncertainty metric is low enough ({unc_metric:.2e} < {self.threshold:.2e}). Using the surrogate result.')
self.results = self.surrogate_calc.results.copy()
return
Expand All @@ -193,7 +209,22 @@ def calculate(
target_calc: Calculator = self.parameters['target_calc']
target_calc.calculate(atoms, properties, system_changes)
self.target_invocations += 1
self.results = target_calc.results.copy()

if self.used_surrogate and blend_with_target:
# return a blend if appropriate
lambda_target = blend_with_target / self.n_blend_steps
results_target = target_calc.results.copy()
results_surrogate = self.surrogate_calc.results.copy()
self.results = {}
for k in results_surrogate.keys():
if k in results_target.keys():
self.results[k] = lambda_target * results_target[k] + (1-lambda_target)*results_surrogate[k]
else:
# the surrogate may have some extra results which we store
self.results[k] = results_surrogate[k]
else:
# otherwise just return the target
self.results = target_calc.results.copy()

# Increment the training set with this new result
db_atoms = atoms.copy()
Expand Down

0 comments on commit e72a960

Please sign in to comment.