99#include " dalotia.hpp"
1010#include " dalotia_safetensors_file.hpp"
1111
12- // for cblas_sgemm:
13- #include " cblas.h"
14-
15- // Kokkos mdspan?
16- // #include "mdspan/mdspan.hpp"
17-
12+ // Boost.Multi must be included BEFORE cblas.h — multi's blas/core.hpp
13+ // uses an #ifdef CBLAS_H guard that skips type definitions needed later.
1814#ifdef DALOTIA_E_WITH_BOOST_MULTI
1915#include < boost/multi/array.hpp>
20- #include < multi/adaptors/blas.hpp>
21- // #include <multi/adaptors/tblis.hpp>
16+ #include < boost/multi/adaptors/blas.hpp>
2217#endif // DALOTIA_E_WITH_BOOST_MULTI
2318
19+ // for cblas_sgemm:
20+ #include " cblas.h"
21+
2422#ifdef DALOTIA_E_WITH_NDIRECT
2523#include < NDIRECT_direct.h>
2624#endif // DALOTIA_E_WITH_NDIRECT
@@ -465,68 +463,62 @@ std::ostream &operator<<(
465463 return os;
466464}
467465
468- void run_inference_boost_multi (std::string filename) {
466+ std::chrono::duration<double > run_inference_boost_multi (
467+ const dalotia::vector<float > &conv1_weight,
468+ const dalotia::vector<float > &conv1_bias,
469+ const std::array<int , 4 > &conv1_weight_extents,
470+ const dalotia::vector<float > &conv2_weight,
471+ const dalotia::vector<float > &conv2_bias,
472+ const std::array<int , 4 > &conv2_weight_extents,
473+ const dalotia::vector<float > &fc1_weight,
474+ const dalotia::vector<float > &fc1_bias,
475+ const std::array<int , 2 > &fc1_weight_extents,
476+ const dalotia::vector<float > &images, const dalotia::vector<float > &labels,
477+ dalotia::vector<int > &results) {
469478 using span_4d_float = multi::array_ref<float , 4 >;
470479 using span_3d_float = multi::array_ref<float , 3 >;
471480 using span_2d_float = multi::array_ref<float , 2 >;
472481
473- auto [conv1_weight, conv1_bias] =
474- test_load (filename, " conv1" ); // TODO why can't I make them const?
475- const auto conv1_weight_span =
476- span_4d_float ({8 , 1 , 3 , 3 }, conv1_weight.data ());
477- const auto conv1_bias_span = span_2d_float ({8 , 1 }, conv1_bias.data ());
482+ const auto conv1_weight_span = span_4d_float (
483+ {conv1_weight_extents[0 ], conv1_weight_extents[1 ],
484+ conv1_weight_extents[2 ], conv1_weight_extents[3 ]},
485+ const_cast <float *>(conv1_weight.data ()));
486+ const auto conv1_bias_span = span_2d_float (
487+ {conv1_weight_extents[0 ], 1 }, const_cast <float *>(conv1_bias.data ()));
478488 assert (conv1_weight_span.sizes ().get <1 >() == 1 ); // 1 input channel
479489
480- auto [conv2_weight, conv2_bias] = test_load (filename, " conv2" );
481- const auto conv2_weight_span =
482- span_4d_float ({16 , 8 , 3 , 3 }, conv1_weight.data ());
483- const auto conv2_bias_span = span_2d_float ({16 , 1 }, conv2_bias.data ());
484-
485- auto [fc1_weight, fc1_bias] = test_load (filename, " fc1" );
486- const auto fc1_weight_span = span_2d_float ({10 , 784 }, fc1_weight.data ());
487- const auto fc1_bias_span = span_2d_float ({10 , 1 }, fc1_bias.data ());
490+ const auto conv2_weight_span = span_4d_float (
491+ {conv2_weight_extents[0 ], conv2_weight_extents[1 ],
492+ conv2_weight_extents[2 ], conv2_weight_extents[3 ]},
493+ const_cast <float *>(conv2_weight.data ()));
494+ const auto conv2_bias_span = span_2d_float (
495+ {conv2_weight_extents[0 ], 1 }, const_cast <float *>(conv2_bias.data ()));
488496
489- // load the mnist test data // as in
490- // https://medium.com/@myringoleMLGOD/simple-convolutional-neural-network-cnn-for-dummies-in-pytorch-a-step-by-step-guide-6f4109f6df80
491- // too
492- std::string mnist_test_images_filename = " t10k-images-idx3-ubyte" ;
493- std::string mnist_test_labels_filename = " t10k-labels-idx3-ubyte" ;
497+ const auto fc1_weight_span = span_2d_float (
498+ {fc1_weight_extents[0 ], fc1_weight_extents[1 ]},
499+ const_cast <float *>(fc1_weight.data ()));
494500
495- auto images = read_mnist_scaled (mnist_test_images_filename);
496- // auto labels = read_mnist(mnist_test_labels_filename);
497501 auto total_num_images = images.size () / (28 * 28 );
498502
499- // minibatching
500503 constexpr size_t batch_size = 64 ;
501- auto num_batches = static_cast <int >(
504+ auto num_batches = static_cast <size_t >(
502505 std::ceil (total_num_images / static_cast <float >(batch_size)));
503- for (size_t batch_index = 0 ; batch_index < 1 ; ++batch_index) {
506+
507+ const auto start = std::chrono::high_resolution_clock::now ();
508+ for (size_t batch_index = 0 ; batch_index < num_batches; ++batch_index) {
504509 auto num_images_in_batch =
505510 std::min (batch_size, total_num_images - batch_index * batch_size);
506511 auto inum_images_in_batch = static_cast <int >(num_images_in_batch);
507- std::cout << " batch index: " << batch_index << " / " << num_batches
508- << " num images in batch: " << num_images_in_batch
509- << std::endl;
510-
511512 // apply first convolution
512513 // copy data to larger array for zero-padding at the edges
513514 auto image_vector_padded =
514515 dalotia::vector<float >(num_images_in_batch * 30 * 30 );
515516 auto image_padded_span = span_3d_float ({inum_images_in_batch, 30 , 30 },
516517 image_vector_padded.data ());
517518
518- std::cout << " image_padded "
519- << image_padded_span (
520- 0 , multi::_,
521- multi::_) // <- TODO why does this segfault on fugaku?
522- << std::endl;
523-
524519 image_padded_span (multi::_, {1 , 29 }, {1 , 29 }) =
525520 span_3d_float ({inum_images_in_batch, 28 , 28 },
526- images.data () + batch_index * (batch_size * 28 * 28 ));
527-
528- std::cout << " image_padded " << image_padded_span (0 , multi::_, multi::_)
529- << std::endl;
521+ const_cast <float *>(images.data ()) + batch_index * (batch_size * 28 * 28 ));
530522
531523 auto conv1_output =
532524 dalotia::vector<float >(num_images_in_batch * 8 * 28 * 28 );
@@ -584,15 +576,15 @@ void run_inference_boost_multi(std::string filename) {
584576 auto conv1_output_pooled_span = span_4d_float (
585577 {inum_images_in_batch, 8 , 14 , 14 }, conv1_output_pooled.data ());
586578#pragma omp parallel for
587- for (int o = 0 ; o < num_images_in_batch ; ++o) {
579+ for (int o = 0 ; o < inum_images_in_batch ; ++o) {
588580 for (int k = 0 ; k < 8 ; ++k) {
589581 for (int i = 0 ; i < 14 ; ++i) {
590582 for (int j = 0 ; j < 14 ; ++j) {
591- auto window = conv1_output_span (
592- o, k, { 2 * i, 2 * i + 1 }, { 2 * j, 2 * j + 1 });
593- auto max_val = (* std::max_element (window. begin (),
594- window. end ()))[ 0 ] ;
595- conv1_output_pooled_span (o, k, i, j) = max_val ;
583+ float mv = 0 ;
584+ for ( int m = 0 ; m < 2 ; ++m)
585+ for ( int n = 0 ; n < 2 ; ++n)
586+ mv = std::max (mv, conv1_output_span (o, k, 2 *i+m, 2 *j+n)) ;
587+ conv1_output_pooled_span (o, k, i, j) = mv ;
596588 }
597589 }
598590 }
@@ -637,26 +629,26 @@ void run_inference_boost_multi(std::string filename) {
637629 for (int l = 0 ; l < conv2_weight_span.sizes ().get <1 >();
638630 ++l) {
639631 value +=
640- conv2_weight_span (l, k , 0 , 0 ) *
632+ conv2_weight_span (k, l , 0 , 0 ) *
641633 c_feature_padded_span (o, l, i - 1 , j - 1 ) +
642- conv2_weight_span (l, k , 0 , 1 ) *
634+ conv2_weight_span (k, l , 0 , 1 ) *
643635 c_feature_padded_span (o, l, i - 1 , j + 0 ) +
644- conv2_weight_span (l, k , 0 , 2 ) *
636+ conv2_weight_span (k, l , 0 , 2 ) *
645637 c_feature_padded_span (o, l, i - 1 , j + 1 ) +
646- conv2_weight_span (l, k , 1 , 0 ) *
638+ conv2_weight_span (k, l , 1 , 0 ) *
647639 c_feature_padded_span (o, l, i + 0 , j - 1 ) +
648- conv2_weight_span (l, k , 1 , 1 ) *
640+ conv2_weight_span (k, l , 1 , 1 ) *
649641 c_feature_padded_span (o, l, i + 0 , j + 0 ) +
650- conv2_weight_span (l, k , 1 , 2 ) *
642+ conv2_weight_span (k, l , 1 , 2 ) *
651643 c_feature_padded_span (o, l, i + 0 , j + 1 ) +
652- conv2_weight_span (l, k , 2 , 0 ) *
644+ conv2_weight_span (k, l , 2 , 0 ) *
653645 c_feature_padded_span (o, l, i + 1 , j - 1 ) +
654- conv2_weight_span (l, k , 2 , 1 ) *
646+ conv2_weight_span (k, l , 2 , 1 ) *
655647 c_feature_padded_span (o, l, i + 1 , j + 0 ) +
656- conv2_weight_span (l, k, 2 , 2 ) *
657- c_feature_padded_span (o, l, i + 1 , j + 1 ) +
658- conv2_bias_span (l, 0 );
648+ conv2_weight_span (k, l, 2 , 2 ) *
649+ c_feature_padded_span (o, l, i + 1 , j + 1 );
659650 }
651+ value += conv2_bias[k];
660652 // apply activation function (relu)
661653 if (value < 0 .) {
662654 value = 0 .;
@@ -667,71 +659,63 @@ void run_inference_boost_multi(std::string filename) {
667659 }
668660 }
669661
670- std::cout << conv2_output_span (0 , multi::_, multi::_, multi::_)
671- << std::endl;
672-
673- if (batch_index == 0 ) { // compare to python result
674- assert (conv2_output_span[0 ][0 ][0 ][0 ] < 0.4063 );
675- assert (conv2_output_span[0 ][0 ][0 ][0 ] > 0.4062 );
676- }
677-
678662 // apply max pooling
679663 dalotia::vector<float > conv2_output_pooled (num_images_in_batch * 16 *
680664 7 * 7 );
681665 auto conv2_output_pooled_span = span_4d_float (
682666 {inum_images_in_batch, 16 , 7 , 7 }, conv2_output_pooled.data ());
683667#pragma omp parallel for
684- for (int o = 0 ; o < num_images_in_batch ; ++o) {
668+ for (int o = 0 ; o < inum_images_in_batch ; ++o) {
685669 for (int i = 0 ; i < 7 ; ++i) {
686670 for (int j = 0 ; j < 7 ; ++j) {
687671 for (int k = 0 ; k < 16 ; ++k) {
688- auto window = conv2_output_span (
689- o, k, { 2 * i, 2 * i + 1 }, { 2 * j, 2 * j + 1 });
690- auto max_val = (* std::max_element (window. begin (),
691- window. end ()))[ 0 ] ;
692- conv2_output_pooled_span (o, k, i, j) = max_val ;
672+ float mv = 0 ;
673+ for ( int m = 0 ; m < 2 ; ++m)
674+ for ( int n = 0 ; n < 2 ; ++n)
675+ mv = std::max (mv, conv2_output_span (o, k, 2 *i+m, 2 *j+n)) ;
676+ conv2_output_pooled_span (o, k, i, j) = mv ;
693677 }
694678 }
695679 }
696680 }
697681
698682 // apply dense layer
683+ // fc1_output = conv2_flat @ fc1_weight^T + fc1_bias
699684 dalotia::vector<float > fc1_output (num_images_in_batch * 10 );
700685 auto fc1_output_span =
701686 span_2d_float ({inum_images_in_batch, 10 }, fc1_output.data ());
702687 auto conv2_output_flattened = span_2d_float (
703688 {inum_images_in_batch, 16 * 7 * 7 }, conv2_output_pooled.data ());
704- // fc1_output_span = multi::blas::gemm(1., conv2_output_flattened,
705- // //TODO use one of them!
706- // fc1_weight_span.transposed());
707- // using multi::operator+=; // doesn't work yet? ->
708- // https://github.com/correaa/boost-multi/blob/master/include/boost/multi/adaptors/blas/README.md
709- // footnote 3
710- // std::transform(fc1_bias_span.begin(), fc1_bias_span.end(),
711- // // appears to not work
712- // fc1_output_span.begin(), fc1_output_span.begin(),
713- // [](auto ex, auto ey) {
714- // return ex[0] + ey[0];
715- // }); // this would also be nicer without the [0]
716- // indexing
717689
718- // {
719- // using namespace tblis::indices;
720- // tblis::mult(fc1_weight_span(a, b), conv2_output_flattened(o,
721- // b),
722- // fc1_output(o, a));
723- // }
690+ // fill with bias
691+ for (int o = 0 ; o < inum_images_in_batch; ++o) {
692+ for (int k = 0 ; k < 10 ; ++k) {
693+ fc1_output_span (o, k) = fc1_bias[k];
694+ }
695+ }
696+ // gemm: C = alpha * A @ B^T + beta * C
697+ multi::blas::gemm (1 .f , conv2_output_flattened,
698+ fc1_weight_span.transposed (),
699+ 1 .f , fc1_output_span);
724700
725- std::transform (fc1_bias.begin (), fc1_bias.end (), fc1_output.begin (),
726- fc1_output.begin (),
727- [](auto ex, auto ey) { return ex + ey; });
701+ // argmax per image -> results
702+ for (size_t o = 0 ; o < num_images_in_batch; ++o) {
703+ auto result = std::max_element (fc1_output.begin () + o * 10 ,
704+ fc1_output.begin () + (o + 1 ) * 10 ) -
705+ (fc1_output.begin () + o * 10 );
706+ results[batch_index * batch_size + o] = result;
707+ }
728708
729- // output first image's result
730- std::cout << " output for first image: " ;
731- for (int i = 0 ; i < 10 ; ++i) {
732- std::cout << fc1_output_span[i][0 ] << " " ;
709+ #ifndef NDEBUG
710+ if (batch_index == 0 ) {
711+ assert_close (conv2_output_pooled[0 ], 0.40625 , 1e-5 );
712+ assert_close (fc1_output[0 ], -80.9247 );
713+ assert_close (fc1_output[7 ], 38.1572 );
733714 }
715+ #endif
734716 }
717+ const auto end = std::chrono::high_resolution_clock::now ();
718+ return end - start;
735719}
736720#endif // DALOTIA_E_WITH_BOOST_MULTI
737721
0 commit comments