2024-05-06 10:49:15 +08:00

272 lines
9.4 KiB
C++

#ifndef DOCTEST_MPI_REPORTER_H
#define DOCTEST_MPI_REPORTER_H
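// This header is not self-contained: doctest/doctest.h (ConsoleReporter, ContextOptions,
// REGISTER_REPORTER, ...) must be included before it.
// #include <doctest/doctest.h>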
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <mutex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "mpi.h"
namespace doctest {
// Number of test cases skipped because they require more MPI processes than are available.
// Not defined in this header; read by MpiConsoleReporter::test_run_end (rank 0 prints a warning).
extern int nb_test_cases_skipped_insufficient_procs;
// Not defined in this header. It is printed as the available number of MPI processes in the
// skipped-tests warning; presumably it returns the size of MPI_COMM_WORLD — TODO confirm.
int mpi_comm_world_size();
namespace {
// https://stackoverflow.com/a/11826666/1583122
/// Stream buffer that discards every character written to it.
struct NullBuffer : std::streambuf {
  /// Called for each character written (there is no put area). Returning a non-EOF value
  /// reports success, so the owning stream never enters a failed state -- including when
  /// called with EOF itself.
  int_type overflow(int_type c) override { return traits_type::not_eof(c); }
  /// Bulk writes: accept all n characters at once instead of one overflow() call per character.
  std::streamsize xsputn(const char_type* /*s*/, std::streamsize n) override { return n; }
};
class NullStream : public std::ostream {
public:
NullStream()
: std::ostream(&nullBuff)
{}
private:
NullBuffer nullBuff = {};
};
static NullStream nullStream;
/* \brief Extends the ConsoleReporter of doctest
* Each process writes its results to its own file
* Intended to be used when a test assertion fails and the user wants to know exactly what happens on which process
*/
struct MpiFileReporter : public ConsoleReporter {
std::ofstream logfile_stream = {};
MpiFileReporter(const ContextOptions& co)
: ConsoleReporter(co,logfile_stream)
{
int rank = 0;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
std::string logfile_name = "doctest_" + std::to_string(rank) + ".log";
logfile_stream = std::ofstream(logfile_name.c_str(), std::fstream::out);
}
};
/* \brief Extends the ConsoleReporter of doctest
* Allows to manage the execution of tests in a parallel framework
* All results are collected on rank 0
*/
struct MpiConsoleReporter : public ConsoleReporter {
private:
static std::ostream& replace_by_null_if_not_rank_0(std::ostream* os) {
int rank = 0;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
if (rank==0) {
return *os;
} else {
return nullStream;
}
}
std::vector<std::pair<std::string, int>> m_failure_str_queue = {};
public:
MpiConsoleReporter(const ContextOptions& co)
: ConsoleReporter(co,replace_by_null_if_not_rank_0(co.cout))
{}
std::string file_line_to_string(const char* file, int line,
const char* tail = ""){
std::stringstream ss;
ss << skipPathFromFilename(file)
<< (opt.gnu_file_line ? ":" : "(")
<< (opt.no_line_numbers ? 0 : line) // 0 or the real num depending on the option
<< (opt.gnu_file_line ? ":" : "):") << tail;
return ss.str();
}
void test_run_end(const TestRunStats& p) override {
ConsoleReporter::test_run_end(p);
const bool anythingFailed = p.numTestCasesFailed > 0 || p.numAssertsFailed > 0;
// -----------------------------------------------------
// > Gather information in rank 0
int n_rank, rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &n_rank);
int g_numAsserts = 0;
int g_numAssertsFailed = 0;
int g_numTestCasesFailed = 0;
MPI_Reduce(&p.numAsserts , &g_numAsserts , 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
MPI_Reduce(&p.numAssertsFailed , &g_numAssertsFailed , 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
MPI_Reduce(&p.numTestCasesFailed, &g_numTestCasesFailed, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
std::vector<int> numAssertsFailedByRank;
if(rank == 0){
numAssertsFailedByRank.resize(static_cast<std::size_t>(n_rank));
}
MPI_Gather(&p.numAssertsFailed, 1, MPI_INT, numAssertsFailedByRank.data(), 1, MPI_INT, 0, MPI_COMM_WORLD);
if(rank == 0) {
separator_to_stream();
s << Color::Cyan << "[doctest] " << Color::None << "assertions on all processes: " << std::setw(6)
<< g_numAsserts << " | "
<< ((g_numAsserts == 0 || anythingFailed) ? Color::None : Color::Green)
<< std::setw(6) << (g_numAsserts - g_numAssertsFailed) << " passed" << Color::None
<< " | " << (g_numAssertsFailed > 0 ? Color::Red : Color::None) << std::setw(6)
<< g_numAssertsFailed << " failed" << Color::None << " |\n";
if (nb_test_cases_skipped_insufficient_procs>0) {
s << Color::Cyan << "[doctest] " << Color::Yellow << "WARNING: Skipped ";
if (nb_test_cases_skipped_insufficient_procs>1) {
s << nb_test_cases_skipped_insufficient_procs << " tests requiring more than ";
} else {
s << nb_test_cases_skipped_insufficient_procs << " test requiring more than ";
}
if (mpi_comm_world_size()>1) {
s << mpi_comm_world_size() << " MPI processes to run\n";
} else {
s << mpi_comm_world_size() << " MPI process to run\n";
}
}
separator_to_stream();
if(g_numAssertsFailed > 0){
s << Color::Cyan << "[doctest] " << Color::None << "fail on rank:" << std::setw(6) << "\n";
for(std::size_t i = 0; i < numAssertsFailedByRank.size(); ++i){
if( numAssertsFailedByRank[i] > 0 ){
s << std::setw(16) << " -> On rank [" << i << "] with " << numAssertsFailedByRank[i] << " test failed" << std::endl;
}
}
}
s << Color::Cyan << "[doctest] " << Color::None
<< "Status: " << (g_numTestCasesFailed > 0 ? Color::Red : Color::Green)
<< ((g_numTestCasesFailed > 0) ? "FAILURE!" : "SUCCESS!") << Color::None << std::endl;
}
}
void test_case_end(const CurrentTestCaseStats& st) override {
if (is_mpi_test_case()) {
// function called by every rank at the end of a test
// if failed assertions happened, they have been sent to rank 0
// here rank zero gathers them and prints them all
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
std::vector<MPI_Request> requests;
requests.reserve(m_failure_str_queue.size()); // avoid realloc & copy of MPI_Request
for (const std::pair<std::string, int> &failure : m_failure_str_queue)
{
const std::string & failure_str = failure.first;
const int failure_line = failure.second;
int failure_msg_size = static_cast<int>(failure_str.size());
requests.push_back(MPI_REQUEST_NULL);
MPI_Isend(failure_str.c_str(), failure_msg_size, MPI_BYTE,
0, failure_line, MPI_COMM_WORLD, &requests.back()); // Tag = file line
}
// Compute the number of assert with fail among all procs
const int nb_fail_asserts = static_cast<int>(m_failure_str_queue.size());
int nb_fail_asserts_glob = 0;
MPI_Reduce(&nb_fail_asserts, &nb_fail_asserts_glob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
if(rank == 0) {
MPI_Status status;
MPI_Status status_recv;
using id_string = std::pair<int,std::string>;
std::vector<id_string> msgs(static_cast<std::size_t>(nb_fail_asserts_glob));
for (std::size_t i=0; i<static_cast<std::size_t>(nb_fail_asserts_glob); ++i) {
MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
int count;
MPI_Get_count(&status, MPI_BYTE, &count);
std::string recv_msg(static_cast<std::size_t>(count),'\0');
void* recv_msg_data = const_cast<char*>(recv_msg.data()); // const_cast needed. Non-const .data() exists in C++11 though...
MPI_Recv(recv_msg_data, count, MPI_BYTE, status.MPI_SOURCE,
status.MPI_TAG, MPI_COMM_WORLD, &status_recv);
msgs[i] = {status.MPI_SOURCE,recv_msg};
}
std::sort(begin(msgs),end(msgs),[](const id_string& x, const id_string& y){ return x.first < y.first; });
// print
if (nb_fail_asserts_glob>0) {
separator_to_stream();
file_line_to_stream(tc->m_file.c_str(), static_cast<int>(tc->m_line), "\n");
if(tc->m_test_suite && tc->m_test_suite[0] != '\0')
s << Color::Yellow << "TEST SUITE: " << Color::None << tc->m_test_suite << "\n";
if(strncmp(tc->m_name, " Scenario:", 11) != 0)
s << Color::Yellow << "TEST CASE: ";
s << Color::None << tc->m_name << "\n\n";
for(const auto& msg : msgs) {
s << msg.second;
}
s << "\n";
}
}
MPI_Waitall(static_cast<int>(requests.size()), requests.data(), MPI_STATUSES_IGNORE);
m_failure_str_queue.clear();
}
ConsoleReporter::test_case_end(st);
}
bool is_mpi_test_case() const {
return tc->m_description != nullptr
&& std::string(tc->m_description) == std::string("MPI_TEST_CASE");
}
void log_assert(const AssertData& rb) override {
if (!is_mpi_test_case()) {
ConsoleReporter::log_assert(rb);
} else {
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
if(!rb.m_failed && !opt.success)
return;
std::lock_guard<std::mutex> lock(mutex);
std::stringstream failure_msg;
failure_msg << Color::Red << "On rank [" << rank << "] : " << Color::None;
failure_msg << file_line_to_string(rb.m_file, rb.m_line, " ");
if((rb.m_at & (assertType::is_throws_as | assertType::is_throws_with)) ==0){
failure_msg << Color::Cyan
<< assertString(rb.m_at)
<< "( " << rb.m_expr << " ) "
<< Color::None
<< (!rb.m_failed ? "is correct!\n" : "is NOT correct!\n")
<< " values: "
<< assertString(rb.m_at)
<< "( " << rb.m_decomp.c_str() << " )\n";
}
m_failure_str_queue.push_back({failure_msg.str(), rb.m_line});
}
}
}; // MpiConsoleReporter
// Register both reporters with doctest under the names "MpiConsoleReporter" and
// "MpiFileReporter" (presumably selected at run time by name via doctest's reporter option).
// "1" is the priority - used for ordering when multiple reporters/listeners are used
// NOTE(review): these registrations sit inside an anonymous namespace in a header, so each
// translation unit including this file registers them again; presumably meant to be included
// from a single TU (e.g. the test main) -- confirm.
REGISTER_REPORTER("MpiConsoleReporter", 1, MpiConsoleReporter);
REGISTER_REPORTER("MpiFileReporter", 1, MpiFileReporter);
} // anonymous
} // doctest
#endif // DOCTEST_MPI_REPORTER_H