Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -5017,6 +5017,12 @@ def HLSLWaveActiveSum : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void (...)";
}

// HLSL wave product reduction builtin, spelled
// __builtin_hlsl_wave_active_product. The prototype is variadic
// ("void (...)"), so the argument count is not enforced here; it is checked
// in SemaHLSL::CheckBuiltinFunctionCall.
def HLSLWaveActiveProduct : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_product"];
let Attributes = [NoThrow, Const];
let Prototype = "void (...)";
}

def HLSLWaveGetLaneIndex : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_get_lane_index"];
let Attributes = [NoThrow, Const];
Expand Down
28 changes: 28 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,23 @@ static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
}
}

// Return the wave active product intrinsic that corresponds to the target
// architecture and to the signedness of QT.
//
// Arch selects the backend (SPIR-V or DXIL); any other architecture is
// unreachable. RT is currently unused. QT is the type of the builtin's
// operand and may be a scalar or a vector.
static Intrinsic::ID getWaveActiveProductIntrinsic(llvm::Triple::ArchType Arch,
                                                   CGHLSLRuntime &RT,
                                                   QualType QT) {
  switch (Arch) {
  case llvm::Triple::spirv:
    // SPIR-V uses one intrinsic for every operand type.
    return Intrinsic::spv_wave_reduce_product;
  case llvm::Triple::dxil: {
    // Use hasUnsignedIntegerRepresentation() rather than
    // isUnsignedIntegerType(): the latter is false for vector types, which
    // would send unsigned vectors (e.g. uint4) down the signed path.
    if (QT->hasUnsignedIntegerRepresentation())
      return Intrinsic::dx_wave_reduce_uproduct;
    return Intrinsic::dx_wave_reduce_product;
  }
  default:
    llvm_unreachable("Intrinsic WaveActiveProduct"
                     " not supported by target architecture");
  }
}

// Return wave active max that corresponds to the QT scalar type
static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
CGHLSLRuntime &RT, QualType QT) {
Expand Down Expand Up @@ -708,6 +725,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
&CGM.getModule(), IID, {OpExpr->getType()}),
ArrayRef{OpExpr}, "hlsl.wave.active.sum");
}
case Builtin::BI__builtin_hlsl_wave_active_product: {
  // The builtin is declared with a variadic prototype, so its single operand
  // has to be retrieved explicitly from the call expression.
  const Expr *Arg = E->getArg(0);
  Value *Operand = EmitScalarExpr(Arg);

  // Pick the target-specific reduction intrinsic from the operand's AST type.
  Intrinsic::ID ProductIID = getWaveActiveProductIntrinsic(
      getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), Arg->getType());

  // The intrinsic is overloaded on the operand's LLVM type.
  llvm::Function *ProductFn = Intrinsic::getOrInsertDeclaration(
      &CGM.getModule(), ProductIID, {Operand->getType()});
  return EmitRuntimeCall(ProductFn, ArrayRef{Operand},
                         "hlsl.wave.active.product");
}
case Builtin::BI__builtin_hlsl_wave_active_max: {
// Due to the use of variadic arguments, explicitly retrieve argument
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Expand Down
124 changes: 124 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2696,6 +2696,130 @@ __attribute__((convergent)) double3 WaveActiveSum(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
__attribute__((convergent)) double4 WaveActiveSum(double4);

//===----------------------------------------------------------------------===//
// WaveActiveProduct builtins
//===----------------------------------------------------------------------===//

/// \fn T WaveActiveProduct(T Val)
/// \brief Wave-level product reduction of \a Val.
/// \param Val The per-lane value: a scalar or a 2-4 element vector.
///
/// Every overload below is an alias of __builtin_hlsl_wave_active_product and
/// is marked convergent. Overloads are declared for half, 16/32/64-bit signed
/// and unsigned integers, float and double.
/// NOTE(review): presumably only the active lanes of the wave contribute, as
/// the name suggests - confirm against the HLSL documentation.

// half overloads.
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) half WaveActiveProduct(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) half2 WaveActiveProduct(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) half3 WaveActiveProduct(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) half4 WaveActiveProduct(half4);

// 16-bit integer overloads are only declared when __HLSL_ENABLE_16_BIT is
// defined.
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int16_t WaveActiveProduct(int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int16_t2 WaveActiveProduct(int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int16_t3 WaveActiveProduct(int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int16_t4 WaveActiveProduct(int16_t4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint16_t WaveActiveProduct(uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint16_t2 WaveActiveProduct(uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint16_t3 WaveActiveProduct(uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint16_t4 WaveActiveProduct(uint16_t4);
#endif

// 32-bit signed integer overloads.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int WaveActiveProduct(int);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int2 WaveActiveProduct(int2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int3 WaveActiveProduct(int3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int4 WaveActiveProduct(int4);

// 32-bit unsigned integer overloads.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint WaveActiveProduct(uint);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint2 WaveActiveProduct(uint2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint3 WaveActiveProduct(uint3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint4 WaveActiveProduct(uint4);

// 64-bit signed integer overloads.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int64_t WaveActiveProduct(int64_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int64_t2 WaveActiveProduct(int64_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int64_t3 WaveActiveProduct(int64_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) int64_t4 WaveActiveProduct(int64_t4);

// 64-bit unsigned integer overloads.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint64_t WaveActiveProduct(uint64_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint64_t2 WaveActiveProduct(uint64_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint64_t3 WaveActiveProduct(uint64_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) uint64_t4 WaveActiveProduct(uint64_t4);

// float overloads.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) float WaveActiveProduct(float);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) float2 WaveActiveProduct(float2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) float3 WaveActiveProduct(float3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) float4 WaveActiveProduct(float4);

// double overloads.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) double WaveActiveProduct(double);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) double2 WaveActiveProduct(double2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) double3 WaveActiveProduct(double3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_product)
__attribute__((convergent)) double4 WaveActiveProduct(double4);


//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3197,7 +3197,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
break;
}
case Builtin::BI__builtin_hlsl_wave_active_max:
case Builtin::BI__builtin_hlsl_wave_active_sum: {
case Builtin::BI__builtin_hlsl_wave_active_sum:
case Builtin::BI__builtin_hlsl_wave_active_product: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;

Expand Down
45 changes: 45 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveProduct.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call.

// Signed 32-bit scalar: the DXIL run expects the signed reduce.product
// intrinsic and the SPIR-V run expects the spv reduce.product intrinsic.
// CHECK-LABEL: test_int
int test_int(int expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.product.i32([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.product.i32([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveProduct(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.product.i32([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.product.i32([[TY]]) #[[#attr:]]

// Unsigned 64-bit scalar: the DXIL run expects the unsigned reduce.uproduct
// intrinsic, while the SPIR-V run expects the same reduce.product intrinsic
// that is used for signed operands.
// CHECK-LABEL: test_uint64_t
uint64_t test_uint64_t(uint64_t expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.product.i64([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.uproduct.i64([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveProduct(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.uproduct.i64([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.product.i64([[TY]]) #[[#attr:]]

// Test basic lowering to runtime function call with vector and float value.

// Four-element float vector: both runs expect the fast-math flags on the call
// and the intrinsic overloaded on v4f32.
// CHECK-LABEL: test_floatv4
float4 test_floatv4(float4 expr) {
  // CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func [[TY1:.*]] @llvm.spv.wave.reduce.product.v4f32([[TY1]] %[[#]])
  // CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn [[TY1:.*]] @llvm.dx.wave.reduce.product.v4f32([[TY1]] %[[#]])
  // CHECK: ret [[TY1]] %[[RET1]]
  return WaveActiveProduct(expr);
}

// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.product.v4f32([[TY1]]) #[[#attr]]
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.reduce.product.v4f32([[TY1]]) #[[#attr]]

// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
28 changes: 28 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveProduct-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

// Calling the builtin with no argument is rejected by the argument-count check.
int test_too_few_arg() {
return __builtin_hlsl_wave_active_product();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

// Calling the builtin with two arguments is rejected by the argument-count check.
float2 test_too_many_arg(float2 p0) {
return __builtin_hlsl_wave_active_product(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

// A scalar bool operand is diagnosed as an invalid operand type.
bool test_expr_bool_type_check(bool p0) {
return __builtin_hlsl_wave_active_product(p0);
// expected-error@-1 {{invalid operand of type 'bool'}}
}

// A bool vector operand is diagnosed as an invalid operand type.
bool2 test_expr_bool_vec_type_check(bool2 p0) {
return __builtin_hlsl_wave_active_product(p0);
// expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
}

// Minimal user-defined aggregate used for the struct operand test below.
struct S { float f; };

// A struct operand is diagnosed: a scalar or vector is required.
S test_expr_struct_type_check(S p0) {
return __builtin_hlsl_wave_active_product(p0);
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
}
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
// Wave product reductions. Both return the same type as their single operand
// and are convergent with no memory access. The "u" variant only accepts
// integer types and is mapped to the unsigned form of the DXIL WaveActiveOp.
def int_dx_wave_reduce_product : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_uproduct : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_get_lane_count
Expand Down
3 changes: 2 additions & 1 deletion llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
// Wave product reduction. Returns the same type as its single operand; the
// SPIR-V instruction selector lowers it to OpGroupNonUniformFMul for float
// operands and OpGroupNonUniformIMul otherwise.
def int_spv_wave_reduce_product : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_get_lane_count
Expand All @@ -136,7 +137,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_sclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_spv_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;

// Create resource handle given the binding information. Returns a
// Create resource handle given the binding information. Returns a
// type appropriate for the kind of resource given the set id, binding id,
// array size of the binding, as well as an index and an indicator
// whether that index may be non-uniform.
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,16 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>,
IntrinArgI8<SignedOpKind_Unsigned>
]>,
IntrinSelect<int_dx_wave_reduce_product,
[
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Product>,
IntrinArgI8<SignedOpKind_Signed>
]>,
IntrinSelect<int_dx_wave_reduce_uproduct,
[
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Product>,
IntrinArgI8<SignedOpKind_Unsigned>
]>,
IntrinSelect<int_dx_wave_reduce_max,
[
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Max>,
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/DirectX/DXILShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ static bool checkWaveOps(Intrinsic::ID IID) {
// Wave Active Op Variants
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_wave_reduce_product:
case Intrinsic::dx_wave_reduce_uproduct:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_umax:
return true;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_product:
case Intrinsic::dx_wave_reduce_umax:
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_wave_reduce_uproduct:
case Intrinsic::dx_imad:
case Intrinsic::dx_umad:
return true;
Expand Down
31 changes: 31 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectWaveReduceSum(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

// Selects the subgroup product reduction for the spv_wave_reduce_product
// intrinsic. ResVReg receives the result, ResType is its SPIR-V type and I is
// the intrinsic instruction being selected.
bool selectWaveReduceProduct(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectConst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

Expand Down Expand Up @@ -2482,6 +2485,32 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
.addUse(I.getOperand(2).getReg());
}

/// Select the spv_wave_reduce_product intrinsic as a subgroup-scoped
/// multiplicative reduction.
///
/// \param ResVReg Virtual register that receives the reduction result.
/// \param ResType SPIR-V type of the result.
/// \param I       Instruction being selected; it must have exactly three
///                operands, and operand 2 is the register to reduce.
/// \returns the built instruction converted to bool.
///
/// Reports a fatal error when no SPIR-V type is known for the input register.
bool SPIRVInstructionSelector::selectWaveReduceProduct(Register ResVReg,
                                                       const SPIRVType *ResType,
                                                       MachineInstr &I) const {
  assert(I.getNumOperands() == 3);
  assert(I.getOperand(2).isReg());

  Register SrcReg = I.getOperand(2).getReg();
  if (!GR.getSPIRVTypeForVReg(SrcReg))
    report_fatal_error("Input Type could not be determined.");

  // 32-bit integer type used for the scope constant operand.
  SPIRVType *ScopeTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);

  // Retrieve the operation to use based on the input type: floating-point
  // inputs use FMul, everything else uses IMul.
  unsigned MulOpcode = SPIRV::OpGroupNonUniformIMul;
  if (GR.isScalarOrVectorOfType(SrcReg, SPIRV::OpTypeFloat))
    MulOpcode = SPIRV::OpGroupNonUniformFMul;

  MachineBasicBlock &MBB = *I.getParent();
  return BuildMI(MBB, I, I.getDebugLoc(), TII.get(MulOpcode))
      .addDef(ResVReg)
      .addUse(GR.getSPIRVTypeID(ResType))
      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, ScopeTy, TII,
                                     !STI.isShader()))
      .addImm(SPIRV::GroupOperation::Reduce)
      .addUse(SrcReg);
}

bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -3433,6 +3462,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectWaveReduceMax(ResVReg, ResType, I, /*IsUnsigned*/ false);
case Intrinsic::spv_wave_reduce_sum:
return selectWaveReduceSum(ResVReg, ResType, I);
case Intrinsic::spv_wave_reduce_product:
return selectWaveReduceProduct(ResVReg, ResType, I);
case Intrinsic::spv_wave_readlane:
return selectWaveOpInst(ResVReg, ResType, I,
SPIRV::OpGroupNonUniformShuffle);
Expand Down
Loading