@@ -45,16 +45,6 @@ def __getitems__(self, keys: Iterable[int]) -> T:
4545 """Returns the value for the given `keys`."""
4646
4747
48- def file_instructions (
49- dataset_info : dataset_info_lib .DatasetInfo ,
50- split : splits_lib .Split | None = None ,
51- ) -> list [shard_utils .FileInstruction ]:
52- """Retrieves the file instructions from the DatasetInfo."""
53- split_infos = dataset_info .splits .values ()
54- split_dict = splits_lib .SplitDict (split_infos = split_infos )
55- return split_dict [split ].file_instructions
56-
57-
5848@dataclasses .dataclass
5949class BaseDataSource (MappingView , Sequence ):
6050 """Base DataSource to override all dunder methods with the deserialization.
@@ -94,6 +84,13 @@ def _deserialize(self, record: Any) -> Any:
9484 return features .deserialize_example_np (record , decoders = self .decoders ) # pylint: disable=attribute-error
9585 raise ValueError ('No features set, cannot decode example!' )
9686
87+ @property
88+ def split_info (self ) -> splits_lib .SplitInfo | splits_lib .SubSplitInfo :
89+ """Returns the SplitInfo for the split."""
90+ split_infos = self .dataset_info .splits .values ()
91+ splits_dict = splits_lib .SplitDict (split_infos = split_infos )
92+ return splits_dict [self .split ] # will raise an error if split is not found
93+
9794 def __getitem__ (self , key : SupportsIndex ) -> Any :
9895 record = self .data_source [key .__index__ ()]
9996 return self ._deserialize (record )
@@ -133,7 +130,7 @@ def __repr__(self) -> str:
133130 )
134131
135132 def __len__ (self ) -> int :
136- return self .data_source . __len__ ( )
133+ return sum ( fi . take for fi in self .split_info . file_instructions )
137134
138135 def __iter__ (self ):
139136 for i in range (self .__len__ ()):
0 commit comments