From 8d306bf6915bec7c1fe3ab1e2f762f4f7734fdf6 Mon Sep 17 00:00:00 2001 From: Aaron Po Date: Thu, 2 Apr 2026 23:24:06 -0400 Subject: [PATCH] Update documentation for llama --- .../src/data_generation/llama/destructor.cpp | 15 ++++ .../llama/generate_brewery.cpp | 36 ++++++++- .../data_generation/llama/generate_user.cpp | 45 +++++++++++ .../src/data_generation/llama/helpers.cpp | 33 ++++++++ pipeline/src/data_generation/llama/infer.cpp | 79 +++++++++++++++++++ pipeline/src/data_generation/llama/load.cpp | 13 +++ .../llama/set_sampling_options.cpp | 25 ++++++ 7 files changed, 245 insertions(+), 1 deletion(-) diff --git a/pipeline/src/data_generation/llama/destructor.cpp b/pipeline/src/data_generation/llama/destructor.cpp index 957e071..b4516a6 100644 --- a/pipeline/src/data_generation/llama/destructor.cpp +++ b/pipeline/src/data_generation/llama/destructor.cpp @@ -1,16 +1,31 @@ +/** + * Destructor Module + * Ensures proper cleanup of llama.cpp resources (context and model) when the + * generator is destroyed, preventing memory leaks and resource exhaustion. + */ + #include "data_generation/llama_generator.h" #include "llama.h" LlamaGenerator::~LlamaGenerator() { + /** + * Free the inference context (contains KV cache and computation state) + */ if (context_ != nullptr) { llama_free(context_); context_ = nullptr; } + /** + * Free the loaded model (contains weights and vocabulary) + */ if (model_ != nullptr) { llama_model_free(model_); model_ = nullptr; } + /** + * Clean up the backend (GPU/CPU acceleration resources) + */ llama_backend_free(); } diff --git a/pipeline/src/data_generation/llama/generate_brewery.cpp b/pipeline/src/data_generation/llama/generate_brewery.cpp index b06bc29..31ea071 100644 --- a/pipeline/src/data_generation/llama/generate_brewery.cpp +++ b/pipeline/src/data_generation/llama/generate_brewery.cpp @@ -1,3 +1,10 @@ +/** + * Brewery Data Generation Module + * Uses the LLM to generate realistic brewery names and descriptions for a given + * location. Implements retry logic with validation and error correction to + * ensure valid JSON output conforming to the expected schema. + */ + #include #include @@ -9,9 +16,16 @@ BreweryResult LlamaGenerator::GenerateBrewery( const std::string& city_name, const std::string& country_name, const std::string& region_context) { + /** + * Preprocess and truncate region context to manageable size + */ const std::string safe_region_context = PrepareRegionContextPublic(region_context); + /** + * System prompt: establishes role and output format constraints + * Instructs LLM to roleplay as brewery owner and output only JSON + */ const std::string system_prompt = "You are the brewmaster and owner of a local craft brewery. " "Write a name and a short, soulful description for your brewery that " @@ -22,6 +36,10 @@ BreweryResult LlamaGenerator::GenerateBrewery( "\"description\". " "Do not include markdown formatting or backticks."; + /** + * User prompt: provides geographic context to guide generation towards + * culturally appropriate and locally-inspired brewery attributes + */ std::string prompt = "Write a brewery name and place-specific long description for a craft " "brewery in " + @@ -32,26 +50,41 @@ BreweryResult LlamaGenerator::GenerateBrewery( ? std::string(".") : std::string(". Regional context: ") + safe_region_context); + /** + * RETRY LOOP with validation and error correction + * Attempts to generate valid brewery data up to 3 times, with feedback-based + * refinement + */ const int max_attempts = 3; std::string raw; std::string last_error; + + // Limit output length to keep it concise and focused + constexpr int max_tokens = 1052; for (int attempt = 0; attempt < max_attempts; ++attempt) { - raw = Infer(system_prompt, prompt, 384); + // Generate brewery data from LLM + raw = Infer(system_prompt, prompt, max_tokens); spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, raw); + // Validate output: parse JSON and check required fields + std::string name; std::string description; const std::string validation_error = ValidateBreweryJsonPublic(raw, name, description); if (validation_error.empty()) { + // Success: return parsed brewery data return {std::move(name), std::move(description)}; } + // Validation failed: log error and prepare corrective feedback + last_error = validation_error; spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", attempt + 1, validation_error); + // Update prompt with error details to guide LLM toward correct output prompt = "Your previous response was invalid. Error: " + validation_error + "\nReturn ONLY valid JSON with this exact schema: " @@ -66,6 +99,7 @@ BreweryResult LlamaGenerator::GenerateBrewery( : std::string("\nRegional context: ") + safe_region_context); } + // All retry attempts exhausted: log failure and throw exception spdlog::error( "LlamaGenerator: malformed brewery response after {} attempts: " "{}", diff --git a/pipeline/src/data_generation/llama/generate_user.cpp b/pipeline/src/data_generation/llama/generate_user.cpp index 22fb57a..1985ec7 100644 --- a/pipeline/src/data_generation/llama/generate_user.cpp +++ b/pipeline/src/data_generation/llama/generate_user.cpp @@ -1,3 +1,11 @@ +/** + * User Profile Generation Module + * Uses the LLM to generate realistic user profiles (username and bio) for craft + * beer enthusiasts. Implements retry logic to handle parsing failures and + * ensures output adheres to strict format constraints (two lines, specific + * character limits). + */ + #include #include @@ -8,6 +16,10 @@ #include "data_generation/llama_generator_helpers.h" UserResult LlamaGenerator::GenerateUser(const std::string& locale) { + /** + * System prompt: specifies exact output format to minimize parsing errors + * Constraints: 2-line output, username format, bio length bounds + */ const std::string system_prompt = "You generate plausible social media profiles for craft beer " "enthusiasts. " @@ -17,39 +29,72 @@ UserResult LlamaGenerator::GenerateUser(const std::string& locale) { "The profile should feel consistent with the locale. " "No preamble, no labels."; + /** + * User prompt: locale parameter guides cultural appropriateness of generated + * profiles + */ std::string prompt = "Generate a craft beer enthusiast profile. Locale: " + locale; + /** + * RETRY LOOP with format validation + * Attempts up to 3 times to generate valid user profile with correct format + */ const int max_attempts = 3; std::string raw; for (int attempt = 0; attempt < max_attempts; ++attempt) { + /** + * Generate user profile (max 128 tokens - should fit 2 lines easily) + */ raw = Infer(system_prompt, prompt, 128); spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}", attempt + 1, raw); try { + /** + * Parse two-line response: first line = username, second line = bio + */ auto [username, bio] = ParseTwoLineResponsePublic( raw, "LlamaGenerator: malformed user response"); + /** + * Remove any whitespace from username (usernames shouldn't have + * spaces) + */ username.erase( std::remove_if(username.begin(), username.end(), [](unsigned char ch) { return std::isspace(ch); }), username.end()); + /** + * Validate both fields are non-empty after processing + */ if (username.empty() || bio.empty()) { throw std::runtime_error("LlamaGenerator: malformed user response"); } + /** + * Truncate bio if exceeds reasonable length for bio field + */ if (bio.size() > 200) bio = bio.substr(0, 200); + /** + * Success: return parsed user profile + */ return {username, bio}; } catch (const std::exception& e) { + /** + * Parsing failed: log and continue to next attempt + */ spdlog::warn( "LlamaGenerator: malformed user response (attempt {}): {}", attempt + 1, e.what()); } } + /** + * All retry attempts exhausted: log failure and throw exception + */ spdlog::error( "LlamaGenerator: malformed user response after {} attempts: {}", max_attempts, raw); diff --git a/pipeline/src/data_generation/llama/helpers.cpp b/pipeline/src/data_generation/llama/helpers.cpp index 18a7e8f..2a6bd2d 100644 --- a/pipeline/src/data_generation/llama/helpers.cpp +++ b/pipeline/src/data_generation/llama/helpers.cpp @@ -1,3 +1,11 @@ +/** + * Helper Functions Module + * Provides utility functions for text processing, parsing, and chat template + * formatting. Functions handle whitespace normalization, response parsing, and + * conversion of prompts to proper chat format using the model's built-in + * template. + */ + #include #include #include @@ -12,6 +20,9 @@ namespace { +/** + * String trimming: removes leading and trailing whitespace + */ std::string Trim(std::string value) { auto not_space = [](unsigned char ch) { return !std::isspace(ch); }; @@ -23,6 +34,10 @@ std::string Trim(std::string value) { return value; } +/** + * Normalize whitespace: collapses multiple spaces/tabs/newlines into single + * spaces + */ std::string CondenseWhitespace(std::string text) { std::string out; out.reserve(text.size()); @@ -44,6 +59,10 @@ std::string CondenseWhitespace(std::string text) { return Trim(std::move(out)); } +/** + * Truncate region context to fit within max length while preserving word + * boundaries + */ std::string PrepareRegionContext(std::string_view region_context, std::size_t max_chars) { std::string normalized = CondenseWhitespace(std::string(region_context)); @@ -61,6 +80,9 @@ std::string PrepareRegionContext(std::string_view region_context, return normalized; } +/** + * Remove common bullet points, numbers, and field labels added by LLM in output + */ std::string StripCommonPrefix(std::string line) { line = Trim(std::move(line)); @@ -102,6 +124,10 @@ std::string StripCommonPrefix(std::string line) { return Trim(std::move(line)); } +/** + * Parse two-line response from LLM: normalize line endings, strip formatting, + * filter spurious output, and combine remaining lines if needed + */ std::pair ParseTwoLineResponse( const std::string& raw, const std::string& error_message) { std::string normalized = raw; @@ -140,6 +166,9 @@ std::pair ParseTwoLineResponse( return {first, second}; } +/** + * Apply model's chat template to user-only prompt, formatting it for the model + */ std::string ToChatPrompt(const llama_model* model, const std::string& user_prompt) { const char* tmpl = llama_model_chat_template(model, nullptr); @@ -173,6 +202,10 @@ std::string ToChatPrompt(const llama_model* model, return std::string(buffer.data(), static_cast(required)); } +/** + * Apply model's chat template to system+user prompt pair, formatting for the + * model + */ std::string ToChatPrompt(const llama_model* model, const std::string& system_prompt, const std::string& user_prompt) { diff --git a/pipeline/src/data_generation/llama/infer.cpp b/pipeline/src/data_generation/llama/infer.cpp index ae1b786..1a1c7d0 100644 --- a/pipeline/src/data_generation/llama/infer.cpp +++ b/pipeline/src/data_generation/llama/infer.cpp @@ -1,3 +1,10 @@ +/** + * Text Generation / Inference Module + * Core module that performs LLM inference: converts text prompts into tokens, + * runs the neural network forward pass, samples the next token, and converts + * output tokens back to text. Supports both simple and system+user prompts. + */ + #include #include @@ -22,21 +29,37 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt, std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, int max_tokens) { + /** + * Validate that model and context are loaded + */ if (model_ == nullptr || context_ == nullptr) throw std::runtime_error("LlamaGenerator: model not loaded"); + /** + * Get vocabulary for tokenization and token-to-text conversion + */ const llama_vocab* vocab = llama_model_get_vocab(model_); if (vocab == nullptr) throw std::runtime_error("LlamaGenerator: vocab unavailable"); + /** + * Clear KV cache to ensure clean inference state (no residual context) + */ llama_memory_clear(llama_get_memory(context_), true); + /** + * TOKENIZATION PHASE + * Convert text prompt into token IDs (integers) that the model understands + */ std::vector prompt_tokens(formatted_prompt.size() + 8); int32_t token_count = llama_tokenize( vocab, formatted_prompt.c_str(), static_cast(formatted_prompt.size()), prompt_tokens.data(), static_cast(prompt_tokens.size()), true, true); + /** + * If buffer too small, negative return indicates required size + */ if (token_count < 0) { prompt_tokens.resize(static_cast(-token_count)); token_count = llama_tokenize( @@ -48,16 +71,31 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, if (token_count < 0) throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); + /** + * CONTEXT SIZE VALIDATION + * Validate and compute effective token budgets based on context window + * constraints + */ const int32_t n_ctx = static_cast(llama_n_ctx(context_)); const int32_t n_batch = static_cast(llama_n_batch(context_)); if (n_ctx <= 1 || n_batch <= 0) throw std::runtime_error("LlamaGenerator: invalid context or batch size"); + /** + * Clamp generation limit to available context window, reserve space for + * output + */ const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); + /** + * Prompt can use remaining context after reserving space for generation + */ int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); prompt_budget = std::max(1, prompt_budget); + /** + * Truncate prompt if necessary to fit within constraints + */ prompt_tokens.resize(static_cast(token_count)); if (token_count > prompt_budget) { spdlog::warn( @@ -68,11 +106,21 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, token_count = prompt_budget; } + /** + * PROMPT PROCESSING PHASE + * Create a batch containing all prompt tokens and feed through the model + * This computes internal representations and fills the KV cache + */ const llama_batch prompt_batch = llama_batch_get_one( prompt_tokens.data(), static_cast(prompt_tokens.size())); if (llama_decode(context_, prompt_batch) != 0) throw std::runtime_error("LlamaGenerator: prompt decode failed"); + /** + * SAMPLER CONFIGURATION PHASE + * Set up the probabilistic token selection pipeline (sampler chain) + * Samplers are applied in sequence: temperature -> top-p -> distribution + */ llama_sampler_chain_params sampler_params = llama_sampler_chain_default_params(); using SamplerPtr = @@ -82,21 +130,48 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, if (!sampler) throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); + /** + * Temperature: scales logits before softmax (controls randomness) + */ llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(sampling_temperature_)); + /** + * Top-P: nucleus sampling - filters to most likely tokens summing to top_p + * probability + */ llama_sampler_chain_add(sampler.get(), llama_sampler_init_top_p(sampling_top_p_, 1)); + /** + * Distribution sampler: selects actual token using configured seed for + * reproducibility + */ llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(sampling_seed_)); + /** + * TOKEN GENERATION LOOP + * Iteratively generate tokens one at a time until max_tokens or + * end-of-sequence + */ std::vector generated_tokens; generated_tokens.reserve(static_cast(effective_max_tokens)); for (int i = 0; i < effective_max_tokens; ++i) { + /** + * Sample next token using configured sampler chain and model logits + * Index -1 means use the last output position from previous batch + */ const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); + /** + * Stop if model predicts end-of-generation token (EOS/EOT) + */ if (llama_vocab_is_eog(vocab, next)) break; generated_tokens.push_back(next); + /** + * Feed the sampled token back into model for next iteration + * (autoregressive) + */ llama_token token = next; const llama_batch one_token_batch = llama_batch_get_one(&token, 1); if (llama_decode(context_, one_token_batch) != 0) @@ -104,6 +179,10 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, "LlamaGenerator: decode failed during generation"); } + /** + * DETOKENIZATION PHASE + * Convert generated token IDs back to text using vocabulary + */ std::string output; for (const llama_token token : generated_tokens) AppendTokenPiecePublic(vocab, token, output); diff --git a/pipeline/src/data_generation/llama/load.cpp b/pipeline/src/data_generation/llama/load.cpp index 1853b5f..a20827d 100644 --- a/pipeline/src/data_generation/llama/load.cpp +++ b/pipeline/src/data_generation/llama/load.cpp @@ -1,3 +1,10 @@ +/** + * Model Loading Module + * This module handles loading a pre-trained LLM model from disk and + * initializing the llama.cpp context for inference. It performs one-time setup + * required before any inference operations can be performed. + */ + #include #include @@ -7,6 +14,9 @@ #include "llama.h" void LlamaGenerator::Load(const std::string& model_path) { + /** + * Validate input and clean up any previously loaded model/context + */ if (model_path.empty()) throw std::runtime_error("LlamaGenerator: model path must not be empty"); @@ -19,6 +29,9 @@ void LlamaGenerator::Load(const std::string& model_path) { model_ = nullptr; } + /** + * Initialize the llama backend (one-time setup for GPU/CPU acceleration) + */ llama_backend_init(); llama_model_params model_params = llama_model_default_params(); diff --git a/pipeline/src/data_generation/llama/set_sampling_options.cpp b/pipeline/src/data_generation/llama/set_sampling_options.cpp index 8953eda..7b9238c 100644 --- a/pipeline/src/data_generation/llama/set_sampling_options.cpp +++ b/pipeline/src/data_generation/llama/set_sampling_options.cpp @@ -1,3 +1,10 @@ +/** + * Sampling Configuration Module + * Configures the hyperparameters that control probabilistic token selection + * during text generation. These settings affect the randomness, diversity, and + * quality of generated output. + */ + #include #include "data_generation/llama_generator.h" @@ -5,19 +12,37 @@ void LlamaGenerator::SetSamplingOptions(float temperature, float top_p, int seed) { + /** + * Validate temperature: controls randomness in output distribution + * 0.0 = deterministic (always pick highest probability token) + * Higher values = more random/diverse output + */ if (temperature < 0.0f) { throw std::runtime_error( "LlamaGenerator: sampling temperature must be >= 0"); } + + /** + * Validate top-p (nucleus sampling): only sample from top cumulative + * probability e.g., top-p=0.9 means sample from tokens that make up 90% of + * probability mass + */ if (!(top_p > 0.0f && top_p <= 1.0f)) { throw std::runtime_error( "LlamaGenerator: sampling top-p must be in (0, 1]"); } + + /** + * Validate seed: for reproducible results (-1 uses random seed) + */ if (seed < -1) { throw std::runtime_error( "LlamaGenerator: seed must be >= 0, or -1 for random"); } + /** + * Store sampling parameters for use during token generation + */ sampling_temperature_ = temperature; sampling_top_p_ = top_p; sampling_seed_ = (seed < 0) ? static_cast(LLAMA_DEFAULT_SEED)