Skip to content

Commit

Permalink
Update docs to not recommend pure max as reduce type.
Browse files Browse the repository at this point in the history
A more typical example for modeling is `mean|sum`. When max is needed,
the safe way in the presence of zero values is `max_no_inf`.

PiperOrigin-RevId: 672983309
  • Loading branch information
arnoegw authored and tensorflower-gardener committed Sep 11, 2024
1 parent bcb74be commit 11cfcc0
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tensorflow_gnn/keras/layers/convolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class SimpleConv(convolution_base.AnyToAnyConvolutionBase):
combined input features (see combine_type).
reduce_type: Specifies how to pool the messages to receivers. Defaults to
`"sum"`, can be any reduce_type understood by `tfgnn.pool()`, including
concatenations like `"sum|max"` (but mind the increased dimension of the
concatenations like `"sum|mean"` (but mind the increased dimension of the
result and the growing number of model weights in the next-state layer).
combine_type: a string understood by tfgnn.combine_values(), to specify how
the inputs are combined before passing them to the message_fn. Defaults
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_gnn/models/mt_albis/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ states from incoming messages. Its main architectural choices are:
* how to aggregate the incoming messages from each node set:
* by element-wise averaging (reduce type `"mean"`),
* by a concatenation of the average with other fixed expressions
(e.g., `"mean|max"`, `"mean|sum"`), or
(e.g., `"mean|max_no_inf"`, `"mean|sum"`), or
* with attention, that is, a trained, data-dependent weighting;
* whether to use residual connections for updating node states;
* if and how to normalize node states.
Expand Down
10 changes: 5 additions & 5 deletions tensorflow_gnn/models/mt_albis/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def MtAlbisSimpleConv( # To be called like a class initializer. pylint: disabl
If left unset for init, the tag must be passed at call time.
reduce_type: Controls how messages are aggregated on an EdgeSet for each
receiver node; defaults to `"mean"`. Can be any reduce_type understood by
`tfgnn.pool()`, including concatenations like `"mean|max"` (but mind the
`tfgnn.pool()`, including concatenations like `"mean|sum"` (but mind the
increased dimension of the result and the growing number of model weights
in the next-state layer).
activation: The nonlinearity used on each message before pooling.
Expand Down Expand Up @@ -291,10 +291,10 @@ def MtAlbisGraphUpdate( # To be called like a class initializer. pylint: disab
simple_conv_reduce_type: For attention_type `"none"`, controls how messages
are aggregated on an EdgeSet for each receiver node. Defaults to `"mean"`;
other recommended values are the concatenations `"mean|sum"`,
`"mean|max"`, and `"mean|sum|max"` (but mind the increased output
dimension and the corresponding increase in the number of weights in the
next-state layer). Technically, can be set to any reduce_type understood
by `tfgnn.pool()`.
`"mean|max_no_inf"`, and `"mean|sum|max_no_inf"` (but mind the increased
output dimension and the corresponding increase in the number of weights
in the next-state layer). Technically, can be set to any reduce_type
understood by `tfgnn.pool()`.
simple_conv_use_receiver_state: For attention_type `"none"`, controls
whether the receiver node state is used in computing each edge's message
(in addition to the sender node state and possibly an `edge feature`).
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_gnn/models/vanilla_mpnn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def VanillaMPNNGraphUpdate( # To be called like a class initializer. pylint: d
this input.
reduce_type: How to pool the messages from edges to receiver nodes; defaults
to `"sum"`. Can be any reduce_type understood by `tfgnn.pool()`, including
concatenations like `"sum|max"` (but mind the increased dimension of the
concatenations like `"sum|mean"` (but mind the increased dimension of the
result and the growing number of model weights in the next-state layer).
l2_regularization: The coefficient of L2 regularization for weights and
biases.
Expand Down

0 comments on commit 11cfcc0

Please sign in to comment.