Conversation

liqiangxl (Collaborator)

No description provided.
github-actions bot commented Sep 4, 2025

Description

  • Enable native max.bf16 PTX instruction on supported architectures

  • Add BFloat16 warp reduction specialization using float conversion

  • Add test cases for BFloat16 max operation and auto-scheduler

  • Support BFloat16 in reduction benchmarks with math precision toggle


Changes walkthrough 📝

Relevant files

Enhancement

csrc/codegen.cpp: Fix formatting and BFloat16 cast in codegen

  • Reformatted code for consistency (no functional change)
  • Added conditional cast to float for BFloat16 output in reduction
  • +10/-11

benchmarks/python/test_reduction.py: Add BFloat16 reduction benchmark with math mode toggle

  • Added benchmark for BFloat16 reduction with configurable math mode
  • Supports toggling native BF16 math vs float emulation
  • +45/-0

runtime/helpers.cu: Implement native max.bf16 intrinsic

  • Implemented native max.bf16 PTX instruction for sm_80+
  • Fallback to float conversion on older architectures
  • +16/-0

runtime/warp.cu: Add BFloat16 warp reduction via float conversion

  • Added BFloat16 specialization for warp reduction
  • Converts to float for shuffle operations and back
  • +26/-0

Tests

tests/cpp/test_math_opt.cpp: Add BFloat16 max operation tests

  • Added test for BFloat16 max using max.bf16 PTX instruction
  • Added auto-scheduled version of BFloat16 max test
  • +48/-0

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Potential Precision Loss

The conversion between bfloat16 and float in the fmax function for __bfloat may introduce precision loss, particularly for values that cannot be represented exactly in bfloat16. This could affect correctness in certain edge cases.

    auto a_float = __bfloat2float(a);
    auto b_float = __bfloat2float(b);
    return __float2bfloat(fmax(a_float, b_float));
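
For context, here is a minimal, self-contained sketch of the pattern the PR describes for runtime/helpers.cu: a native max.bf16 path guarded for sm_80+, with the float round-trip shown above as the fallback. The Bf16 struct and the bf16_to_float / float_to_bf16 / bf16_max names are stand-ins invented for this illustration, not the runtime's actual __bfloat helpers, and the narrowing conversion truncates rather than rounds for brevity.

    // Hedged sketch (not the PR's code): architecture-guarded bf16 max.
    struct Bf16 {
      unsigned short raw; // bfloat16 bit pattern
    };

    __device__ inline float bf16_to_float(Bf16 x) {
      // bfloat16 is the upper 16 bits of an IEEE-754 binary32 value.
      return __uint_as_float(static_cast<unsigned int>(x.raw) << 16);
    }

    __device__ inline Bf16 float_to_bf16(float f) {
      // Truncating narrow; a production helper would round to nearest even.
      Bf16 r;
      r.raw = static_cast<unsigned short>(__float_as_uint(f) >> 16);
      return r;
    }

    __device__ inline Bf16 bf16_max(Bf16 a, Bf16 b) {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
      // Native PTX max.bf16 on 16-bit ("h") registers, available on sm_80+.
      Bf16 out;
      asm("max.bf16 %0, %1, %2;" : "=h"(out.raw) : "h"(a.raw), "h"(b.raw));
      return out;
    #else
      // Older architectures: widen to float, take the max, narrow back.
      return float_to_bf16(fmaxf(bf16_to_float(a), bf16_to_float(b)));
    #endif
    }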
Inefficient Type Conversion

The warpReduceTIDX specialization for __bfloat converts to and from float for each operation, which may add unnecessary computational overhead. This could be optimized by using native bfloat16 operations where the target architecture supports them.

    float out_fp32 = __bfloat2float(out);
    float inp_val_fp32 = __bfloat2float(inp_val);
    float init_val_fp32 = __bfloat2float(init_val);
    auto reduction_op_fp32 = [](float& a, float b) { a = fmaxf(a, b); };
    warpReduceTIDX<SINGLE_WARP, Aligned, float>(
        out_fp32,
        inp_val_fp32,
        reduction_op_fp32,
        shared_mem,
        read_write_pred,
        init_val_fp32,
        block_dim);
    out = __float2bfloat(out_fp32);
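
As a self-contained illustration of this widen-shuffle-narrow pattern (not the runtime's warpReduceTIDX itself), a bf16 warp max-reduction written against CUDA's standard cuda_bf16.h type could look like the following; warp_max_bf16 is a hypothetical name:

    #include <cuda_bf16.h>

    // Hedged sketch: each lane widens its bf16 value to float, the warp
    // reduces in float via shuffles, and the result is narrowed back to bf16.
    __device__ __nv_bfloat16 warp_max_bf16(__nv_bfloat16 v) {
      float acc = __bfloat162float(v); // widen once per lane
      // Tree reduction across a full 32-lane warp.
      for (int offset = 16; offset > 0; offset >>= 1) {
        acc = fmaxf(acc, __shfl_down_sync(0xffffffffu, acc, offset));
      }
      return __float2bfloat16(acc); // lane 0 now holds the warp maximum
    }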
Conditional Shared Memory Casting

The conditional casting of the shared-memory pointer based on the output data type may lead to inconsistent behavior or potential bugs if the DataType::BFloat16 handling is not properly tested across different scenarios.

    if (output->dtype() != DataType::BFloat16) {
      func_args.arg(genStaticCast(genPtrType(output->dtype()), "shared_mem"));
    } else {
      func_args.arg(genStaticCast(genPtrType(DataType::Float), "shared_mem"));
    }
