Skip to content

Commit d32fe55

Browse files
* Create a SoftmaxV2 layer which has better numerical handling of the mask (currently if the mask violates any of the assumptions it will do numerically silly things silently).
* Plumb an argument that would opt into the usage of the new softmax layer for the official keras `MultiHeadAttention` layer and the model garden `TransformerEncoderBlock` layer. PiperOrigin-RevId: 831896060
1 parent 924b2d0 commit d32fe55

11 files changed

+576
-6
lines changed

tf_keras/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tf_class {
129129
}
130130
member_method {
131131
name: "__init__"
132-
argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
132+
argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'use_softmax_v2\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
133133
}
134134
member_method {
135135
name: "add_loss"
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
path: "tensorflow.keras.layers.SoftmaxV2"
2+
tf_class {
3+
is_instance: "<class \'tf_keras.layers.activation.softmax.SoftmaxV2\'>"
4+
is_instance: "<class \'tf_keras.engine.base_layer.Layer\'>"
5+
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
6+
is_instance: "<class \'tensorflow.python.trackable.autotrackable.AutoTrackable\'>"
7+
is_instance: "<class \'tensorflow.python.trackable.base.Trackable\'>"
8+
is_instance: "<class \'tf_keras.utils.version_utils.LayerVersionSelector\'>"
9+
is_instance: "<type \'object\'>"
10+
member {
11+
name: "activity_regularizer"
12+
mtype: "<type \'property\'>"
13+
}
14+
member {
15+
name: "compute_dtype"
16+
mtype: "<type \'property\'>"
17+
}
18+
member {
19+
name: "dtype"
20+
mtype: "<type \'property\'>"
21+
}
22+
member {
23+
name: "dtype_policy"
24+
mtype: "<type \'property\'>"
25+
}
26+
member {
27+
name: "dynamic"
28+
mtype: "<type \'property\'>"
29+
}
30+
member {
31+
name: "inbound_nodes"
32+
mtype: "<type \'property\'>"
33+
}
34+
member {
35+
name: "input"
36+
mtype: "<type \'property\'>"
37+
}
38+
member {
39+
name: "input_mask"
40+
mtype: "<type \'property\'>"
41+
}
42+
member {
43+
name: "input_shape"
44+
mtype: "<type \'property\'>"
45+
}
46+
member {
47+
name: "input_spec"
48+
mtype: "<type \'property\'>"
49+
}
50+
member {
51+
name: "losses"
52+
mtype: "<type \'property\'>"
53+
}
54+
member {
55+
name: "metrics"
56+
mtype: "<type \'property\'>"
57+
}
58+
member {
59+
name: "name"
60+
mtype: "<type \'property\'>"
61+
}
62+
member {
63+
name: "name_scope"
64+
mtype: "<type \'property\'>"
65+
}
66+
member {
67+
name: "non_trainable_variables"
68+
mtype: "<type \'property\'>"
69+
}
70+
member {
71+
name: "non_trainable_weights"
72+
mtype: "<type \'property\'>"
73+
}
74+
member {
75+
name: "outbound_nodes"
76+
mtype: "<type \'property\'>"
77+
}
78+
member {
79+
name: "output"
80+
mtype: "<type \'property\'>"
81+
}
82+
member {
83+
name: "output_mask"
84+
mtype: "<type \'property\'>"
85+
}
86+
member {
87+
name: "output_shape"
88+
mtype: "<type \'property\'>"
89+
}
90+
member {
91+
name: "stateful"
92+
mtype: "<type \'property\'>"
93+
}
94+
member {
95+
name: "submodules"
96+
mtype: "<type \'property\'>"
97+
}
98+
member {
99+
name: "supports_masking"
100+
mtype: "<type \'property\'>"
101+
}
102+
member {
103+
name: "trainable"
104+
mtype: "<type \'property\'>"
105+
}
106+
member {
107+
name: "trainable_variables"
108+
mtype: "<type \'property\'>"
109+
}
110+
member {
111+
name: "trainable_weights"
112+
mtype: "<type \'property\'>"
113+
}
114+
member {
115+
name: "updates"
116+
mtype: "<type \'property\'>"
117+
}
118+
member {
119+
name: "variable_dtype"
120+
mtype: "<type \'property\'>"
121+
}
122+
member {
123+
name: "variables"
124+
mtype: "<type \'property\'>"
125+
}
126+
member {
127+
name: "weights"
128+
mtype: "<type \'property\'>"
129+
}
130+
member_method {
131+
name: "__init__"
132+
argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], "
133+
}
134+
member_method {
135+
name: "add_loss"
136+
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
137+
}
138+
member_method {
139+
name: "add_metric"
140+
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
141+
}
142+
member_method {
143+
name: "add_update"
144+
argspec: "args=[\'self\', \'updates\'], varargs=None, keywords=None, defaults=None"
145+
}
146+
member_method {
147+
name: "add_variable"
148+
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
149+
}
150+
member_method {
151+
name: "add_weight"
152+
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregationV2.NONE\'], "
153+
}
154+
member_method {
155+
name: "build"
156+
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
157+
}
158+
member_method {
159+
name: "build_from_config"
160+
argspec: "args=[\'self\', \'config\'], varargs=None, keywords=None, defaults=None"
161+
}
162+
member_method {
163+
name: "call"
164+
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
165+
}
166+
member_method {
167+
name: "compute_mask"
168+
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
169+
}
170+
member_method {
171+
name: "compute_output_shape"
172+
argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
173+
}
174+
member_method {
175+
name: "compute_output_signature"
176+
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
177+
}
178+
member_method {
179+
name: "count_params"
180+
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
181+
}
182+
member_method {
183+
name: "finalize_state"
184+
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
185+
}
186+
member_method {
187+
name: "from_config"
188+
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
189+
}
190+
member_method {
191+
name: "get_build_config"
192+
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
193+
}
194+
member_method {
195+
name: "get_config"
196+
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
197+
}
198+
member_method {
199+
name: "get_input_at"
200+
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
201+
}
202+
member_method {
203+
name: "get_input_mask_at"
204+
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
205+
}
206+
member_method {
207+
name: "get_input_shape_at"
208+
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
209+
}
210+
member_method {
211+
name: "get_output_at"
212+
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
213+
}
214+
member_method {
215+
name: "get_output_mask_at"
216+
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
217+
}
218+
member_method {
219+
name: "get_output_shape_at"
220+
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
221+
}
222+
member_method {
223+
name: "get_weights"
224+
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
225+
}
226+
member_method {
227+
name: "load_own_variables"
228+
argspec: "args=[\'self\', \'store\'], varargs=None, keywords=None, defaults=None"
229+
}
230+
member_method {
231+
name: "save_own_variables"
232+
argspec: "args=[\'self\', \'store\'], varargs=None, keywords=None, defaults=None"
233+
}
234+
member_method {
235+
name: "set_weights"
236+
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
237+
}
238+
member_method {
239+
name: "with_name_scope"
240+
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
241+
}
242+
}

tf_keras/api/golden/v1/tensorflow.keras.layers.pbtxt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,10 @@ tf_module {
420420
name: "Softmax"
421421
mtype: "<type \'type\'>"
422422
}
423+
member {
424+
name: "SoftmaxV2"
425+
mtype: "<type \'type\'>"
426+
}
423427
member {
424428
name: "SpatialDropout1D"
425429
mtype: "<type \'type\'>"

tf_keras/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tf_class {
129129
}
130130
member_method {
131131
name: "__init__"
132-
argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
132+
argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'use_softmax_v2\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
133133
}
134134
member_method {
135135
name: "add_loss"

0 commit comments

Comments
 (0)