Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tmva/tmva/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(TMVAUtils
TMVA/BatchGenerator/RBatchLoader.hxx
TMVA/BatchGenerator/RChunkLoader.hxx
TMVA/BatchGenerator/RChunkConstructor.hxx
TMVA/BatchGenerator/RFlat2DMatrix.hxx

SOURCES

Expand Down
23 changes: 10 additions & 13 deletions tmva/tmva/inc/TMVA/BatchGenerator/RBatchGenerator.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#ifndef TMVA_RBATCHGENERATOR
#define TMVA_RBATCHGENERATOR

#include "TMVA/RTensor.hxx"
#include "TMVA/BatchGenerator/RFlat2DMatrix.hxx"
#include "ROOT/RDF/RDatasetSpec.hxx"
#include "TMVA/BatchGenerator/RChunkLoader.hxx"
#include "TMVA/BatchGenerator/RBatchLoader.hxx"
Expand Down Expand Up @@ -100,11 +100,12 @@ private:
std::size_t fNumTrainingBatches;
std::size_t fNumValidationBatches;

TMVA::Experimental::RTensor<float> fTrainTensor;
TMVA::Experimental::RTensor<float> fTrainChunkTensor;
// flattened buffers for chunks and temporary tensors (rows * cols)
RFlat2DMatrix fTrainTensor;
RFlat2DMatrix fTrainChunkTensor;

TMVA::Experimental::RTensor<float> fValidationTensor;
TMVA::Experimental::RTensor<float> fValidationChunkTensor;
RFlat2DMatrix fValidationTensor;
RFlat2DMatrix fValidationChunkTensor;

public:
RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize,
Expand All @@ -125,11 +126,7 @@ public:
fShuffle(shuffle),
fNotFiltered(f_rdf.GetFilterNames().empty()),
fUseWholeFile(maxChunks == 0),
fNumColumns(cols.size()),
fTrainTensor({0, 0}),
fTrainChunkTensor({0, 0}),
fValidationTensor({0, 0}),
fValidationChunkTensor({0, 0})
fNumColumns(cols.size())
{

fNumEntries = f_rdf.Count().GetValue();
Expand Down Expand Up @@ -255,7 +252,7 @@ public:
}

/// \brief Loads a training batch from the queue
TMVA::Experimental::RTensor<float> GetTrainBatch()
RFlat2DMatrix GetTrainBatch()
{
auto batchQueue = fBatchLoader->GetNumTrainingBatchQueue();

Expand All @@ -276,8 +273,8 @@ public:
return fBatchLoader->GetTrainBatch();
}

/// \brief Loads a validation batch from the queue
TMVA::Experimental::RTensor<float> GetValidationBatch()
/// \brief Loads a validation batch from the queue
RFlat2DMatrix GetValidationBatch()
{
auto batchQueue = fBatchLoader->GetNumValidationBatchQueue();

Expand Down
Loading
Loading