-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathai.cpp
More file actions
164 lines (138 loc) · 7.4 KB
/
ai.cpp
File metadata and controls
164 lines (138 loc) · 7.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#include <mlx/mlx.h>
#include <iostream>
#include <vector>
#include <memory>
#include <cmath>
#include <optional>
#include <fstream>
#include <sstream>
#include <unordered_map>
#include <functional>
#include <iomanip>
#include <limits>
#include <chrono>
#include "lib/kv_cache.cpp"
#include "lib/rms_norm.cpp"
#include "lib/linear.cpp"
#include "lib/rope.cpp"
#include "lib/embedding.cpp"
#include "lib/cross_entropy.cpp"
#include "lib/model_args.cpp"
#include "lib/feed_forward.cpp"
#include "lib/model.cpp"
#include "lib/optimizer.cpp"
#include "lib/transformer.cpp"
#include "lib/tokenizer.cpp"
#include "lib/generator.cpp"
using namespace mlx::core;
// Entry point: end-to-end demo of the transformer training pipeline.
//
// Steps performed, in order:
//   1. Build the model, tokenizer (loaded from "./tokenizer.model") and generator.
//   2. Run a generation with a fixed prompt before any training.
//   3. Tokenize one hard-coded passage and build next-token input/target arrays.
//   4. Run a fixed number of AdamW steps on that single sequence, printing the
//      loss and per-phase wall-clock timings for each step.
//   5. Run the same generation again so the effect of training can be compared.
//
// Returns 0 on success. Returns 1 if the training text yields fewer than two
// tokens, or if any std::exception escapes the pipeline (the message is
// written to stderr).
int main() {
    // steady_clock is monotonic, so measured intervals cannot go negative or
    // jump if the system wall clock is adjusted while a step is running.
    using Clock = std::chrono::steady_clock;

    // Single source of truth for the step count (loop bound and progress text).
    constexpr int kNumSteps = 20;

    try {
        std::cout << "\n=== MLX C++ Language Model Training ===\n" << std::endl;

        // Initialize model and tokenizer
        std::cout << "🔧 Initializing model and tokenizer..." << std::endl;
        TransformerModel model;
        Tokenizer tokenizer("./tokenizer.model");
        Generator generator(model, tokenizer);
        std::cout << "✓ Initialization complete\n" << std::endl;

        std::string prompt = "Már nem volt fiatal, de még";

        // Prints the banner, runs one generation for `prompt`, and closes the
        // banner. Used once before and once after training.
        const auto run_generation_demo = [&](const char* heading) {
            std::cout << heading << prompt << "\"\n" << std::endl;
            std::cout << "Generated text:" << std::endl;
            std::cout << "----------------" << std::endl;
            generator.generate(prompt);
            std::cout << "----------------\n" << std::endl;
        };

        // Whole milliseconds elapsed between two time points.
        const auto elapsed_ms = [](Clock::time_point from, Clock::time_point to) {
            return std::chrono::duration_cast<std::chrono::milliseconds>(to - from).count();
        };

        // First generation test
        run_generation_demo("📝 Initial generation test with prompt:\n\"");

        // Training preparation
        std::cout << "🔄 Preparing training data..." << std::endl;
        std::string training_text = "Már nem volt fiatal, de még elég jól bírta magát; "
            "ismerték és félték a nádas lakói, de még azon túl is, közelben-távolban, "
            "minden négylábú lény. Látása nem romlott, s ha ezerméteres magasságból "
            "kiszemelte zsákmányát, úgy csapott le rá, mint egy kalapács, mely egyetlen "
            "ütéssel veri be a szöget. És így, viruló korában, ereje teljében, két lassú "
            "szárnycsapás között egyszer csak megállt a szíve. De nem mertek előbújni sem "
            "a nyulak, sem az ürgék, sem a környező falvak baromfiai, mert ő ott lebegett "
            "ezer méter magasban, kiterjesztett szárnyával, fenyegető mozdulatlanságban "
            "túlélve a halált még két vagy három perccel, míg el nem állt a szél.";
        std::vector<int> tokens = tokenizer.encode(training_text);

        // Next-token prediction needs at least one (input, target) pair. With
        // an empty token list `tokens.end() - 1` below would be undefined
        // behaviour, so bail out explicitly instead.
        if (tokens.size() < 2) {
            std::cerr << "\n❌ Error: training text produced fewer than 2 tokens; "
                         "cannot build input/target pairs" << std::endl;
            return 1;
        }

        // Create input/target arrays: inputs are tokens[0..n-2], targets are
        // the same sequence shifted left by one, tokens[1..n-1].
        std::vector<int> token_vec(tokens.begin(), tokens.end() - 1);
        array x = array(token_vec.data(), {1, static_cast<int>(token_vec.size())}, int32);
        std::vector<int> target_vec(tokens.begin() + 1, tokens.end());
        array y = array(target_vec.data(), {1, static_cast<int>(target_vec.size())}, int32);
        // One-element array holding the full token count.
        // NOTE(review): presumably the per-sequence length consumed by the loss
        // inside value_and_grad — confirm against TransformerModel.
        array z = array({static_cast<int>(tokens.size())}, {1}, int32);

        std::cout << "✓ Training data prepared" << std::endl;
        std::cout << " • Sequence length: " << tokens.size() << " tokens" << std::endl;
        std::cout << " • Input shape: [1, " << token_vec.size() << "]" << std::endl;
        std::cout << " • Target shape: [1, " << target_vec.size() << "]\n" << std::endl;

        // Initialize optimizer.
        // The banner below is hand-written text; keep it in sync with the
        // constructor arguments when changing hyperparameters.
        std::cout << "🔧 Initializing AdamW optimizer with:" << std::endl;
        std::cout << " • Learning rate: 1e-5" << std::endl;
        std::cout << " • Beta1: 0.9" << std::endl;
        std::cout << " • Beta2: 0.97" << std::endl;
        std::cout << " • Epsilon: 1e-5" << std::endl;
        std::cout << " • Weight decay: 0.0\n" << std::endl;
        // NOTE(review): argument order is assumed to match the banner above
        // (lr, beta1, beta2, eps, weight decay) — confirm against AdamW.
        std::unique_ptr<AdamW> optimizer = std::make_unique<AdamW>(
            1e-5, 0.9, 0.97, 1e-5, 0.0
        );

        // Training loop
        std::cout << "🚀 Starting training loop (" << kNumSteps << " iterations)" << std::endl;
        std::cout << "----------------------------------------" << std::endl;

        float best_loss = std::numeric_limits<float>::max();
        float final_loss = best_loss;  // overwritten by every completed step

        for (int step = 0; step < kNumSteps; ++step) {
            const auto step_start = Clock::now();

            // Forward + backward pass; evaluating the loss here attributes the
            // cost of materializing it to this phase.
            auto [loss, grads] = model.value_and_grad(x, y, z);
            eval({loss});
            const auto fwd_bwd_end = Clock::now();

            // Optimizer update, then evaluate the model state immediately so
            // the update cost is attributed to this phase.
            optimizer->update(model, grads);
            eval(model.state_arrays());
            const auto optim_end = Clock::now();

            // Evaluate every gradient array; whatever was not already
            // evaluated by the two phases above is timed as "remaining".
            for (const auto& [_, param] : grads) {
                eval(param);
            }
            const auto remaining_end = Clock::now();

            const float loss_value = loss.item<float>();
            // A step is marked as best when it matches or beats every previous
            // step (same outcome as comparing against the updated minimum).
            const bool is_best = loss_value <= best_loss;
            best_loss = std::min(best_loss, loss_value);
            final_loss = loss_value;

            std::cout << "Step " << std::setw(2) << step + 1 << "/" << kNumSteps
                      << " | Loss: " << std::fixed << std::setprecision(6) << loss_value;
            if (is_best) {
                std::cout << " ⭐";
            }
            std::cout << "\n • Forward+Backward: " << elapsed_ms(step_start, fwd_bwd_end) << "ms"
                      << " | Optimizer: " << elapsed_ms(fwd_bwd_end, optim_end) << "ms"
                      << " | Remaining eval: " << elapsed_ms(optim_end, remaining_end) << "ms"
                      << " | Total: " << elapsed_ms(step_start, remaining_end) << "ms"
                      << std::endl;
        }

        std::cout << "----------------------------------------" << std::endl;
        std::cout << "✓ Training complete" << std::endl;
        // "Final" is the loss of the last step; the minimum over all steps is
        // reported separately so the two are not conflated.
        std::cout << " • Final loss: " << std::fixed << std::setprecision(6) << final_loss << std::endl;
        std::cout << " • Best loss: " << std::fixed << std::setprecision(6) << best_loss << "\n" << std::endl;

        // Final generation test
        run_generation_demo("📝 Final generation test with same prompt:\n\"");

        std::cout << "✨ All operations completed successfully!\n" << std::endl;
        return 0;
    } catch (const std::exception& e) {
        std::cerr << "\n❌ Error: " << e.what() << std::endl;
        return 1;
    }
}