Skip to content

Commit 5e37505

Browse files
committed
scripts to run training on mnist
1 parent e21b4ee commit 5e37505

13 files changed

+752
-0
lines changed

Datasets/DatasetUtilities.swift

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
import Foundation
15+
import ModelSupport
16+
17+
#if canImport(FoundationNetworking)
18+
import FoundationNetworking
19+
#endif
20+
21+
/// Namespace of helpers that download, cache, and unpack dataset resources.
public enum DatasetUtilities {
    /// The process's current working directory at the time this value is first read.
    /// Used as the default local storage directory.
    public static let currentWorkingDirectoryURL = URL(
        fileURLWithPath: FileManager.default.currentDirectoryPath)

    /// Makes sure a resource exists on local disk, fetching and extracting it when it is missing.
    ///
    /// - Parameters:
    ///   - filename: Name of the extracted resource (without the archive extension).
    ///   - fileExtension: Archive extension of the remote file (`"gz"`, `"tar.gz"`, or `"tgz"`).
    ///   - remoteRoot: Remote location that `filename.fileExtension` is appended to.
    ///   - localStorageDirectory: Directory in which the resource is looked up and stored.
    /// - Returns: The local URL `localStorageDirectory/filename`.
    public static func downloadResource(
        filename: String,
        fileExtension: String,
        remoteRoot: URL,
        localStorageDirectory: URL = currentWorkingDirectoryURL
    ) -> URL {
        printError("Loading resource: \(filename)")

        let resource = ResourceDefinition(
            filename: filename,
            fileExtension: fileExtension,
            remoteRoot: remoteRoot,
            localStorageDirectory: localStorageDirectory)

        let localURL = resource.localURL

        // Only hit the network when the extracted resource is not already on disk.
        if !FileManager.default.fileExists(atPath: localURL.path) {
            printError(
                "File does not exist locally at expected path: \(localURL.path) and must be fetched"
            )
            fetchFromRemoteAndSave(resource)
        }

        return localURL
    }

    /// Ensures the resource is available locally (see `downloadResource`) and returns its bytes.
    ///
    /// Terminates the process with `fatalError` if the local file cannot be read.
    ///
    /// - Parameters:
    ///   - filename: Name of the extracted resource (without the archive extension).
    ///   - fileExtension: Archive extension of the remote file.
    ///   - remoteRoot: Remote location that `filename.fileExtension` is appended to.
    ///   - localStorageDirectory: Directory in which the resource is looked up and stored.
    /// - Returns: The full contents of the local resource file.
    public static func fetchResource(
        filename: String,
        fileExtension: String,
        remoteRoot: URL,
        localStorageDirectory: URL = currentWorkingDirectoryURL
    ) -> Data {
        let localURL = DatasetUtilities.downloadResource(
            filename: filename, fileExtension: fileExtension, remoteRoot: remoteRoot,
            localStorageDirectory: localStorageDirectory)

        do {
            return try Data(contentsOf: localURL)
        } catch {
            // Surface the underlying error so a missing or unreadable file is diagnosable.
            fatalError("Failed to load contents of resource: \(localURL) with error: \(error)")
        }
    }

    /// Describes where a resource lives remotely and where it is stored locally.
    struct ResourceDefinition {
        let filename: String
        let fileExtension: String
        let remoteRoot: URL
        let localStorageDirectory: URL

        /// Local location of the extracted resource: `localStorageDirectory/filename`.
        var localURL: URL {
            localStorageDirectory.appendingPathComponent(filename)
        }

        /// Remote location of the archive: `remoteRoot/filename.fileExtension`.
        var remoteURL: URL {
            remoteRoot.appendingPathComponent(filename).appendingPathExtension(fileExtension)
        }

        /// Local location of the downloaded archive: `localURL` plus `fileExtension`.
        var archiveURL: URL {
            localURL.appendingPathExtension(fileExtension)
        }
    }

    /// Downloads the resource's archive into its local storage directory and extracts it.
    ///
    /// Terminates the process with `fatalError` if the download throws.
    static func fetchFromRemoteAndSave(_ resource: ResourceDefinition) {
        let remoteLocation = resource.remoteURL
        let archiveLocation = resource.localStorageDirectory

        do {
            printError("Fetching URL: \(remoteLocation)...")
            try download(from: remoteLocation, to: archiveLocation)
        } catch {
            fatalError("Failed to fetch and save resource with error: \(error)")
        }
        printError("Archive saved to: \(archiveLocation.path)")

        extractArchive(for: resource)
    }

    /// Extracts the downloaded archive with the system `gunzip` or `tar` tool, then removes
    /// the archive file if it is still present.
    ///
    /// Exits the process with status -1 when the extension is unsupported, the tool cannot be
    /// launched, the tool reports a non-zero exit status, or the archive cannot be removed.
    static func extractArchive(for resource: ResourceDefinition) {
        printError("Extracting archive...")

        let archivePath = resource.archiveURL.path

        #if os(macOS)
            let binaryLocation = "/usr/bin/"
        #else
            let binaryLocation = "/bin/"
        #endif

        let toolName: String
        let arguments: [String]
        switch resource.fileExtension {
        case "gz":
            toolName = "gunzip"
            arguments = [archivePath]
        case "tar.gz", "tgz":
            toolName = "tar"
            arguments = ["xzf", archivePath, "-C", resource.localStorageDirectory.path]
        default:
            printError("Unable to find archiver for extension \(resource.fileExtension).")
            exit(-1)
        }
        let toolLocation = "\(binaryLocation)\(toolName)"

        let task = Process()
        task.executableURL = URL(fileURLWithPath: toolLocation)
        task.arguments = arguments
        do {
            try task.run()
            task.waitUntilExit()
        } catch {
            printError("Failed to extract \(archivePath) with error: \(error)")
            exit(-1)
        }

        // `run()` only throws when the tool cannot be launched. A tool that starts and then
        // fails reports it through its exit status, so check that before deleting the archive.
        guard task.terminationStatus == 0 else {
            printError(
                "Failed to extract \(archivePath): \(toolName) exited with status "
                    + "\(task.terminationStatus)")
            exit(-1)
        }

        // Clean up the archive if the extraction tool left it in place.
        if FileManager.default.fileExists(atPath: archivePath) {
            do {
                try FileManager.default.removeItem(atPath: archivePath)
            } catch {
                printError("Could not remove archive, error: \(error)")
                exit(-1)
            }
        }
    }
}
+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
import TensorFlow
15+
16+
/// Requirements for a dataset that exposes separate training and test collections of
/// `LabeledExample` values, together with the number of examples in each.
public protocol ImageClassificationDataset {
    /// Creates the dataset with no caller-supplied configuration.
    init()

    // MARK: Training split

    /// The examples in the training split.
    var trainingDataset: Dataset<LabeledExample> { get }
    /// The number of examples in the training split.
    var trainingExampleCount: Int { get }

    // MARK: Test split

    /// The examples in the test split.
    var testDataset: Dataset<LabeledExample> { get }
    /// The number of examples in the test split.
    var testExampleCount: Int { get }
}

Datasets/LabeledExample.swift

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
import TensorFlow
15+
16+
/// A pair of tensors — integer labels and floating-point data — that travel together
/// as a single `TensorGroup`.
public struct LabeledExample: TensorGroup {
    /// The label tensor.
    public var label: Tensor<Int32>
    /// The data tensor.
    public var data: Tensor<Float>

    /// Creates an example from an existing label tensor and data tensor.
    public init(label: Tensor<Int32>, data: Tensor<Float>) {
        self.label = label
        self.data = data
    }

    /// Creates an example from exactly two tensor handles: the label handle first,
    /// followed by the data handle.
    ///
    /// - Precondition: `_handles` contains exactly two elements.
    public init<C: RandomAccessCollection>(
        _handles: C
    ) where C.Element: _AnyTensorHandle {
        precondition(_handles.count == 2)
        let firstPosition = _handles.startIndex
        let secondPosition = _handles.index(after: firstPosition)
        let labelHandle = TensorHandle<Int32>(handle: _handles[firstPosition])
        let dataHandle = TensorHandle<Float>(handle: _handles[secondPosition])
        self.label = Tensor<Int32>(handle: labelHandle)
        self.data = Tensor<Float>(handle: dataHandle)
    }
}

Datasets/MNIST/MNIST.swift

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
// Original source:
15+
// "The MNIST database of handwritten digits"
16+
// Yann LeCun, Corinna Cortes, and Christopher J.C. Burges
17+
// http://yann.lecun.com/exdb/mnist/
18+
import Foundation
19+
import TensorFlow
20+
21+
/// The MNIST dataset, exposed as a training split and a test split of `LabeledExample`s.
public struct MNIST: ImageClassificationDataset {
    /// The training split, built from the `train-*` files.
    public let trainingDataset: Dataset<LabeledExample>
    /// The test split, built from the `t10k-*` files.
    public let testDataset: Dataset<LabeledExample>
    /// Number of examples reported for the training split.
    public let trainingExampleCount = 60_000
    /// Number of examples reported for the test split.
    public let testExampleCount = 10_000

    /// Creates the dataset without flattening or normalizing the images.
    public init() {
        self.init(flattening: false, normalizing: false)
    }

    /// Creates the dataset, fetching the source files into `localStorageDirectory` as needed.
    ///
    /// - Parameters:
    ///   - flattening: Forwarded to the loader for both splits.
    ///   - normalizing: Forwarded to the loader for both splits.
    ///   - localStorageDirectory: Where the MNIST files are looked up and stored; defaults to an
    ///     `MNIST` entry inside the temporary directory.
    public init(
        flattening: Bool = false, normalizing: Bool = false,
        localStorageDirectory: URL = FileManager.default.temporaryDirectory.appendingPathComponent(
            "MNIST")
    ) {
        // Training split: fetch first, then wrap, before touching the test files.
        let trainingExamples = fetchDataset(
            localStorageDirectory: localStorageDirectory,
            imagesFilename: "train-images-idx3-ubyte",
            labelsFilename: "train-labels-idx1-ubyte",
            flattening: flattening,
            normalizing: normalizing)
        trainingDataset = Dataset<LabeledExample>(elements: trainingExamples)

        // Test split.
        let testExamples = fetchDataset(
            localStorageDirectory: localStorageDirectory,
            imagesFilename: "t10k-images-idx3-ubyte",
            labelsFilename: "t10k-labels-idx1-ubyte",
            flattening: flattening,
            normalizing: normalizing)
        testDataset = Dataset<LabeledExample>(elements: testExamples)
    }
}
53+
54+
/// Fetches one MNIST image/label file pair and decodes it into a single `LabeledExample`
/// holding every image and label in the pair.
///
/// - Parameters:
///   - localStorageDirectory: Directory in which the files are looked up and stored.
///   - imagesFilename: Name of the extracted image file.
///   - labelsFilename: Name of the extracted label file.
///   - flattening: When `true`, images are shaped `[count, 28 * 28]`; otherwise they are
///     shaped `[count, 28, 28, 1]`.
///   - normalizing: When `true`, pixel values are rescaled from `0...1` to `-1...1`.
///     This now applies to both the flattened and the non-flattened layout.
/// - Returns: The labels and the pixel data, with pixel bytes divided by 255.
fileprivate func fetchDataset(
    localStorageDirectory: URL,
    imagesFilename: String,
    labelsFilename: String,
    flattening: Bool,
    normalizing: Bool
) -> LabeledExample {
    guard let remoteRoot = URL(string: "https://storage.googleapis.com/cvdf-datasets/mnist") else {
        fatalError("Failed to create MNIST root url: https://storage.googleapis.com/cvdf-datasets/mnist")
    }

    let imagesData = DatasetUtilities.fetchResource(
        filename: imagesFilename,
        fileExtension: "gz",
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)
    let labelsData = DatasetUtilities.fetchResource(
        filename: labelsFilename,
        fileExtension: "gz",
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)

    // Skip the leading bytes of each file (16 for images, 8 for labels) so that only the
    // per-pixel and per-label bytes remain.
    let images = [UInt8](imagesData).dropFirst(16).map(Float.init)
    let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)

    let rowCount = labels.count
    let (imageWidth, imageHeight) = (28, 28)

    // Fail with a readable message if the two files disagree on the number of examples.
    precondition(
        images.count == rowCount * imageHeight * imageWidth,
        "MNIST image data (\(images.count) bytes) does not match \(rowCount) labels "
            + "of \(imageHeight)x\(imageWidth) pixels")

    var imageTensor: Tensor<Float>
    if flattening {
        imageTensor =
            Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images)
            / 255.0
    } else {
        imageTensor =
            Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
            .transposed(permutation: [0, 2, 3, 1]) / 255  // NHWC
    }

    // Apply the optional rescale for either layout; previously it was silently skipped
    // when `flattening` was false.
    if normalizing {
        imageTensor = imageTensor * 2.0 - 1.0
    }

    return LabeledExample(label: Tensor(labels), data: imageTensor)
}

Examples/LeNet-MNIST/main.swift

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import TensorFlow
2+
import Datasets
3+
4+
// Training hyperparameters: number of passes over the training set and examples per batch.
let epochCount: Int = 12
let batchSize: Int = 128

// Load MNIST with its default options.
let dataset: MNIST = MNIST()

// The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`.
var classifier = Sequential {
    // Two convolution + average-pooling stages.
    Conv2D<Float>(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
    AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
    AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
    // Flatten, then three dense layers ending in 10 outputs.
    Flatten<Float>()
    Dense<Float>(inputSize: 400, outputSize: 120, activation: relu)
    Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
    Dense<Float>(inputSize: 84, outputSize: 10)
}

// Plain stochastic gradient descent over the classifier's parameters.
let optimizer = SGD(for: classifier, learningRate: 0.1)

print("Beginning training...")
23+
24+
/// Running totals accumulated over the batches of one epoch.
struct Statistics {
    /// Number of predictions that matched their label.
    var correctGuessCount = 0
    /// Number of predictions made.
    var totalGuessCount = 0
    /// Sum of the per-batch loss values.
    var totalLoss: Float = 0
    /// Number of batches processed.
    var batches = 0
}
30+
31+
// Batch the test split once, up front; the same batches are evaluated after every epoch.
let testBatches = dataset.testDataset.batched(batchSize)

// The training loop: one shuffled pass over the training split, then one pass over the test
// split, followed by a one-line summary of loss and accuracy for both.
for epoch in 1...epochCount {
    var trainStats = Statistics()
    var testStats = Statistics()

    // Reshuffle every epoch, seeding the shuffle with the epoch number.
    let shuffledTraining = dataset.trainingDataset.shuffled(
        sampleCount: dataset.trainingExampleCount, randomSeed: Int64(epoch))

    Context.local.learningPhase = .training
    for batch in shuffledTraining.batched(batchSize) {
        let targets = batch.label
        let inputs = batch.data
        // Gradient of the loss with respect to the model; statistics are gathered as a side
        // effect of the forward pass inside the closure.
        let gradients = TensorFlow.gradient(at: classifier) { model -> Tensor<Float> in
            let logits = model(inputs)
            let hits = logits.argmax(squeezingAxis: 1) .== targets
            trainStats.correctGuessCount += Int(
                Tensor<Int32>(hits).sum().scalarized())
            trainStats.totalGuessCount += batch.data.shape[0]
            let batchLoss = softmaxCrossEntropy(logits: logits, labels: targets)
            trainStats.totalLoss += batchLoss.scalarized()
            trainStats.batches += 1
            return batchLoss
        }
        // Move the model's differentiable variables along the gradient vector.
        optimizer.update(&classifier, along: gradients)
    }

    Context.local.learningPhase = .inference
    for batch in testBatches {
        let targets = batch.label
        let inputs = batch.data
        // Forward pass only: accumulate accuracy and loss on the test split.
        let logits = classifier(inputs)
        let hits = logits.argmax(squeezingAxis: 1) .== targets
        testStats.correctGuessCount += Int(Tensor<Int32>(hits).sum().scalarized())
        testStats.totalGuessCount += batch.data.shape[0]
        let batchLoss = softmaxCrossEntropy(logits: logits, labels: targets)
        testStats.totalLoss += batchLoss.scalarized()
        testStats.batches += 1
    }

    let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
    let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
    print(
        """
        [Epoch \(epoch)] \
        Training Loss: \(trainStats.totalLoss / Float(trainStats.batches)), \
        Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
        (\(trainAccuracy)), \
        Test Loss: \(testStats.totalLoss / Float(testStats.batches)), \
        Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
        (\(testAccuracy))
        """)
}

0 commit comments

Comments
 (0)