Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions include/stdexec/__detail/__parallel_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "__bulk.hpp"
#include "__domain.hpp"
#include "__manual_lifetime.hpp"
#include "__parallel_scheduler_replacement_api.hpp"
#include "__schedulers.hpp"
#include "__sender_introspection.hpp"
Expand Down Expand Up @@ -287,6 +288,14 @@ namespace STDEXEC
return {};
}

template <class... _Env>
[[nodiscard]]
auto query(get_completion_domain_t<set_value_t>, _Env const &...) const noexcept
-> __parallel_scheduler_domain
{
return {};
}

/// Schedules new work, returning the sender that signals the start of the work.
[[nodiscard]]
auto schedule() const noexcept -> __parallel_sender
Expand Down Expand Up @@ -328,10 +337,10 @@ namespace STDEXEC
template <class _Previous>
struct __forward_args_receiver : parallel_scheduler_replacement::bulk_item_receiver_proxy
{
using __storage_t = __detail::__sender_data_t<_Previous>;
using __storage_t = __decay_t<__detail::__sender_data_t<_Previous>>;

/// Storage for the arguments received from the previous sender.
alignas(__storage_t) unsigned char __arguments_data_[sizeof(__storage_t)];
__manual_lifetime<__storage_t> __arguments_;
};

/// Derived class that properly forwards the arguments received from `_Previous` to the receiver methods.
Expand All @@ -345,55 +354,59 @@ namespace STDEXEC
/// Stores `__as` in the base class storage, with the right types.
explicit __typed_forward_args_receiver(_As&&... __as)
{
static_assert(sizeof(std::tuple<_As...>) <= sizeof(__base_t::__arguments_data_));
// BUGBUG: this seems wrong. we are not ever destroying this tuple.
new (__base_t::__arguments_data_) std::tuple<__decay_t<_As>...>{std::move(__as)...};
__base_t::__arguments_.__construct(std::forward<_As>(__as)...);
}

/// Calls `set_value()` on the final receiver of the bulk operation, using the values from the previous sender.
void set_value() noexcept override
{
auto __state = reinterpret_cast<_BulkState*>(this);
auto __args = std::move(__base_t::__arguments_.__get());
__base_t::__arguments_.__destroy();
std::destroy_at(this);
std::apply(
[&](auto&&... __args)
{
STDEXEC::set_value(std::forward<__rcvr_t>(__state->__rcvr_),
std::forward<_As>(__args)...);
STDEXEC::set_value(std::move(__state->__rcvr_), std::move(__args)...);
},
*reinterpret_cast<std::tuple<_As...>*>(__base_t::__arguments_data_));
std::move(__args));
}

/// Calls `set_error()` on the final receiver of the bulk operation, passing `__ex`.
void set_error(std::exception_ptr __ex) noexcept override
{
auto __state = reinterpret_cast<_BulkState*>(this);
STDEXEC::set_error(std::forward<__rcvr_t>(__state->__rcvr_), std::move(__ex));
__base_t::__arguments_.__destroy();
std::destroy_at(this);
STDEXEC::set_error(std::move(__state->__rcvr_), std::move(__ex));
}

/// Calls `set_stopped()` on the final receiver of the bulk operation.
void set_stopped() noexcept override
{
auto __state = reinterpret_cast<_BulkState*>(this);
STDEXEC::set_stopped(std::forward<__rcvr_t>(__state->__rcvr_));
__base_t::__arguments_.__destroy();
std::destroy_at(this);
STDEXEC::set_stopped(std::move(__state->__rcvr_));
}

/// Calls the bulk functor passing `__index` and the values from the previous sender.
void execute(uint32_t __begin, uint32_t __end) noexcept override
void execute(size_t __begin, size_t __end) noexcept override
{
auto __state = reinterpret_cast<_BulkState*>(this);
if constexpr (_BulkState::__is_unchunked)
{
(void) __end; // not used
// If we are not parallelizing, we need to run all the iterations sequentially.
uint32_t __increments = 1;
size_t __increments = 1;
if constexpr (!_BulkState::__parallelize)
{
__increments = static_cast<uint32_t>(__state->__size_);
__increments = static_cast<size_t>(__state->__size_);
}
for (uint32_t __i = __begin; __i < __begin + __increments; __i++)
for (size_t __i = __begin; __i < __begin + __increments; __i++)
{
std::apply([&](auto&&... __args) { __state->__fun_(__i, __args...); },
*reinterpret_cast<std::tuple<_As...>*>(__base_t::__arguments_data_));
__base_t::__arguments_.__get());
}
}
else
Expand All @@ -402,10 +415,10 @@ namespace STDEXEC
if constexpr (!_BulkState::__parallelize)
{
__begin = 0;
__end = static_cast<uint32_t>(__state->__size_);
__end = static_cast<size_t>(__state->__size_);
}
std::apply([&](auto&&... __args) { __state->__fun_(__begin, __end, __args...); },
*reinterpret_cast<std::tuple<_As...>*>(__base_t::__arguments_data_));
__base_t::__arguments_.__get());
}
}

Expand Down Expand Up @@ -504,7 +517,7 @@ namespace STDEXEC
__typed_forward_args_receiver_t(std::forward<_As>(__as)...);

auto __scheduler = __sched_;
auto __size = static_cast<uint32_t>(__state_.__size_);
auto __size = static_cast<size_t>(__state_.__size_);

auto __storage = __state_.__prepare_storage_for_backend(&__state_);
// This might destroy the `this` object.
Expand All @@ -514,14 +527,14 @@ namespace STDEXEC
if constexpr (_BulkState::__is_unchunked)
{
__scheduler->schedule_bulk_unchunked(_BulkState::__parallelize ? __size : 1,
__storage,
*__r);
*__r,
__storage);
}
else
{
__scheduler->schedule_bulk_chunked(_BulkState::__parallelize ? __size : 1,
__storage,
*__r);
*__r,
__storage);
}
}

Expand Down Expand Up @@ -603,7 +616,7 @@ namespace STDEXEC
&__system_bulk_op::__prepare_storage_for_backend_impl;

// Start using the preallocated buffer to store the inner operation state.
new (__preallocated_.__as_ptr()) __inner_op_state(__initFunc(*this));
new (__preallocated_.__as_ptr()) __inner_op_state(std::move(__initFunc)(*this));
}

__system_bulk_op(__system_bulk_op const &) = delete;
Expand Down
214 changes: 214 additions & 0 deletions test/stdexec/schedulers/test_parallel_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
* limitations under the License.
*/

#include <atomic>
#include <memory>
#include <stdexcept>
#include <thread>

#define STDEXEC_PARALLEL_SCHEDULER_HEADER_ONLY 1
Expand Down Expand Up @@ -407,6 +410,165 @@ struct my_inline_scheduler_backend_impl : scr::parallel_scheduler_backend
}
};

enum class bulk_completion_kind
{
error,
stopped
};

struct terminal_bulk_scheduler_backend_impl : scr::parallel_scheduler_backend
{
explicit terminal_bulk_scheduler_backend_impl(bulk_completion_kind completion) noexcept
: completion_(completion)
{}

void schedule(scr::receiver_proxy& r, std::span<std::byte>) noexcept override
{
r.set_value();
}

void schedule_bulk_chunked(size_t count,
scr::bulk_item_receiver_proxy& r,
std::span<std::byte>) noexcept override
{
r.execute(0, count);
complete(r);
}

void schedule_bulk_unchunked(size_t count,
scr::bulk_item_receiver_proxy& r,
std::span<std::byte>) noexcept override
{
for (size_t i = 0; i < count; ++i)
r.execute(i, i + 1);
complete(r);
}

void complete(scr::bulk_item_receiver_proxy& r) const noexcept
{
if (completion_ == bulk_completion_kind::error)
{
r.set_error(std::make_exception_ptr(std::runtime_error{"bulk"}));
}
else
{
r.set_stopped();
}
}

bulk_completion_kind completion_;
};

auto error_bulk_scheduler_backend() -> std::shared_ptr<scr::parallel_scheduler_backend>
{
return std::make_shared<terminal_bulk_scheduler_backend_impl>(bulk_completion_kind::error);
}

auto stopped_bulk_scheduler_backend() -> std::shared_ptr<scr::parallel_scheduler_backend>
{
return std::make_shared<terminal_bulk_scheduler_backend_impl>(bulk_completion_kind::stopped);
}

struct backend_factory_guard
{
explicit backend_factory_guard(scr::__parallel_scheduler_backend_factory_t factory)
: old_factory_(scr::set_parallel_scheduler_backend(factory))
{}

~backend_factory_guard()
{
(void) scr::set_parallel_scheduler_backend(old_factory_);
}

scr::__parallel_scheduler_backend_factory_t old_factory_;
};

struct destructor_tracked_value
{
explicit destructor_tracked_value(std::shared_ptr<std::atomic<int>> live) noexcept
: live_(std::move(live))
{
live_->fetch_add(1, std::memory_order_relaxed);
}

destructor_tracked_value(destructor_tracked_value const & other) noexcept
: live_(other.live_)
{
live_->fetch_add(1, std::memory_order_relaxed);
}

destructor_tracked_value(destructor_tracked_value&& other) noexcept
: live_(other.live_)
{
live_->fetch_add(1, std::memory_order_relaxed);
}

auto operator=(destructor_tracked_value const &) -> destructor_tracked_value& = delete;
auto operator=(destructor_tracked_value&&) -> destructor_tracked_value& = delete;

~destructor_tracked_value()
{
live_->fetch_sub(1, std::memory_order_relaxed);
}

std::shared_ptr<std::atomic<int>> live_;
};

struct tracked_value_sender
{
using sender_concept = ex::sender_tag;
using completion_signatures = ex::completion_signatures<ex::set_value_t(destructor_tracked_value)>;

struct env
{
STDEXEC::parallel_scheduler sched_;

auto query(ex::get_completion_scheduler_t<ex::set_value_t>, auto const &...) const noexcept
-> STDEXEC::parallel_scheduler
{
return sched_;
}

auto query(ex::get_completion_domain_t<ex::set_value_t>, auto const &...) const noexcept
{
return ex::get_domain(sched_);
}
};

template <class Receiver>
struct operation
{
using operation_state_concept = ex::operation_state_tag;

Receiver rcvr_;
destructor_tracked_value value_;

void start() & noexcept
{
ex::set_value(std::move(rcvr_), std::move(value_));
}
};

tracked_value_sender(STDEXEC::parallel_scheduler sched, std::shared_ptr<std::atomic<int>> live)
: sched_(std::move(sched))
, value_(std::move(live))
{}

auto get_env() const noexcept -> env
{
return {sched_};
}

template <ex::receiver Receiver>
auto connect(Receiver rcvr) && noexcept -> operation<Receiver>
{
return {std::move(rcvr), std::move(value_)};
}

STDEXEC::parallel_scheduler sched_;
destructor_tracked_value value_;
};

TEST_CASE("can change the implementation of parallel scheduler at runtime",
"[scheduler][parallel_scheduler]")
{
Expand Down Expand Up @@ -451,6 +613,58 @@ TEST_CASE("can change the implementation of parallel scheduler at runtime, with
(void) scr::set_parallel_scheduler_backend(old_factory);
}

TEST_CASE("bulk on parallel_scheduler destroys stored predecessor values",
"[scheduler][parallel_scheduler]")
{
auto live = std::make_shared<std::atomic<int>>(0);

{
STDEXEC::parallel_scheduler sched = STDEXEC::get_parallel_scheduler();
auto snd = tracked_value_sender{sched, live}
| ex::bulk(ex::par, 16, [](std::size_t, destructor_tracked_value&) noexcept {});

auto result = ex::sync_wait(std::move(snd));
REQUIRE(result.has_value());
}

CHECK(live->load(std::memory_order_relaxed) == 0);
}

TEST_CASE("bulk on parallel_scheduler destroys stored predecessor values after error",
"[scheduler][parallel_scheduler]")
{
backend_factory_guard guard{error_bulk_scheduler_backend};
auto live = std::make_shared<std::atomic<int>>(0);

{
STDEXEC::parallel_scheduler sched = STDEXEC::get_parallel_scheduler();
auto snd = tracked_value_sender{sched, live}
| ex::bulk(ex::par, 16, [](std::size_t, destructor_tracked_value&) noexcept {});

CHECK_THROWS_AS(ex::sync_wait(std::move(snd)), std::runtime_error);
}

CHECK(live->load(std::memory_order_relaxed) == 0);
}

TEST_CASE("bulk on parallel_scheduler destroys stored predecessor values after stopped",
"[scheduler][parallel_scheduler]")
{
backend_factory_guard guard{stopped_bulk_scheduler_backend};
auto live = std::make_shared<std::atomic<int>>(0);

{
STDEXEC::parallel_scheduler sched = STDEXEC::get_parallel_scheduler();
auto snd = tracked_value_sender{sched, live}
| ex::bulk(ex::par, 16, [](std::size_t, destructor_tracked_value&) noexcept {});

auto result = ex::sync_wait(std::move(snd));
CHECK_FALSE(result.has_value());
}

CHECK(live->load(std::memory_order_relaxed) == 0);
}

TEST_CASE("empty environment always returns nullopt for any query",
"[scheduler][parallel_scheduler]")
{
Expand Down