@@ -152,6 +152,24 @@ pub fn create_line_search(config: LineSearchConfig) -> Box<dyn LineSearch> {
152152 }
153153}
154154
155+ fn unflatten_tensors (
156+ flat : & [ f64 ] ,
157+ shapes : & [ Vec < usize > ] ,
158+ ) -> Result < Vec < Vec < f32 > > > {
159+ let mut result = Vec :: new ( ) ;
160+ let mut offset = 0 ;
161+ for shape in shapes {
162+ let size: usize = shape. iter ( ) . product ( ) ;
163+ if offset + size > flat. len ( ) {
164+ return Err ( anyhow:: anyhow!( "Size mismatch in unflattening" ) ) ;
165+ }
166+ let chunk = & flat[ offset..offset + size] ;
167+ result. push ( chunk. iter ( ) . map ( |& x| x as f32 ) . collect ( ) ) ;
168+ offset += size;
169+ }
170+ Ok ( result)
171+ }
172+
155173/// Trait for line search algorithms
156174pub trait LineSearch : Send + Sync + Debug {
157175 /// Perform 1D line search optimization
@@ -191,32 +209,26 @@ pub trait LineSearch: Send + Sync + Debug {
191209 /// executes the graph, and returns the loss value.
192210 fn evaluate_at_step (
193211 & self ,
194- context : & OptimizationContext ,
212+ context : & mut OptimizationContext ,
195213 current_params : & [ f64 ] ,
196214 direction : & [ f64 ] ,
197215 step : f64 ,
198216 ) -> Result < f64 > {
199217 if self . is_verbose ( ) {
200218 println ! ( "LineSearch: Evaluating f(x + alpha * d) at alpha = {:.6e}" , step) ;
201219 }
202- let candidate_params: Vec < f32 > = current_params
220+ let candidate_params: Vec < f64 > = current_params
203221 . iter ( )
204222 . zip ( direction. iter ( ) )
205- . map ( |( x, d) | ( x + step * d) as f32 )
223+ . map ( |( x, d) | x + step * d)
206224 . collect ( ) ;
207225
208- let mut offset = 0 ;
209- for weight in & context. weights {
210- let len: usize = weight. shape . to_shape ( ) . iter ( ) . map ( |d| d. to_usize ( ) . unwrap ( ) ) . product ( ) ;
211-
212- if offset + len > candidate_params. len ( ) {
213- return Err ( anyhow:: anyhow!( "Parameter size mismatch" ) ) ;
214- }
215-
216- let chunk = & candidate_params[ offset..offset + len] ;
217- context. graph ( ) . set_tensor ( weight. id , 0 , Tensor :: new ( chunk. to_vec ( ) ) ) ;
218- offset += len;
219- }
226+ let shapes = context. weights . iter ( ) . map ( |w| w. shape . to_shape ( ) . iter ( ) . map (
227+ |& d| d. to_usize ( ) . unwrap ( )
228+ ) . collect_vec ( ) ) . collect :: < Vec < _ > > ( ) ;
229+
230+ let mut weights_data = unflatten_tensors ( & candidate_params, & shapes) ?;
231+ context. write_weights ( & mut weights_data) ;
220232
221233 context. graph ( ) . execute ( ) ;
222234 let f_val = context
@@ -235,32 +247,26 @@ pub trait LineSearch: Send + Sync + Debug {
235247 /// This is more efficient than separate calls when both are needed.
236248 fn evaluate_with_gradient (
237249 & self ,
238- context : & OptimizationContext ,
250+ context : & mut OptimizationContext ,
239251 current_params : & [ f64 ] ,
240252 direction : & [ f64 ] ,
241253 step : f64 ,
242254 ) -> Result < ( f64 , Vec < f64 > ) > {
243255 if self . is_verbose ( ) {
244256 println ! ( "LineSearch: Evaluating f and g at alpha = {:.6e}" , step) ;
245257 }
246- let candidate_params: Vec < f32 > = current_params
258+ let candidate_params: Vec < f64 > = current_params
247259 . iter ( )
248260 . zip ( direction. iter ( ) )
249- . map ( |( x, d) | ( x + step * d) as f32 )
261+ . map ( |( x, d) | x + step * d)
250262 . collect ( ) ;
251263
252- let mut offset = 0 ;
253- for weight in & context. weights {
254- let len: usize = weight. shape . to_shape ( ) . iter ( ) . map ( |d| d. to_usize ( ) . unwrap ( ) ) . product ( ) ;
255-
256- if offset + len > candidate_params. len ( ) {
257- return Err ( anyhow:: anyhow!( "Parameter size mismatch" ) ) ;
258- }
259-
260- let chunk = & candidate_params[ offset..offset + len] ;
261- context. graph ( ) . set_tensor ( weight. id , 0 , Tensor :: new ( chunk. to_vec ( ) ) ) ;
262- offset += len;
263- }
264+ let shapes = context. weights . iter ( ) . map ( |w| w. shape . to_shape ( ) . iter ( ) . map (
265+ |& d| d. to_usize ( ) . unwrap ( )
266+ ) . collect_vec ( ) ) . collect :: < Vec < _ > > ( ) ;
267+
268+ let mut weights_data = unflatten_tensors ( & candidate_params, & shapes) ?;
269+ context. write_weights ( & mut weights_data) ;
264270
265271 context. graph ( ) . execute ( ) ;
266272 // Get loss
0 commit comments