Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions KLR/Core/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def Stmt.engine : Stmt -> Engine
@[serde tag = 104]
structure Block where
label : String
noReorder : Bool
body : List Stmt
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

Expand Down
2 changes: 1 addition & 1 deletion KLR/Extract/Extract/SerdeCpp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ private def genSer (ty : LeanType) : MetaM Unit := do
match v with
| .prod n fs => do
IO.println s!"case {name}::Tag::{Cpp.varName n}: \{"
IO.println s!" auto *typed_value = static_cast<const {Cpp.subclassName n} *>(value.get());"
if fs.isEmpty then
IO.println s!" return true; // {Cpp.varName n} variant has no fields to serialize"
else
-- Generate proper sequential serialization with error checking
IO.println s!" auto *typed_value = static_cast<const {Cpp.subclassName n} *>(value.get());"
let rec serializeFields (fields : List Field) : IO Unit := do
match fields with
| [] => pure ()
Expand Down
4 changes: 1 addition & 3 deletions KLR/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,10 @@ def runLncKernels (k : NKI.Kernel) (genDebug : Bool := false)

let mut result := [{ res with result := () }]
let mut bodies := [res.result.body]
let mut edges := k.edges ++ res.edges.map fun (a,b) => ⟨ a, [b] ⟩
for i in [1:num] do
let res <- runNkiKernel k genDebug (i,num)
result := { res with result := () } :: result
bodies := res.result.body :: bodies
edges := edges ++ res.edges.map fun (a,b) => ⟨ a, [b] ⟩
sharedBuffers := sharedBuffers ++ res.sharedBuffers
outputLists := outputLists ++ [res.result.outputs]
compareLabels firstKernelLabels res.labels
Expand All @@ -145,7 +143,7 @@ def runLncKernels (k : NKI.Kernel) (genDebug : Bool := false)
outputs := k0.outputs
bodies := bodies.reverse
sharedConstants := []
edges := edges
edges := k.edges
sharedBuffers := dedupedSharedBuffers
}
return (result.reverse, kernel, outputsByPosition)
61 changes: 13 additions & 48 deletions KLR/Trace/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,6 @@ abbrev SharedConstant := String × TensorLib.Tensor
abbrev SharedConstants := Array SharedConstant
abbrev Env := Std.HashMap Name Term

structure LastInst where
act : Option String := none
dma : Option String := none
dve : Option String := none
pe : Option String := none
pool : Option String := none
sp : Option String := none

structure State where
globals : Env := ∅
locals : Env := ∅
Expand All @@ -235,9 +227,7 @@ structure State where
-- debug info
debug : DebugInfo := {}
-- no reorder
edges : List (String × String) := []
noReorderDepth : Nat := 0
lastInst : LastInst := {}

instance : Inhabited State where
default := {}
Expand Down Expand Up @@ -336,28 +326,6 @@ def enterFun (m : Trace a) : Trace a := do

-- append fully traced statement

private def swapLast (engine : Engine) (name : String) (last : LastInst) : LastInst × List String :=
let (names, last) := match engine with
| .act => ([last.act], { last with act := some name})
| .dma => ([last.dma], { last with dma := some name})
| .dve => ([last.dve], { last with dve := some name})
| .pe => ([last.pe], { last with pe := some name})
| .pool => ([last.pool], { last with pool := some name})
| .sp => ([last.sp], { last with sp := some name})
| .unassigned =>
([last.act, last.dma, last.dve, last.pe, last.pool, last.sp],
{ act := some name, dma := some name,
dve := some name, pe := some name,
pool := some name, sp := some name })
let names := names.filterMap fun n => n.map toString
(last, names)

private def updateLast (engine : Engine) (name : String) (state : State) : State :=
if state.noReorderDepth == 0 then state else
let (last, names) := swapLast engine name state.lastInst
let edges := names.map (·, name)
{ state with lastInst := last, edges := edges ++ state.edges }

def add_stmt (stmt : Pos -> Stmt) : Trace Unit := do
let pos <- getPos
let (stmt, name) <- match stmt pos with
Expand All @@ -367,7 +335,6 @@ def add_stmt (stmt : Pos -> Stmt) : Trace Unit := do
| .oper op (some name) pos =>
pure (Core.Stmt.oper op (some name) pos, name)
modifyThe State fun s =>
let s := updateLast stmt.engine name s
{ s with stmts := s.stmts.push stmt }
dbgAdd name

Expand Down Expand Up @@ -413,14 +380,14 @@ def addImm (src dst : String) (imm : Int) : Trace Unit := do
})
(<- genLabel `brnz))

def endBlock (next : Option String := none) : Trace Unit := do
def endBlock (next : Option String := none) (noReorder : Bool := false) : Trace Unit := do
if let some target := next then
jmp target

modify fun st =>
let body := match st.label with
| none => st.body
| some lbl => st.body.push ⟨ lbl, st.stmts.toList ⟩
| some lbl => st.body.push ⟨ lbl, noReorder, st.stmts.toList ⟩

{ st with
body := body
Expand All @@ -429,22 +396,21 @@ def endBlock (next : Option String := none) : Trace Unit := do
}

def beginBlock (label : Option String := none) : Trace String := do
if (<- get).noReorderDepth > 0 then
throw "Dynamic control-flow cannot be nested with a no_reorder block"
let l := label.getD ((<- genLabel `label))
endBlock l
return l

def beginWithBlock : Trace Unit :=
def beginWithBlock : Trace Unit := do
if (<- get).noReorderDepth == 0 then
let _ <- beginBlock
modify fun s => { s with noReorderDepth := s.noReorderDepth + 1 }

def endWithBlock : Trace Unit :=
modify fun s =>
let (i, d) := match s.noReorderDepth with
| 0 | 1 => ({ : LastInst}, 0)
| .succ n => (s.lastInst, n)
{ s with
noReorderDepth := d
lastInst := i
}
def endWithBlock : Trace Unit := do
if (<- get).noReorderDepth == 1 then
endBlock (<- genLabel `label) true
modify fun s => { s with noReorderDepth := s.noReorderDepth - 1 }

private def identity (n : Nat) : TensorLib.Tensor := Id.run do
let dtype := TensorLib.Dtype.int8
Expand Down Expand Up @@ -487,7 +453,7 @@ def addId : Trace Unit := do
let lbl := (<- genLabel `init)
let idTensor := identity 128
modify fun s => { s with
body := #[Block.mk lbl [initStmt]] ++ s.body,
body := #[Block.mk lbl false [initStmt]] ++ s.body,
sharedConstants := s.sharedConstants.push (hbmInitName.toString, idTensor)
}
extend_global idName (.access (.simple tensorName))
Expand Down Expand Up @@ -528,7 +494,6 @@ structure TraceResult (a : Type) where
sharedBuffers : List (TensorName × Pos)
debug : Array DebugItem
labels : Array String
edges : List (String × String)
result : a

-- Run a `Trace` monad computation, and handle any generated warnings or errors.
Expand All @@ -538,7 +503,7 @@ def tracer (genDebug : Bool) (g : List (Name × Term)) (m : Trace a) : PassM (Tr
runPassWith initialState do
let x <- m
let st <- get
return ⟨st.sharedConstants, st.sharedBuffers.toList, st.debug.leaf, st.labels, st.edges, x⟩
return ⟨st.sharedConstants, st.sharedBuffers.toList, st.debug.leaf, st.labels, x⟩

-- Truthiness of Terms following Python
namespace Term
Expand Down
2 changes: 1 addition & 1 deletion interop/klr/NKI.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ Stmt =
| oper(Operator op, String? name, Pos pos)


Block = (String label, Stmt* body)
Block = (String label, Bool noReorder, Stmt* body)

Kernel = (String name, TensorName* inputs, TensorName* outputs, Block* body)

Expand Down
1 change: 1 addition & 0 deletions interop/klr/klir_ast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1572,6 +1572,7 @@ struct StmtOperWrapper final : Stmt {

struct Block final {
String label;
Bool noReorder;
List<Ptr<Stmt>> body;
};

Expand Down
3 changes: 3 additions & 0 deletions interop/klr/klir_pretty_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4310,6 +4310,9 @@ std::string to_string(Block &BlockInstance) {
result += "label=";
result += BlockInstance.label;
result += ", ";
result += "noReorder=";
result += std::to_string(BlockInstance.noReorder);
result += ", ";
result += "body=";
{
size_t i1 = 0;
Expand Down
19 changes: 6 additions & 13 deletions interop/klr/klir_serde.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,8 +648,6 @@ bool Immediate_ser(FILE *out, const Ptr<Immediate> &value) {
return Nat_ser(out, typed_value->reg);
}
case Immediate::Tag::pointer: {
auto *typed_value =
static_cast<const ImmediatePointerWrapper *>(value.get());
return true; // pointer variant has no fields to serialize
}
case Immediate::Tag::int32: {
Expand Down Expand Up @@ -1206,8 +1204,6 @@ bool ActivationImm_ser(FILE *out, const Ptr<ActivationImm> &value) {
return Nat_ser(out, typed_value->reg);
}
case ActivationImm::Tag::pointer: {
auto *typed_value =
static_cast<const ActivationImmPointerWrapper *>(value.get());
return true; // pointer variant has no fields to serialize
}
case ActivationImm::Tag::float32: {
Expand Down Expand Up @@ -1557,11 +1553,9 @@ bool DmaBounds_ser(FILE *out, const Ptr<DmaBounds> &value) {
// Serialize the fields based on the specific variant
switch (value->tag) {
case DmaBounds::Tag::skip: {
auto *typed_value = static_cast<const DmaBoundsSkipWrapper *>(value.get());
return true; // skip variant has no fields to serialize
}
case DmaBounds::Tag::error: {
auto *typed_value = static_cast<const DmaBoundsErrorWrapper *>(value.get());
return true; // error variant has no fields to serialize
}
case DmaBounds::Tag::reg: {
Expand Down Expand Up @@ -1626,8 +1620,6 @@ bool IndexMissBehavior_ser(FILE *out, const Ptr<IndexMissBehavior> &value) {
return Immediate_ser(out, typed_value->value);
}
case IndexMissBehavior::Tag::skip: {
auto *typed_value =
static_cast<const IndexMissBehaviorSkipWrapper *>(value.get());
return true; // skip variant has no fields to serialize
}
default:
Expand Down Expand Up @@ -2516,8 +2508,6 @@ bool ReplicaGroup_ser(FILE *out, const Ptr<ReplicaGroup> &value) {
// Serialize the fields based on the specific variant
switch (value->tag) {
case ReplicaGroup::Tag::unspecified: {
auto *typed_value =
static_cast<const ReplicaGroupUnspecifiedWrapper *>(value.get());
return true; // unspecified variant has no fields to serialize
}
case ReplicaGroup::Tag::named: {
Expand Down Expand Up @@ -3666,10 +3656,12 @@ bool Stmt_ser(FILE *out, const Ptr<Stmt> &value) {
}

bool Block_ser(FILE *out, const Ptr<Block> &value) {
if (!serialize_tag(out, 104, 0, 2))
if (!serialize_tag(out, 104, 0, 3))
return false;
if (!String_ser(out, value->label))
return false;
if (!Bool_ser(out, value->noReorder))
return false;
if (!List_Stmt_ser(out, value->body))
return false;
return true;
Expand Down Expand Up @@ -7730,14 +7722,15 @@ Ptr<Block> Block_des(FILE *in) {
msg << "Could not find tag, expecting Block:104,0";
throw std::runtime_error(msg.str());
}
if (t != 104 || c != 0 || l != 2) {
if (t != 104 || c != 0 || l != 3) {
std::ostringstream msg;
msg << "Expecting Block:(104,0,2)";
msg << "Expecting Block:(104,0,3)";
msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")";
throw std::runtime_error(msg.str());
}
Ptr<Block> x = ptr<Block>();
x->label = String_des(in);
x->noReorder = Bool_des(in);
x->body = List_Stmt_des(in);
return x;
}
Expand Down
Loading