Skip to content

Commit 9cee8e0

Browse files
author
Theresa
committed
mnist: fix with boost::multi
1 parent 1feafb0 commit 9cee8e0

2 files changed

Lines changed: 89 additions & 107 deletions

File tree

benchmarks/mnist/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@ target_link_libraries( mnist_cpp PUBLIC dalotia::dalotia_cpp )
44
target_include_directories( mnist_cpp PUBLIC ${BLAS_INCLUDE_DIRS})
55
target_link_libraries(mnist_cpp PRIVATE BLAS::BLAS)
66
if (DALOTIA_E_WITH_BOOST_MULTI)
7-
target_compile_options( mnist_cpp PUBLIC "-DDALOTIA_E_WITH_BOOST_MULTI")
8-
target_include_directories( mnist_cpp PUBLIC ${MULTI_CPP_INCLUDE_DIR} ${MULTI_DIR})
9-
message(STATUS "multi: ${MULTI_DIR}")
10-
add_dependencies( mnist_cpp multi ) #TODO tblis?
7+
target_compile_definitions( mnist_cpp PUBLIC DALOTIA_E_WITH_BOOST_MULTI)
8+
target_link_libraries( mnist_cpp PUBLIC multi )
119
endif(DALOTIA_E_WITH_BOOST_MULTI)
1210
if (DALOTIA_E_WITH_NDIRECT)
1311
target_compile_options( mnist_cpp PUBLIC "-DDALOTIA_E_WITH_NDIRECT")

benchmarks/mnist/mnist.cpp

Lines changed: 87 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,16 @@
99
#include "dalotia.hpp"
1010
#include "dalotia_safetensors_file.hpp"
1111

12-
// for cblas_sgemm:
13-
#include "cblas.h"
14-
15-
// Kokkos mdspan?
16-
// #include "mdspan/mdspan.hpp"
17-
12+
// Boost.Multi must be included BEFORE cblas.h — multi's blas/core.hpp
13+
// uses an #ifdef CBLAS_H guard that skips type definitions needed later.
1814
#ifdef DALOTIA_E_WITH_BOOST_MULTI
1915
#include <boost/multi/array.hpp>
20-
#include <multi/adaptors/blas.hpp>
21-
// #include <multi/adaptors/tblis.hpp>
16+
#include <boost/multi/adaptors/blas.hpp>
2217
#endif // DALOTIA_E_WITH_BOOST_MULTI
2318

19+
// for cblas_sgemm:
20+
#include "cblas.h"
21+
2422
#ifdef DALOTIA_E_WITH_NDIRECT
2523
#include <NDIRECT_direct.h>
2624
#endif // DALOTIA_E_WITH_NDIRECT
@@ -465,68 +463,62 @@ std::ostream &operator<<(
465463
return os;
466464
}
467465

468-
void run_inference_boost_multi(std::string filename) {
466+
std::chrono::duration<double> run_inference_boost_multi(
467+
const dalotia::vector<float> &conv1_weight,
468+
const dalotia::vector<float> &conv1_bias,
469+
const std::array<int, 4> &conv1_weight_extents,
470+
const dalotia::vector<float> &conv2_weight,
471+
const dalotia::vector<float> &conv2_bias,
472+
const std::array<int, 4> &conv2_weight_extents,
473+
const dalotia::vector<float> &fc1_weight,
474+
const dalotia::vector<float> &fc1_bias,
475+
const std::array<int, 2> &fc1_weight_extents,
476+
const dalotia::vector<float> &images, const dalotia::vector<float> &labels,
477+
dalotia::vector<int> &results) {
469478
using span_4d_float = multi::array_ref<float, 4>;
470479
using span_3d_float = multi::array_ref<float, 3>;
471480
using span_2d_float = multi::array_ref<float, 2>;
472481

473-
auto [conv1_weight, conv1_bias] =
474-
test_load(filename, "conv1"); // TODO why can't I make them const?
475-
const auto conv1_weight_span =
476-
span_4d_float({8, 1, 3, 3}, conv1_weight.data());
477-
const auto conv1_bias_span = span_2d_float({8, 1}, conv1_bias.data());
482+
const auto conv1_weight_span = span_4d_float(
483+
{conv1_weight_extents[0], conv1_weight_extents[1],
484+
conv1_weight_extents[2], conv1_weight_extents[3]},
485+
const_cast<float*>(conv1_weight.data()));
486+
const auto conv1_bias_span = span_2d_float(
487+
{conv1_weight_extents[0], 1}, const_cast<float*>(conv1_bias.data()));
478488
assert(conv1_weight_span.sizes().get<1>() == 1); // 1 input channel
479489

480-
auto [conv2_weight, conv2_bias] = test_load(filename, "conv2");
481-
const auto conv2_weight_span =
482-
span_4d_float({16, 8, 3, 3}, conv1_weight.data());
483-
const auto conv2_bias_span = span_2d_float({16, 1}, conv2_bias.data());
484-
485-
auto [fc1_weight, fc1_bias] = test_load(filename, "fc1");
486-
const auto fc1_weight_span = span_2d_float({10, 784}, fc1_weight.data());
487-
const auto fc1_bias_span = span_2d_float({10, 1}, fc1_bias.data());
490+
const auto conv2_weight_span = span_4d_float(
491+
{conv2_weight_extents[0], conv2_weight_extents[1],
492+
conv2_weight_extents[2], conv2_weight_extents[3]},
493+
const_cast<float*>(conv2_weight.data()));
494+
const auto conv2_bias_span = span_2d_float(
495+
{conv2_weight_extents[0], 1}, const_cast<float*>(conv2_bias.data()));
488496

489-
// load the mnist test data // as in
490-
// https://medium.com/@myringoleMLGOD/simple-convolutional-neural-network-cnn-for-dummies-in-pytorch-a-step-by-step-guide-6f4109f6df80
491-
// too
492-
std::string mnist_test_images_filename = "t10k-images-idx3-ubyte";
493-
std::string mnist_test_labels_filename = "t10k-labels-idx3-ubyte";
497+
const auto fc1_weight_span = span_2d_float(
498+
{fc1_weight_extents[0], fc1_weight_extents[1]},
499+
const_cast<float*>(fc1_weight.data()));
494500

495-
auto images = read_mnist_scaled(mnist_test_images_filename);
496-
// auto labels = read_mnist(mnist_test_labels_filename);
497501
auto total_num_images = images.size() / (28 * 28);
498502

499-
// minibatching
500503
constexpr size_t batch_size = 64;
501-
auto num_batches = static_cast<int>(
504+
auto num_batches = static_cast<size_t>(
502505
std::ceil(total_num_images / static_cast<float>(batch_size)));
503-
for (size_t batch_index = 0; batch_index < 1; ++batch_index) {
506+
507+
const auto start = std::chrono::high_resolution_clock::now();
508+
for (size_t batch_index = 0; batch_index < num_batches; ++batch_index) {
504509
auto num_images_in_batch =
505510
std::min(batch_size, total_num_images - batch_index * batch_size);
506511
auto inum_images_in_batch = static_cast<int>(num_images_in_batch);
507-
std::cout << "batch index: " << batch_index << " / " << num_batches
508-
<< " num images in batch: " << num_images_in_batch
509-
<< std::endl;
510-
511512
// apply first convolution
512513
// copy data to larger array for zero-padding at the edges
513514
auto image_vector_padded =
514515
dalotia::vector<float>(num_images_in_batch * 30 * 30);
515516
auto image_padded_span = span_3d_float({inum_images_in_batch, 30, 30},
516517
image_vector_padded.data());
517518

518-
std::cout << "image_padded "
519-
<< image_padded_span(
520-
0, multi::_,
521-
multi::_) // <- TODO why does this segfault on fugaku?
522-
<< std::endl;
523-
524519
image_padded_span(multi::_, {1, 29}, {1, 29}) =
525520
span_3d_float({inum_images_in_batch, 28, 28},
526-
images.data() + batch_index * (batch_size * 28 * 28));
527-
528-
std::cout << "image_padded " << image_padded_span(0, multi::_, multi::_)
529-
<< std::endl;
521+
const_cast<float*>(images.data()) + batch_index * (batch_size * 28 * 28));
530522

531523
auto conv1_output =
532524
dalotia::vector<float>(num_images_in_batch * 8 * 28 * 28);
@@ -584,15 +576,15 @@ void run_inference_boost_multi(std::string filename) {
584576
auto conv1_output_pooled_span = span_4d_float(
585577
{inum_images_in_batch, 8, 14, 14}, conv1_output_pooled.data());
586578
#pragma omp parallel for
587-
for (int o = 0; o < num_images_in_batch; ++o) {
579+
for (int o = 0; o < inum_images_in_batch; ++o) {
588580
for (int k = 0; k < 8; ++k) {
589581
for (int i = 0; i < 14; ++i) {
590582
for (int j = 0; j < 14; ++j) {
591-
auto window = conv1_output_span(
592-
o, k, {2 * i, 2 * i + 1}, {2 * j, 2 * j + 1});
593-
auto max_val = (*std::max_element(window.begin(),
594-
window.end()))[0];
595-
conv1_output_pooled_span(o, k, i, j) = max_val;
583+
float mv = 0;
584+
for (int m = 0; m < 2; ++m)
585+
for (int n = 0; n < 2; ++n)
586+
mv = std::max(mv, conv1_output_span(o, k, 2*i+m, 2*j+n));
587+
conv1_output_pooled_span(o, k, i, j) = mv;
596588
}
597589
}
598590
}
@@ -637,26 +629,26 @@ void run_inference_boost_multi(std::string filename) {
637629
for (int l = 0; l < conv2_weight_span.sizes().get<1>();
638630
++l) {
639631
value +=
640-
conv2_weight_span(l, k, 0, 0) *
632+
conv2_weight_span(k, l, 0, 0) *
641633
c_feature_padded_span(o, l, i - 1, j - 1) +
642-
conv2_weight_span(l, k, 0, 1) *
634+
conv2_weight_span(k, l, 0, 1) *
643635
c_feature_padded_span(o, l, i - 1, j + 0) +
644-
conv2_weight_span(l, k, 0, 2) *
636+
conv2_weight_span(k, l, 0, 2) *
645637
c_feature_padded_span(o, l, i - 1, j + 1) +
646-
conv2_weight_span(l, k, 1, 0) *
638+
conv2_weight_span(k, l, 1, 0) *
647639
c_feature_padded_span(o, l, i + 0, j - 1) +
648-
conv2_weight_span(l, k, 1, 1) *
640+
conv2_weight_span(k, l, 1, 1) *
649641
c_feature_padded_span(o, l, i + 0, j + 0) +
650-
conv2_weight_span(l, k, 1, 2) *
642+
conv2_weight_span(k, l, 1, 2) *
651643
c_feature_padded_span(o, l, i + 0, j + 1) +
652-
conv2_weight_span(l, k, 2, 0) *
644+
conv2_weight_span(k, l, 2, 0) *
653645
c_feature_padded_span(o, l, i + 1, j - 1) +
654-
conv2_weight_span(l, k, 2, 1) *
646+
conv2_weight_span(k, l, 2, 1) *
655647
c_feature_padded_span(o, l, i + 1, j + 0) +
656-
conv2_weight_span(l, k, 2, 2) *
657-
c_feature_padded_span(o, l, i + 1, j + 1) +
658-
conv2_bias_span(l, 0);
648+
conv2_weight_span(k, l, 2, 2) *
649+
c_feature_padded_span(o, l, i + 1, j + 1);
659650
}
651+
value += conv2_bias[k];
660652
// apply activation function (relu)
661653
if (value < 0.) {
662654
value = 0.;
@@ -667,71 +659,63 @@ void run_inference_boost_multi(std::string filename) {
667659
}
668660
}
669661

670-
std::cout << conv2_output_span(0, multi::_, multi::_, multi::_)
671-
<< std::endl;
672-
673-
if (batch_index == 0) { // compare to python result
674-
assert(conv2_output_span[0][0][0][0] < 0.4063);
675-
assert(conv2_output_span[0][0][0][0] > 0.4062);
676-
}
677-
678662
// apply max pooling
679663
dalotia::vector<float> conv2_output_pooled(num_images_in_batch * 16 *
680664
7 * 7);
681665
auto conv2_output_pooled_span = span_4d_float(
682666
{inum_images_in_batch, 16, 7, 7}, conv2_output_pooled.data());
683667
#pragma omp parallel for
684-
for (int o = 0; o < num_images_in_batch; ++o) {
668+
for (int o = 0; o < inum_images_in_batch; ++o) {
685669
for (int i = 0; i < 7; ++i) {
686670
for (int j = 0; j < 7; ++j) {
687671
for (int k = 0; k < 16; ++k) {
688-
auto window = conv2_output_span(
689-
o, k, {2 * i, 2 * i + 1}, {2 * j, 2 * j + 1});
690-
auto max_val = (*std::max_element(window.begin(),
691-
window.end()))[0];
692-
conv2_output_pooled_span(o, k, i, j) = max_val;
672+
float mv = 0;
673+
for (int m = 0; m < 2; ++m)
674+
for (int n = 0; n < 2; ++n)
675+
mv = std::max(mv, conv2_output_span(o, k, 2*i+m, 2*j+n));
676+
conv2_output_pooled_span(o, k, i, j) = mv;
693677
}
694678
}
695679
}
696680
}
697681

698682
// apply dense layer
683+
// fc1_output = conv2_flat @ fc1_weight^T + fc1_bias
699684
dalotia::vector<float> fc1_output(num_images_in_batch * 10);
700685
auto fc1_output_span =
701686
span_2d_float({inum_images_in_batch, 10}, fc1_output.data());
702687
auto conv2_output_flattened = span_2d_float(
703688
{inum_images_in_batch, 16 * 7 * 7}, conv2_output_pooled.data());
704-
// fc1_output_span = multi::blas::gemm(1., conv2_output_flattened,
705-
// //TODO use one of them!
706-
// fc1_weight_span.transposed());
707-
// using multi::operator+=; // doesn't work yet? ->
708-
// https://github.com/correaa/boost-multi/blob/master/include/boost/multi/adaptors/blas/README.md
709-
// footnote 3
710-
// std::transform(fc1_bias_span.begin(), fc1_bias_span.end(),
711-
// // appears to not work
712-
// fc1_output_span.begin(), fc1_output_span.begin(),
713-
// [](auto ex, auto ey) {
714-
// return ex[0] + ey[0];
715-
// }); // this would also be nicer without the [0]
716-
// indexing
717689

718-
// {
719-
// using namespace tblis::indices;
720-
// tblis::mult(fc1_weight_span(a, b), conv2_output_flattened(o,
721-
// b),
722-
// fc1_output(o, a));
723-
// }
690+
// fill with bias
691+
for (int o = 0; o < inum_images_in_batch; ++o) {
692+
for (int k = 0; k < 10; ++k) {
693+
fc1_output_span(o, k) = fc1_bias[k];
694+
}
695+
}
696+
// gemm: C = alpha * A @ B^T + beta * C
697+
multi::blas::gemm(1.f, conv2_output_flattened,
698+
fc1_weight_span.transposed(),
699+
1.f, fc1_output_span);
724700

725-
std::transform(fc1_bias.begin(), fc1_bias.end(), fc1_output.begin(),
726-
fc1_output.begin(),
727-
[](auto ex, auto ey) { return ex + ey; });
701+
// argmax per image -> results
702+
for (size_t o = 0; o < num_images_in_batch; ++o) {
703+
auto result = std::max_element(fc1_output.begin() + o * 10,
704+
fc1_output.begin() + (o + 1) * 10) -
705+
(fc1_output.begin() + o * 10);
706+
results[batch_index * batch_size + o] = result;
707+
}
728708

729-
// output first image's result
730-
std::cout << "output for first image: ";
731-
for (int i = 0; i < 10; ++i) {
732-
std::cout << fc1_output_span[i][0] << " ";
709+
#ifndef NDEBUG
710+
if (batch_index == 0) {
711+
assert_close(conv2_output_pooled[0], 0.40625, 1e-5);
712+
assert_close(fc1_output[0], -80.9247);
713+
assert_close(fc1_output[7], 38.1572);
733714
}
715+
#endif
734716
}
717+
const auto end = std::chrono::high_resolution_clock::now();
718+
return end - start;
735719
}
736720
#endif // DALOTIA_E_WITH_BOOST_MULTI
737721

0 commit comments

Comments (0)