diff --git a/crates/luminal_training/src/optimizer.rs b/crates/luminal_training/src/optimizer.rs
index a62e592e..8c2856d1 100644
--- a/crates/luminal_training/src/optimizer.rs
+++ b/crates/luminal_training/src/optimizer.rs
@@ -50,7 +50,7 @@ pub fn sgd_on_graph(
         // SGD
         let new_weight = old_weight - (gradient * lr.expand_to(grad_shape));
-        new_weight.keep();
+        new_weight.retrieve();
         new_weights.push(new_weight.id);
     }
diff --git a/examples/train_math_net/src/main.rs b/examples/train_math_net/src/main.rs
index 34a9394e..fc3d52e1 100644
--- a/examples/train_math_net/src/main.rs
+++ b/examples/train_math_net/src/main.rs
@@ -13,6 +13,9 @@ fn main() {
     let model = <(Linear<8, 16>, Swish, Linear<16, 16>, Swish, Linear<16, 5>)>::initialize(&mut cx);
     let mut input = cx.tensor::<R1<8>>();
    let mut target = cx.tensor::<R1<5>>();
+    model.0.weight.retrieve();
+    model.2.weight.retrieve();
+    model.4.weight.retrieve();
     let mut output = model.forward(input).retrieve();
     let mut loss = mse_loss(output, target).retrieve();