[CUB] Refactor DevicePartition::If to always take an environment #9393
miscco wants to merge 3 commits
We want to use the env-based API to ensure that we take advantage of user-provided tunings.
We want to be able to pass tunings to the APIs that take user-provided memory. Make sure we can pass any environment or stream type to them.
Overview
This PR refactors CUB and libcudacxx's device algorithms to consistently use environment-based execution dispatch instead of explicit CUDA stream parameters. The primary goal is to enable callers to pass tunings and custom execution policies when user-provided temporary storage is supplied.
Key Changes
Core Dispatch Infrastructure (
| Layer / File(s) | Summary |
|---|---|
| Dispatch infrastructure: const-ref environments and user-memory overloads `cub/cub/detail/env_dispatch.cuh` | `dispatch_with_env` and `dispatch_with_env_and_tuning` changed to accept the environment by const reference. New overloads added for user-provided temporary storage: they query stream and tuning from the environment, select a policy via `DefaultPolicySelector`, and invoke the algorithm callable with explicit storage pointers. |
| DeviceFind: FindIf, LowerBound, UpperBound with environment execution `cub/cub/device/device_find.cuh` | `DeviceFind::FindIf`, `LowerBound`, and `UpperBound` main overloads updated to accept `const EnvT& env` (defaulting to `cuda::std::execution::env<>`), route through `detail::dispatch_with_env_and_tuning` with policy selection, and move null-storage handling into env-dispatch lambdas. Convenience overloads pass the environment by const reference. |
| DevicePartition: If overloads with environment execution `cub/cub/device/device_partition.cuh` | Two-way and three-way `DevicePartition::If` device storage overloads updated to accept `const EnvT& env` instead of `cudaStream_t`. Default policy selectors are built from iterator/value types and invoked via `detail::dispatch_with_env_and_tuning`, with the policy forwarded into partition dispatch. Environment-based overloads pass the environment by const reference. |
| PSTL backends: propagate policy to CUB device algorithms `libcudacxx/include/cuda/std/__pstl/cuda/find_if.h`, `partition.h`, `partition_copy.h`, `stable_partition.h` | CUDA PSTL dispatch implementations updated to pass the execution `__policy` instead of `__stream.get()` to CUB `DeviceFind` and `DevicePartition` calls during both temporary storage sizing and kernel invocation. `find_if` updated to select the CUB offset type and use a typed `nullptr` for the output pointer. |
| Tests: DeviceFind and DevicePartition with environment and user-provided memory `cub/test/catch2_test_device_find_env.cu`, `catch2_test_device_partition_if.cu` | Test coverage expanded to validate algorithms with user-provided temporary storage and multiple environment forms (CUDA stream, `cuda::stream`, `cuda::stream_ref`, `cuda::std::execution::env`, `cuda::execution::gpu` policies). Existing tests refactored into unprovided vs. user-provided memory sections; new `C2H_TEST`s verify `nullptr` queries and allocated-storage runs with CUDA error and synchronization checks. |
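The user-memory dispatch described in the summary can be illustrated with a small host-only model. This is a sketch under stated assumptions, not the code in `env_dispatch.cuh`: every name below (`default_tuning`, `wide_tuning`, `empty_env`, `tuned_env`, `select_policy`, `partition_like`, `run_two_phase`) is hypothetical, and the "storage requirement" is a made-up formula. It only shows the shape of the idea: the policy comes from the environment when one is present, falls back to a default otherwise, and the same entry point serves both the size query (null storage) and the run.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical tunings; real CUB policies carry far more information.
struct default_tuning { static constexpr std::size_t items_per_thread = 4; };
struct wide_tuning    { static constexpr std::size_t items_per_thread = 16; };

// Hypothetical environments: one carries no tuning, one carries a tuning.
struct empty_env {};
template <class Tuning>
struct tuned_env { using tuning = Tuning; };

// Policy selection: use the environment's tuning if it has one,
// otherwise fall back to the default.
template <class Env, class = void>
struct select_policy { using type = default_tuning; };
template <class Env>
struct select_policy<Env, std::void_t<typename Env::tuning>> {
  using type = typename Env::tuning;
};

constexpr std::size_t bytes_per_item = 4; // made-up storage formula

// Two-phase entry point: null storage reports the required size,
// non-null storage "runs". Returns 0 on success, 1 if storage is too small.
template <class Env>
int partition_like(void* temp_storage, std::size_t& temp_bytes, const Env& /*env*/) {
  using policy = typename select_policy<Env>::type;
  constexpr std::size_t required = policy::items_per_thread * bytes_per_item;
  if (temp_storage == nullptr) {
    temp_bytes = required;
    return 0;
  }
  return temp_bytes >= required ? 0 : 1;
}

// Caller-side pattern with user-provided memory: query, allocate, run.
template <class Env>
std::size_t run_two_phase(const Env& env) {
  std::size_t bytes = 0;
  assert(partition_like(nullptr, bytes, env) == 0);
  std::vector<unsigned char> storage(bytes);
  assert(partition_like(storage.data(), bytes, env) == 0);
  return bytes;
}
```

In this model, `run_two_phase(empty_env{})` and `run_two_phase(tuned_env<wide_tuning>{})` go through the same call path but end up with different storage sizes, because the tuning carried by the environment changes the selected policy.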
Suggested reviewers
- shwina
- gevtushenko
Actionable comments posted: 1
🧹 Nitpick comments (1)
cub/cub/detail/env_dispatch.cuh (1)
117-129: suggestion: This new user-memory path now feeds policy selection through the env/tuning dispatcher, so it can change the kernel variant selected for DeviceFind and DevicePartition. Please run the SASS-diff and benchmark flow for the affected Device* benchmarks before merge. As per coding guidelines, `**/*.{cpp,cu,cuh}`: Do not commit SASS code changes without running benchmarks to check for performance regressions.
Source: Coding guidelines
📒 Files selected for processing (9)
- cub/cub/detail/env_dispatch.cuh
- cub/cub/device/device_find.cuh
- cub/cub/device/device_partition.cuh
- cub/test/catch2_test_device_find_env.cu
- cub/test/catch2_test_device_partition_if.cu
- libcudacxx/include/cuda/std/__pstl/cuda/find_if.h
- libcudacxx/include/cuda/std/__pstl/cuda/partition.h
- libcudacxx/include/cuda/std/__pstl/cuda/partition_copy.h
- libcudacxx/include/cuda/std/__pstl/cuda/stable_partition.h
```cpp
auto test_find_if = [&](const auto& env) {
  size_t num_bytes = 0;
  error = cub::DeviceFind::FindIf(nullptr, num_bytes, d_in.begin(), d_out.begin(), predicate, num_items, env);
  REQUIRE(error == cudaSuccess);
  REQUIRE(cudaSuccess == cudaPeekAtLastError());
  REQUIRE(cudaSuccess == cudaDeviceSynchronize());
  REQUIRE(num_bytes == expected_bytes_allocated);

  error = cub::DeviceFind::FindIf(temp_storage, num_bytes, d_in.begin(), d_out.begin(), predicate, num_items, env);
  REQUIRE(error == cudaSuccess);
  REQUIRE(cudaSuccess == cudaPeekAtLastError());
  REQUIRE(cudaSuccess == cudaDeviceSynchronize());
  REQUIRE(d_out[0] == 5);
};
```
important: These tests assume temp-storage size is identical across all environment forms (num_bytes == expected_bytes_allocated) and reuse a buffer sized from a baseline query. The env dispatch path can legally select different policy/tuning per env, so required bytes may differ; this makes the test brittle and can fail valid implementations. Query and allocate temp storage per env invocation instead of enforcing cross-env equality.
As per coding guidelines, cub/**/* reviews should prioritize stream behavior and test-coverage risks, and this assertion over-constrains valid env behavior.
Also applies to: 309-340, 436-467
Source: Coding guidelines
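The suggestion to query and allocate per environment can be sketched with a host-only mock. This is not the CUB test code: `small_env`, `large_env`, `mock_find_if`, `run_with_fresh_storage`, and `run_with_shared_storage` are hypothetical names, and the per-environment byte counts are invented to stand in for "different policy/tuning per env". The sketch contrasts a buffer sized from one baseline query with a buffer sized by each invocation's own query.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical environments whose selected policy needs different storage.
struct small_env { static constexpr std::size_t temp_bytes = 64; };
struct large_env { static constexpr std::size_t temp_bytes = 256; };

// Mock two-phase algorithm: null storage reports the size,
// non-null storage succeeds only if the buffer is large enough.
template <class Env>
bool mock_find_if(void* temp_storage, std::size_t& temp_bytes, const Env& /*env*/) {
  if (temp_storage == nullptr) {
    temp_bytes = Env::temp_bytes;
    return true;
  }
  return temp_bytes >= Env::temp_bytes;
}

// Robust pattern: each environment gets its own query and allocation.
template <class Env>
bool run_with_fresh_storage(const Env& env) {
  std::size_t bytes = 0;
  if (!mock_find_if(nullptr, bytes, env)) {
    return false;
  }
  std::vector<unsigned char> storage(bytes);
  return mock_find_if(storage.data(), bytes, env);
}

// Brittle pattern: reuse a non-empty buffer sized from a baseline query.
template <class Env>
bool run_with_shared_storage(std::vector<unsigned char>& storage, const Env& env) {
  std::size_t bytes = storage.size();
  return mock_find_if(storage.data(), bytes, env);
}
```

With a shared buffer of 64 bytes, the `small_env` run succeeds and the `large_env` run fails, while `run_with_fresh_storage` succeeds for both. A test written this way checks each environment against its own reported size instead of asserting equality across environments.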
🥳 CI Workflow Results
🟩 Finished in 1h 29m: Pass: 100%/343 | Total: 6d 17h | Max: 1h 28m | Hits: 69%/714109
DeviceFindIf#9318