Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(co): any/some support shared_task #82

Merged
merged 8 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/when_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ task<int> f0() {
co_return 1;
}

task<const char *> f1() {
shared_task<std::string> f1() {
printf("f1 start.\n");
printf("f1 done.\n");
co_return "f1 Great!";
Expand Down
4 changes: 2 additions & 2 deletions example/when_any.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ task<int> f0() {
co_return 1;
}

task<const char *> f1() {
shared_task<std::string> f1() {
printf("f1 start.\n");
printf("f1 done.\n");
co_return "f1 Great!";
Expand All @@ -32,7 +32,7 @@ task<> run() {
overload{
[](std::monostate) { std::cout << "(void)\n"; },
[](int x) { std::cout << x << " : int\n"; },
[](const char *s) { std::cout << s << " : string\n"; },
[](const std::string &s) { std::cout << s << " : string\n"; },
},
var
);
Expand Down
4 changes: 2 additions & 2 deletions example/when_some.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ task<int> f0() {
co_return 0;
}

task<const char *> f1() {
shared_task<std::string> f1() {
printf("f1 start.\n");
printf("f1 done.\n");
co_return "f1 Great!";
Expand All @@ -32,7 +32,7 @@ task<> run() {
overload{
[](std::monostate) { std::cout << "(void)\n"; },
[](int x) { std::cout << x << " : int\n"; },
[](const char *s) { std::cout << s << " : string\n"; },
[](const std::string &s) { std::cout << s << " : string\n"; },
},
var
);
Expand Down
1 change: 1 addition & 0 deletions include/co_context/all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <co_context/io_context.hpp>
#include <co_context/lazy_io.hpp>
#include <co_context/net.hpp>
#include <co_context/shared_task.hpp>
#include <co_context/task.hpp>
#include <co_context/utility/as_buffer.hpp>
#include <co_context/utility/defer.hpp>
Expand Down
169 changes: 100 additions & 69 deletions include/co_context/co/when_all.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#pragma once

#include <co_context/config.hpp>
#include <co_context/detail/tasklike.hpp>
#include <co_context/io_context.hpp>
#include <co_context/lazy_io.hpp>
#include <co_context/task.hpp>
#include <co_context/utility/as_atomic.hpp>
#include <co_context/utility/mpl.hpp>

Expand All @@ -15,107 +14,139 @@

namespace co_context::detail {

template<typename... Ts>
using tuple_or_void = std::conditional_t<
std::is_same_v<std::tuple<>, mpl::remove_t<void, Ts...>>,
void,
mpl::remove_t<void, Ts...>>;

template<typename... Ts>
struct all_meta {
using result_type = std::tuple<Ts...>;
using buffer_type = std::tuple<mpl::uninitialized<Ts>...>;
using result_type_list = mpl::type_list<Ts...>;

static_assert(sizeof...(Ts) != 0);
static_assert(mpl::count_v<result_type_list, void> == 0);

buffer_type buffer;
struct all_meta_base {
std::coroutine_handle<> await_handle;

// NOTE NOT thread-safe! If `resume_on` is used, race condition may
// happen!
uint32_t count_down;
uint32_t wait_num;

explicit all_meta(std::coroutine_handle<> await_handle, uint32_t n) noexcept
explicit
all_meta_base(std::coroutine_handle<> await_handle, uint32_t n) noexcept
: await_handle(await_handle)
, count_down(n) {}
, wait_num(n) {}

template<safety is_thread_safe>
void count_down() noexcept {
bool need_wakeup;
if constexpr (is_thread_safe) {
need_wakeup =
(as_atomic(wait_num).fetch_sub(1, std::memory_order_relaxed)
== 1);
} else {
need_wakeup = (--wait_num == 0);
}
if (need_wakeup) {
if constexpr (is_thread_safe) {
std::atomic_thread_fence(std::memory_order_release);
}
detail::co_spawn_handle(await_handle);
}
}
};

template<mpl::TL tuple_list>
struct all_meta : all_meta_base {
static_assert(mpl::count_v<tuple_list, void> == 0);

using value_type = tuple_list::template to<std::tuple>;
using buffer_type =
mpl::map_t<tuple_list, mpl::uninitialized>::template to<std::tuple>;

buffer_type buffer;

~all_meta() noexcept(noexcept(std::destroy_at(&as_result()))) {
explicit all_meta(std::coroutine_handle<> await_handle, uint32_t n) noexcept
: all_meta_base(await_handle, n) {}

all_meta(const all_meta &) = delete;
all_meta(all_meta &&) = delete;
all_meta &operator=(const all_meta &) = delete;
all_meta &operator=(all_meta &&) = delete;

~all_meta() noexcept(std::is_nothrow_destructible_v<value_type>) {
std::destroy_at(&as_result());
}

result_type &as_result() & noexcept {
return *reinterpret_cast<result_type *>(&buffer);
value_type &as_result() & noexcept {
return *reinterpret_cast<value_type *>(&buffer);
}
};

template<>
struct all_meta<> {
std::coroutine_handle<> await_handle;

// NOTE NOT thread-safe! If `resume_on` is used, race condition may
// happen!
uint32_t count_down;

explicit all_meta(std::coroutine_handle<> await_handle, uint32_t n) noexcept
: await_handle(await_handle)
, count_down(n) {}
struct all_meta<mpl::type_list<>> : all_meta_base {
using all_meta_base::all_meta_base;
using value_type = void;
};

template<typename... Ts>
using to_all_meta_t = typename clear_void_t<Ts...>::template to<all_meta>;
template<tasklike... task_types>
struct all_trait {
private:
using type_list = mpl::type_list<typename task_types::value_type...>;

template<safety is_thread_safe, size_t idx, typename... Ts>
task<void> all_evaluate_to(
to_all_meta_t<Ts...> &meta, task<mpl::select_t<idx, Ts...>> &&node
) {
using node_return_type = mpl::select_t<idx, Ts...>;
using tuple_list = typename type_list::template to<clear_void_t>;

if constexpr (std::is_void_v<node_return_type>) {
co_await node;
} else {
using list = mpl::first_N_t<mpl::type_list<Ts...>, idx + 1>;
public:
using meta_type = all_meta<tuple_list>;

constexpr size_t pos = idx - mpl::count_v<list, void>;
using value_type = meta_type::value_type;

auto *const location =
reinterpret_cast<node_return_type *>(std::get<pos>(meta.buffer).data
);
template<size_t idx>
static constexpr size_t buffer_offset_v =
idx - mpl::count_v<mpl::first_N_t<type_list, idx + 1>, void>;
};

std::construct_at(location, std::move(co_await std::move(node)));
}
template<
safety is_thread_safe,
typename all_meta_type,
size_t buffer_offset,
tasklike task_type>
requires std::is_base_of_v<all_meta_base, all_meta_type>
task<void> all_evaluate_to(all_meta_type &meta, task_type &&node) {
using node_value_type = typename task_type::value_type;

bool wakeup;
if constexpr (is_thread_safe) {
wakeup =
(as_atomic(meta.count_down).fetch_sub(1, std::memory_order_relaxed)
== 1);
} else {
wakeup = (--meta.count_down == 0);
std::atomic_thread_fence(std::memory_order_acquire);
}
if (wakeup) {
if constexpr (is_thread_safe) {
std::atomic_thread_fence(std::memory_order_release);

if constexpr (std::is_void_v<node_value_type>) {
co_await node;
} else {
auto *const location = reinterpret_cast<node_value_type *>(
std::get<buffer_offset>(meta.buffer).data
);

if constexpr (requires { typename task_type::is_shared_task; }) {
std::construct_at(location, co_await std::forward<task_type>(node));
} else {
std::construct_at(
location, std::move(co_await std::forward<task_type>(node))
);
}
detail::co_spawn_handle(meta.await_handle);
}

meta.template count_down<is_thread_safe>();
}

} // namespace co_context::detail

namespace co_context {

template<safety is_thread_safe = safety::safe, typename... Ts>
task<detail::tuple_or_void<Ts...>> all(task<Ts>... node) {
constexpr size_t n = sizeof...(Ts);
template<safety is_thread_safe = safety::safe, tasklike... task_types>
task<typename detail::all_trait<task_types...>::value_type>
all(task_types... node) {
constexpr size_t n = sizeof...(task_types);
static_assert(n >= 2, "too few tasks for `all(...)`");

using meta_type = detail::to_all_meta_t<Ts...>;
meta_type meta{co_await lazy::who_am_i(), n};
using trait = detail::all_trait<task_types...>;

using all_meta_type = trait::meta_type;

all_meta_type meta{co_await lazy::who_am_i(), n};

auto spawn_all = [&]<size_t... idx>(std::index_sequence<idx...>) {
(..., co_spawn(all_evaluate_to<is_thread_safe, idx, Ts...>(
(..., co_spawn(detail::all_evaluate_to<
is_thread_safe, all_meta_type,
trait::template buffer_offset_v<idx>, task_types>(
meta, std::move(node)
)));
};
Expand All @@ -124,15 +155,15 @@ task<detail::tuple_or_void<Ts...>> all(task<Ts>... node) {
std::atomic_thread_fence(std::memory_order_release);
}

spawn_all(std::index_sequence_for<Ts...>{});
spawn_all(std::index_sequence_for<task_types...>{});

co_await lazy::forget();

if constexpr (is_thread_safe) {
std::atomic_thread_fence(std::memory_order_acquire);
}

if constexpr (std::is_void_v<detail::tuple_or_void<Ts...>>) {
if constexpr (std::is_void_v<typename trait::value_type>) {
co_return;
} else {
co_return std::move(meta.as_result());
Expand Down
Loading
Loading