Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
DEFINE_bool(overfit_single_batch, true, "overfit just one batch of data");
// memory management
DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if using parallel training mode");
// flash attention
DEFINE_bool(flash, false, "Whether to enable FlashAttention");
// parallel
DEFINE_int32(
nthread_per_process, 1,
Expand Down Expand Up @@ -191,9 +193,10 @@ void Train(const nn::parallel::Rank &rank) {
std::shared_ptr<nn::Module> model = nullptr;

if (!FLAGS_llmc_filepath.empty()) {
model = GPT2::FromLLMC(FLAGS_llmc_filepath);
model = GPT2::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else if (kModelToConfigs.count(FLAGS_model)) {
model_config = kModelToConfigs.at(FLAGS_model);
model_config.use_flash_attention = FLAGS_flash;
model = std::make_shared<GPT2>(model_config);
} else {
model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
Expand Down
45 changes: 26 additions & 19 deletions example/gpt2/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,27 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
std::shared_ptr<infini_train::Tensor> y;

const bool is_flash_dtype = q->Dtype() == DataType::kFLOAT32 || q->Dtype() == DataType::kBFLOAT16;
const bool short_mha_shape = (T <= 128 && head_dim <= 64);
const bool can_use_flash_kernel = config_.use_flash_attention && !short_mha_shape && q->GetDevice().IsCUDA()
&& is_flash_dtype && k->Dtype() == q->Dtype() && v->Dtype() == q->Dtype();

if (can_use_flash_kernel) {
y = nn::function::ScaledDotProductAttention(q, k, v, nullptr, 0.0, true);
} else {
// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
y = att->Matmul(v);
}
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});

Expand Down Expand Up @@ -198,8 +209,8 @@ GPT2FirstStage::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>>
auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled();
int tp_rank = 0;
if (tp_world_size > 1) {
auto tp_group = nn::parallel::ProcessGroupFactory::Instance(device.type())
->Get(nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank()));
auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank()));
tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank());
}
int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1];
Expand Down Expand Up @@ -307,11 +318,6 @@ GPT2::GPT2(const GPT2Config &config)
modules_[kTransformerLayerName] = std::make_shared<nn::ModuleDict>(std::move(transformer));

// FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation
// TODO: Implement real GPT-2 weight tying: make lm_head.weight share the exact same Parameter/Tensor (same
// shared_ptr/storage) as transformer.wte.weight (pointer aliasing, not value copy), and ensure the tie is applied
// after loading weights so it won't be overwritten. Also fix GPT2::FromLLMC() loading logic to respect weight tying
// (do not create/load a separate lm_head.weight tensor; load once into the tied weight) so parameter counting
// matches PyTorch/PEFT.
if (nn::parallel::global::GetPipelineParallelSize() == 1) {
// https://paperswithcode.com/method/weight-tying
*mutable_module(kTransformerLayerName)
Expand Down Expand Up @@ -356,7 +362,7 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
}
} // namespace

std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath, bool use_flash_attention) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -384,7 +390,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
.original_vocab_size = vocab_size,
.n_layer = n_layer,
.n_head = n_head,
.n_embd = n_embd});
.n_embd = n_embd,
.use_flash_attention = use_flash_attention});

LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
<< " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
Expand Down
3 changes: 2 additions & 1 deletion example/gpt2/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct GPT2Config {
int64_t n_layer = 12;
int64_t n_head = 12;
int64_t n_embd = 768;
bool use_flash_attention = false;
};

class NewGELU : public infini_train::nn::CloneableModule<NewGELU> {
Expand Down Expand Up @@ -140,7 +141,7 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<GPT2> FromPretrained(ModelType model_type);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath, bool use_flash_attention = false);

int GetChunkSize() const;

Expand Down
5 changes: 4 additions & 1 deletion example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
DEFINE_bool(overfit_single_batch, true, "overfit just one batch of data");
// memory management
DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if using parallel training mode");
// flash attention
DEFINE_bool(flash, false, "Whether to enable FlashAttention");
// parallel
DEFINE_int32(
nthread_per_process, 1,
Expand Down Expand Up @@ -168,9 +170,10 @@ void Train(const nn::parallel::Rank &rank) {
// ManualSeed(42);

LLaMA3Config model_config = LLaMA3Config();
model_config.use_flash_attention = FLAGS_flash;
std::shared_ptr<nn::Module> model = nullptr;
if (!FLAGS_llmc_filepath.empty()) {
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else {
model = std::make_shared<LLaMA3>(model_config);
}
Expand Down
84 changes: 60 additions & 24 deletions example/llama3/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,34 +207,69 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
// TODO(zbl): use kv cache during inference
// if (use_kv_) { ... }

// align n_head in GQA
// (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
k = RepeatKV(k, n_rep_);
v = RepeatKV(v, n_rep_);

// (B, T, H_local, D) -> (B, H_local, T, D)
q = q->Transpose(1, 2);
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// TODO(zbl): support flash attention later
// if (flash_) { ... }

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
std::shared_ptr<Tensor> y;

auto q_flash_candidate = q;
auto k_flash_candidate = k;
auto v_flash_candidate = v;
if (config_.use_flash_attention && q->GetDevice().IsCUDA()) {
const DataType target_dtype = v->Dtype();
if (q_flash_candidate->Dtype() != target_dtype) {
q_flash_candidate = std::make_shared<Tensor>(q_flash_candidate->To(target_dtype));
}
if (k_flash_candidate->Dtype() != target_dtype) {
k_flash_candidate = std::make_shared<Tensor>(k_flash_candidate->To(target_dtype));
}
}

const bool is_flash_dtype
= q_flash_candidate->Dtype() == DataType::kFLOAT32 || q_flash_candidate->Dtype() == DataType::kBFLOAT16;
const bool can_use_flash_kernel = config_.use_flash_attention && q_flash_candidate->GetDevice().IsCUDA()
&& is_flash_dtype && k_flash_candidate->Dtype() == q_flash_candidate->Dtype()
&& v_flash_candidate->Dtype() == q_flash_candidate->Dtype();

if (can_use_flash_kernel) {
// Flash path keeps native GQA shape (k/v on KV_local heads) to avoid RepeatKV expansion.
auto k_flash = k_flash_candidate->Transpose(1, 2);
auto v_flash = v_flash_candidate->Transpose(1, 2);

// Training mask in this model is standard causal Triu; prefer causal fast-path with null attn_mask.
const bool use_causal_fast_path = (mask != nullptr && start_pos == nullptr);
const auto &attn_mask = use_causal_fast_path ? nullptr : mask;
const bool is_causal = (mask == nullptr) || use_causal_fast_path;

y = nn::function::ScaledDotProductAttention(q_flash_candidate, k_flash, v_flash, attn_mask, 0.0, is_causal,
std::nullopt, n_rep_ > 1);
} else {
// align n_head in GQA
// (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
k = RepeatKV(k, n_rep_);
v = RepeatKV(v, n_rep_);

// (B, T, H_local, D) -> (B, H_local, T, D)
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
y = att->Matmul(v);
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
auto y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
// output projection
Expand Down Expand Up @@ -457,7 +492,7 @@ constexpr int32_t kLLaMA3Magic = 20240803;
constexpr int32_t kLLaMA3FP32Version = 3;
} // namespace

std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath, bool use_flash_attention) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -491,6 +526,7 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
.n_head = n_head,
.n_kv_head = n_kv_head,
.n_embd = n_embd,
.use_flash_attention = use_flash_attention,
.ffn_dim_multiplier = ffn_dim_multiplier,
.multiple_of = multiple_of,
.rope_theta = rope_theta,
Expand Down
5 changes: 4 additions & 1 deletion example/llama3/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ struct LLaMA3Config {
int64_t n_kv_head = 8; // Num of Key/Value heads(< n_head if using GQA)
int64_t n_embd = 2048; // Hidden size

// Attention config
bool use_flash_attention = false;

// FFN config
std::optional<float> ffn_dim_multiplier = 1.5f; // FFN dim multiplier
int64_t multiple_of = 256; // FFN dims must be multiple of this number
Expand Down Expand Up @@ -179,7 +182,7 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<LLaMA3> FromPretrained(ModelType model_type);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath, bool use_flash_attention = false);

int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); }

Expand Down
38 changes: 38 additions & 0 deletions infini_train/include/autograd/attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include <cstdint>
#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

// Forward declaration only — keeps the full tensor header out of this interface.
namespace infini_train {
class Tensor;
} // namespace infini_train

namespace infini_train::autograd {

// Autograd node for scaled dot product attention; the user-facing entry point
// is nn::function::ScaledDotProductAttention. The call's hyper-parameters are
// captured at construction, and any forward-pass state that Backward() needs
// is recorded via SetupContext().
class ScaledDotProductAttention : public Function {
public:
    // Type tag passed to the Function base (used to identify this node).
    static constexpr char kType[] = "ScaledDotProductAttentionFunction";

    // dropout_p:  dropout probability applied to the attention weights.
    // is_causal:  if true, apply a causal mask to the attention window.
    // scale:      softmax scaling factor. NOTE(review): the public wrapper takes
    //             std::optional<double> defaulting to 1/sqrt(E); presumably that
    //             default is resolved before this node is constructed — confirm
    //             against the .cc.
    // enable_gqa: if true, enable Grouped Query Attention handling.
    ScaledDotProductAttention(double dropout_p, bool is_causal, double scale, bool enable_gqa)
        : Function(kType), dropout_p_(dropout_p), is_causal_(is_causal), scale_(scale), enable_gqa_(enable_gqa) {}

    // Computes the attention output from the input tensors
    // (query, key, value, and possibly an attention mask — see has_attn_mask_).
    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
    // Records whatever forward-pass context the backward pass requires.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
    // Produces gradients with respect to the forward inputs.
    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    // Immutable configuration captured by the constructor.
    const double dropout_p_ = 0.0;
    const bool is_causal_ = false;
    const double scale_ = 1.0;
    const bool enable_gqa_ = false;
    // Mutable state presumably filled in during Forward()/SetupContext() for
    // use in Backward() — TODO confirm against the implementation file:
    bool has_attn_mask_ = false;            // whether an explicit attn_mask input was supplied
    std::shared_ptr<Tensor> lse_ = nullptr; // looks like the saved log-sum-exp of the attention logits — verify
    uint64_t rng_seed_ = 0;                 // assumed RNG state for replaying the dropout mask — verify
    uint64_t rng_offset_ = 0;
};

} // namespace infini_train::autograd
25 changes: 25 additions & 0 deletions infini_train/include/nn/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

namespace infini_train {
Expand All @@ -10,6 +11,30 @@ class Tensor;

namespace infini_train::nn::function {

// Computes the scaled dot product attention.
//
// Ref: PyTorch scaled_dot_product_attention.
//
// Args:
// query: Query tensor; shape (N, ..., L, E).
// key: Key tensor; shape (N, ..., S, E).
// value: Value tensor; shape (N, ..., S, E).
// attn_mask: Optional additive-style mask tensor using framework convention;
// shape must be broadcastable to (N, ..., L, S).
// dropout_p: Dropout probability; defaults to 0.0.
// is_causal: If true, applies a causal mask to the attention window.
// scale: Scaling factor applied prior to softmax. Defaults to 1 / sqrt(E).
// enable_gqa: If true, enables Grouped Query Attention support.
//
// Returns:
// Attention output tensor; shape (N, ..., L, E).
std::shared_ptr<Tensor> ScaledDotProductAttention(const std::shared_ptr<Tensor> &query,
const std::shared_ptr<Tensor> &key,
const std::shared_ptr<Tensor> &value,
const std::shared_ptr<Tensor> &attn_mask = nullptr,
double dropout_p = 0.0, bool is_causal = false,
std::optional<double> scale = std::nullopt, bool enable_gqa = false);

// Returns the lower triangular part of a 2D tensor or a batch of matrices.
//
// The lower triangular part includes elements on and below the specified
Expand Down
Loading