From ce3fb73c7a5eaa9ee1a65324c866a8c5ba0a802d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 12 Mar 2026 20:47:00 +0100 Subject: [PATCH] [asm] Add --mlir-timing support to waveasm-translate Register MLIR pass manager and timing manager CLI options so waveasm-translate accepts --mlir-timing for per-pass timing reports. Also add llvm::Timer instrumentation for the translation and assembly emission phases which are outside the PassManager. Plumb print_pass_times through WaveASMCompiler so the Python side can request timing via the same flag used for Wave graph passes. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/waveasm_e2e.py | 10 +++ .../waveasm-translate/waveasm-translate.cpp | 65 +++++++++++++------ waveasm/waveasm_e2e.py | 10 +++ 3 files changed, 64 insertions(+), 21 deletions(-) diff --git a/wave_lang/kernel/wave/waveasm_e2e.py b/wave_lang/kernel/wave/waveasm_e2e.py index ff4f008d95..fffba77a1f 100644 --- a/wave_lang/kernel/wave/waveasm_e2e.py +++ b/wave_lang/kernel/wave/waveasm_e2e.py @@ -30,6 +30,7 @@ import os import subprocess +import sys import tempfile from dataclasses import dataclass from pathlib import Path @@ -159,10 +160,12 @@ def __init__( target: str = "gfx942", codeobj: str = "5", keep_temp_files: bool = False, + print_pass_times: bool = False, ): self.target = target self.codeobj = codeobj self.keep_temp_files = keep_temp_files + self.print_pass_times = print_pass_times self.waveasm_translate = get_waveasm_translate_path() self.clang = get_clang_path() self._temp_dir = None @@ -229,6 +232,9 @@ def compile_mlir_to_asm( ] ) + if self.print_pass_times: + cmd.append("--mlir-timing") + cmd.append(str(mlir_file)) try: @@ -242,6 +248,10 @@ def compile_mlir_to_asm( if result.returncode != 0: return False, result.stderr, result.stderr + # Print timing report from stderr when requested. + if self.print_pass_times and result.stderr: + print(result.stderr, file=sys.stderr) + # The output is printed to stdout. asm_text = result.stdout diff --git a/waveasm/tools/waveasm-translate/waveasm-translate.cpp b/waveasm/tools/waveasm-translate/waveasm-translate.cpp index 4af6f98773..765c539ae1 100644 --- a/waveasm/tools/waveasm-translate/waveasm-translate.cpp +++ b/waveasm/tools/waveasm-translate/waveasm-translate.cpp @@ -34,10 +34,13 @@ #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" #include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/Timing.h" #include "mlir/Transforms/Passes.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Timer.h" #define DEBUG_TYPE "waveasm-translate" #include "llvm/Support/InitLLVM.h" @@ -110,6 +113,8 @@ int main(int argc, char **argv) { // Register passes so PassPipelineCLParser can expose them as CLI flags. waveasm::registerWaveASMPasses(); mlir::registerTransformsPasses(); + mlir::registerPassManagerCLOptions(); + mlir::registerDefaultTimingManagerCLOptions(); // Construct AFTER pass registration — PassNameParser::initialize() snapshots // the registry, so passes must already be registered at this point. @@ -160,6 +165,11 @@ int main(int argc, char **argv) { // Run pre-translation MLIR passes. { PassManager prePm(&context); + if (failed(applyPassManagerCLOptions(prePm))) { + llvm::errs() << "Failed to apply pass manager CLI options\n"; + return 1; + } + applyDefaultTimingPassManagerCLOptions(prePm); // Scalarize vector.extract from broadcast+dense-const patterns so the // translator only sees ordinary scalar IR. prePm.addPass(waveasm::createWAVEASMExtractScalarization()); @@ -171,28 +181,39 @@ int main(int argc, char **argv) { } } - // Use TranslationOptions if workgroup size is specified - if (workgroupSizeX > 0 || workgroupSizeY > 0 || workgroupSizeZ > 0) { - waveasm::TranslationOptions options; - options.targetId = targetId.getValue(); - options.workgroupSizeX = workgroupSizeX; - options.workgroupSizeY = workgroupSizeY; - options.workgroupSizeZ = workgroupSizeZ; - options.subgroupSize = subgroupSize; - if (failed(waveasm::translateModule(*module, options))) { - llvm::errs() << "Translation failed\n"; - return 1; - } - } else { - if (failed(waveasm::translateModule(*module, targetId))) { - llvm::errs() << "Translation failed\n"; - return 1; + // Translate MLIR to WaveASM IR. + { + llvm::Timer translationTimer("TranslateFromMLIR", + "MLIR to WaveASM translation"); + translationTimer.startTimer(); + if (workgroupSizeX > 0 || workgroupSizeY > 0 || workgroupSizeZ > 0) { + waveasm::TranslationOptions options; + options.targetId = targetId.getValue(); + options.workgroupSizeX = workgroupSizeX; + options.workgroupSizeY = workgroupSizeY; + options.workgroupSizeZ = workgroupSizeZ; + options.subgroupSize = subgroupSize; + if (failed(waveasm::translateModule(*module, options))) { + llvm::errs() << "Translation failed\n"; + return 1; + } + } else { + if (failed(waveasm::translateModule(*module, targetId))) { + llvm::errs() << "Translation failed\n"; + return 1; + } } + translationTimer.stopTimer(); } } // Build pass pipeline from CLI flags. PassManager pm(&context); + if (failed(applyPassManagerCLOptions(pm))) { + llvm::errs() << "Failed to apply pass manager CLI options\n"; + return 1; + } + applyDefaultTimingPassManagerCLOptions(pm); if (passPipeline.hasAnyOccurrences()) { auto errorHandler = [](const Twine &msg) { llvm::errs() << msg << "\n"; @@ -238,18 +259,20 @@ int main(int argc, char **argv) { return 1; } - // Emit assembly if requested + // Emit assembly if requested. if (emitAssembly) { - // Create an empty physical mapping (for already-physical registers) + llvm::Timer asmTimer("EmitAssembly", "WaveASM assembly emission"); + asmTimer.startTimer(); + // Create an empty physical mapping (for already-physical registers). waveasm::PhysicalMapping mapping; - // Find all programs and emit assembly for each + // Find all programs and emit assembly for each. bool success = true; module->walk([&](waveasm::ProgramOp program) { - if (failed(waveasm::writeAssembly(program, mapping, outputStream))) { + if (failed(waveasm::writeAssembly(program, mapping, outputStream))) success = false; - } }); + asmTimer.stopTimer(); return success ? 0 : 1; } diff --git a/waveasm/waveasm_e2e.py b/waveasm/waveasm_e2e.py index 1a33e798f8..765456c8e4 100644 --- a/waveasm/waveasm_e2e.py +++ b/waveasm/waveasm_e2e.py @@ -30,6 +30,7 @@ import os import subprocess +import sys import tempfile from dataclasses import dataclass from pathlib import Path @@ -157,10 +158,12 @@ def __init__( target: str = "gfx942", codeobj: str = "5", keep_temp_files: bool = False, + print_pass_times: bool = False, ): self.target = target self.codeobj = codeobj self.keep_temp_files = keep_temp_files + self.print_pass_times = print_pass_times self.waveasm_translate = get_waveasm_translate_path() self.clang = get_clang_path() self._temp_dir: Optional[Path] = None @@ -227,6 +230,9 @@ def compile_mlir_to_asm( ] ) + if self.print_pass_times: + cmd.append("--mlir-timing") + cmd.append(str(mlir_file)) try: @@ -240,6 +246,10 @@ def compile_mlir_to_asm( if result.returncode != 0: return False, result.stderr, result.stderr + # Print timing report from stderr when requested. + if self.print_pass_times and result.stderr: + print(result.stderr, file=sys.stderr) + # The output is printed to stdout asm_text = result.stdout