Skip to content

Commit

Permalink
added some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Dymond committed Oct 4, 2024
1 parent 1396eb1 commit efac13e
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/arc_spice/dropout_utils/variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,21 +135,27 @@ def clean_inference(self, x: Union[np.ndarray, bytes, str]):
return output

def variational_inference(self, x, n_runs=5):
# we need clean inputs to pass to each step, we run that first
output = {"clean": {}, "variational": {}}
output["clean"] = self.clean_inference(x)
# each step accepts a different input from the clean pipeline
input_map = {
"transcription": x,
"translation": output["clean"]["transcription"]["outputs"],
"summarisation": output["clean"]["translation"]["outputs"],
}
# for each model in pipeline
for model_key, pl in self.pipeline_map.items():
# perhaps we could use a context handler here?
# turn on dropout for this model
set_dropout(model=pl.model, dropout_flag=True)
# create the output list
output["variational"][model_key] = [None] * n_runs
# do n runs of the inference
for run_idx in range(n_runs):
output["variational"][model_key][run_idx] = self.func_map[model_key](
input_map[model_key]
)
# turn off dropout for this model
set_dropout(model=pl.model, dropout_flag=False)
return output

Expand Down

0 comments on commit efac13e

Please sign in to comment.