@@ -68,48 +68,49 @@ def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None:
6868 v = get_analysis_vars (
6969 context_keys = ["target_token_analysis" ],
7070 local_keys = [
71- "logit_idx" ,
72- "logit_p" ,
71+ "targets" ,
7372 "total_nodes" ,
7473 "total_active_feats" ,
7574 "max_feature_nodes" ,
7675 "edge_matrix" ,
7776 "row_to_node_index" ,
78- "logit_vecs" ,
7977 "n_layers" ,
8078 "n_pos" ,
8179 ],
8280 local_vars = local_vars ,
8381 )
8482 tta = v ["target_token_analysis" ]
85- tta .update_logit_info (v ["logit_idx" ], v ["logit_p" ])
86- max_n_logits = len (v ["logit_idx" ])
83+ targets = v ["targets" ]
84+ # Extract token IDs and probabilities from targets object
85+ logit_idx , logit_p = targets .token_ids , targets .logit_probabilities
86+ tta .update_logit_info (logit_idx , logit_p )
87+ max_n_logits = len (targets )
8788
8889 HOOK_REGISTRY .set_context (
8990 max_feature_nodes = v ["max_feature_nodes" ],
9091 total_nodes = v ["total_nodes" ],
91- logit_idx = v [ " logit_idx" ] ,
92- logit_p = v [ " logit_p" ] ,
92+ logit_idx = logit_idx ,
93+ logit_p = logit_p ,
9394 n_pos = v ["n_pos" ],
9495 max_n_logits = max_n_logits ,
9596 )
9697 data = {
97- "logit_idx" : v [ " logit_idx" ] ,
98- "logit_p" : VarAnnotate ("logit_p" , var_value = v [ " logit_p" ] , annotation = "non-demeaned logits probabilities" ),
98+ "logit_idx" : logit_idx ,
99+ "logit_p" : VarAnnotate ("logit_p" , var_value = logit_p , annotation = "logit probabilities" ),
99100 "target_tokens" : tta .tokens ,
100101 "target_logit_indices" : tta .logit_indices ,
101102 "target_logit_p" : tta .logit_probabilities ,
102- "logit_cumulative_prob" : float (v [ " logit_p" ] .sum ().item ()),
103+ "logit_cumulative_prob" : float (logit_p .sum ().item ()),
103104 "total_nodes" : v ["total_nodes" ],
104105 "max_feature_nodes" : v ["max_feature_nodes" ],
105106 "total_active_feats" : v ["total_active_feats" ],
106- "n_logits" : len (v [ "logit_idx" ] ),
107+ "n_logits" : len (targets ),
107108 "n_layers" : v ["n_layers" ],
108109 "n_pos" : v ["n_pos" ],
109110 "max_n_logits" : max_n_logits ,
110111 "edge_matrix.shape" : v ["edge_matrix" ].shape ,
111112 "row_to_node_index.shape" : v ["row_to_node_index" ].shape if v ["row_to_node_index" ] is not None else None ,
112- "logit_vecs.shape" : v [ "logit_vecs" ] .shape ,
113+ "logit_vecs.shape" : targets . logit_vectors .shape ,
113114 }
114115 analysis_log_point ("after building input vectors w/ target logits" , data )
115116
@@ -151,7 +152,7 @@ def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None:
151152 # Use dict directly for cleaner access
152153 v = get_analysis_vars (
153154 context_keys = ["target_token_analysis" ],
154- local_keys = ["n_visited" , "max_feature_nodes" , "logit_p " , "edge_matrix" , "ctx" ],
155+ local_keys = ["n_visited" , "max_feature_nodes" , "targets " , "edge_matrix" , "ctx" ],
155156 local_vars = local_vars ,
156157 )
157158 tta = v ["target_token_analysis" ]
0 commit comments