
Conversation

@rdspring1
Collaborator

  • Add warp_specialize and inline_at schedule operations
  • Create test_register_sharing_circular_buffering_pointwise example.

@rdspring1 rdspring1 added the Direct Bindings Python extension with direct mapping to NvFuser CPP objects. label Dec 15, 2025
@rdspring1
Collaborator Author

!test


github-actions bot commented Dec 15, 2025

Review updated until commit 647ceb8

Description

  • Add inline_at schedule operation for selective tensor inlining

  • Add warp_specialize operation for warp specialization circular buffering

  • Create comprehensive test demonstrating warp specialization with TMA loads

  • Include device compatibility checks for Hopper+ architecture

Changes walkthrough

Relevant files

Enhancement
schedule.cpp: Add warp specialization schedule operations
python/python_direct/schedule.cpp

  • Added inline_at function for selective tensor operation inlining
  • Added warp_specialize function for warp specialization circular buffering
  • Created bindCircularBuffering binding function
  • Updated schedule operator bindings to include new functionality
  • +80/-0

Tests
test_tutorial.py: Add warp specialization test example
tests/python/direct/test_tutorial.py

  • Added test_warp_specialized_circular_buffering_pointwise test function
  • Demonstrates TMA load operations with warp specialization
  • Includes device compatibility checks for Hopper+ architecture
  • Validates kernel generation with warp specialization conditional statements
  • +83/-0

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Missing Parameter Validation

The new inline_at and warp_specialize functions lack input validation. Consider adding checks for null pointers, valid position indices, positive stage counts, and reasonable prefetch distances to prevent runtime errors and improve user experience.

    schedule.def(
        "inline_at",
        [](TensorView* reference_tv,
           int64_t pos,
           bool best_effort,
           const std::vector<TensorView*>& selected_tensors) {
          if (selected_tensors.empty()) {
            // Inline to the position corresponding to the reference position in
            // the reference tensor for all tensors in the current fusion.
            inlineAllAt(reference_tv, pos, best_effort);
          } else {
            // Inline to the position corresponding to the reference position in
            // the reference tensor for selected tensors in the current fusion.
            std::unordered_set<TensorView*> selected_tv_set(
                selected_tensors.begin(), selected_tensors.end());
            inlineSelectedAt(selected_tv_set, reference_tv, pos, best_effort);
          }
        },
        R"(
          Inline operations at a specific position for the selected tensors.
          If selected_tensors is empty, inlines all operations.
    
          Parameters
          ----------
          reference_tv : TensorView
              The reference TensorView whose position will be used for inlining.
          pos : int, optional
              The position to inline at. -1 means the last position.
          best_effort : bool, optional
              Whether to try to inline even if the exact position is not possible (default: False).
          selected_tensors : List[TensorView], optional
              List of TensorViews to inline. If empty, inlines all operations.
    
          Returns
          -------
          None
        )",
        py::arg("reference_tv"),
        py::arg("pos") = -1,
        py::arg("best_effort") = false,
        py::arg("selected_tensors") = std::vector<TensorView*>());
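A pure-Python sketch of what the suggested validation could enforce. The bounds below are illustrative assumptions, not nvfuser's documented contract, and the real binding would raise from C++ rather than from a helper like this:

```python
# Hypothetical validation sketch for the checks the reviewer suggests.
# The exact bounds are illustrative assumptions, not nvfuser's contract.
def validate_schedule_params(pos, ndims, number_of_stages, prefetch_distance):
    # pos == -1 means "the last position"; otherwise it must index the domain.
    if not (pos == -1 or 0 <= pos <= ndims):
        raise ValueError(f"pos must be -1 or in [0, {ndims}], got {pos}")
    # Circular buffering needs at least two stages to overlap load and compute.
    if number_of_stages < 2:
        raise ValueError(f"number_of_stages must be >= 2, got {number_of_stages}")
    # Prefetching past the buffer depth would clobber stages still in use.
    if not (0 <= prefetch_distance < number_of_stages):
        raise ValueError(
            f"prefetch_distance must be in [0, {number_of_stages}), "
            f"got {prefetch_distance}"
        )
```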
Unclear API Documentation

The best_effort parameter in inline_at lacks clear documentation about the difference in behavior between True and False states. Users need to understand what "try to inline even if exact position is not possible" means in practice.

best_effort : bool, optional
    Whether to try to inline even if the exact position is not possible (default: False).

Limited Test Coverage

While the new test is comprehensive for the happy path, consider adding tests for edge cases like invalid parameters, different tensor dimensions, and error conditions to ensure robustness of the new functionality.

    def test_warp_specialized_circular_buffering_pointwise():
        def _definition_func(fd: FusionDefinition, inputs):
            tv0 = fd.from_pytorch(inputs[0])
            tv1 = fd.from_pytorch(inputs[1])
            tv2 = fd.ops.add(tv0, tv1)
            fd.add_output(tv2)
    
        def _schedule_func(fd: FusionDefinition):
            # Parameters
            number_of_stages = 4
            prefetch_distance = 1
            bulk_inner_dim = 128
    
            tv_inputs = list(filter(lambda v: v.is_tensor(), fd.fusion.inputs()))
            assert len(tv_inputs) == 2
            tv0, tv1 = tv_inputs
    
            # Use TMA to load TV0 into shared memory
            tv3 = tv0.cache_after(LoadStoreOpType.tma)
            tv3.set_memory_type(MemoryType.shared)
    
            tv4 = tv1.cache_after(LoadStoreOpType.tma)
            tv4.set_memory_type(MemoryType.shared)
    
            tv_outputs = list(filter(lambda v: v.is_tensor(), fd.fusion.outputs()))
            assert len(tv_outputs) == 1
            reference = tv_outputs[0]
    
            # [M, N] -> [M, N/bid, bid]
            reference.split(-1, bulk_inner_dim)
            fd.sched.transform_like(reference)
    
            tv3.axis(0).parallelize(ParallelType.grid_x)
            tv4.axis(0).parallelize(ParallelType.grid_x)
    
            # Set computeAt position
            fd.sched.inline_at(reference, pos=2)
    
            # Circular Buffer with TMA loads
            tv3.axis(2).parallelize(ParallelType.tma)
            tv4.axis(2).parallelize(ParallelType.tma)
            fd.sched.warp_specialize(
                tv3, number_of_stages, prefetch_distance, ParallelType.block_y
            )
            fd.sched.warp_specialize(
                tv4, number_of_stages, prefetch_distance, ParallelType.block_y
            )
    
            # Split reference to parallelize TMA tile
            reference.split(-1, bulk_inner_dim)
            reference.axis(0).parallelize(ParallelType.grid_x)
            reference.axis(-1).parallelize(ParallelType.block_x)
    
        # Inputs
        tensor_outer_dim = 128
        tensor_inner_dim = 1024
        t0 = torch.randn(
            tensor_outer_dim,
            tensor_inner_dim,
            dtype=torch.float,
            device=torch.device("cuda:0"),
        )
        t1 = torch.randn(
            tensor_outer_dim,
            tensor_inner_dim,
            dtype=torch.float,
            device=torch.device("cuda:0"),
        )
        inputs = [t0, t1]
    
        with FusionDefinition() as fd:
            _definition_func(fd, inputs)
            _schedule_func(fd)
    
        outputs = fd.manual_execute(inputs)
        warp_specialization_if_stmt = "if ((((nvfuser_index_t)threadIdx.y) >= 1LL)) {"
        assert warp_specialization_if_stmt in fd.fusion.print_kernel()


    greptile-apps bot commented Dec 15, 2025

Greptile Overview

Greptile Summary

This PR adds Python bindings for warp specialization circular buffering to the direct bindings API:

• Added warp_specialize function binding in schedule.cpp that wraps the C++ WarpSpecialized circular buffer type for TMA loads with customizable stages, prefetch distance, parallel type, and optional register configuration
• Added inline_at function binding that wraps inlineAllAt/inlineSelectedAt for controlling compute-at positions in scheduled fusions
• Included a comprehensive test example test_warp_specialized_circular_buffering_pointwise demonstrating warp specialization with TMA operations on a simple pointwise addition kernel

The implementation follows existing patterns in the codebase and correctly handles the optional num_registers parameter for register sharing between warps.
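The stages/prefetch mechanics summarized above can be illustrated with a toy trace, assuming nothing about nvfuser internals (circular_buffer_trace is invented purely for illustration):

```python
# Toy model of circular buffering: loads run prefetch_distance iterations
# ahead of compute, cycling through number_of_stages buffer slots.
# Illustration only; not nvfuser's implementation.
def circular_buffer_trace(num_iters, number_of_stages, prefetch_distance):
    events = []
    # Prologue: issue the first prefetch_distance loads up front.
    for i in range(min(prefetch_distance, num_iters)):
        events.append(("load", i, i % number_of_stages))
    # Main loop: each compute step issues the load running ahead of it.
    for i in range(num_iters):
        ahead = i + prefetch_distance
        if ahead < num_iters:
            events.append(("load", ahead, ahead % number_of_stages))
        events.append(("compute", i, i % number_of_stages))
    return events

# With prefetch_distance < number_of_stages, every compute reads the slot its
# own load wrote, and a slot is only reused after its occupant was computed.
trace = circular_buffer_trace(num_iters=6, number_of_stages=4, prefetch_distance=1)
```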

Confidence Score: 5/5

• This PR is safe to merge - it adds new API bindings with a working test example and follows established patterns.
• High confidence because: (1) the new bindings correctly wrap existing C++ functionality with proper parameter handling, (2) the test properly guards for Hopper+ devices using the standard is_pre_hopper() check, (3) the test validates both kernel generation and numerical correctness, and (4) the implementation follows existing patterns in the codebase.
• No files require special attention.

Important Files Changed

File Analysis

• python/python_direct/schedule.cpp (5/5): Added inline_at and warp_specialize Python bindings for schedule operations. The implementations correctly wrap the C++ inlineAllAt/inlineSelectedAt and WarpSpecialized circular buffering functions.
• tests/python/direct/test_tutorial.py (5/5): Added test_warp_specialized_circular_buffering_pointwise example demonstrating the new warp_specialize and inline_at schedule operations with TMA loads. Test properly guarded for Hopper+ devices.

    Sequence Diagram

    sequenceDiagram
        participant Test as Test Code
        participant FD as FusionDefinition
        participant Sched as Schedule Module
        participant TV as TensorView
        participant CB as CircularBuffer
    
        Test->>FD: Create fusion definition
        FD->>TV: from_pytorch(t0, t1)
        FD->>TV: ops.add(tv0, tv1)
        FD->>FD: add_output(tv2)
        
        Note over Test,CB: Schedule Phase
        Test->>TV: cache_after(LoadStoreOpType.tma)
        Test->>TV: set_memory_type(MemoryType.shared)
        Test->>TV: split(-1, bulk_inner_dim)
        Test->>Sched: transform_like(reference)
        Test->>TV: parallelize(ParallelType.grid_x)
        Test->>Sched: inline_at(reference, pos=2)
        Test->>TV: parallelize(ParallelType.tma)
        Test->>Sched: warp_specialize(tv, stages, prefetch, ParallelType.block_y)
        Sched->>CB: WarpSpecialized(parallel_type)
        CB->>TV: circularBuffer(stages, prefetch, type)
        
        Note over Test,CB: Execution Phase
        Test->>FD: manual_execute(inputs)
        FD-->>Test: outputs
    

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 1 comment

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
    @rdspring1
    Collaborator Author

    !test

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments
