From e36d46633a885360bd3b07739b1c16589e3bde42 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 13 Jan 2021 05:19:45 -0500 Subject: [PATCH] Add MNIST dataset (#1730) * Add MNIST dataset * Update datasets/mnist/README.md Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- datasets/mnist/README.md | 146 ++++++++++++++++++ datasets/mnist/dataset_infos.json | 1 + .../mnist/dummy/mnist/1.0.0/dummy_data.zip | Bin 0 -> 2672 bytes datasets/mnist/mnist.py | 116 ++++++++++++++ 4 files changed, 263 insertions(+) create mode 100644 datasets/mnist/README.md create mode 100644 datasets/mnist/dataset_infos.json create mode 100644 datasets/mnist/dummy/mnist/1.0.0/dummy_data.zip create mode 100644 datasets/mnist/mnist.py diff --git a/datasets/mnist/README.md b/datasets/mnist/README.md new file mode 100644 index 00000000000..9a05178b13d --- /dev/null +++ b/datasets/mnist/README.md @@ -0,0 +1,146 @@ +--- +annotations_creators: +- experts +language_creators: +- found +languages: [] +licenses: +- MIT +multilinguality: [] +size_categories: +- 10K0B1BkB+P5fLc}0CG=)0D#YfBS!?vd5sSc!Vwv! 
z$=(U(Jn|C;0F<|igomLBAuvKP3Wq%dBVZzRU=ir3aI7Zo*U$1B9O;JdkaLZ1o{gYK|&Da$Yp!G@YD}-muu2;X{<3vT<+hyTYc=UW)o64 zz$v5MvcW6mlxW1YH;xA606AkX3)ppkZLfhD_!CPd^n{RrW~Gfru9bPW_~TDq=82~b z?ZhvpcPqXfM?$hUGU`K>^qk+@&s7a~w48>#PmV*?_a|rW?5+*^#L7N!L}N_M8@Y12 z!@{_Ou=7T~abYHHwK3yWJ*KL*Pkk1Kvc~UGXIn=Bsqi}_e-O_H3+beNd>r1 z;gje*5ziBN@}jC|Ri+~tBJ8LuLiw`k=>w|tC8+zYcyheLm0j+$ecFgohTrL<0@BkG zQ=mw1Td$#0KDW<>nG;HgvA!`n*|+ z!4$FX%^Jb#-OXML42T;_&%pDV*^DI`>tw$L`Q4~OVNKQZ3MQIoqVxCFL((4lsiDo& z>S9c^Y_E^%@y0&g1YGN2w5|KoX0Fbq^^>B8ZjYg(ZV!FZr<3j~t2|@Ygpy3*6W)Q0 zeq1dr7>8H&XYO}z%BN!Qz-iV`1jJTL@~jJI^O@$=^b*Ph_=#=G?tQMFe_k>X48M9) zY>32KTogcqE|n^*$if`%duCu<{!0|0=P5eyTX#%aOl2H0#;YFo*+=4Tb-{xb#Iowld~yH^UQzf4z=M zO}A_(C7H7-!amZ0?IL-+1wE-KeUM~n>LfQ@&PM30VUmQC{x$VFoqN%PB64CGn^Dt7eT<)C{=)o^$>FxJSU%$IA7%Oi^?wU@n9+kV+teDh0 zk#O6dD`!*~za%1uRKR9x@BTYA_I{HZ2M+{p z9~#0O(4+4`5!Gc^SqZrqHQ<{Nuc(d&e!X!U1YVYBz=Kif`TS6%d3ed7t}W zuIwn9w!(_H3^?XQf+uNn5vbTg$V5Ihvrbi<&{O(m?YVK@(%aem5zh}4-@e6oW*qM9 z(hV9-oo{_|**|r8*&%bGx{|n|A1s44xlcT3pL`+UIrEQm^Or6!oYjcK(&60=s^Pgr*H`z%cEcat~*214A)D|xS>-9HG+(LAtANF*5U_j>*ifX3( zkIR^4q(CGJnb3XoXyLa}c~+}M%ZLo{&f6`ztIQ)>c&aH+NL3=f)5*ZFlfposKS_ES zXW^ZzZ)N{GR6lDPki)lo!3?a8wv-q~^0F>g^UIyzPl{n!FpEXQe~qPgpRC~L*F7ga zr<;_6?e{<_3x?EI6zmyoA$FthE_=+L(0Jbm{_R!MU9YZVFGO$WFpIJegj&H8QKFtA z4@?u$7<4|1p+gs~5eUzKWcY^ay%;y99=;o6!6h1YcqzFTHG*Amzc_0`CBuaHOseW?t%)*q&$J3oXRftp1*2jTH@xPw&aBjf*;XyfQ3NN0%XV$E3cc=@ko^BWB$EDx^5GhiFG+IMnsXmtfI3 z4m-l^a;35vqFOmUlf&xBF&~QObFW5MBs?+3CbYQYVM1%oeqSpCvP+pp@t@id#VNMd zV{7?V*fK3V9l~Eeg(?2cft}rT*15c)Eje`6D$45^oKRKJUMVYQ{vU_3l#;~--qGXsOWD{+eagZw7EY2hqQI=za}k$w@zx4wC$XJ VNaEq$TvQ&;{s{;G`0{Pu{RKmqCDs4{ literal 0 HcmV?d00001 diff --git a/datasets/mnist/mnist.py b/datasets/mnist/mnist.py new file mode 100644 index 00000000000..58e48b22b55 --- /dev/null +++ b/datasets/mnist/mnist.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""MNIST Data Set"""

from __future__ import absolute_import, division, print_function

import struct

import numpy as np

import datasets


_CITATION = """\
@article{lecun2010mnist,
  title={MNIST handwritten digit database},
  author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
  journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
  volume={2},
  year={2010}
}
"""

_DESCRIPTION = """\
The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000
images per class. There are 60,000 training images and 10,000 test images.
"""

# Base URL of the mirror; the per-file names in _URLS are appended to it.
_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"
_URLS = {
    "train_images": "train-images-idx3-ubyte.gz",
    "train_labels": "train-labels-idx1-ubyte.gz",
    "test_images": "t10k-images-idx3-ubyte.gz",
    "test_labels": "t10k-labels-idx1-ubyte.gz",
}

# Height and width of one image. Shared by the declared feature shape and by
# the reshape of the raw pixel buffer so the two cannot drift apart.
_IMAGE_SHAPE = (28, 28)

# Number of leading metadata bytes in the image and label files respectively.
_IMAGES_HEADER_SIZE = 16
_LABELS_HEADER_SIZE = 8


class MNIST(datasets.GeneratorBasedBuilder):
    """MNIST Data Set"""

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name="mnist",
            version=datasets.Version("1.0.0"),
            description=_DESCRIPTION,
        )
    ]

    def _info(self):
        """Return the dataset metadata: features, supervised keys, homepage and citation."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "image": datasets.Array2D(shape=_IMAGE_SHAPE, dtype="uint8"),
                    "label": datasets.features.ClassLabel(names=["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]),
                }
            ),
            supervised_keys=("image", "label"),
            homepage="http://yann.lecun.com/exdb/mnist/",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download and extract the four archives and describe the train and test splits.

        Args:
            dl_manager: object whose ``download_and_extract`` is called with a
                ``{key: url}`` dict; the result is indexed with the same keys.

        Returns:
            A list with one ``datasets.SplitGenerator`` for the train split and
            one for the test split. Each passes ``filepath`` (a two-element list:
            images file, labels file) and ``split`` to ``_generate_examples``.
        """
        urls_to_download = {key: _URL + fname for key, fname in _URLS.items()}
        downloaded_files = dl_manager.download_and_extract(urls_to_download)

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": [downloaded_files["train_images"], downloaded_files["train_labels"]],
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "filepath": [downloaded_files["test_images"], downloaded_files["test_labels"]],
                    "split": "test",
                },
            ),
        ]

    def _generate_examples(self, filepath, split):
        """Yield ``(index, example)`` pairs read from the raw image and label files.

        Args:
            filepath: sequence whose first item is the path of the images file
                and whose second item is the path of the labels file.
            split: name of the split being generated; not used by the parser.

        Yields:
            ``(idx, {"image": <uint8 array of shape _IMAGE_SHAPE>, "label": <str>})``
            for ``idx`` in ``range(size)``, where ``size`` is the count stored in
            the images file header.
        """
        # Images
        with open(filepath[0], "rb") as f:
            # Bytes 4-8 of the header hold the number of images as a big-endian
            # unsigned int. The remaining header bytes are skipped (presumably
            # the magic number and the row/column counts of the IDX format).
            header = f.read(_IMAGES_HEADER_SIZE)
            size = struct.unpack_from(">I", header, 4)[0]
            images = np.frombuffer(f.read(), dtype=np.uint8).reshape((size,) + _IMAGE_SHAPE)

        # Labels
        with open(filepath[1], "rb") as f:
            # The label header is skipped entirely; one byte per label follows.
            f.read(_LABELS_HEADER_SIZE)
            labels = np.frombuffer(f.read(), dtype=np.uint8)

        for idx in range(size):
            # Labels are emitted as digit strings, matching the ClassLabel names.
            yield idx, {"image": images[idx], "label": str(labels[idx])}