@@ -383,7 +383,7 @@ pub struct AdamOptimizer {
     /// Stagnation multiplier for relaxed convergence criteria (future use)
     stagnation_multiplier: f32,
     /// Learning rate tensor for graph-based updates
-    lr_tensor: Option<GraphTensor>,
+    lr_tensor: Option<SafeTensor>,
     /// Stagnation count threshold (future use)
     stagnation_count: usize,
     /// Name of the optimizer variant
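The field now stores a `SafeTensor` instead of a raw `GraphTensor`. The `SafeTensor` definition is not part of this diff; the `SafeTensor(lr)` construction in the next hunk suggests a tuple-struct newtype over the graph handle. A minimal sketch under that assumption:

```rust
// Hypothetical sketch only; the real SafeTensor type is defined outside this diff.
// Assumes a plain newtype wrapper around the graph tensor handle.
pub struct SafeTensor(pub GraphTensor);

impl SafeTensor {
    /// Borrow the wrapped graph tensor (assumed helper, not shown in the patch).
    pub fn inner(&self) -> &GraphTensor {
        &self.0
    }
}
```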
@@ -462,7 +462,7 @@ impl Optimizer for AdamOptimizer {
     ) -> OptimizerSetup {
         // Create learning rate tensor
         let lr = graph.tensor(()).set(vec![self.config.learning_rate]);
-        self.lr_tensor = Some(lr);
+        self.lr_tensor = Some(SafeTensor(lr));

         // Create constants for beta values and epsilon
         let beta1 = graph.tensor(()).set(vec![self.config.beta1]);
@@ -489,35 +489,35 @@ impl Optimizer for AdamOptimizer {
         for (i, (w, g)) in weights.iter().zip(gradients.iter()).enumerate() {
             // Create moment estimate tensors initialized to zero
             // These will accumulate across iterations
-            let m = graph.tensor(()).set(vec![0.0f32]).expand_to(g.shape);
-            let v = graph.tensor(()).set(vec![0.0f32]).expand_to(g.shape);
+            let m = graph.tensor(()).set(vec![0.0f32]).expand(g.shape);
+            let v = graph.tensor(()).set(vec![0.0f32]).expand(g.shape);

             // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
-            let m_new = m * beta1.expand_to(g.shape) + *g * one_minus_beta1.expand_to(g.shape);
+            let m_new = m * beta1.expand(g.shape) + *g * one_minus_beta1.expand(g.shape);

             // v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
             let g_squared = *g * *g;
             let v_new =
-                v * beta2.expand_to(g.shape) + g_squared * one_minus_beta2.expand_to(g.shape);
+                v * beta2.expand(g.shape) + g_squared * one_minus_beta2.expand(g.shape);

             // Bias-corrected estimates
             // m_hat = m_t / (1 - beta1^t)
-            let m_hat = m_new / bc1_tensor.expand_to(g.shape);
+            let m_hat = m_new / bc1_tensor.expand(g.shape);

             // v_hat = v_t / (1 - beta2^t)
-            let v_hat = v_new / bc2_tensor.expand_to(g.shape);
+            let v_hat = v_new / bc2_tensor.expand(g.shape);

             // Update: theta_{t+1} = theta_t - lr * m_hat / (sqrt(v_hat) + epsilon)
             let v_hat_sqrt = v_hat.sqrt();
-            let denom = v_hat_sqrt + epsilon.expand_to(g.shape);
+            let denom = v_hat_sqrt + epsilon.expand(g.shape);
             let update = m_hat / denom;

-            let mut w_new = *w - update * lr.expand_to(g.shape);
+            let mut w_new = *w - update * lr.expand(g.shape);

             // Apply weight decay if configured
             if self.config.weight_decay > 0.0 {
                 let wd = graph.tensor(()).set(vec![self.config.weight_decay]);
-                w_new = w_new - *w * wd.expand_to(g.shape) * lr.expand_to(g.shape);
+                w_new = w_new - *w * wd.expand(g.shape) * lr.expand(g.shape);
             }

             new_weights.push(w_new);
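The loop body builds the standard Adam update as graph operations. For reference, a minimal scalar sketch of the same math in plain Rust (illustration only, independent of the graph API; the parameter names mirror the fields read from `self.config`):

```rust
// Illustrative scalar version of the Adam step constructed above; not part of the patch.
// Returns the updated weight and the new first/second moment estimates.
fn adam_scalar_step(
    w: f32, g: f32, m: f32, v: f32, t: i32,
    lr: f32, beta1: f32, beta2: f32, eps: f32, weight_decay: f32,
) -> (f32, f32, f32) {
    // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    let m_new = beta1 * m + (1.0 - beta1) * g;
    // v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    let v_new = beta2 * v + (1.0 - beta2) * g * g;
    // Bias correction: m_hat = m_t / (1 - beta1^t), v_hat = v_t / (1 - beta2^t)
    let m_hat = m_new / (1.0 - beta1.powi(t));
    let v_hat = v_new / (1.0 - beta2.powi(t));
    // theta_{t+1} = theta_t - lr * m_hat / (sqrt(v_hat) + epsilon)
    let mut w_new = w - lr * m_hat / (v_hat.sqrt() + eps);
    // Decoupled weight decay term, matching the extra subtraction when weight_decay > 0.0
    if weight_decay > 0.0 {
        w_new -= lr * weight_decay * w;
    }
    (w_new, m_new, v_new)
}
```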
@@ -545,7 +545,6 @@ impl Optimizer for AdamOptimizer {
     fn reset(&mut self) {
         self.state.reset();
         self.current_lr = self.config.learning_rate;
-        self.lr_tensor = None;
         // Note: name is not reset as it's determined by configuration
     }

@@ -587,4 +586,4 @@ mod tests {
         assert!(state.v.is_none());
         assert!(state.v_max.is_none());
     }
-}
+}