@@ -175,13 +175,12 @@ def _trim_inputs(self, inputs):
175175 def _combine_inputs (self , segments ):
176176 """Combine inputs with start and end values added."""
177177 dtype = segments [0 ].dtype
178+ batch_size = segments [0 ].nrows ()
178179 start_value = tf .convert_to_tensor (self .start_value , dtype = dtype )
179180 end_value = tf .convert_to_tensor (self .end_value , dtype = dtype )
180181
181- start_column = tf .tile ([start_value ], [segments [0 ].nrows ()])
182- start_column = tf .expand_dims (start_column , 1 )
183- end_column = tf .tile ([end_value ], [segments [0 ].nrows ()])
184- end_column = tf .expand_dims (end_column , 1 )
182+ start_column = tf .fill ((batch_size , 1 ), start_value )
183+ end_column = tf .fill ((batch_size , 1 ), end_value )
185184 ones_column = tf .ones_like (start_column , dtype = tf .int32 )
186185
187186 segments_to_combine = [start_column ]
@@ -211,7 +210,7 @@ def call(self, inputs):
211210 segments = self ._trim_inputs (inputs )
212211 token_ids , segment_ids = self ._combine_inputs (segments )
213212 # Pad to dense tensor output.
214- shape = tf .cast ([- 1 , self .sequence_length ], " int64" )
213+ shape = tf .cast ([- 1 , self .sequence_length ], tf . int64 )
215214 token_ids = token_ids .to_tensor (
216215 shape = shape , default_value = self .pad_value
217216 )
0 commit comments