From 05783f0000499c54cadeea2b3f14297ec4d0fa11 Mon Sep 17 00:00:00 2001 From: zhouyc Date: Thu, 2 Jul 2020 11:35:31 +0800 Subject: [PATCH] bug fix cant serialization fgcnn model bug report: Invalid input for serialization, type: --- deepctr/models/fgcnn.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/deepctr/models/fgcnn.py b/deepctr/models/fgcnn.py index 1c7ee47f..7addb183 100644 --- a/deepctr/models/fgcnn.py +++ b/deepctr/models/fgcnn.py @@ -70,8 +70,12 @@ def FGCNN(linear_feature_columns, dnn_feature_columns, conv_kernel_width=(7, 7, combined_input = concat_func([origin_input, new_features], axis=1) else: combined_input = origin_input - inner_product = tf.keras.layers.Flatten()(InnerProductLayer()( - tf.keras.layers.Lambda(unstack, mask=[None] * int(combined_input.shape[1]))(combined_input))) + + #inner_product = tf.keras.layers.Flatten()(InnerProductLayer()( + # tf.keras.layers.Lambda(unstack, mask=[None] * int(combined_input.shape[1]))(combined_input))) + + inner_product = tf.keras.layers.Flatten()(InnerProductLayer()(tf.split(combined_input,combined_input.shape[1],1))) + linear_signal = tf.keras.layers.Flatten()(combined_input) dnn_input = tf.keras.layers.Concatenate()([linear_signal, inner_product]) dnn_input = tf.keras.layers.Flatten()(dnn_input)