@@ -229,18 +229,63 @@ def keras_to_jte_optimizer(
   # pylint: disable-next=protected-access
   learning_rate = keras_to_jte_learning_rate(optimizer._learning_rate)

-  # SGD or Adagrad
+  # Unsupported keras optimizer general options.
+  if optimizer.clipnorm is not None:
+    raise ValueError("Unsupported optimizer option `clipnorm`.")
+  if optimizer.global_clipnorm is not None:
+    raise ValueError("Unsupported optimizer option `global_clipnorm`.")
+  if optimizer.use_ema:
+    raise ValueError("Unsupported optimizer option `use_ema`.")
+  if optimizer.loss_scale_factor is not None:
+    raise ValueError("Unsupported optimizer option `loss_scale_factor`.")
+
+  # Supported optimizers.
   if isinstance(optimizer, keras.optimizers.SGD):
+    if getattr(optimizer, "nesterov", False):
+      raise ValueError("Unsupported optimizer option `nesterov`.")
+    if getattr(optimizer, "momentum", 0.0) != 0.0:
+      raise ValueError("Unsupported optimizer option `momentum`.")
     return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
   elif isinstance(optimizer, keras.optimizers.Adagrad):
+    if getattr(optimizer, "epsilon", 1e-7) != 1e-7:
+      raise ValueError("Unsupported optimizer option `epsilon`.")
     return embedding_spec.AdagradOptimizerSpec(
         learning_rate=learning_rate,
         initial_accumulator_value=optimizer.initial_accumulator_value,
     )
+  elif isinstance(optimizer, keras.optimizers.Adam):
+    if getattr(optimizer, "amsgrad", False):
+      raise ValueError("Unsupported optimizer option `amsgrad`.")

-  # Default to SGD for now, since other optimizers are still being created,
-  # and we don't want to fail.
-  return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
+    return embedding_spec.AdamOptimizerSpec(
+        learning_rate=learning_rate,
+        beta_1=optimizer.beta_1,
+        beta_2=optimizer.beta_2,
+        epsilon=optimizer.epsilon,
+    )
+  elif isinstance(optimizer, keras.optimizers.Ftrl):
+    if (
+        getattr(optimizer, "l2_shrinkage_regularization_strength", 0.0)
+        != 0.0
+    ):
+      raise ValueError(
+          "Unsupported optimizer option "
+          "`l2_shrinkage_regularization_strength`."
+      )
+
+    return embedding_spec.FTRLOptimizerSpec(
+        learning_rate=learning_rate,
+        learning_rate_power=optimizer.learning_rate_power,
+        l1_regularization_strength=optimizer.l1_regularization_strength,
+        l2_regularization_strength=optimizer.l2_regularization_strength,
+        beta=optimizer.beta,
+        initial_accumulator_value=optimizer.initial_accumulator_value,
+    )
+
+  raise ValueError(
+      f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
+      f"one of [Adagrad, Adam, Ftrl, SGD]."
+  )


 def jte_to_keras_optimizer(
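For context between the two hunks: a minimal usage sketch of the stricter conversion above, assuming `keras_to_jte_optimizer` and `embedding_spec` are importable from this module's package (names mirror the diff; the snippet is illustrative, not part of the patch).

import keras
# Assumes keras_to_jte_optimizer is in scope, e.g. imported from the
# module being patched here.

# A plain Adam converts to AdamOptimizerSpec with matching hyperparameters.
spec = keras_to_jte_optimizer(
    keras.optimizers.Adam(learning_rate=1e-3, beta_1=0.9, beta_2=0.999)
)

# Options the embedding engine cannot honor now raise, instead of silently
# falling back to SGD (the old behavior removed in this hunk).
try:
  keras_to_jte_optimizer(keras.optimizers.Adam(amsgrad=True))
except ValueError as e:
  print(e)  # Unsupported optimizer option `amsgrad`.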
@@ -262,8 +307,33 @@ def jte_to_keras_optimizer(
         learning_rate=learning_rate,
         initial_accumulator_value=optimizer.initial_accumulator_value,
     )
+  elif isinstance(optimizer, embedding_spec.AdamOptimizerSpec):
+    return keras.optimizers.Adam(
+        learning_rate=learning_rate,
+        beta_1=optimizer.beta_1,
+        beta_2=optimizer.beta_2,
+        epsilon=optimizer.epsilon,
+    )
+  elif isinstance(optimizer, embedding_spec.FTRLOptimizerSpec):
+    if getattr(optimizer, "initial_linear_value", 0.0) != 0.0:
+      raise ValueError(
+          "Unsupported optimizer option `initial_linear_value`."
+      )
+    if getattr(optimizer, "multiply_linear_by_learning_rate", False):
+      raise ValueError(
+          "Unsupported optimizer option "
+          "`multiply_linear_by_learning_rate`."
+      )
+    return keras.optimizers.Ftrl(
+        learning_rate=learning_rate,
+        learning_rate_power=optimizer.learning_rate_power,
+        initial_accumulator_value=optimizer.initial_accumulator_value,
+        l1_regularization_strength=optimizer.l1_regularization_strength,
+        l2_regularization_strength=optimizer.l2_regularization_strength,
+        beta=optimizer.beta,
+    )

-  raise ValueError(f"Unknown optimizer spec {optimizer}")
+  raise ValueError(f"Unknown optimizer spec {type(optimizer)}.")


 def _keras_to_jte_table_config(
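And a hedged round-trip sketch for the reverse direction added in the second hunk. The `FTRLOptimizerSpec` constructor arguments are assumed to match the attributes the diff reads from it; treat this as a sketch under that assumption, not a definitive API.

import keras
# Assumes embedding_spec and jte_to_keras_optimizer are in scope.

spec = embedding_spec.FTRLOptimizerSpec(
    learning_rate=0.05,
    learning_rate_power=-0.5,
    l1_regularization_strength=0.001,
    l2_regularization_strength=0.001,
    beta=0.0,
    initial_accumulator_value=0.1,
)
opt = jte_to_keras_optimizer(spec)
assert isinstance(opt, keras.optimizers.Ftrl)

# A spec carrying options keras.optimizers.Ftrl lacks is rejected: e.g. a
# spec with multiply_linear_by_learning_rate set would raise ValueError.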