/** * Text Generation / Inference Module * Core module that performs LLM inference: converts text prompts into tokens, * runs the neural network forward pass, samples the next token, and converts * output tokens back to text. Supports both simple and system+user prompts. */ #include #include #include #include #include #include #include "data_generation/llama_generator.h" #include "data_generation/llama_generator_helpers.h" #include "llama.h" std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) { return InferFormatted(ToChatPromptPublic(model_, prompt), max_tokens); } std::string LlamaGenerator::Infer(const std::string& system_prompt, const std::string& prompt, int max_tokens) { return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt), max_tokens); } std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, int max_tokens) { /** * Validate that model and context are loaded */ if (model_ == nullptr || context_ == nullptr) throw std::runtime_error("LlamaGenerator: model not loaded"); /** * Get vocabulary for tokenization and token-to-text conversion */ const llama_vocab* vocab = llama_model_get_vocab(model_); if (vocab == nullptr) throw std::runtime_error("LlamaGenerator: vocab unavailable"); /** * Clear KV cache to ensure clean inference state (no residual context) */ llama_memory_clear(llama_get_memory(context_), true); /** * TOKENIZATION PHASE * Convert text prompt into token IDs (integers) that the model understands */ std::vector prompt_tokens(formatted_prompt.size() + 8); int32_t token_count = llama_tokenize( vocab, formatted_prompt.c_str(), static_cast(formatted_prompt.size()), prompt_tokens.data(), static_cast(prompt_tokens.size()), true, true); /** * If buffer too small, negative return indicates required size */ if (token_count < 0) { prompt_tokens.resize(static_cast(-token_count)); token_count = llama_tokenize( vocab, formatted_prompt.c_str(), static_cast(formatted_prompt.size()), prompt_tokens.data(), static_cast(prompt_tokens.size()), true, true); } if (token_count < 0) throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); /** * CONTEXT SIZE VALIDATION * Validate and compute effective token budgets based on context window * constraints */ const int32_t n_ctx = static_cast(llama_n_ctx(context_)); const int32_t n_batch = static_cast(llama_n_batch(context_)); if (n_ctx <= 1 || n_batch <= 0) throw std::runtime_error("LlamaGenerator: invalid context or batch size"); /** * Clamp generation limit to available context window, reserve space for * output */ const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); /** * Prompt can use remaining context after reserving space for generation */ int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); prompt_budget = std::max(1, prompt_budget); /** * Truncate prompt if necessary to fit within constraints */ prompt_tokens.resize(static_cast(token_count)); if (token_count > prompt_budget) { spdlog::warn( "LlamaGenerator: prompt too long ({} tokens), truncating to {} " "tokens to fit n_batch/n_ctx limits", token_count, prompt_budget); prompt_tokens.resize(static_cast(prompt_budget)); token_count = prompt_budget; } /** * PROMPT PROCESSING PHASE * Create a batch containing all prompt tokens and feed through the model * This computes internal representations and fills the KV cache */ const llama_batch prompt_batch = llama_batch_get_one( prompt_tokens.data(), static_cast(prompt_tokens.size())); if (llama_decode(context_, prompt_batch) != 0) throw std::runtime_error("LlamaGenerator: prompt decode failed"); /** * SAMPLER CONFIGURATION PHASE * Set up the probabilistic token selection pipeline (sampler chain) * Samplers are applied in sequence: temperature -> top-p -> distribution */ llama_sampler_chain_params sampler_params = llama_sampler_chain_default_params(); using SamplerPtr = std::unique_ptr; SamplerPtr sampler(llama_sampler_chain_init(sampler_params), &llama_sampler_free); if (!sampler) throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); /** * Temperature: scales logits before softmax (controls randomness) */ llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(sampling_temperature_)); /** * Top-P: nucleus sampling - filters to most likely tokens summing to top_p * probability */ llama_sampler_chain_add(sampler.get(), llama_sampler_init_top_p(sampling_top_p_, 1)); /** * Distribution sampler: selects actual token using configured seed for * reproducibility */ llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(sampling_seed_)); /** * TOKEN GENERATION LOOP * Iteratively generate tokens one at a time until max_tokens or * end-of-sequence */ std::vector generated_tokens; generated_tokens.reserve(static_cast(effective_max_tokens)); for (int i = 0; i < effective_max_tokens; ++i) { /** * Sample next token using configured sampler chain and model logits * Index -1 means use the last output position from previous batch */ const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); /** * Stop if model predicts end-of-generation token (EOS/EOT) */ if (llama_vocab_is_eog(vocab, next)) break; generated_tokens.push_back(next); /** * Feed the sampled token back into model for next iteration * (autoregressive) */ llama_token token = next; const llama_batch one_token_batch = llama_batch_get_one(&token, 1); if (llama_decode(context_, one_token_batch) != 0) throw std::runtime_error( "LlamaGenerator: decode failed during generation"); } /** * DETOKENIZATION PHASE * Convert generated token IDs back to text using vocabulary */ std::string output; for (const llama_token token : generated_tokens) AppendTokenPiecePublic(vocab, token, output); /** * Advance seed for next generation to improve output diversity */ sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1; return output; }