@@ -326,12 +326,13 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long
326326 return NULL ;
327327}
328328
329- RAI_Model * RAI_ModelCreateORT (RAI_Backend backend , const char * devicestr , RAI_ModelOpts opts ,
330- const char * modeldef , size_t modellen , RAI_Error * error ) {
329+ int RAI_ModelCreateORT (RAI_Model * model , RAI_Error * error ) {
331330
332331 const OrtApi * ort = OrtGetApiBase ()-> GetApi (1 );
333332 char * * inputs_ = NULL ;
334333 char * * outputs_ = NULL ;
334+ size_t ninputs ;
335+ size_t noutputs ;
335336 OrtSessionOptions * session_options = NULL ;
336337 OrtSession * session = NULL ;
337338 OrtStatus * status = NULL ;
@@ -348,7 +349,7 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
348349 }
349350
350351 ONNX_VALIDATE_STATUS (ort -> CreateSessionOptions (& session_options ))
351- if (strcasecmp (devicestr , "CPU" ) == 0 ) {
352+ if (strcasecmp (model -> devicestr , "CPU" ) == 0 ) {
352353 // These are required to ensure that onnx will use the registered REDIS allocator (for
353354 // a model that defined to run on CPU).
354355 ONNX_VALIDATE_STATUS (
@@ -359,24 +360,31 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
359360 // TODO: these options could be configured at the AI.CONFIG level
360361 ONNX_VALIDATE_STATUS (ort -> SetSessionGraphOptimizationLevel (session_options , ORT_ENABLE_BASIC ))
361362 ONNX_VALIDATE_STATUS (
362- ort -> SetIntraOpNumThreads (session_options , (int )opts .backends_intra_op_parallelism ))
363+ ort -> SetIntraOpNumThreads (session_options , (int )model -> opts .backends_intra_op_parallelism ))
363364 ONNX_VALIDATE_STATUS (
364- ort -> SetInterOpNumThreads (session_options , (int )opts .backends_inter_op_parallelism ))
365+ ort -> SetInterOpNumThreads (session_options , (int )model -> opts .backends_inter_op_parallelism ))
365366
366367 // If the model is set for GPU, this will set CUDA provider for the session,
367368 // so that onnx will use its own allocator for CUDA (not Redis allocator)
368- if (!setDeviceId (devicestr , session_options , error )) {
369+ if (!setDeviceId (model -> devicestr , session_options , error )) {
369370 ort -> ReleaseSessionOptions (session_options );
370- return NULL ;
371+ return REDISMODULE_ERR ;
371372 }
372373
373374 ONNX_VALIDATE_STATUS (
374- ort -> CreateSessionFromArray (env , modeldef , modellen , session_options , & session ))
375+ ort -> CreateSessionFromArray (env , model -> data , model -> datalen , session_options , & session ))
375376 ort -> ReleaseSessionOptions (session_options );
376377
378+ model -> session = session ;
379+
377380 size_t n_input_nodes ;
378- ONNX_VALIDATE_STATUS (ort -> SessionGetInputCount (session , & n_input_nodes ))
379381 size_t n_output_nodes ;
382+
383+ // We save the model's inputs and outputs only in the first time that we create the model.
384+ // We might create the model again when loading from RDB, in this case the inputs and outputs
385+ // are already loaded from RDB.
386+ // if (!model->inputs) {
387+ ONNX_VALIDATE_STATUS (ort -> SessionGetInputCount (session , & n_input_nodes ))
380388 ONNX_VALIDATE_STATUS (ort -> SessionGetOutputCount (session , & n_output_nodes ))
381389
382390 inputs_ = array_new (char * , n_input_nodes );
@@ -393,27 +401,13 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
393401 outputs_ = array_append (outputs_ , output_name );
394402 }
395403
396- // Since ONNXRuntime doesn't have a re-serialization function,
397- // we cache the blob in order to re-serialize it.
398- // Not optimal for storage purposes, but again, it may be temporary
399- char * buffer = RedisModule_Calloc (modellen , sizeof (* buffer ));
400- memcpy (buffer , modeldef , modellen );
401-
402- RAI_Model * ret = RedisModule_Calloc (1 , sizeof (* ret ));
403- ret -> model = NULL ;
404- ret -> session = session ;
405- ret -> backend = backend ;
406- ret -> devicestr = RedisModule_Strdup (devicestr );
407- ret -> refCount = 1 ;
408- ret -> opts = opts ;
409- ret -> data = buffer ;
410- ret -> datalen = modellen ;
411- ret -> ninputs = n_input_nodes ;
412- ret -> noutputs = n_output_nodes ;
413- ret -> inputs = inputs_ ;
414- ret -> outputs = outputs_ ;
415-
416- return ret ;
404+ model -> ninputs = n_input_nodes ;
405+ model -> noutputs = n_output_nodes ;
406+ model -> inputs = inputs_ ;
407+ model -> outputs = outputs_ ;
408+ //}
409+
410+ return REDISMODULE_OK ;
417411
418412error :
419413 RAI_SetError (error , RAI_EMODELCREATE , ort -> GetErrorMessage (status ));
@@ -438,28 +432,32 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
438432 ort -> ReleaseSession (session );
439433 }
440434 ort -> ReleaseStatus (status );
441- return NULL ;
435+ return REDISMODULE_ERR ;
442436}
443437
444438void RAI_ModelFreeORT (RAI_Model * model , RAI_Error * error ) {
445439 const OrtApi * ort = OrtGetApiBase ()-> GetApi (1 );
446440 OrtStatus * status = NULL ;
447441
448- for (uint32_t i = 0 ; i < model -> ninputs ; i ++ ) {
449- ONNX_VALIDATE_STATUS (ort -> AllocatorFree (global_allocator , model -> inputs [i ]))
442+ if (model -> inputs ) {
443+ for (uint32_t i = 0 ; i < model -> ninputs ; i ++ ) {
444+ ONNX_VALIDATE_STATUS (ort -> AllocatorFree (global_allocator , model -> inputs [i ]))
445+ }
446+ array_free (model -> inputs );
447+ model -> inputs = NULL ;
450448 }
451- array_free (model -> inputs );
452449
453- for (uint32_t i = 0 ; i < model -> noutputs ; i ++ ) {
454- ONNX_VALIDATE_STATUS (ort -> AllocatorFree (global_allocator , model -> outputs [i ]))
450+ if (model -> outputs ) {
451+ for (uint32_t i = 0 ; i < model -> noutputs ; i ++ ) {
452+ ONNX_VALIDATE_STATUS (ort -> AllocatorFree (global_allocator , model -> outputs [i ]))
453+ }
454+ array_free (model -> outputs );
455+ model -> outputs = NULL ;
455456 }
456- array_free (model -> outputs );
457457
458- RedisModule_Free (model -> devicestr );
459- RedisModule_Free (model -> data );
460- ort -> ReleaseSession (model -> session );
461- model -> model = NULL ;
462- model -> session = NULL ;
458+ if (model -> session ) {
459+ ort -> ReleaseSession (model -> session );
460+ }
463461 return ;
464462
465463error :
0 commit comments