Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions NAM/slimmable_wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <algorithm>
#include <cmath>
#include <optional>
#include <stdexcept>

namespace nam
Expand Down Expand Up @@ -120,6 +121,12 @@ std::vector<float> extract_slimmed_weights(const std::vector<wavenet::LayerArray
for (int arr = 0; arr < num_arrays; arr++)
{
const auto& p = original_params[arr];
if (p.head_kernel_size != 1)
{
throw std::runtime_error(
"SlimmableWavenet: head rechannel kernel_size must be 1 (slimming with head kernel_size > 1 is not "
"implemented)");
}
validate_groups(p);

const int full_ch = p.channels;
Expand Down Expand Up @@ -258,8 +265,9 @@ std::vector<wavenet::LayerArrayParams> modify_params_for_channels(
int new_head_size = (i < num_arrays - 1) ? new_channels_per_array[i + 1] : p.head_size;

modified.push_back(wavenet::LayerArrayParams(
new_input_size, p.condition_size, new_head_size, new_ch, new_bottleneck, std::vector<int>(p.kernel_sizes),
std::vector<int>(p.dilations), std::vector<activations::ActivationConfig>(p.activation_configs),
new_input_size, p.condition_size, new_head_size, p.head_kernel_size, new_ch, new_bottleneck,
std::vector<int>(p.kernel_sizes), std::vector<int>(p.dilations),
std::vector<activations::ActivationConfig>(p.activation_configs),
std::vector<wavenet::GatingMode>(p.gating_modes), p.head_bias, p.groups_input, p.groups_input_mixin,
p.layer1x1_params, p.head1x1_params, std::vector<activations::ActivationConfig>(p.secondary_activation_configs),
p.conv_pre_film_params, p.conv_post_film_params, p.input_mixin_pre_film_params, p.input_mixin_post_film_params,
Expand Down Expand Up @@ -326,6 +334,9 @@ SlimmableWavenet::SlimmableWavenet(std::vector<wavenet::LayerArrayParams> origin
if (!any_slimmable)
throw std::runtime_error("SlimmableWavenet: at least one layer array must have allowed_channels");

if (with_head)
throw std::runtime_error("SlimmableWavenet: post-stack head is not supported");

// Build with full channel counts as default (ratio=1.0)
std::vector<int> full_channels(_original_params.size());
for (size_t i = 0; i < _original_params.size(); i++)
Expand Down Expand Up @@ -360,8 +371,8 @@ void SlimmableWavenet::_rebuild_model(const std::vector<int>& target_channels)
condition_dsp = get_dsp(_condition_dsp_json);

double sampleRate = _current_sample_rate > 0 ? _current_sample_rate : GetExpectedSampleRate();
_active_model = std::make_unique<wavenet::WaveNet>(
_in_channels, *params_ptr, _head_scale, _with_head, std::move(weights), std::move(condition_dsp), sampleRate);
_active_model = std::make_unique<wavenet::WaveNet>(_in_channels, *params_ptr, _head_scale, _with_head, std::nullopt,
std::move(weights), std::move(condition_dsp), sampleRate);
_current_channels = target_channels;

if (_current_buffer_size > 0)
Expand Down
233 changes: 214 additions & 19 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,75 @@
#include "slimmable_wavenet.h"
#include "wavenet.h"

// PostStackHead (WaveNet post-stack head) =====================================

// Builds the post-stack head: a chain of (activation -> Conv1D) stages.
// Hidden stages use params.channels outputs; the final stage maps to
// params.out_channels. Throws on an empty kernel list, a kernel size < 1,
// or an activation the runtime does not support.
nam::wavenet::PostStackHead::PostStackHead(const WaveNetHeadParams& params)
: _in_channels(params.in_channels)
, _out_channels(params.out_channels)
{
  if (params.kernel_sizes.empty())
    throw std::runtime_error("PostStackHead: kernel_sizes must be non-empty");
  const size_t num_stages = params.kernel_sizes.size();
  int in_ch = params.in_channels;
  for (size_t stage = 0; stage < num_stages; stage++)
  {
    const bool is_last = (stage + 1 == num_stages);
    const int out_ch = is_last ? params.out_channels : params.channels;
    const int kernel = params.kernel_sizes[stage];
    if (kernel < 1)
      throw std::runtime_error("PostStackHead: kernel_sizes entries must be >= 1");
    // One activation instance per stage (same config for every stage).
    auto activation = nam::activations::Activation::get_activation(params.activation_config);
    if (activation == nullptr)
      throw std::runtime_error("PostStackHead: unsupported activation for post-stack head");
    _activations.push_back(std::move(activation));
    nam::Conv1D conv;
    conv.set_size_(in_ch, out_ch, kernel, true, 1, 1);
    _convs.push_back(std::move(conv));
    // Next stage consumes this stage's output channel count.
    in_ch = out_ch;
  }
}

void nam::wavenet::PostStackHead::set_weights_(std::vector<float>::iterator& weights)
{
for (size_t i = 0; i < _convs.size(); i++)
_convs[i].set_weights_(weights);
}

void nam::wavenet::PostStackHead::SetMaxBufferSize(const int maxBufferSize)
{
for (size_t i = 0; i < _convs.size(); i++)
_convs[i].SetMaxBufferSize(maxBufferSize);
}

long nam::wavenet::PostStackHead::receptive_field() const
{
long rf = 1;
for (size_t i = 0; i < _convs.size(); i++)
{
const long k = _convs[i].get_kernel_size();
rf += k - 1;
}
return rf;
}

// Runs the head over `num_frames` columns of `work`. Each stage applies its
// activation in place and then its conv; the final result is left in the
// last conv's output buffer (see get_last_output()).
void nam::wavenet::PostStackHead::process(Eigen::MatrixXf& work, const int num_frames)
{
  if (_convs.empty())
    return; // ctor guarantees non-empty; kept for safety, matches original no-op
  // Stage 0 consumes (and mutates) the caller-provided scratch buffer.
  // NOTE: apply() touches the first in_ch*num_frames elements of the
  // column-major storage, i.e. exactly the first num_frames columns.
  _activations[0]->apply(work.data(), (long)(_convs[0].get_in_channels() * num_frames));
  _convs[0].Process(work, num_frames);
  // Later stages read/modify the previous conv's output buffer in place.
  for (size_t stage = 1; stage < _convs.size(); stage++)
  {
    auto& buffer = _convs[stage - 1].GetOutput();
    _activations[stage]->apply(buffer.data(), (long)(_convs[stage].get_in_channels() * num_frames));
    _convs[stage].Process(buffer, num_frames);
  }
}

// Layer ======================================================================

void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
Expand Down Expand Up @@ -306,7 +375,7 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
nam::wavenet::_LayerArray::_LayerArray(const LayerArrayParams& params)
: _rechannel(params.input_size, params.channels, false)
, _head_rechannel(params.head1x1_params.active ? params.head1x1_params.out_channels : params.bottleneck,
params.head_size, params.head_bias)
params.head_size, params.head_kernel_size, params.head_bias ? 1 : 0, 1, 1)
, _head_output_size(params.head1x1_params.active ? params.head1x1_params.out_channels : params.bottleneck)
{
const size_t num_layers = params.dilations.size();
Expand Down Expand Up @@ -345,6 +414,7 @@ long nam::wavenet::_LayerArray::get_receptive_field() const
long result = 0;
for (size_t i = 0; i < this->_layers.size(); i++)
result += this->_layers[i].get_dilation() * (this->_layers[i].get_kernel_size() - 1);
result += (long)this->_head_rechannel.get_kernel_size() - 1;
return result;
}

Expand Down Expand Up @@ -431,8 +501,8 @@ void nam::wavenet::_LayerArray::ProcessInner(const Eigen::MatrixXf& layer_inputs
this->_layers[last_layer].GetOutputNextLayer().leftCols(num_frames);
#endif

// Process head rechannel
_head_rechannel.process_(this->_head_inputs, num_frames);
// Process head rechannel (causal Conv1D)
_head_rechannel.Process(this->_head_inputs, num_frames);
}


Expand Down Expand Up @@ -460,16 +530,27 @@ long nam::wavenet::_LayerArray::_get_channels() const
return this->_layers.size() > 0 ? this->_layers[0].get_channels() : 0;
}

namespace
{
// Determines the WaveNet's output channel count before construction:
// the post-stack head's out_channels when one is configured, otherwise
// the last layer array's head_size.
int wave_net_output_channels(const std::vector<nam::wavenet::LayerArrayParams>& layer_array_params,
                             const bool with_head, const std::optional<nam::wavenet::WaveNetHeadParams>& head_params)
{
  if (layer_array_params.empty())
    throw std::runtime_error("WaveNet requires at least one layer array");
  const bool use_head = with_head && head_params.has_value();
  return use_head ? head_params->out_channels : layer_array_params.back().head_size;
}
} // namespace

// WaveNet ====================================================================

nam::wavenet::WaveNet::WaveNet(const int in_channels,
const std::vector<nam::wavenet::LayerArrayParams>& layer_array_params,
const float head_scale, const bool with_head, std::vector<float> weights,
const float head_scale, const bool with_head,
std::optional<WaveNetHeadParams> head_params, std::vector<float> weights,
std::unique_ptr<DSP> condition_dsp, const double expected_sample_rate)
: DSP(in_channels,
layer_array_params.empty() ? throw std::runtime_error("WaveNet requires at least one layer array")
: layer_array_params.back().head_size,
expected_sample_rate)
: DSP(in_channels, wave_net_output_channels(layer_array_params, with_head, head_params), expected_sample_rate)
, _condition_dsp(std::move(condition_dsp))
, _head_scale(head_scale)
{
Expand All @@ -484,10 +565,22 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels,
throw std::runtime_error(ss.str().c_str());
}
}
if (layer_array_params.empty())
throw std::runtime_error("WaveNet requires at least one layer array");
if (with_head)
throw std::runtime_error("Head not implemented!");
{
if (!head_params.has_value())
throw std::runtime_error("WaveNet: with_head is true but head configuration is missing");
if (head_params->in_channels != layer_array_params.back().head_size)
{
std::stringstream ss;
ss << "WaveNet head in_channels (" << head_params->in_channels << ") must match last layer array head_size ("
<< layer_array_params.back().head_size << ")";
throw std::runtime_error(ss.str());
}
this->_post_stack_head = std::make_unique<PostStackHead>(*head_params);
}
else if (head_params.has_value())
throw std::runtime_error("WaveNet: head configuration provided but with_head is false");

for (size_t i = 0; i < layer_array_params.size(); i++)
{
// Quick assert that the condition_dsp will output compatibly with this layer array
Expand Down Expand Up @@ -518,6 +611,8 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels,
mPrewarmSamples = this->_condition_dsp != nullptr ? this->_condition_dsp->PrewarmSamples() : 1;
for (size_t i = 0; i < this->_layer_arrays.size(); i++)
mPrewarmSamples += this->_layer_arrays[i].get_receptive_field();
if (this->_post_stack_head != nullptr)
mPrewarmSamples += this->_post_stack_head->receptive_field() - 1;
}

void nam::wavenet::WaveNet::set_weights_(std::vector<float>& weights)
Expand All @@ -527,6 +622,8 @@ void nam::wavenet::WaveNet::set_weights_(std::vector<float>& weights)
// so we don't need to set its weights here.
for (size_t i = 0; i < this->_layer_arrays.size(); i++)
this->_layer_arrays[i].set_weights_(it);
if (this->_post_stack_head != nullptr)
this->_post_stack_head->set_weights_(it);
this->_head_scale = *(it++); // TODO `LayerArray.absorb_head_scale()`
if (it != weights.end())
{
Expand Down Expand Up @@ -579,6 +676,12 @@ void nam::wavenet::WaveNet::SetMaxBufferSize(const int maxBufferSize)

for (size_t i = 0; i < this->_layer_arrays.size(); i++)
this->_layer_arrays[i].SetMaxBufferSize(maxBufferSize);

if (this->_post_stack_head != nullptr)
{
this->_post_stack_head->SetMaxBufferSize(maxBufferSize);
this->_scaled_head_scratch.resize(this->_post_stack_head->in_channels(), maxBufferSize);
}
}

void nam::wavenet::WaveNet::_process_condition(const int num_frames)
Expand Down Expand Up @@ -656,9 +759,39 @@ void nam::wavenet::WaveNet::process(NAM_SAMPLE** input, NAM_SAMPLE** output, con
}
}

// (Head not implemented)

auto& final_head_outputs = this->_layer_arrays.back().GetHeadOutputs();

if (this->_post_stack_head != nullptr)
{
assert(final_head_outputs.rows() == this->_post_stack_head->in_channels());
const int head_in = this->_post_stack_head->in_channels();
for (int ch = 0; ch < head_in; ch++)
{
for (int s = 0; s < num_frames; s++)
this->_scaled_head_scratch(ch, s) = this->_head_scale * final_head_outputs(ch, s);
}
this->_post_stack_head->process(this->_scaled_head_scratch, num_frames);
const Eigen::MatrixXf& head_out = this->_post_stack_head->get_last_output();
assert(head_out.rows() == out_channels);

if (out_channels == 1)
{
const float* __restrict__ src = head_out.data();
NAM_SAMPLE* __restrict__ dst = output[0];
for (int s = 0; s < num_frames; s++)
dst[s] = (NAM_SAMPLE)src[s];
}
else
{
for (int ch = 0; ch < out_channels; ch++)
{
for (int s = 0; s < num_frames; s++)
output[ch][s] = (NAM_SAMPLE)head_out(ch, s);
}
}
return;
}

assert(final_head_outputs.rows() == out_channels);

// Optimized output copy with head_scale multiplication
Expand Down Expand Up @@ -729,7 +862,41 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json

const int input_size = layer_config["input_size"];
const int condition_size = layer_config["condition_size"];
const int head_size = layer_config["head_size"];

int head_size = 0;
int head_kernel_size = 1;
bool head_bias = false;

// Prefer nested "head" (matches trainer export). Legacy .nam uses head_size + head_bias (implicit kernel 1).
if (layer_config.find("head") != layer_config.end() && !layer_config["head"].is_null())
{
const auto& head_json = layer_config["head"];
if (!head_json.is_object())
{
throw std::runtime_error("Layer array " + std::to_string(i) + ": 'head' must be a JSON object");
}
head_size = head_json.at("out_channels").get<int>();
head_kernel_size = head_json.at("kernel_size").get<int>();
head_bias = head_json.at("bias").get<bool>();
}
else if (layer_config.find("head_size") != layer_config.end())
{
head_size = layer_config["head_size"].get<int>();
head_kernel_size = 1;
head_bias = layer_config.at("head_bias").get<bool>();
}
else
{
throw std::runtime_error("Layer array " + std::to_string(i)
+ ": expected 'head' object with out_channels, kernel_size, and bias, "
"or legacy 'head_size' and 'head_bias'");
}

if (head_kernel_size < 1)
{
throw std::runtime_error("Layer array " + std::to_string(i) + ": head.kernel_size must be >= 1");
}

const auto dilations = layer_config["dilations"];
const size_t num_layers = dilations.size();

Expand Down Expand Up @@ -921,8 +1088,6 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
secondary_activation_configs.resize(num_layers, activations::ActivationConfig{});
}

const bool head_bias = layer_config["head_bias"];

// Parse head1x1 parameters
bool head1x1_active = false;
int head1x1_out_channels = channels;
Expand Down Expand Up @@ -967,7 +1132,7 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
}

wc.layer_array_params.push_back(nam::wavenet::LayerArrayParams(
input_size, condition_size, head_size, channels, bottleneck, std::move(kernel_sizes), dilations,
input_size, condition_size, head_size, head_kernel_size, channels, bottleneck, std::move(kernel_sizes), dilations,
std::move(activation_configs), std::move(gating_modes), head_bias, groups, groups_input_mixin, layer1x1_params,
head1x1_params, std::move(secondary_activation_configs), conv_pre_film_params, conv_post_film_params,
input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params,
Expand All @@ -981,14 +1146,44 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
if (wc.layer_array_params.empty())
throw std::runtime_error("WaveNet config requires at least one layer array");

if (wc.with_head)
{
const nlohmann::json& hj = config["head"];
WaveNetHeadParams hp;
const int implied_in = wc.layer_array_params.back().head_size;
// New trainer export omits in_channels (single source: last layer head_size). Legacy .nam may include it.
if (hj.find("in_channels") != hj.end() && !hj["in_channels"].is_null())
{
const int legacy_in = hj["in_channels"].get<int>();
if (legacy_in != implied_in)
{
std::stringstream ss;
ss << "WaveNet config: head.in_channels (" << legacy_in << ") must equal last layer's head_size (" << implied_in
<< ")";
throw std::runtime_error(ss.str());
}
}
hp.in_channels = implied_in;
hp.channels = hj.at("channels").get<int>();
hp.out_channels = hj.at("out_channels").get<int>();
hp.kernel_sizes = hj.at("kernel_sizes").get<std::vector<int>>();
hp.activation_config = nam::activations::ActivationConfig::from_json(hj.at("activation"));
if (hp.kernel_sizes.empty())
throw std::runtime_error("WaveNet config: head.kernel_sizes must be non-empty");
wc.head_params = std::move(hp);
}
else
wc.head_params = std::nullopt;

return wc;
}

// WaveNetConfig::create()
std::unique_ptr<nam::DSP> nam::wavenet::WaveNetConfig::create(std::vector<float> weights, double sampleRate)
{
return std::make_unique<nam::wavenet::WaveNet>(
in_channels, layer_array_params, head_scale, with_head, std::move(weights), std::move(condition_dsp), sampleRate);
return std::make_unique<nam::wavenet::WaveNet>(in_channels, layer_array_params, head_scale, with_head,
std::move(head_params), std::move(weights), std::move(condition_dsp),
sampleRate);
}

namespace
Expand Down
Loading
Loading