Skip to content

Commit

Permalink
add ground truth plot
Browse files Browse the repository at this point in the history
  • Loading branch information
khairulislam committed Nov 1, 2023
1 parent e350277 commit e426707
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 58 deletions.
1 change: 0 additions & 1 deletion explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def _attribute_by_index(
)['mu_star'].data
attr[:, pred_index, t] = torch.tensor(morris_index, device=device)

print(attr.shape)
return attr

def attribute(
Expand Down
198 changes: 198 additions & 0 deletions plot_ground_truth.ipynb

Large diffs are not rendered by default.

Binary file added results/weekly_ground_truth.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
58 changes: 1 addition & 57 deletions utils/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def batch_compute_attr(
baselines = get_baseline(inputs, mode=baseline_mode)

# get attributions
attr = compute_attr(
attr = compute_regressor_attr(
inputs, baselines, explainer, additional_forward_args, exp.args
)
attr_list.append(attr)
Expand Down Expand Up @@ -155,62 +155,6 @@ def reshape_over_output_horizon(attr, inputs, args):

return attr

def compute_attr(
inputs, baselines, explainer,
additional_forward_args, args
):
assert type(inputs) == torch.Tensor, \
f'Only input type tensor supported, found {type(inputs)} instead.'
name = explainer.get_name()

# these methods don't support having multiple outputs at the same time
if name in ['Deep Lift', 'Lime', 'Integrated Gradients', 'Gradient Shap']:
attr_list = []
for target in range(args.pred_len):
score = explainer.attribute(
inputs=inputs, baselines=baselines, target=target,
additional_forward_args=additional_forward_args
)
attr_list.append(score)

attr = torch.stack(attr_list)
# pred_len x batch x seq_len x features -> batch x pred_len x seq_len x features
attr = attr.permute(1, 0, 2, 3)

elif name == 'Feature Ablation':
attr = explainer.attribute(
inputs=inputs, baselines=baselines,
additional_forward_args=additional_forward_args
)
elif name == 'Occlusion':
attr = explainer.attribute(
inputs=inputs,
baselines=baselines,
sliding_window_shapes = (1,1),
additional_forward_args=additional_forward_args
)
elif name == 'Augmented Occlusion':
attr = explainer.attribute(
inputs=inputs,
sliding_window_shapes = (1,1),
additional_forward_args=additional_forward_args
)
elif name == 'Morris Sensitivity':
attr = explainer.attribute(
inputs=inputs,
additional_forward_args=additional_forward_args
)
else:
print(f'{name} not supported.')
raise NotImplementedError

attr = attr.reshape(
# batch x pred_len x seq_len x features
(inputs.shape[0], args.pred_len, args.seq_len, attr.shape[-1])
)

return attr

def get_total_data(dataloader, device, add_x_mark=False):
if add_x_mark:
return (
Expand Down

0 comments on commit e426707

Please sign in to comment.