mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
CORRECTNESS FIXES: - json_loader: Add RollbackTransaction() and call it on exception instead of CommitTransaction(). Prevents partial data corruption on parse/disk errors. - wikipedia_service: Fix invalid MediaWiki API parameter explaintext=true -> explaintext=1. Now returns plain text instead of HTML markup in contexts. - helpers: Fix ParseTwoLineResponse filter to only remove known thinking tags (<think>, <reasoning>, <reflect>) instead of any <...> pattern. Prevents silently removing legitimate output like <username>content</username>. RELIABILITY & DESIGN IMPROVEMENTS: - load/main: Make n_ctx (context window size) configurable via --n-ctx flag (default 2048, range 1-32768) to support larger models like Qwen3-14B. - generate_brewery: Prevent retry prompt growth by extracting location context into constant and using compact retry format (error + schema + location only). Avoids token truncation on final retry attempts. - database: Fix data representativeness by changing QueryCities from ORDER BY name (alphabetic bias) to ORDER BY RANDOM() for unbiased sampling. Convert all SQLITE_STATIC to SQLITE_TRANSIENT to prevent use-after-free risks. POLISH: - infer: Advance sampling seed between generation calls to improve diversity across brewery and user generation. - data_downloader: Remove unnecessary commit hash truncation; use full hash. - json_loader: Fix misleading log message from "RapidJSON" to "Boost.JSON".
197 lines
6.9 KiB
C++
197 lines
6.9 KiB
C++
/**
|
|
* Text Generation / Inference Module
|
|
* Core module that performs LLM inference: converts text prompts into tokens,
|
|
* runs the neural network forward pass, samples the next token, and converts
|
|
* output tokens back to text. Supports both simple and system+user prompts.
|
|
*/
|
|
|
|
#include <spdlog/spdlog.h>
|
|
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "data_generation/llama_generator.h"
|
|
#include "data_generation/llama_generator_helpers.h"
|
|
#include "llama.h"
|
|
|
|
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
|
|
return InferFormatted(ToChatPromptPublic(model_, prompt), max_tokens);
|
|
}
|
|
|
|
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
|
const std::string& prompt, int max_tokens) {
|
|
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
|
|
max_tokens);
|
|
}
|
|
|
|
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|
int max_tokens) {
|
|
/**
|
|
* Validate that model and context are loaded
|
|
*/
|
|
if (model_ == nullptr || context_ == nullptr)
|
|
throw std::runtime_error("LlamaGenerator: model not loaded");
|
|
|
|
/**
|
|
* Get vocabulary for tokenization and token-to-text conversion
|
|
*/
|
|
const llama_vocab* vocab = llama_model_get_vocab(model_);
|
|
if (vocab == nullptr)
|
|
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
|
|
|
/**
|
|
* Clear KV cache to ensure clean inference state (no residual context)
|
|
*/
|
|
llama_memory_clear(llama_get_memory(context_), true);
|
|
|
|
/**
|
|
* TOKENIZATION PHASE
|
|
* Convert text prompt into token IDs (integers) that the model understands
|
|
*/
|
|
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
|
|
int32_t token_count = llama_tokenize(
|
|
vocab, formatted_prompt.c_str(),
|
|
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
|
|
static_cast<int32_t>(prompt_tokens.size()), true, true);
|
|
|
|
/**
|
|
* If buffer too small, negative return indicates required size
|
|
*/
|
|
if (token_count < 0) {
|
|
prompt_tokens.resize(static_cast<std::size_t>(-token_count));
|
|
token_count = llama_tokenize(
|
|
vocab, formatted_prompt.c_str(),
|
|
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
|
|
static_cast<int32_t>(prompt_tokens.size()), true, true);
|
|
}
|
|
|
|
if (token_count < 0)
|
|
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
|
|
|
|
/**
|
|
* CONTEXT SIZE VALIDATION
|
|
* Validate and compute effective token budgets based on context window
|
|
* constraints
|
|
*/
|
|
const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
|
|
const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
|
|
if (n_ctx <= 1 || n_batch <= 0)
|
|
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
|
|
|
|
/**
|
|
* Clamp generation limit to available context window, reserve space for
|
|
* output
|
|
*/
|
|
const int32_t effective_max_tokens =
|
|
std::max(1, std::min(max_tokens, n_ctx - 1));
|
|
/**
|
|
* Prompt can use remaining context after reserving space for generation
|
|
*/
|
|
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
|
|
prompt_budget = std::max<int32_t>(1, prompt_budget);
|
|
|
|
/**
|
|
* Truncate prompt if necessary to fit within constraints
|
|
*/
|
|
prompt_tokens.resize(static_cast<std::size_t>(token_count));
|
|
if (token_count > prompt_budget) {
|
|
spdlog::warn(
|
|
"LlamaGenerator: prompt too long ({} tokens), truncating to {} "
|
|
"tokens to fit n_batch/n_ctx limits",
|
|
token_count, prompt_budget);
|
|
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
|
|
token_count = prompt_budget;
|
|
}
|
|
|
|
/**
|
|
* PROMPT PROCESSING PHASE
|
|
* Create a batch containing all prompt tokens and feed through the model
|
|
* This computes internal representations and fills the KV cache
|
|
*/
|
|
const llama_batch prompt_batch = llama_batch_get_one(
|
|
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
|
|
if (llama_decode(context_, prompt_batch) != 0)
|
|
throw std::runtime_error("LlamaGenerator: prompt decode failed");
|
|
|
|
/**
|
|
* SAMPLER CONFIGURATION PHASE
|
|
* Set up the probabilistic token selection pipeline (sampler chain)
|
|
* Samplers are applied in sequence: temperature -> top-p -> distribution
|
|
*/
|
|
llama_sampler_chain_params sampler_params =
|
|
llama_sampler_chain_default_params();
|
|
using SamplerPtr =
|
|
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
|
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
|
|
&llama_sampler_free);
|
|
if (!sampler)
|
|
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
|
|
|
|
/**
|
|
* Temperature: scales logits before softmax (controls randomness)
|
|
*/
|
|
llama_sampler_chain_add(sampler.get(),
|
|
llama_sampler_init_temp(sampling_temperature_));
|
|
/**
|
|
* Top-P: nucleus sampling - filters to most likely tokens summing to top_p
|
|
* probability
|
|
*/
|
|
llama_sampler_chain_add(sampler.get(),
|
|
llama_sampler_init_top_p(sampling_top_p_, 1));
|
|
/**
|
|
* Distribution sampler: selects actual token using configured seed for
|
|
* reproducibility
|
|
*/
|
|
llama_sampler_chain_add(sampler.get(),
|
|
llama_sampler_init_dist(sampling_seed_));
|
|
|
|
/**
|
|
* TOKEN GENERATION LOOP
|
|
* Iteratively generate tokens one at a time until max_tokens or
|
|
* end-of-sequence
|
|
*/
|
|
std::vector<llama_token> generated_tokens;
|
|
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
|
|
|
|
for (int i = 0; i < effective_max_tokens; ++i) {
|
|
/**
|
|
* Sample next token using configured sampler chain and model logits
|
|
* Index -1 means use the last output position from previous batch
|
|
*/
|
|
const llama_token next =
|
|
llama_sampler_sample(sampler.get(), context_, -1);
|
|
/**
|
|
* Stop if model predicts end-of-generation token (EOS/EOT)
|
|
*/
|
|
if (llama_vocab_is_eog(vocab, next)) break;
|
|
generated_tokens.push_back(next);
|
|
/**
|
|
* Feed the sampled token back into model for next iteration
|
|
* (autoregressive)
|
|
*/
|
|
llama_token token = next;
|
|
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
|
|
if (llama_decode(context_, one_token_batch) != 0)
|
|
throw std::runtime_error(
|
|
"LlamaGenerator: decode failed during generation");
|
|
}
|
|
|
|
/**
|
|
* DETOKENIZATION PHASE
|
|
* Convert generated token IDs back to text using vocabulary
|
|
*/
|
|
std::string output;
|
|
for (const llama_token token : generated_tokens)
|
|
AppendTokenPiecePublic(vocab, token, output);
|
|
|
|
/**
|
|
* Advance seed for next generation to improve output diversity
|
|
*/
|
|
sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1;
|
|
|
|
return output;
|
|
}
|