/**
 * Text Generation / Inference Module
 *
 * Core module that performs LLM inference: converts text prompts into tokens,
 * runs the neural network forward pass, samples the next token, and converts
 * output tokens back to text for system+user chat prompts.
 */

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <format>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>

#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
#include "llama.h"

// Extra slots added on top of the initial one-token-per-byte tokenization
// buffer guess (presumably headroom for special tokens inserted by
// add_special). If the guess is still too small, tokenization is retried with
// the exact size reported by llama_tokenize.
static constexpr size_t kPromptTokenSlack = 8;

// Minimum tokens to keep when using top-p sampling. Ensures at least one
// candidate token remains available even with very restrictive top-p values.
static constexpr size_t kTopPMinKeep = 1;

namespace {

// Owning handle for a llama.cpp sampler; released with llama_sampler_free.
using SamplerHandle =
    std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;

// Per-call sampling parameters used to build the sampler chain.
struct SamplerConfig {
  float temperature;
  uint32_t top_k;
  float top_p;
  uint32_t seed;
};

/**
 * Builds the sampler chain applied to each generation step, in this order:
 * [grammar] -> temperature -> top-k -> top-p -> distribution (seeded).
 *
 * @param vocab   Model vocabulary (required by the grammar sampler).
 * @param config  Sampling parameters for this call.
 * @param grammar Grammar text handed to llama_sampler_init_grammar with start
 *                rule "root"; an empty view adds no grammar sampler.
 * @return Owning handle to the initialized chain.
 * @throws std::runtime_error if the chain or any sampler fails to initialize.
 */
SamplerHandle MakeSamplerChain(const llama_vocab* vocab,
                               const SamplerConfig& config,
                               std::string_view grammar) {
  const llama_sampler_chain_params sampler_params =
      llama_sampler_chain_default_params();
  SamplerHandle chain(llama_sampler_chain_init(sampler_params),
                      llama_sampler_free);
  if (!chain) {
    throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
  }

  // The chain takes ownership of every sampler added to it.
  auto add_sampler = [&](llama_sampler* sampler, const char* error_message) {
    if (sampler == nullptr) {
      throw std::runtime_error(error_message);
    }
    llama_sampler_chain_add(chain.get(), sampler);
  };

  if (!grammar.empty()) {
    // llama_sampler_init_grammar needs a NUL-terminated string.
    const std::string grammar_text(grammar);
    add_sampler(
        llama_sampler_init_grammar(vocab, grammar_text.c_str(), "root"),
        "LlamaGenerator: failed to initialize grammar sampler");
  }
  add_sampler(llama_sampler_init_temp(config.temperature),
              "LlamaGenerator: failed to initialize temperature sampler");
  add_sampler(llama_sampler_init_top_k(static_cast<int32_t>(config.top_k)),
              "LlamaGenerator: failed to initialize top-k sampler");
  add_sampler(llama_sampler_init_top_p(config.top_p, kTopPMinKeep),
              "LlamaGenerator: failed to initialize top-p sampler");
  add_sampler(llama_sampler_init_dist(config.seed),
              "LlamaGenerator: failed to initialize distribution sampler");
  return chain;
}

/**
 * Tokenizes `text` (adding special tokens and parsing special-token text).
 * Starts from a size guess and, if llama_tokenize reports that the buffer was
 * too small (negative return), retries once with the exact required size.
 *
 * @throws std::runtime_error if tokenization still fails.
 */
std::vector<llama_token> TokenizePrompt(const llama_vocab* vocab,
                                        const std::string& text) {
  std::vector<llama_token> tokens(text.size() + kPromptTokenSlack);
  const auto tokenize = [&]() {
    return llama_tokenize(vocab, text.c_str(),
                          static_cast<int32_t>(text.size()), tokens.data(),
                          static_cast<int32_t>(tokens.size()),
                          /*add_special=*/true, /*parse_special=*/true);
  };

  int32_t token_count = tokenize();
  if (token_count < 0) {
    // Negative return value is the negated required buffer size.
    tokens.resize(static_cast<size_t>(-token_count));
    token_count = tokenize();
  }
  if (token_count < 0) {
    throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
  }
  tokens.resize(static_cast<size_t>(token_count));
  return tokens;
}

/**
 * Feeds the prompt through the model in consecutive slices of at most
 * `n_batch` tokens. A single llama_decode call cannot take more than n_batch
 * tokens, so chunking lets any prompt that fits the context window be
 * processed instead of being truncated to one batch.
 *
 * Positions are left to llama_decode (llama_batch_get_one sets none), so each
 * slice continues where the previous one ended. The single-token generation
 * loop relies on the same mechanism. After the last slice, the logits for the
 * final prompt token are available at index -1.
 *
 * @throws std::runtime_error if `tokens` is empty (there would be no logits to
 *         sample from) or if any decode call fails.
 */
void DecodePrompt(llama_context* ctx, std::vector<llama_token>& tokens,
                  const int32_t n_batch) {
  if (tokens.empty()) {
    throw std::runtime_error("LlamaGenerator: prompt decode failed");
  }
  const auto chunk_size = static_cast<size_t>(n_batch);
  for (size_t offset = 0; offset < tokens.size(); offset += chunk_size) {
    const size_t count = std::min(chunk_size, tokens.size() - offset);
    const llama_batch batch = llama_batch_get_one(
        tokens.data() + offset, static_cast<int32_t>(count));
    if (llama_decode(ctx, batch) != 0) {
      throw std::runtime_error("LlamaGenerator: prompt decode failed");
    }
  }
}

}  // namespace

/**
 * Formats a system + user prompt with the configured prompt formatter and runs
 * inference on the result. See InferFormatted for parameters and errors.
 */
std::string LlamaGenerator::Infer(const std::string& system_prompt,
                                  const std::string& prompt,
                                  const int max_tokens,
                                  std::string_view grammar) {
  return InferFormatted(prompt_formatter_->Format(system_prompt, prompt),
                        max_tokens, grammar);
}

/**
 * Runs one full inference pass on an already-formatted prompt.
 *
 * The steps are:
 *  1. Clear the KV cache.
 *  2. Tokenize the prompt.
 *  3. Clamp budgets to the context window.
 *  4. Decode the prompt in n_batch-sized chunks.
 *  5. Sample tokens until end-of-generation or `max_tokens`.
 *  6. Detokenize the sampled tokens.
 *
 * @param formatted_prompt Prompt text, already in the model's chat format.
 * @param max_tokens       Generation limit; clamped to [1, n_ctx - 1].
 * @param grammar          Optional grammar constraining the output; empty
 *                         disables it.
 * @return Concatenated text of the generated tokens (end-of-generation token
 *         excluded).
 * @throws std::runtime_error if the model/context/vocab is unavailable,
 *         sampler setup or tokenization fails, the context/batch sizes are
 *         invalid, or any decode call fails.
 */
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
                                           const int max_tokens,
                                           std::string_view grammar) {
  if (!model_ || !context_) {
    throw std::runtime_error("LlamaGenerator: model not loaded");
  }

  // Vocabulary is needed for tokenization, EOG detection and detokenization.
  const llama_vocab* vocab = llama_model_get_vocab(model_.get());
  if (vocab == nullptr) {
    throw std::runtime_error("LlamaGenerator: vocab unavailable");
  }

  // A fresh seed is drawn from the generator's RNG on every call.
  const SamplerConfig sampler_config{
      .temperature = sampling_temperature_,
      .top_k = sampling_top_k_,
      .top_p = sampling_top_p_,
      .seed = static_cast<uint32_t>(rng_()),
  };
  const auto sampler = MakeSamplerChain(vocab, sampler_config, grammar);

  auto* ctx = context_.get();

  // Clear the KV cache so no residual context from a previous call leaks in.
  llama_memory_clear(llama_get_memory(ctx), true);

  std::vector<llama_token> prompt_tokens =
      TokenizePrompt(vocab, formatted_prompt);

  const auto n_ctx = static_cast<int32_t>(llama_n_ctx(ctx));
  const auto n_batch = static_cast<int32_t>(llama_n_batch(ctx));
  if (n_ctx <= 1 || n_batch <= 0) {
    throw std::runtime_error("LlamaGenerator: invalid context or batch size");
  }

  // Generate at least one token, and leave at least one context slot for the
  // prompt (n_ctx > 1 is guaranteed above, so the range is non-empty).
  const int32_t effective_max_tokens =
      std::clamp<int32_t>(max_tokens, 1, n_ctx - 1);

  // The prompt may use whatever the generation budget leaves free. n_batch
  // does not limit it because DecodePrompt splits the prompt into batches.
  const int32_t prompt_budget = n_ctx - effective_max_tokens;
  const auto token_count = static_cast<int32_t>(prompt_tokens.size());
  if (token_count > prompt_budget) {
    if (logger_) {
      logger_->Log({.level = LogLevel::Warn,
                    .phase = PipelinePhase::BreweryAndBeerGeneration,
                    .message = std::format(
                        "LlamaGenerator: prompt too long ({} tokens), "
                        "truncating to {} tokens to fit the n_ctx limit",
                        token_count, prompt_budget)});
    }
    // Keeps the beginning of the prompt and drops the tail.
    prompt_tokens.resize(static_cast<size_t>(prompt_budget));
  }

  // Fill the KV cache with the prompt; leaves logits for its last token.
  DecodePrompt(ctx, prompt_tokens, n_batch);

  // Autoregressive generation: sample, stop on end-of-generation, otherwise
  // feed the token back so the next step can condition on it.
  std::vector<llama_token> generated_tokens;
  generated_tokens.reserve(static_cast<size_t>(effective_max_tokens));
  for (int32_t i = 0; i < effective_max_tokens; ++i) {
    // Index -1 samples from the logits of the last decoded position.
    llama_token next = llama_sampler_sample(sampler.get(), ctx, -1);
    if (llama_vocab_is_eog(vocab, next)) {
      break;
    }
    generated_tokens.push_back(next);

    // Nothing samples after the final iteration, so decoding this token would
    // be a wasted forward pass.
    if (i + 1 == effective_max_tokens) {
      break;
    }
    const llama_batch one_token_batch = llama_batch_get_one(&next, 1);
    if (llama_decode(ctx, one_token_batch) != 0) {
      throw std::runtime_error(
          "LlamaGenerator: decode failed during generation");
    }
  }

  // Convert generated token IDs back to text.
  std::string output;
  for (const llama_token token : generated_tokens) {
    AppendTokenPiece(vocab, token, output);
  }
  return output;
}