diff --git a/KLR/Core/Basic.lean b/KLR/Core/Basic.lean index d6c0b0f1..24c1b62c 100644 --- a/KLR/Core/Basic.lean +++ b/KLR/Core/Basic.lean @@ -58,6 +58,7 @@ def Stmt.engine : Stmt -> Engine @[serde tag = 104] structure Block where label : String + noReorder : Bool body : List Stmt deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp diff --git a/KLR/Extract/Extract/SerdeCpp.lean b/KLR/Extract/Extract/SerdeCpp.lean index 0032b1c1..ab1c3961 100644 --- a/KLR/Extract/Extract/SerdeCpp.lean +++ b/KLR/Extract/Extract/SerdeCpp.lean @@ -144,11 +144,11 @@ private def genSer (ty : LeanType) : MetaM Unit := do match v with | .prod n fs => do IO.println s!"case {name}::Tag::{Cpp.varName n}: \{" - IO.println s!" auto *typed_value = static_cast(value.get());" if fs.isEmpty then IO.println s!" return true; // {Cpp.varName n} variant has no fields to serialize" else -- Generate proper sequential serialization with error checking + IO.println s!" auto *typed_value = static_cast(value.get());" let rec serializeFields (fields : List Field) : IO Unit := do match fields with | [] => pure () diff --git a/KLR/Trace.lean b/KLR/Trace.lean index ec3c9298..c23bbba3 100644 --- a/KLR/Trace.lean +++ b/KLR/Trace.lean @@ -122,12 +122,10 @@ def runLncKernels (k : NKI.Kernel) (genDebug : Bool := false) let mut result := [{ res with result := () }] let mut bodies := [res.result.body] - let mut edges := k.edges ++ res.edges.map fun (a,b) => ⟨ a, [b] ⟩ for i in [1:num] do let res <- runNkiKernel k genDebug (i,num) result := { res with result := () } :: result bodies := res.result.body :: bodies - edges := edges ++ res.edges.map fun (a,b) => ⟨ a, [b] ⟩ sharedBuffers := sharedBuffers ++ res.sharedBuffers outputLists := outputLists ++ [res.result.outputs] compareLabels firstKernelLabels res.labels @@ -145,7 +143,7 @@ def runLncKernels (k : NKI.Kernel) (genDebug : Bool := false) outputs := k0.outputs bodies := bodies.reverse sharedConstants := [] - edges := edges + edges := k.edges sharedBuffers := dedupedSharedBuffers } return (result.reverse, kernel, outputsByPosition) diff --git a/KLR/Trace/Types.lean b/KLR/Trace/Types.lean index 9719b270..33a0dbc2 100644 --- a/KLR/Trace/Types.lean +++ b/KLR/Trace/Types.lean @@ -209,14 +209,6 @@ abbrev SharedConstant := String × TensorLib.Tensor abbrev SharedConstants := Array SharedConstant abbrev Env := Std.HashMap Name Term -structure LastInst where - act : Option String := none - dma : Option String := none - dve : Option String := none - pe : Option String := none - pool : Option String := none - sp : Option String := none - structure State where globals : Env := ∅ locals : Env := ∅ @@ -235,9 +227,7 @@ structure State where -- debug info debug : DebugInfo := {} -- no reorder - edges : List (String × String) := [] noReorderDepth : Nat := 0 - lastInst : LastInst := {} instance : Inhabited State where default := {} @@ -336,28 +326,6 @@ def enterFun (m : Trace a) : Trace a := do -- append fully traced statement -private def swapLast (engine : Engine) (name : String) (last : LastInst) : LastInst × List String := - let (names, last) := match engine with - | .act => ([last.act], { last with act := some name}) - | .dma => ([last.dma], { last with dma := some name}) - | .dve => ([last.dve], { last with dve := some name}) - | .pe => ([last.pe], { last with pe := some name}) - | .pool => ([last.pool], { last with pool := some name}) - | .sp => ([last.sp], { last with sp := some name}) - | .unassigned => - ([last.act, last.dma, last.dve, last.pe, last.pool, last.sp], - { act := some name, dma := some name, - dve := some name, pe := some name, - pool := some name, sp := some name }) - let names := names.filterMap fun n => n.map toString - (last, names) - -private def updateLast (engine : Engine) (name : String) (state : State) : State := - if state.noReorderDepth == 0 then state else - let (last, names) := swapLast engine name state.lastInst - let edges := names.map (·, name) - { state with lastInst := last, edges := edges ++ state.edges } - def add_stmt (stmt : Pos -> Stmt) : Trace Unit := do let pos <- getPos let (stmt, name) <- match stmt pos with @@ -367,7 +335,6 @@ def add_stmt (stmt : Pos -> Stmt) : Trace Unit := do | .oper op (some name) pos => pure (Core.Stmt.oper op (some name) pos, name) modifyThe State fun s => - let s := updateLast stmt.engine name s { s with stmts := s.stmts.push stmt } dbgAdd name @@ -413,14 +380,14 @@ def addImm (src dst : String) (imm : Int) : Trace Unit := do }) (<- genLabel `brnz)) -def endBlock (next : Option String := none) : Trace Unit := do +def endBlock (next : Option String := none) (noReorder : Bool := false) : Trace Unit := do if let some target := next then jmp target modify fun st => let body := match st.label with | none => st.body - | some lbl => st.body.push ⟨ lbl, st.stmts.toList ⟩ + | some lbl => st.body.push ⟨ lbl, noReorder, st.stmts.toList ⟩ { st with body := body @@ -429,22 +396,21 @@ def endBlock (next : Option String := none) : Trace Unit := do } def beginBlock (label : Option String := none) : Trace String := do + if (<- get).noReorderDepth > 0 then + throw "Dynamic control-flow cannot be nested with a no_reorder block" let l := label.getD ((<- genLabel `label)) endBlock l return l -def beginWithBlock : Trace Unit := +def beginWithBlock : Trace Unit := do + if (<- get).noReorderDepth == 0 then + let _ <- beginBlock modify fun s => { s with noReorderDepth := s.noReorderDepth + 1 } -def endWithBlock : Trace Unit := - modify fun s => - let (i, d) := match s.noReorderDepth with - | 0 | 1 => ({ : LastInst}, 0) - | .succ n => (s.lastInst, n) - { s with - noReorderDepth := d - lastInst := i - } +def endWithBlock : Trace Unit := do + if (<- get).noReorderDepth == 1 then + endBlock (<- genLabel `label) true + modify fun s => { s with noReorderDepth := s.noReorderDepth - 1 } private def identity (n : Nat) : TensorLib.Tensor := Id.run do let dtype := TensorLib.Dtype.int8 @@ -487,7 +453,7 @@ def addId : Trace Unit := do let lbl := (<- genLabel `init) let idTensor := identity 128 modify fun s => { s with - body := #[Block.mk lbl [initStmt]] ++ s.body, + body := #[Block.mk lbl false [initStmt]] ++ s.body, sharedConstants := s.sharedConstants.push (hbmInitName.toString, idTensor) } extend_global idName (.access (.simple tensorName)) @@ -528,7 +494,6 @@ structure TraceResult (a : Type) where sharedBuffers : List (TensorName × Pos) debug : Array DebugItem labels : Array String - edges : List (String × String) result : a -- Run a `Trace` monad computation, and handle any generated warnings or errors. @@ -538,7 +503,7 @@ def tracer (genDebug : Bool) (g : List (Name × Term)) (m : Trace a) : PassM (Tr runPassWith initialState do let x <- m let st <- get - return ⟨st.sharedConstants, st.sharedBuffers.toList, st.debug.leaf, st.labels, st.edges, x⟩ + return ⟨st.sharedConstants, st.sharedBuffers.toList, st.debug.leaf, st.labels, x⟩ -- Truthiness of Terms following Python namespace Term diff --git a/interop/klr/NKI.asdl b/interop/klr/NKI.asdl index ef9da411..5ecf8209 100644 --- a/interop/klr/NKI.asdl +++ b/interop/klr/NKI.asdl @@ -448,7 +448,7 @@ Stmt = | oper(Operator op, String? name, Pos pos) -Block = (String label, Stmt* body) +Block = (String label, Bool noReorder, Stmt* body) Kernel = (String name, TensorName* inputs, TensorName* outputs, Block* body) diff --git a/interop/klr/klir_ast.hpp b/interop/klr/klir_ast.hpp index baf4be72..6cbe1cd7 100644 --- a/interop/klr/klir_ast.hpp +++ b/interop/klr/klir_ast.hpp @@ -1572,6 +1572,7 @@ struct StmtOperWrapper final : Stmt { struct Block final { String label; + Bool noReorder; List> body; }; diff --git a/interop/klr/klir_pretty_print.cpp b/interop/klr/klir_pretty_print.cpp index 42ce6d40..a6fee5f2 100644 --- a/interop/klr/klir_pretty_print.cpp +++ b/interop/klr/klir_pretty_print.cpp @@ -4310,6 +4310,9 @@ std::string to_string(Block &BlockInstance) { result += "label="; result += BlockInstance.label; result += ", "; + result += "noReorder="; + result += std::to_string(BlockInstance.noReorder); + result += ", "; result += "body="; { size_t i1 = 0; diff --git a/interop/klr/klir_serde.cpp b/interop/klr/klir_serde.cpp index 52eb610c..7c6bf26f 100644 --- a/interop/klr/klir_serde.cpp +++ b/interop/klr/klir_serde.cpp @@ -648,8 +648,6 @@ bool Immediate_ser(FILE *out, const Ptr &value) { return Nat_ser(out, typed_value->reg); } case Immediate::Tag::pointer: { - auto *typed_value = - static_cast(value.get()); return true; // pointer variant has no fields to serialize } case Immediate::Tag::int32: { @@ -1206,8 +1204,6 @@ bool ActivationImm_ser(FILE *out, const Ptr &value) { return Nat_ser(out, typed_value->reg); } case ActivationImm::Tag::pointer: { - auto *typed_value = - static_cast(value.get()); return true; // pointer variant has no fields to serialize } case ActivationImm::Tag::float32: { @@ -1557,11 +1553,9 @@ bool DmaBounds_ser(FILE *out, const Ptr &value) { // Serialize the fields based on the specific variant switch (value->tag) { case DmaBounds::Tag::skip: { - auto *typed_value = static_cast(value.get()); return true; // skip variant has no fields to serialize } case DmaBounds::Tag::error: { - auto *typed_value = static_cast(value.get()); return true; // error variant has no fields to serialize } case DmaBounds::Tag::reg: { @@ -1626,8 +1620,6 @@ bool IndexMissBehavior_ser(FILE *out, const Ptr &value) { return Immediate_ser(out, typed_value->value); } case IndexMissBehavior::Tag::skip: { - auto *typed_value = - static_cast(value.get()); return true; // skip variant has no fields to serialize } default: @@ -2516,8 +2508,6 @@ bool ReplicaGroup_ser(FILE *out, const Ptr &value) { // Serialize the fields based on the specific variant switch (value->tag) { case ReplicaGroup::Tag::unspecified: { - auto *typed_value = - static_cast(value.get()); return true; // unspecified variant has no fields to serialize } case ReplicaGroup::Tag::named: { @@ -3666,10 +3656,12 @@ bool Stmt_ser(FILE *out, const Ptr &value) { } bool Block_ser(FILE *out, const Ptr &value) { - if (!serialize_tag(out, 104, 0, 2)) + if (!serialize_tag(out, 104, 0, 3)) return false; if (!String_ser(out, value->label)) return false; + if (!Bool_ser(out, value->noReorder)) + return false; if (!List_Stmt_ser(out, value->body)) return false; return true; @@ -7730,14 +7722,15 @@ Ptr Block_des(FILE *in) { msg << "Could not find tag, expecting Block:104,0"; throw std::runtime_error(msg.str()); } - if (t != 104 || c != 0 || l != 2) { + if (t != 104 || c != 0 || l != 3) { std::ostringstream msg; - msg << "Expecting Block:(104,0,2)"; + msg << "Expecting Block:(104,0,3)"; msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")"; throw std::runtime_error(msg.str()); } Ptr x = ptr(); x->label = String_des(in); + x->noReorder = Bool_des(in); x->body = List_Stmt_des(in); return x; }