Skip to content

Commit 2e6f73b

Browse files
committed
Updated examples
1 parent afadccd commit 2e6f73b

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

examples/mnist/cpu.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ fn main() {
117117
let (images, labels) = load_mnist(&data_dir, "train").expect("Failed to load training data");
118118
println!("Loaded {} training images", images.shape()[0]);
119119

120-
println!("\nTraining with batch size 32...");
121-
nn.train(images, labels, SGD::new(0.01), 100, 32);
120+
println!("\nTraining with batch size 256...");
121+
nn.train(images, labels, SGD::new(0.01), 25, 256);
122122

123123
println!("\nSaving model to {}...", model_path);
124124
nn.save(model_path).expect("Failed to save model");

examples/mnist/gpu.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ fn main() {
111111
NeuralNetwork::load(model_path, CrossEntropy).expect("Failed to load model")
112112
} else {
113113
println!("Creating new model...");
114-
let dense_layer_1 = DenseLayer::new(28 * 28, 500, ReLU);
114+
let dense_layer_1 = DenseLayer::new(28 * 28, 100, ReLU);
115115
let dense_layer_2 = DenseLayer::new(500, 128, ReLU);
116116
let dense_layer_3 = DenseLayer::new(128, 10, Softmax);
117117
NeuralNetwork::new(
@@ -124,7 +124,7 @@ fn main() {
124124
println!("Loaded {} training images", images.shape()[0]);
125125

126126
println!("\nTraining with batch size 1024...");
127-
nn.train(images, labels, SGD::new(0.01), 100, 1024);
127+
nn.train(images, labels, SGD::new(0.01), 25, 1024);
128128

129129
println!("\nSaving model to {}...", model_path);
130130
nn.save(model_path).expect("Failed to save model");

0 commit comments

Comments
 (0)