Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 72 additions & 9 deletions core/src/SYCL/Kokkos_SYCL_ParallelFor_MDRange.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ class Kokkos::Impl::ParallelFor<FunctorType, Kokkos::MDRangePolicy<Traits...>,
array_type m_upper;
array_type m_extent; // tile_size * num_tiles

template <typename FunctorWrapper>
template <bool UseStride, typename FunctorWrapper>
sycl::event sycl_direct_launch(const FunctorWrapper& functor_wrapper,
const sycl::event& memcpy_event) const {
const sycl::event& memcpy_event,
const sycl::nd_range<3>& range) const {
// Convenience references
const Kokkos::SYCL& space = m_policy.space();
sycl::queue& q = space.sycl_queue();
Expand All @@ -50,9 +51,6 @@ class Kokkos::Impl::ParallelFor<FunctorType, Kokkos::MDRangePolicy<Traits...>,

desul::ensure_sycl_lock_arrays_on_device(q);

const auto range =
Kokkos::Impl::compute_device_launch_params(m_policy, m_max_grid_size);

auto cgh_lambda = [&, range](sycl::handler& cgh) {
const sycl::range<3> global_range = range.get_global_range();
const sycl::range<3> local_range = range.get_local_range();
Expand Down Expand Up @@ -83,8 +81,8 @@ class Kokkos::Impl::ParallelFor<FunctorType, Kokkos::MDRangePolicy<Traits...>,
const index_type n_global_z = item.get_group_range(0);

Kokkos::Impl::DeviceIterate<Policy::rank, array_index_type, index_type,
Policy::inner_direction, true, FunctorType,
typename Policy::work_tag>(
Policy::inner_direction, UseStride,
FunctorType, typename Policy::work_tag>(
lower_bound, upper_bound, extent, functor_wrapper.get_functor(),
{n_global_x, n_global_y, n_global_z},
{n_local_x, n_local_y, n_local_z}, {global_x, global_y, global_z},
Expand Down Expand Up @@ -150,8 +148,73 @@ class Kokkos::Impl::ParallelFor<FunctorType, Kokkos::MDRangePolicy<Traits...>,

auto functor_wrapper =
Impl::make_sycl_function_wrapper(m_functor, indirectKernelMem);
sycl::event event =
sycl_direct_launch(functor_wrapper, functor_wrapper.get_copy_event());

// Compute the launch range once, up front, so that the same nd_range is used
// both to decide whether a grid-stride loop is needed and for the launch
// itself.
const auto range =
Kokkos::Impl::compute_device_launch_params(m_policy, m_max_grid_size);
const sycl::range<3> local_range = range.get_local_range();

// Conservative default: keep the strided kernel unless every launch
// dimension checked below is shown to be covered. The flag also stays true
// for any rank outside 1..6, since no branch below matches.
bool need_grid_stride = true;

// For each of the (up to three) launch dimensions the test is
//   m_max_grid_size[i] * local_range[i] >= extent mapped to dimension i
// and striding is disabled only if it holds for all of them.
// NOTE(review): this assumes index i of local_range, m_max_grid_size and the
// m_extent grouping all refer to the same launch dimension -- TODO confirm
// against compute_device_launch_params.
// NOTE(review): for ranks 4-6 the extents are multiplied before the cast to
// std::size_t; if the element type of m_extent is narrower than std::size_t
// the product could overflow -- TODO confirm the element type.
if constexpr (Policy::rank == 1) {
// Rank 1: only launch dimension 0 is checked.
if ((m_max_grid_size[0] * local_range[0]) >=
static_cast<std::size_t>(m_extent[0])) {
need_grid_stride = false;
}
} else if constexpr (Policy::rank == 2) {
// Rank 2: one extent per launch dimension (0 and 1).
if ((m_max_grid_size[0] * local_range[0]) >=
static_cast<std::size_t>(m_extent[0]) &&
(m_max_grid_size[1] * local_range[1]) >=
static_cast<std::size_t>(m_extent[1])) {
need_grid_stride = false;
}
} else if constexpr (Policy::rank == 3) {
// Rank 3: one extent per launch dimension (0, 1 and 2).
if ((m_max_grid_size[0] * local_range[0]) >=
static_cast<std::size_t>(m_extent[0]) &&
(m_max_grid_size[1] * local_range[1]) >=
static_cast<std::size_t>(m_extent[1]) &&
(m_max_grid_size[2] * local_range[2]) >=
static_cast<std::size_t>(m_extent[2])) {
need_grid_stride = false;
}
} else if constexpr (Policy::rank == 4) {
// Rank 4: extents 0 and 1 are combined (product) for launch dimension 0;
// extents 2 and 3 are checked individually.
if ((m_max_grid_size[0] * local_range[0]) >=
static_cast<std::size_t>(m_extent[0] * m_extent[1]) &&
(m_max_grid_size[1] * local_range[1]) >=
static_cast<std::size_t>(m_extent[2]) &&
(m_max_grid_size[2] * local_range[2]) >=
static_cast<std::size_t>(m_extent[3])) {
need_grid_stride = false;
}
} else if constexpr (Policy::rank == 5) {
// Rank 5: extents {0,1} and {2,3} are combined pairwise; extent 4 is
// checked on its own.
if ((m_max_grid_size[0] * local_range[0]) >=
static_cast<std::size_t>(m_extent[0] * m_extent[1]) &&
(m_max_grid_size[1] * local_range[1]) >=
static_cast<std::size_t>(m_extent[2] * m_extent[3]) &&
(m_max_grid_size[2] * local_range[2]) >=
static_cast<std::size_t>(m_extent[4])) {
need_grid_stride = false;
}
} else if constexpr (Policy::rank == 6) {
// Rank 6: extents {0,1}, {2,3} and {4,5} are combined pairwise.
if ((m_max_grid_size[0] * local_range[0]) >=
static_cast<std::size_t>(m_extent[0] * m_extent[1]) &&
(m_max_grid_size[1] * local_range[1]) >=
static_cast<std::size_t>(m_extent[2] * m_extent[3]) &&
(m_max_grid_size[2] * local_range[2]) >=
static_cast<std::size_t>(m_extent[4] * m_extent[5])) {
need_grid_stride = false;
}
}

// The decision is only known at run time, but it is forwarded as a template
// argument, so both instantiations of sycl_direct_launch are spelled out and
// the matching one is selected here. Both receive the same wrapper, copy
// event and precomputed range.
sycl::event event;
if (need_grid_stride) {
event = sycl_direct_launch<true>(functor_wrapper,
functor_wrapper.get_copy_event(), range);
} else {
event = sycl_direct_launch<false>(
functor_wrapper, functor_wrapper.get_copy_event(), range);
}

functor_wrapper.register_event(event);
}

Expand Down
Loading