@@ -65,41 +65,40 @@ def __init__(self, input_files, ntuple_map, nevents=-1, vectorized=False, batch_
65
65
self .nevents = nevents
66
66
self ._trees = {}
67
67
68
- self ._used_arrays = False
69
- self ._used_trees = False
68
+ self ._used_arrays = vectorized
69
+ self ._used_trees = not vectorized
70
70
self ._passed_events = 0
71
71
self ._batch_size = batch_size
72
72
73
- if vectorized :
74
- self ._load_arrays ()
75
- else :
76
- self ._load_trees ()
73
+ self ._numentries = uproot .numentries (
74
+ self .input_files ,
75
+ list (self ._treeNames )[0 ],
76
+ total = False ,
77
+ )
77
78
78
def _load_trees(self, input_file):
    """Build a ``TreeChain`` over ``input_file`` for every configured tree.

    Each name in ``self._treeNames`` is chained over the single given file
    and stored in ``self._trees`` under that name. When ``TreeChain`` raises
    ``RuntimeError`` the failure is logged and any entry already stored under
    that name (e.g. from a previously loaded file) is removed.

    :param input_file: the one input file to open for this pass.
    """
    for tree_name in self._treeNames:
        try:
            chain = TreeChain(
                tree_name,
                [input_file],
                cache=True,
                events=self.nevents,
            )
            self._trees[tree_name] = chain
            logger.debug("Successfully loaded {0}".format(tree_name))
        except RuntimeError as err:
            logger.warning("Cannot find tree: {0} in input file".format(tree_name))
            logger.error(err)
            if tree_name not in self._trees:
                continue
            # Drop the stale entry so it is not iterated for this file.
            logger.warning('DELETING TREE')
            del self._trees[tree_name]
97
96
98
- def _load_arrays (self ):
97
+ def _load_arrays (self , input_file ):
99
98
for treeName in self ._treeNames :
100
99
try :
101
100
self ._trees [treeName ] = uproot .iterate (
102
- self . input_files ,
101
+ [ input_file ] ,
103
102
treeName ,
104
103
entrysteps = self ._batch_size ,
105
104
)
@@ -112,27 +111,45 @@ def _load_arrays(self):
112
111
logger .warning ('DELETING TREE' )
113
112
del self ._trees [treeName ]
114
113
continue
115
- self ._used_arrays = True
116
114
117
115
def __contains__(self, name):
    """Return True when ``name`` is a key of the alias map.

    Membership is tested on the mapping itself rather than on ``.keys()``,
    which avoids building a temporary key list on Python 2.
    NOTE(review): assumes ``self._aliasMap`` is a dict-like mapping whose
    ``in`` operator checks its keys -- confirm against where it is built.
    """
    return name in self._aliasMap
119
117
120
118
def __iter__(self):
    """Iterate over the events of every input file, one file at a time.

    For each entry of ``self.input_files`` the trees are loaded for that
    file only (``_load_arrays`` in vectorized mode, ``_load_trees``
    otherwise) and everything produced by ``new_uproot_event`` or
    ``new_event`` is yielded.

    When ``self.nevents`` is positive and ``self._passed_events`` has
    already reached it, no further file is opened. Without this check the
    per-file generator would yield one more batch from every remaining
    file before noticing the limit.

    Any ``Exception`` raised while reading is logged as critical and the
    process exits with status -1.
    """
    # event loop
    for input_file in self.input_files:
        # Honour the event cap across file boundaries, not just within a file.
        if self.nevents > 0 and self._passed_events >= self.nevents:
            break
        nevents = self._numentries[input_file]
        logger.info('Opening file {} ({} events)'.format(input_file, nevents))
        if self._used_arrays:
            self._load_arrays(input_file)
        else:
            self._load_trees(input_file)
        try:
            eventGenerator = self.new_event
            if self._used_arrays:
                eventGenerator = self.new_uproot_event
            # Explicit loop instead of ``yield from`` to stay Python 2 compatible.
            for event in eventGenerator():
                yield event
        except Exception as e:
            logger.critical("Error when reading data from ROOT file: {}".format(e))
            sys.exit(-1)
        logger.info('Closing file {}'.format(input_file))
137
+
138
def new_event(self):
    """Yield one ``Event`` per lockstep step through all loaded trees.

    The trees in ``self._trees`` are advanced together with ``zip``; the
    per-step tuple itself is not used, each ``Event`` is built from the
    tree mapping and the alias map.
    """
    lockstep = six.moves.zip(*six.itervalues(self._trees))
    for _ in lockstep:
        yield Event(self._trees, self._aliasMap)
141
+
142
def new_uproot_event(self):
    """Yield one ``UprootEvent`` per batch drawn from all loaded trees.

    The iterators in ``self._trees`` are advanced together; each step's
    items are paired back with their tree names to form the event data.
    After every yielded batch ``self._passed_events`` grows by
    ``self._batch_size``, and iteration ends once a positive
    ``self.nevents`` has been reached.
    """
    batches = six.moves.zip(*six.itervalues(self._trees))
    for batch in batches:
        # Iterating the dict yields keys in the same order as its values.
        per_tree = dict(six.moves.zip(self._trees, batch))
        yield UprootEvent(per_tree, self._aliasMap, batch_size=self._batch_size)
        self._passed_events += self._batch_size
        limit_reached = self.nevents > 0 and self._passed_events >= self.nevents
        if limit_reached:
            return
149
+
150
@property
def numentries(self):
    """Return the sum of all values stored in ``self._numentries``.

    ``self._numentries`` is indexed by input file elsewhere in this class,
    so this is presumably the total entry count over all input files.
    An empty mapping gives 0.
    """
    # sum() consumes the values view directly; no intermediate list needed.
    return sum(self._numentries.values())
136
153
137
154
138
155
class UprootEvent (object ):
0 commit comments