Files
the-biergarten-app/tooling/pipeline/src/data_generation/llama/infer.cc
Aaron Po 54a46458a3 Integrate logging channel system
update logging to use logger channel

updates
2026-05-20 00:38:00 -04:00

240 lines
7.8 KiB
C++

/**
* Text Generation / Inference Module
* Core module that performs LLM inference: converts text prompts into tokens,
* runs the neural network forward pass, samples the next token, and converts
* output tokens back to text for system+user chat prompts.
*/
#include <algorithm>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
#include "llama.h"
// Minimum number of candidates the top-p sampler must retain, guaranteeing
// that at least one token stays selectable however small top_p is.
static constexpr size_t kTopPMinKeep{1};
// Extra slots added to the byte-count-sized buffer used for the first
// llama_tokenize attempt (presumably room for special tokens — confirm).
static constexpr size_t kPromptTokenSlack{8};
namespace {
using SamplerHandle =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
struct SamplerConfig {
float temperature;
uint32_t top_k;
float top_p;
uint32_t seed;
};
SamplerHandle MakeSamplerChain(const llama_vocab* vocab,
const SamplerConfig& config,
std::string_view grammar) {
const llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
SamplerHandle chain(llama_sampler_chain_init(sampler_params),
llama_sampler_free);
if (!chain) {
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
}
auto add_sampler = [&](llama_sampler* sampler, const char* error_message) {
if (sampler == nullptr) {
throw std::runtime_error(error_message);
}
llama_sampler_chain_add(chain.get(), sampler);
};
if (!grammar.empty()) {
const std::string grammar_text(grammar);
add_sampler(llama_sampler_init_grammar(vocab, grammar_text.c_str(), "root"),
"LlamaGenerator: failed to initialize grammar sampler");
}
add_sampler(llama_sampler_init_temp(config.temperature),
"LlamaGenerator: failed to initialize temperature sampler");
add_sampler(llama_sampler_init_top_k(static_cast<int32_t>(config.top_k)),
"LlamaGenerator: failed to initialize top-k sampler");
add_sampler(llama_sampler_init_top_p(config.top_p, kTopPMinKeep),
"LlamaGenerator: failed to initialize top-p sampler");
add_sampler(llama_sampler_init_dist(config.seed),
"LlamaGenerator: failed to initialize distribution sampler");
return chain;
}
} // namespace
/**
 * Formats a system + user prompt pair with prompt_formatter_ and generates a
 * completion for the result.
 *
 * @param system_prompt Instructions for the system role.
 * @param prompt        The user message.
 * @param max_tokens    Generation limit, forwarded to InferFormatted.
 * @param grammar       Optional grammar constraint, forwarded unchanged.
 * @return Whatever InferFormatted produces for the formatted prompt.
 */
std::string LlamaGenerator::Infer(const std::string& system_prompt,
                                  const std::string& prompt,
                                  const int max_tokens,
                                  std::string_view grammar) {
  const auto chat_prompt = prompt_formatter_->Format(system_prompt, prompt);
  return InferFormatted(chat_prompt, max_tokens, grammar);
}
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
const int max_tokens,
std::string_view grammar) {
/**
* Validate that model and context are loaded
*/
if (!model_ || !context_) {
throw std::runtime_error("LlamaGenerator: model not loaded");
}
/**
* Get vocabulary for tokenization and token-to-text conversion
*/
const llama_vocab* vocab = llama_model_get_vocab(model_.get());
if (vocab == nullptr) {
throw std::runtime_error("LlamaGenerator: vocab unavailable");
}
const SamplerConfig sampler_config{
.temperature = sampling_temperature_,
.top_k = sampling_top_k_,
.top_p = sampling_top_p_,
.seed = static_cast<uint32_t>(rng_()),
};
auto sampler = MakeSamplerChain(vocab, sampler_config, grammar);
/**
* Clear KV cache to ensure clean inference state (no residual context)
*/
llama_memory_clear(llama_get_memory(context_.get()), true);
/**
* TOKENIZATION PHASE
* Convert text prompt into token IDs (integers) that the model understands
*/
std::vector<llama_token> prompt_tokens(formatted_prompt.size() +
kPromptTokenSlack);
int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
/**
* If buffer too small, negative return indicates required size
*/
if (token_count < 0) {
prompt_tokens.resize(static_cast<size_t>(-token_count));
token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
}
if (token_count < 0) {
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
}
/**
* CONTEXT SIZE VALIDATION
* Validate and compute effective token budgets based on context window
* constraints
*/
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_.get()));
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_.get()));
if (n_ctx <= 1 || n_batch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
/**
* Clamp generation limit to available context window, reserve space for
* output
*/
const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
/**
* Prompt can use remaining context after reserving space for generation
*/
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
/**
* Truncate prompt if necessary to fit within constraints
*/
prompt_tokens.resize(static_cast<size_t>(token_count));
if (token_count > prompt_budget) {
if (logger_) {
logger_->Log(
LogLevel::Warn, PipelinePhase::BreweryAndBeerGeneration,
std::string("LlamaGenerator: prompt too long (") +
std::to_string(token_count) + ") tokens, truncating to " +
std::to_string(prompt_budget) +
" tokens to fit n_batch/n_ctx limits");
}
prompt_tokens.resize(static_cast<size_t>(prompt_budget));
token_count = prompt_budget;
}
/**
* PROMPT PROCESSING PHASE
* Create a batch containing all prompt tokens and feed through the model
* This computes internal representations and fills the KV cache
*/
const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_.get(), prompt_batch) != 0) {
throw std::runtime_error("LlamaGenerator: prompt decode failed");
}
/**
* TOKEN GENERATION LOOP
* Iteratively generate tokens one at a time until max_tokens or
* end-of-sequence
*/
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<size_t>(effective_max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) {
/**
* Sample next token using configured sampler chain and model logits
* Index -1 means use the last output position from previous batch
*/
const llama_token next =
llama_sampler_sample(sampler.get(), context_.get(), -1);
/**
* Stop if model predicts end-of-generation token (EOS/EOT)
*/
if (llama_vocab_is_eog(vocab, next)) {
break;
}
generated_tokens.push_back(next);
/**
* Feed the sampled token back into model for next iteration
* (autoregressive)
*/
llama_token decode_token = next;
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
if (llama_decode(context_.get(), one_token_batch) != 0) {
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
}
/**
* DETOKENIZATION PHASE
* Convert generated token IDs back to text using vocabulary
*/
std::string output;
for (const llama_token token : generated_tokens) {
AppendTokenPiece(vocab, token, output);
}
return output;
}