Skip to content

Commit

Permalink
template tests
Browse files Browse the repository at this point in the history
Signed-off-by: Mishin, Ilya <[email protected]>
  • Loading branch information
Iliamish committed Nov 10, 2021
1 parent dad5e1b commit 80d0385
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 88 deletions.
5 changes: 2 additions & 3 deletions include/oneapi/tbb/detail/_flow_graph_node_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,10 @@ class function_input_base : public receiver<Input>, no_assign {
}

graph_task* try_put_task( const input_type& t) override {
if(my_is_no_throw){
if ( my_is_no_throw )
return try_put_task_impl(t, has_policy<lightweight, Policy>());
}else{
else
return try_put_task_impl(t, std::false_type());
}
}

//! Adds src to the list of cached predecessors.
Expand Down
101 changes: 64 additions & 37 deletions test/common/graph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -761,39 +761,7 @@ struct condition_predicate {
std::atomic<unsigned> g_lightweight_count;
std::atomic<unsigned> g_task_count;

class limited_lightweight_checker_body_noexcept {
public:
limited_lightweight_checker_body_noexcept() {
g_body_count = 0;
g_lightweight_count = 0;
g_task_count = 0;
}
private:
void increase_and_check(const std::thread::id& /*input*/) {
++g_body_count;

bool is_inside_task = oneapi::tbb::task::current_context() != nullptr;

if(is_inside_task) {
++g_task_count;
} else {
std::unique_lock<std::mutex> lock(m);
lightweight_condition.wait(lock, condition_predicate());
++g_lightweight_count;
lightweight_work_processed = true;
}
}
public:
template<typename gateway_type>
void operator()(const std::thread::id& input, gateway_type&) noexcept {
increase_and_check(input);
}
output_tuple_type operator()(const std::thread::id& input) noexcept {
increase_and_check(input);
return output_tuple_type();
}
};

template <bool NoExcept>
class limited_lightweight_checker_body {
public:
limited_lightweight_checker_body() {
Expand All @@ -818,10 +786,10 @@ class limited_lightweight_checker_body {
}
public:
template<typename gateway_type>
void operator()(const std::thread::id& input, gateway_type&) {
void operator()(const std::thread::id& input, gateway_type&) noexcept(NoExcept) {
increase_and_check(input);
}
output_tuple_type operator()(const std::thread::id& input) {
output_tuple_type operator()(const std::thread::id& input) noexcept(NoExcept) {
increase_and_check(input);
return output_tuple_type();
}
Expand All @@ -832,7 +800,7 @@ void test_limited_lightweight_execution(unsigned N, unsigned concurrency) {
CHECK_MESSAGE(concurrency != tbb::flow::unlimited,
"Test for limited concurrency cannot be called with unlimited concurrency argument");
tbb::flow::graph g;
NodeType node(g, concurrency, limited_lightweight_checker_body_noexcept());
NodeType node(g, concurrency, limited_lightweight_checker_body<true>());
// Execute first body as lightweight, then wait for all other threads to fill internal buffer.
// Then unblock the lightweight thread and check if other body executions are inside oneTBB task.
utils::SpinBarrier barrier(N - concurrency);
Expand All @@ -850,7 +818,7 @@ void test_limited_lightweight_execution_with_throwing_body(unsigned N, unsigned
CHECK_MESSAGE(concurrency != tbb::flow::unlimited,
"Test for limited concurrency cannot be called with unlimited concurrency argument");
tbb::flow::graph g;
NodeType node(g, concurrency, limited_lightweight_checker_body());
NodeType node(g, concurrency, limited_lightweight_checker_body<false>());
// Body is no noexcept, in this case it must be executed as tasks, instead of lightweight execution
utils::SpinBarrier barrier(N);
utils::NativeParallelFor(N, native_loop_limited_body<NodeType>(node, barrier));
Expand All @@ -862,6 +830,63 @@ void test_limited_lightweight_execution_with_throwing_body(unsigned N, unsigned
lightweight_work_processed = false;
}

template <int Threshold>
struct throwing_body{
std::atomic<int>& my_counter;

throwing_body(std::atomic<int>& counter) : my_counter(counter) {}

template<typename input_type, typename gateway_type>
void operator()(const input_type&, gateway_type&) {
++my_counter;
if(my_counter == Threshold)
throw Threshold;
}

template<typename input_type>
output_tuple_type operator()(const input_type&) {
++my_counter;
if(my_counter == Threshold)
throw Threshold;
return output_tuple_type();
}
};

//! Test excesption thrown in node with lightweight policy was rethrown by graph
template<template<typename, typename, typename> class NodeType>
void test_exception_ligthweight_policy(){
std::atomic<int> counter {0};
constexpr int threshold = 10;

using IndexerNodeType = oneapi::tbb::flow::indexer_node<int, int>;
using FuncNodeType = NodeType<IndexerNodeType::output_type, output_tuple_type, tbb::flow::lightweight>;
oneapi::tbb::flow::graph g;

IndexerNodeType indexer(g);
FuncNodeType tested_node(g, oneapi::tbb::flow::serial, throwing_body<threshold>(counter));
oneapi::tbb::flow::make_edge(indexer, tested_node);

utils::NativeParallelFor( threshold * 2, [&](int i){
if(i % 2)
std::get<1>(indexer.input_ports()).try_put(1);
else
std::get<0>(indexer.input_ports()).try_put(0);
});

bool catchException = false;
try
{
g.wait_for_all();
}
catch (const int& exc)
{
catchException = true;
CHECK_MESSAGE( exc == threshold, "graph.wait_for_all() rethrow current exception" );
}
CHECK_MESSAGE( catchException, "The exception must be thrown from graph.wait_for_all()" );
CHECK_MESSAGE( counter == threshold, "Graph must cancel all tasks after exception" );
}

template<typename NodeType>
void test_lightweight(unsigned N) {
test_unlimited_lightweight_execution<NodeType>(N);
Expand All @@ -876,6 +901,8 @@ void test(unsigned N) {
typedef std::thread::id input_type;
typedef NodeType<input_type, output_tuple_type, tbb::flow::queueing_lightweight> node_type;
test_lightweight<node_type>(N);

test_exception_ligthweight_policy<NodeType>();
}

} // namespace lightweight_testing
Expand Down
48 changes: 0 additions & 48 deletions test/tbb/test_multifunction_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,48 +480,6 @@ void test_ports_return_references() {
test_output_ports_return_ref(mf_node);
}

void test_exception_ligthweight_policy(){
std::atomic<int> counter {0};
constexpr int threshold = 10;

using IndexerNodeType = oneapi::tbb::flow::indexer_node<int,int>;
using MultifunctionNodeType = oneapi::tbb::flow::multifunction_node<IndexerNodeType::output_type,
std::tuple<int>, oneapi::tbb::flow::lightweight>;

oneapi::tbb::flow::graph g;

auto multifunctionNodeBody = [&](MultifunctionNodeType::input_type, MultifunctionNodeType::output_ports_type)
{
++counter;
if(counter == threshold)
throw threshold;
};

IndexerNodeType indexer(g);
MultifunctionNodeType multi(g, oneapi::tbb::flow::serial, multifunctionNodeBody);
oneapi::tbb::flow::make_edge(indexer, multi);

utils::NativeParallelFor( threshold * 2, [&](int i){
if(i % 2)
std::get<1>(indexer.input_ports()).try_put(1);
else
std::get<0>(indexer.input_ports()).try_put(0);
} );

bool catchException = false;
try
{
g.wait_for_all();
}
catch (const int& exc)
{
catchException = true;
CHECK_MESSAGE( exc == threshold, "graph.wait_for_all() rethrow current exception" );
}
CHECK_MESSAGE( catchException, "The exception must be thrown from graph.wait_for_all()" );
CHECK_MESSAGE( counter == threshold, "Graph must cancel all tasks after exception" );
}

#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
#include <array>
#include <vector>
Expand Down Expand Up @@ -589,12 +547,6 @@ TEST_CASE("Lightweight testing"){
lightweight_testing::test<tbb::flow::multifunction_node>(10);
}

//! Test excesption thrown in node with lightweight policy was rethrown by graph
//! \brief \ref error_guessing
TEST_CASE("Exception in lightweight node"){
test_exception_ligthweight_policy();
}

#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
//! Test follows and precedes API
//! \brief \ref error_guessing
Expand Down

0 comments on commit 80d0385

Please sign in to comment.