Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions waveasm/include/waveasm/Dialect/WaveASMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,17 @@ class SALUCmpOp<string mnemonic, list<Trait> traits = []>
let assemblyFormat = "$src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($result)";
}

// SALU Physical-Register Write (unary): write src to a physical SGPR dst.
// The destination is an input operand (a PrecoloredSReg reference), not an
// SSA result. Marked with SpecialRegOp to prevent DCE/CSE -- the write is
// a side effect visible only through the physical register file.
//
// Operands:
//   $dst - the SGPR being written. It is listed under `ins` because the op
//          declares no results (`outs` is empty).
//   $src - the value to copy: an SGPR or an immediate.
//
// Textual form: `<mnemonic> $dst, $src : type($dst), type($src)`.
//
// NOTE(review): $dst is constrained by WaveASM_AnySGPR; judging by its name
// this may accept more than precolored registers -- confirm whether a
// tighter constraint is intended.
class SALUPhysUnaryOp<string mnemonic, list<Trait> traits = []>
: WAVEASMOp<mnemonic, !listconcat([WaveASM_SpecialRegOp], traits)> {
let arguments = (ins WaveASM_AnySGPR:$dst, WaveASM_SRegOrImm:$src);
let results = (outs);
let assemblyFormat = "$dst `,` $src attr-dict `:` type($dst) `,` type($src)";
}

// MFMA: Matrix multiply-accumulate with tied accumulator
// The result is tied to operand index 2 (the accumulator) when acc is a VGPR.
// MFMA instructions support inline 0 for the accumulator (no tying needed).
Expand Down Expand Up @@ -690,6 +701,16 @@ def WaveASM_S_ABS_I32 : SALUUnaryOp<"s_abs_i32">;
def WaveASM_S_SEXT_I32_I8 : SALUUnaryOp<"s_sext_i32_i8">;
def WaveASM_S_SEXT_I32_I16 : SALUUnaryOp<"s_sext_i32_i16">;

//===----------------------------------------------------------------------===//
// SALU Physical-Register Write Instructions
//===----------------------------------------------------------------------===//
// Non-Pure variants for writing to physical (precolored) SGPRs as a side
// effect. Used for SRD setup and cache-swizzle descriptor construction
// where the destination register has no SSA consumer.
//
// The `_phys` suffix is part of the IR mnemonic only: the assembly emitter
// prints these ops as plain `s_mov_b32` / `s_mov_b64` with the destination
// operand first and the source operand second.

// 32-bit move into a physical SGPR.
def WaveASM_S_MOV_B32_PHYS : SALUPhysUnaryOp<"s_mov_b32_phys">;
// 64-bit move into a physical SGPR pair.
def WaveASM_S_MOV_B64_PHYS : SALUPhysUnaryOp<"s_mov_b64_phys">;

//===----------------------------------------------------------------------===//
// SALU Binary Instructions
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions waveasm/lib/Transforms/AssemblyEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ std::optional<std::string> KernelGenerator::generateOp(Operation *op) {
return result;
})

// Physical-register moves carry their destination as an operand rather
// than a result, so both fields are resolved with resolveValue() and
// passed to the formatter in (dst, src) order. The emitted mnemonic is the
// plain hardware name; the `_phys` suffix exists only in the IR.
.Case<S_MOV_B32_PHYS>([&](S_MOV_B32_PHYS movOp) {
return formatter.format("s_mov_b32", {resolveValue(movOp.getDst()),
resolveValue(movOp.getSrc())});
})
// 64-bit variant: same operand handling, different mnemonic.
.Case<S_MOV_B64_PHYS>([&](S_MOV_B64_PHYS movOp) {
return formatter.format("s_mov_b64", {resolveValue(movOp.getDst()),
resolveValue(movOp.getSrc())});
})
.Case<S_BRANCH>([&](S_BRANCH branchOp) {
return std::string(" s_branch ") +
branchOp.getTarget().getRootReference().str();
Expand Down
119 changes: 74 additions & 45 deletions waveasm/lib/Transforms/TranslateFromMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@ using namespace mlir;

namespace waveasm {

// Named constants for SRD descriptor fields and limits.
// kDefaultMaxBufferSize lives in TranslationContext (header).
static constexpr int64_t kSRDStrideSwizzle = 0x20000;
static constexpr int64_t kMaxNumRecords32 = 0xFFFFFFFF;
// Immediate written to the fourth SRD register (s[base + 3]) whenever an SRD
// is filled in by the prologue or by the base-adjustment path.
constexpr int64_t kSRDStrideDescriptor = 0x20000;
// Upper bound applied to a buffer size (via std::min) before it is written to
// the third SRD register (s[base + 2]); this is the largest value that fits
// in 32 bits.
constexpr int64_t kMaxSRDNumRecords = 0xFFFFFFFF;

//===----------------------------------------------------------------------===//
// TranslationContext Implementation
Expand Down Expand Up @@ -240,9 +238,9 @@ void TranslationContext::emitSRDPrologue() {
RawOp::create(builder, loc, mainLabel + ":");

// Step 3: Copy from preload locations to SRD positions and fill
// size/stride. Must use RawOp: S_MOV_B64/S_MOV_B32 are Pure (SALUUnaryOp)
// and write to physical registers with no SSA consumer, so CSE/DCE
// eliminates them.
// size/stride. Uses SALUPhys ops (non-Pure, SpecialRegOp trait) so that
// writes to physical registers survive DCE/CSE.
auto *mlirCtx = builder.getContext();
for (size_t i = 0; i < pendingSRDs.size(); ++i) {
const auto &pending = pendingSRDs[i];
int64_t srdBase = pending.srdBaseIndex;
Expand All @@ -252,20 +250,31 @@ void TranslationContext::emitSRDPrologue() {
auto srdReg = PrecoloredSRegOp::create(builder, loc, srdType, srdBase, 4);

// Copy base address with s_mov_b64
std::string movB64Str = "s_mov_b64 s[" + std::to_string(srdBase) + ":" +
std::to_string(srdBase + 1) + "], s[" +
std::to_string(preloadBase) + ":" +
std::to_string(preloadBase + 1) + "]";
RawOp::create(builder, loc, movB64Str);

int64_t clampedSize = std::min(pending.bufferSize, kMaxNumRecords32);
std::string movSizeStr = "s_mov_b32 s" + std::to_string(srdBase + 2) +
", 0x" + llvm::utohexstr(clampedSize);
RawOp::create(builder, loc, movSizeStr);

std::string movStrideStr = "s_mov_b32 s" + std::to_string(srdBase + 3) +
", 0x" + llvm::utohexstr(kSRDStrideSwizzle);
RawOp::create(builder, loc, movStrideStr);
auto dstBaseType = PSRegType::get(mlirCtx, srdBase, 2);
auto dstBaseReg =
PrecoloredSRegOp::create(builder, loc, dstBaseType, srdBase, 2);
auto srcBaseType = PSRegType::get(mlirCtx, preloadBase, 2);
auto srcBaseReg =
PrecoloredSRegOp::create(builder, loc, srcBaseType, preloadBase, 2);
S_MOV_B64_PHYS::create(builder, loc, dstBaseReg, srcBaseReg);

// Fill size and stride (clamp to 32-bit max for >4GB buffers;
// per-workgroup SRD adjustment handles the actual addressing)
int64_t clampedSize = std::min(pending.bufferSize, kMaxSRDNumRecords);
auto dstSizeType = PSRegType::get(mlirCtx, srdBase + 2, 1);
auto dstSizeReg =
PrecoloredSRegOp::create(builder, loc, dstSizeType, srdBase + 2, 1);
auto sizeImm = ConstantOp::create(
builder, loc, createImmType(clampedSize), clampedSize);
S_MOV_B32_PHYS::create(builder, loc, dstSizeReg, sizeImm);

auto dstStrideType = PSRegType::get(mlirCtx, srdBase + 3, 1);
auto dstStrideReg =
PrecoloredSRegOp::create(builder, loc, dstStrideType, srdBase + 3, 1);
auto strideImm =
ConstantOp::create(builder, loc, createImmType(kSRDStrideDescriptor),
kSRDStrideDescriptor);
S_MOV_B32_PHYS::create(builder, loc, dstStrideReg, strideImm);

mapper.mapValue(pending.memref, srdReg);
}
Expand Down Expand Up @@ -301,22 +310,33 @@ void TranslationContext::emitSRDPrologue() {
/*expcnt=*/IntegerAttr{});

// Step 3: Fill SRD[2:3] with size and stride.
// Must use RawOp: Pure S_MOV_B32 to physical registers gets DCE'd.
// Uses S_MOV_B32_PHYS (non-Pure, SpecialRegOp) to survive DCE/CSE.
auto *mlirCtx = builder.getContext();
for (size_t i = 0; i < pendingSRDs.size(); ++i) {
const auto &pending = pendingSRDs[i];
int64_t srdBase = pending.srdBaseIndex;

auto srdType = createSRegType(4, 4);
auto srdReg = PrecoloredSRegOp::create(builder, loc, srdType, srdBase, 4);

int64_t clampedSize = std::min(pending.bufferSize, kMaxNumRecords32);
std::string movSizeStr = "s_mov_b32 s" + std::to_string(srdBase + 2) +
", 0x" + llvm::utohexstr(clampedSize);
RawOp::create(builder, loc, movSizeStr);

std::string movStrideStr = "s_mov_b32 s" + std::to_string(srdBase + 3) +
", 0x" + llvm::utohexstr(kSRDStrideSwizzle);
RawOp::create(builder, loc, movStrideStr);
// Fill size (clamp to 32-bit max for >4GB buffers;
// per-workgroup SRD adjustment handles the actual addressing)
int64_t clampedSize = std::min(pending.bufferSize, kMaxSRDNumRecords);
auto dstSizeType = PSRegType::get(mlirCtx, srdBase + 2, 1);
auto dstSizeReg =
PrecoloredSRegOp::create(builder, loc, dstSizeType, srdBase + 2, 1);
auto sizeImm = ConstantOp::create(
builder, loc, createImmType(clampedSize), clampedSize);
S_MOV_B32_PHYS::create(builder, loc, dstSizeReg, sizeImm);

// Fill stride descriptor
auto dstStrideType = PSRegType::get(mlirCtx, srdBase + 3, 1);
auto dstStrideReg =
PrecoloredSRegOp::create(builder, loc, dstStrideType, srdBase + 3, 1);
auto strideImm =
ConstantOp::create(builder, loc, createImmType(kSRDStrideDescriptor),
kSRDStrideDescriptor);
S_MOV_B32_PHYS::create(builder, loc, dstStrideReg, strideImm);

mapper.mapValue(pending.memref, srdReg);
}
Expand Down Expand Up @@ -447,13 +467,13 @@ Value emitSRDBaseAdjustment(const TranslationContext::PendingSRDBaseAdjust &adj,
assert(N + 4 < 108 && "SRD allocation exceeds SGPR limit");

// Copy source SRD base to new SRD.
// Must use RawOp: S_MOV_B64 is Pure (SALUUnaryOp) and writes to a
// physical register with no SSA consumer, so CSE/DCE eliminates it.
std::string copyBase = "s_mov_b64 s[" + std::to_string(N) + ":" +
std::to_string(N + 1) + "], s[" +
std::to_string(adj.srcSrdBase) + ":" +
std::to_string(adj.srcSrdBase + 1) + "]";
RawOp::create(builder, loc, copyBase);
// S_MOV_B64_PHYS (non-Pure, SpecialRegOp) survives DCE/CSE.
auto dstBaseType = PSRegType::get(mlirCtx, N, 2);
auto dstBaseReg = PrecoloredSRegOp::create(builder, loc, dstBaseType, N, 2);
auto srcBaseType = PSRegType::get(mlirCtx, adj.srcSrdBase, 2);
auto srcBaseReg =
PrecoloredSRegOp::create(builder, loc, srcBaseType, adj.srcSrdBase, 2);
S_MOV_B64_PHYS::create(builder, loc, dstBaseReg, srcBaseReg);

// Get element offset -> SGPR via v_readfirstlane_b32 (or s_mov_b32 if
// already scalar). Pinned to s[N+3].
Expand Down Expand Up @@ -488,15 +508,24 @@ Value emitSRDBaseAdjustment(const TranslationContext::PendingSRDBaseAdjust &adj,
S_ADD_U32::create(builder, loc, base0Type, sccType, base0, byteOffLo);
S_ADDC_U32::create(builder, loc, base1Type, sccType, base1, byteOffHi);

// Set num_records and stride using buffer size from the source SRD.
// Set num_records and stride.
// S_MOV_B32_PHYS (non-Pure, SpecialRegOp) survives DCE/CSE.
int64_t bufferSize = ctx.getBufferSizeForSRD(adj.srcSrdBase);
int64_t clampedSize = std::min(bufferSize, kMaxNumRecords32);
std::string movSize = "s_mov_b32 s" + std::to_string(N + 2) + ", 0x" +
llvm::utohexstr(clampedSize);
RawOp::create(builder, loc, movSize);
std::string movStride = "s_mov_b32 s" + std::to_string(N + 3) + ", 0x" +
llvm::utohexstr(kSRDStrideSwizzle);
RawOp::create(builder, loc, movStride);
int64_t clampedSize = std::min(bufferSize, kMaxSRDNumRecords);
auto dstSizeType = PSRegType::get(mlirCtx, N + 2, 1);
auto dstSizeReg =
PrecoloredSRegOp::create(builder, loc, dstSizeType, N + 2, 1);
auto sizeImm = ConstantOp::create(
builder, loc, ctx.createImmType(clampedSize), clampedSize);
S_MOV_B32_PHYS::create(builder, loc, dstSizeReg, sizeImm);

auto dstStrideType = PSRegType::get(mlirCtx, N + 3, 1);
auto dstStrideReg =
PrecoloredSRegOp::create(builder, loc, dstStrideType, N + 3, 1);
auto strideImm =
ConstantOp::create(builder, loc, ctx.createImmType(kSRDStrideDescriptor),
kSRDStrideDescriptor);
S_MOV_B32_PHYS::create(builder, loc, dstStrideReg, strideImm);

auto srdType = ctx.createSRegType(4, 4);
auto srd = PrecoloredSRegOp::create(builder, loc, srdType, N, 4);
Expand Down
61 changes: 42 additions & 19 deletions waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,25 +423,48 @@ LogicalResult handleFatRawBufferCast(Operation *op, TranslationContext &ctx) {
return success();
}

std::string mov0 = "s_mov_b32 s" + std::to_string(newSrdBase) + ", s" +
std::to_string(srcSrdBase);
RawOp::create(builder, loc, mov0);

std::string and1 = "s_and_b32 s" + std::to_string(newSrdBase + 1) + ", s" +
std::to_string(srcSrdBase + 1) + ", 0xffff";
RawOp::create(builder, loc, and1);

std::string or1 = "s_or_b32 s" + std::to_string(newSrdBase + 1) + ", s" +
std::to_string(newSrdBase + 1) + ", 0x40400000";
RawOp::create(builder, loc, or1);

std::string mov2 =
"s_mov_b32 s" + std::to_string(newSrdBase + 2) + ", 0x7ffffffd";
RawOp::create(builder, loc, mov2);

std::string mov3 =
"s_mov_b32 s" + std::to_string(newSrdBase + 3) + ", 0x27000";
RawOp::create(builder, loc, mov3);
auto *mlirCtx = builder.getContext();

// SRD[0] = srcSrd[0]
auto dst0Type = PSRegType::get(mlirCtx, newSrdBase, 1);
auto dst0 = PrecoloredSRegOp::create(builder, loc, dst0Type, newSrdBase, 1);
auto src0Type = PSRegType::get(mlirCtx, srcSrdBase, 1);
auto src0 = PrecoloredSRegOp::create(builder, loc, src0Type, srcSrdBase, 1);
S_MOV_B32_PHYS::create(builder, loc, dst0, src0);

// SRD[1] = (srcSrd[1] & 0xffff) | 0x40400000
// Compute in SSA (Pure ops with data dependency) then write once.
auto dst1Type = PSRegType::get(mlirCtx, newSrdBase + 1, 1);
auto dst1 =
PrecoloredSRegOp::create(builder, loc, dst1Type, newSrdBase + 1, 1);
auto src1Type = PSRegType::get(mlirCtx, srcSrdBase + 1, 1);
auto src1 =
PrecoloredSRegOp::create(builder, loc, src1Type, srcSrdBase + 1, 1);
auto maskImm =
ConstantOp::create(builder, loc, ctx.createImmType(0xffff), 0xffff);
auto sregTy = ctx.createSRegType(1);
auto andResult = S_AND_B32::create(builder, loc, sregTy, src1, maskImm);

auto orImm = ConstantOp::create(builder, loc, ctx.createImmType(0x40400000),
0x40400000);
auto orResult = S_OR_B32::create(builder, loc, sregTy, andResult, orImm);
S_MOV_B32_PHYS::create(builder, loc, dst1, orResult);

// SRD[2] = 0x7ffffffd (num_records)
auto dst2Type = PSRegType::get(mlirCtx, newSrdBase + 2, 1);
auto dst2 =
PrecoloredSRegOp::create(builder, loc, dst2Type, newSrdBase + 2, 1);
auto sizeImm = ConstantOp::create(builder, loc, ctx.createImmType(0x7ffffffd),
0x7ffffffd);
S_MOV_B32_PHYS::create(builder, loc, dst2, sizeImm);

// SRD[3] = 0x27000 (stride / descriptor flags)
auto dst3Type = PSRegType::get(mlirCtx, newSrdBase + 3, 1);
auto dst3 =
PrecoloredSRegOp::create(builder, loc, dst3Type, newSrdBase + 3, 1);
auto strideImm =
ConstantOp::create(builder, loc, ctx.createImmType(0x27000), 0x27000);
S_MOV_B32_PHYS::create(builder, loc, dst3, strideImm);

auto srdType = ctx.createSRegType(4, 4);
auto newSrd = PrecoloredSRegOp::create(builder, loc, srdType, newSrdBase, 4);
Expand Down
87 changes: 87 additions & 0 deletions waveasm/test/Transforms/salu-phys-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// RUN: waveasm-translate --waveasm-scoped-cse %s 2>&1 | FileCheck %s --check-prefix=CSE
// RUN: waveasm-translate --disable-pass-verifier --waveasm-linear-scan --emit-assembly %s | FileCheck %s --check-prefix=ASM
//
// Test SALUPhys ops: non-Pure physical-register-write variants of SALU
// instructions. They must survive CSE (SpecialRegOp trait) and emit
// the correct assembly mnemonic (without the _phys suffix).

//===----------------------------------------------------------------------===//
// Test 1: s_mov_b32_phys / s_mov_b64_phys survive CSE
//===----------------------------------------------------------------------===//

// CSE-LABEL: waveasm.program @phys_mov_survives_cse
// ASM-LABEL: phys_mov_survives_cse:
// Three physical-register writes with distinct destinations (s8, s9 and
// s[10:11]). None of them defines an SSA result, so nothing in the IR uses
// them; the checks below require every op to still be present after the
// scoped-CSE run and to appear under its plain mnemonic in the assembly run.
waveasm.program @phys_mov_survives_cse target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
%dst0 = waveasm.precolored.sreg 8 : !waveasm.psreg<8>
%dst1 = waveasm.precolored.sreg 9 : !waveasm.psreg<9>
%dst_pair = waveasm.precolored.sreg 10, 2 : !waveasm.psreg<10, 2>
%src_pair = waveasm.precolored.sreg 2, 2 : !waveasm.psreg<2, 2>
%imm_size = waveasm.constant 4096 : !waveasm.imm<4096>
%imm_stride = waveasm.constant 131072 : !waveasm.imm<131072>

// Two s_mov_b32_phys to different destinations -- both must survive DCE
// (zero results + SpecialRegOp trait prevents trivial dead elimination).
// CSE: waveasm.s_mov_b32_phys
// CSE: waveasm.s_mov_b32_phys
waveasm.s_mov_b32_phys %dst0, %imm_size : !waveasm.psreg<8>, !waveasm.imm<4096>
waveasm.s_mov_b32_phys %dst1, %imm_stride : !waveasm.psreg<9>, !waveasm.imm<131072>

// s_mov_b64_phys must also survive.
// CSE: waveasm.s_mov_b64_phys
waveasm.s_mov_b64_phys %dst_pair, %src_pair : !waveasm.psreg<10, 2>, !waveasm.psreg<2, 2>

// Expected assembly: one line per write, in program order, with the
// immediates printed in decimal and the 64-bit operands as register ranges.
// ASM: s_mov_b32 s8, 4096
// ASM: s_mov_b32 s9, 131072
// ASM: s_mov_b64 s[10:11], s[2:3]

// CSE: waveasm.s_endpgm
// ASM: s_endpgm
waveasm.s_endpgm
}

//===----------------------------------------------------------------------===//
// Test 2: AND-OR pattern uses Pure SSA ops + final MOV_PHYS write.
// The SSA data dependency AND->OR->MOV enforces ordering structurally.
//===----------------------------------------------------------------------===//

// CSE-LABEL: waveasm.program @ssa_and_or_with_phys_write
// ASM-LABEL: ssa_and_or_with_phys_write:
// Computes (s3 & 0xffff) | 0x40400000 as ordinary SSA values of type
// !waveasm.sreg and then stores the result into precolored s5 with a single
// s_mov_b32_phys. Only the final op targets a physical register.
waveasm.program @ssa_and_or_with_phys_write target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
%dst = waveasm.precolored.sreg 5 : !waveasm.psreg<5>
%src = waveasm.precolored.sreg 3 : !waveasm.psreg<3>
// 65535 == 0xffff, 1077936128 == 0x40400000.
%mask = waveasm.constant 65535 : !waveasm.imm<65535>
%flags = waveasm.constant 1077936128 : !waveasm.imm<1077936128>

// Pure AND and OR compute in SSA; final MOV_PHYS writes to physical reg.
// CSE: waveasm.s_and_b32
// CSE: waveasm.s_or_b32
// CSE: waveasm.s_mov_b32_phys
%and = waveasm.s_and_b32 %src, %mask : !waveasm.psreg<3>, !waveasm.imm<65535> -> !waveasm.sreg
%or = waveasm.s_or_b32 %and, %flags : !waveasm.sreg, !waveasm.imm<1077936128> -> !waveasm.sreg
waveasm.s_mov_b32_phys %dst, %or : !waveasm.psreg<5>, !waveasm.sreg

// The source register of the final move is chosen by the allocator, so the
// last assembly check pins only the destination (s5).
// ASM: s_and_b32
// ASM: s_or_b32
// ASM: s_mov_b32 s5,

// CSE: waveasm.s_endpgm
// ASM: s_endpgm
waveasm.s_endpgm
}

//===----------------------------------------------------------------------===//
// Test 3: Pure s_mov_b32 IS still CSE'd (contrast with _phys variant)
//===----------------------------------------------------------------------===//

// CSE-LABEL: waveasm.program @pure_mov_is_csed
// Contrast case: the ordinary s_mov_b32 defines an SSA result, and both
// instances below take the same operand, so the scoped-CSE run is expected
// to keep only one of them. This program is checked by the scoped-CSE run
// only; it has no assembly checks.
waveasm.program @pure_mov_is_csed target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
%imm = waveasm.constant 42 : !waveasm.imm<42>

// Two identical Pure s_mov_b32 -- second should be CSE'd away.
// CSE: waveasm.s_mov_b32
%r0 = waveasm.s_mov_b32 %imm : !waveasm.imm<42> -> !waveasm.sreg
// The negative check below spells out the operand and its type, so it can
// only match a second s_mov_b32 that survived after the one matched above.
// CSE-NOT: waveasm.s_mov_b32 %{{.*}} : !waveasm.imm<42>
%r1 = waveasm.s_mov_b32 %imm : !waveasm.imm<42> -> !waveasm.sreg

waveasm.s_endpgm
}
Loading