Refactor Llama generator, helpers, and build assets

make Gemma 4 the default model, enable thinking mode
style updates
This commit is contained in:
Aaron Po
2026-04-10 00:03:45 -04:00
parent 7ca651a886
commit 56ec728ba7
61 changed files with 1430 additions and 1905 deletions

View File

@@ -1,51 +0,0 @@
/**
* @file data_generation/llama/constructor.cpp
* @brief LlamaGenerator constructor implementation.
*/
#include <random>
#include <stdexcept>
#include <string>
#include "biergarten_data_generator.h"
#include "data_generation/llama_generator.h"
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path)
: rng_() {
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
if (options.temperature < 0.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (options.seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
if (options.n_ctx == 0 || options.n_ctx > 32768) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
if (options.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
} else {
rng_.seed(static_cast<uint32_t>(options.seed));
}
n_ctx_ = options.n_ctx;
Load(model_path);
}

View File

@@ -1,26 +0,0 @@
/**
* @file data_generation/llama/destructor.cpp
* @brief Releases llama model/context resources and backend state during
* LlamaGenerator teardown to avoid leaks across runs.
*/
#include "data_generation/llama_generator.h"
#include "llama.h"
LlamaGenerator::~LlamaGenerator() {
/**
* Free the inference context (contains KV cache and computation state)
*/
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
/**
* Free the loaded model (contains weights and vocabulary)
*/
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
}

View File

@@ -6,65 +6,109 @@
#include <spdlog/spdlog.h>
#include <array>
#include <format>
#include <optional>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
static std::string ExtractFinalJsonPayload(std::string raw_response) {
auto trim = [](const std::string_view text) -> std::string_view {
const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) {
return {};
}
const std::size_t last = text.find_last_not_of(" \t\n\r");
return text.substr(first, last - first + 1);
};
static constexpr std::array<std::string_view, 6> separator_tokens = {
"<|think|>", "<think|>", "<|turn|>",
"<turn|>", "<channel|>", "<|channel|>"};
std::size_t separator_pos = std::string::npos;
std::size_t separator_length = 0;
for (const std::string_view token : separator_tokens) {
const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos ||
candidate_pos > separator_pos)) {
separator_pos = candidate_pos;
separator_length = token.size();
}
}
if (separator_pos != std::string::npos) {
raw_response.erase(0, separator_pos + separator_length);
}
const std::string_view trimmed = trim(raw_response);
const std::string json_candidate =
ExtractLastJsonObjectPublic(std::string(trimmed));
if (!json_candidate.empty()) {
return ExtractLastJsonObjectPublic(std::string(trimmed));
}
return std::string(trimmed);
}
BreweryResult LlamaGenerator::GenerateBrewery(
const std::string& city_name, const std::string& country_name,
const std::string& region_context) {
const Location& location, const std::string& region_context) {
/**
* Preprocess and truncate region context to manageable size
*/
const std::string safe_region_context =
PrepareRegionContextPublic(region_context);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string region_suffix =
safe_region_context.empty()
? "."
: std::format(". Regional context: {}", safe_region_context);
/**
* Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found
* Default path: prompts/brewery_system_prompt_expanded.txt
*/
const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/brewery_system_prompt_expanded.txt");
LoadBrewerySystemPrompt("prompts/system.md");
/**
* User prompt: provides geographic context to guide generation towards
* culturally appropriate and locally-inspired brewery attributes
* culturally relevant and locally-inspired brewery attributes
*/
std::string prompt =
std::string prompt = std::format(
"Write a brewery name and place-specific long description for a craft "
"brewery in " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string(".")
: std::string(". Regional context: ") + safe_region_context);
"brewery in {}{}{}",
location.city, country_suffix, region_suffix);
/**
* Store location context for retry prompts (without repeating full context)
*/
const std::string retry_location =
"Location: " + city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name);
std::format("Location: {}{}", location.city, country_suffix);
/**
* RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based
* refinement
*/
const int max_attempts = 3;
constexpr int max_attempts = 3;
std::string raw;
std::string last_error;
// Limit output length to keep it concise and focused
constexpr int max_tokens = 1052;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052;
// Generate brewery data from LLM
raw = Infer(system_prompt, prompt, max_tokens);
raw = this->Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
@@ -72,29 +116,29 @@ BreweryResult LlamaGenerator::GenerateBrewery(
std::string name;
std::string description;
const std::string validation_error =
ValidateBreweryJsonPublic(raw, name, description);
if (validation_error.empty()) {
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error =
ValidateBreweryJsonPublic(json_only, name, description);
if (!validation_error.has_value()) {
// Success: return parsed brewery data
return {std::move(name), std::move(description)};
return BreweryResult{.name = std::move(name),
.description = std::move(description)};
}
// Validation failed: log error and prepare corrective feedback
last_error = validation_error;
last_error = *validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error);
attempt + 1, *validation_error);
// Update prompt with error details to guide LLM toward correct output.
// For retries, use a compact prompt format to avoid exceeding token
// limits.
prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys."
"\n\n" +
retry_location;
prompt = std::format(
R"(Your previous response was invalid. Error: {}
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}.
Do not include markdown, comments, extra keys, or literal placeholder values.
{})",
*validation_error, retry_location);
}
// All retry attempts exhausted: log failure and throw exception

View File

@@ -6,7 +6,6 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept>
#include <string>
@@ -14,87 +13,6 @@
#include "data_generation/llama_generator_helpers.h"
UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
/**
* System prompt: specifies exact output format to minimize parsing errors
* Constraints: 2-line output, username format, bio length bounds
*/
const std::string system_prompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
/**
* User prompt: locale parameter guides cultural appropriateness of generated
* profiles
*/
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
/**
* RETRY LOOP with format validation
* Attempts up to 3 times to generate valid user profile with correct format
*/
const int max_attempts = 3;
std::string raw;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
/**
* Generate user profile (max 128 tokens - should fit 2 lines easily)
*/
raw = Infer(system_prompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
/**
* Parse two-line response: first line = username, second line = bio
*/
auto [username, bio] = ParseTwoLineResponsePublic(
raw, "LlamaGenerator: malformed user response");
/**
* Remove any whitespace from username (usernames shouldn't have
* spaces)
*/
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
/**
* Validate both fields are non-empty after processing
*/
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
/**
* Truncate bio if exceeds reasonable length for bio field
*/
if (bio.size() > 200) bio = bio.substr(0, 200);
/**
* Success: return parsed user profile
*/
return {username, bio};
} catch (const std::exception& e) {
/**
* Parsing failed: log and continue to next attempt
*/
spdlog::warn(
"LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
/**
* All retry attempts exhausted: log failure and throw exception
*/
spdlog::error(
"LlamaGenerator: malformed user response after {} attempts: {}",
max_attempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
return {.username = "test_user",
.bio = "This is a test user profile from " + locale + "."};
}

View File

@@ -4,13 +4,17 @@
* parsing, token decoding, and JSON validation helpers for Llama modules.
*/
#include <spdlog/spdlog.h>
#include <algorithm>
#include <array>
#include <boost/json.hpp>
#include <cctype>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
#include "data_generation/llama_generator.h"
@@ -19,40 +23,42 @@
/**
* String trimming: removes leading and trailing whitespace
*/
static std::string Trim(std::string value) {
auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
static std::string Trim(std::string_view value) {
constexpr std::string_view whitespace = " \t\n\r\f\v";
const std::size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) {
return {};
}
value.erase(value.begin(),
std::find_if(value.begin(), value.end(), not_space));
value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(),
value.end());
return value;
const std::size_t last_index = value.find_last_not_of(whitespace);
return std::string(value.substr(first_index, last_index - first_index + 1));
}
/**
* Normalize whitespace: collapses multiple spaces/tabs/newlines into single
* spaces
*/
static std::string CondenseWhitespace(std::string text) {
static std::string CondenseWhitespace(std::string_view text) {
std::string out;
out.reserve(text.size());
bool in_whitespace = false;
for (unsigned char ch : text) {
if (std::isspace(ch)) {
if (!in_whitespace) {
out.push_back(' ');
in_whitespace = true;
bool pending_space = false;
for (const unsigned char chr : text) {
if (std::isspace(chr) != 0) {
if (!out.empty()) {
pending_space = true;
}
continue;
}
in_whitespace = false;
out.push_back(static_cast<char>(ch));
if (pending_space) {
out.push_back(' ');
pending_space = false;
}
out.push_back(static_cast<char>(chr));
}
return Trim(std::move(out));
return out;
}
/**
@@ -60,14 +66,14 @@ static std::string CondenseWhitespace(std::string text) {
* boundaries
*/
static std::string PrepareRegionContext(std::string_view region_context,
std::size_t max_chars) {
std::string normalized = CondenseWhitespace(std::string(region_context));
const size_t max_chars) {
std::string normalized = CondenseWhitespace(region_context);
if (normalized.size() <= max_chars) {
return normalized;
}
normalized.resize(max_chars);
const std::size_t last_space = normalized.find_last_of(' ');
const size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space);
}
@@ -76,108 +82,20 @@ static std::string PrepareRegionContext(std::string_view region_context,
return normalized;
}
/**
* Remove common bullet points, numbers, and field labels added by LLM in output
*/
static std::string StripCommonPrefix(std::string line) {
line = Trim(std::move(line));
static std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
std::string combined_prompt;
combined_prompt.append(system_prompt);
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = Trim(line.substr(1));
} else {
std::size_t i = 0;
while (i < line.size() &&
std::isdigit(static_cast<unsigned char>(line[i]))) {
++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = Trim(line.substr(i + 1));
}
}
auto strip_label = [&line](const std::string& label) {
if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = Trim(line.substr(label.size()));
}
}
};
strip_label("name:");
strip_label("brewery name:");
strip_label("description:");
strip_label("username:");
strip_label("bio:");
return Trim(std::move(line));
}
/**
* Parse two-line response from LLM: normalize line endings, strip formatting,
* filter spurious output, and combine remaining lines if needed
*/
static std::pair<std::string, std::string> ParseTwoLineResponse(
const std::string& raw, const std::string& error_message) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = StripCommonPrefix(std::move(line));
if (!line.empty()) lines.push_back(std::move(line));
}
std::vector<std::string> filtered;
for (auto& l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
// Filter known thinking tags like <think>...</think>, but be conservative
// to avoid removing legitimate output. Only filter specific known
// patterns.
if (!l.empty() && l.front() == '<' && low.back() == '>') {
// Only filter if it's a known thinking tag: <think>, <reasoning>, etc.
if (low.find("think") != std::string::npos ||
low.find("reasoning") != std::string::npos ||
low.find("reflect") != std::string::npos) {
continue;
}
}
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
filtered.push_back(std::move(l));
}
if (filtered.size() < 2) throw std::runtime_error(error_message);
std::string first = Trim(filtered.front());
std::string second;
for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) second += ' ';
second += filtered[i];
}
second = Trim(std::move(second));
if (first.empty() || second.empty()) throw std::runtime_error(error_message);
return {first, second};
}
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
// No template found, fallback to raw text
return system_prompt + "\n\n" + user_prompt;
spdlog::warn(
"LlamaGenerator: missing chat template; using raw prompt fallback");
return combined_prompt;
}
const std::array<llama_chat_message, 2> messages = {
@@ -186,65 +104,62 @@ std::string ToChatPrompt(const llama_model* model,
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
int32_t required =
llama_chat_apply_template(tmpl, messages.data(), 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
auto apply_template_with_resize =
[&](const llama_chat_message* chat_messages,
int32_t message_count) -> int32_t {
int32_t result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
// FALLBACK: If the template fails (e.g., Gemma rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
if (required < 0) {
std::string combined_prompt = system_prompt + "\n\n" + user_prompt;
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
required = llama_chat_apply_template(tmpl, fallback_msg.data(), 1, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// THE FIX: Ultimate fallback. If the GGUF's internal template is
// completely unparseable (which happens with complex Jinja macros),
// degrade gracefully to raw text instead of throwing a runtime_error.
if (required < 0) {
return combined_prompt;
if (result < 0) {
return result;
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(
tmpl, fallback_msg.data(), 1, true, buffer.data(),
if (result >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
return combined_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
return result;
};
int32_t template_result = apply_template_with_resize(messages.data(), 2);
if (template_result >= 0) {
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
// Standard buffer resize if the original "system" + "user" array succeeded
// but needed more space
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages.data(), 2, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
spdlog::warn(
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
// Final safety net on resize
if (required < 0) {
return system_prompt + "\n\n" + user_prompt;
}
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
template_result = apply_template_with_resize(fallback_msg.data(), 1);
// Ultimate fallback: if GGUF template parsing still fails, use raw text.
if (template_result < 0) {
spdlog::warn(
"LlamaGenerator: chat template fallback failed (result {}); using "
"raw prompt text",
template_result);
return combined_prompt;
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
std::array<char, 256> buffer{};
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true);
int32_t bytes = llama_token_to_piece(vocab, token, buffer.data(),
buffer.size(), 0, true);
if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
@@ -263,12 +178,14 @@ static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
static bool ExtractFirstJsonObject(const std::string& text,
std::string& json_out) {
static bool ExtractLastJsonObject(const std::string& text,
std::string& json_out) {
std::size_t start = std::string::npos;
int depth = 0;
bool in_string = false;
bool escaped = false;
bool found = false;
std::string candidate;
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
@@ -303,18 +220,32 @@ static bool ExtractFirstJsonObject(const std::string& text,
}
--depth;
if (depth == 0 && start != std::string::npos) {
json_out = text.substr(start, i - start + 1);
return true;
candidate = text.substr(start, i - start + 1);
found = true;
}
}
}
return false;
if (!found) {
return false;
}
json_out = std::move(candidate);
return true;
}
static std::string ValidateBreweryJson(const std::string& raw,
std::string& name_out,
std::string& description_out) {
std::string ExtractLastJsonObjectPublic(const std::string& text) {
std::string extracted;
if (ExtractLastJsonObject(text, extracted)) {
return extracted;
}
return {};
}
static std::optional<std::string> ValidateBreweryJson(
const std::string& raw, std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
@@ -333,9 +264,11 @@ static std::string ValidateBreweryJson(const std::string& raw,
return false;
}
name_out = Trim(std::string(obj.at("name").as_string().c_str()));
description_out =
Trim(std::string(obj.at("description").as_string().c_str()));
const auto& name_value = obj.at("name").as_string();
const auto& description_value = obj.at("description").as_string();
name_out = Trim(std::string_view(name_value.data(), name_value.size()));
description_out = Trim(
std::string_view(description_value.data(), description_value.size()));
if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty";
@@ -371,7 +304,7 @@ static std::string ValidateBreweryJson(const std::string& raw,
std::string validation_error;
if (ec) {
std::string extracted;
if (!ExtractFirstJsonObject(raw, extracted)) {
if (!ExtractLastJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
@@ -385,14 +318,14 @@ static std::string ValidateBreweryJson(const std::string& raw,
return validation_error;
}
return {};
return std::nullopt;
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
return std::nullopt;
}
// Forward declarations for helper functions exposed to other translation units
@@ -401,16 +334,6 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
return PrepareRegionContext(region_context, max_chars);
}
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message) {
return ParseTwoLineResponse(raw, error_message);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt) {
return ToChatPrompt(model, user_prompt, "");
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
@@ -422,8 +345,8 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
AppendTokenPiece(vocab, token, output);
}
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out) {
std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out,
std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out);
}

View File

@@ -2,7 +2,7 @@
* Text Generation / Inference Module
* Core module that performs LLM inference: converts text prompts into tokens,
* runs the neural network forward pass, samples the next token, and converts
* output tokens back to text. Supports both simple and system+user prompts.
* output tokens back to text for system+user chat prompts.
*/
#include <spdlog/spdlog.h>
@@ -17,174 +17,156 @@
#include "data_generation/llama_generator_helpers.h"
#include "llama.h"
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, prompt), max_tokens);
}
static constexpr std::size_t kPromptTokenSlack = 8;
std::string LlamaGenerator::Infer(const std::string& system_prompt,
const std::string& prompt, int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
max_tokens);
const std::string& prompt,
const int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
max_tokens);
}
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
int max_tokens) {
/**
* Validate that model and context are loaded
*/
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
const int max_tokens) {
/**
* Validate that model and context are loaded
*/
if (model_ == nullptr || context_ == nullptr) {
throw std::runtime_error("LlamaGenerator: model not loaded");
}
/**
* Get vocabulary for tokenization and token-to-text conversion
*/
const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
/**
* Get vocabulary for tokenization and token-to-text conversion
*/
const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) {
throw std::runtime_error("LlamaGenerator: vocab unavailable");
}
/**
* Clear KV cache to ensure clean inference state (no residual context)
*/
llama_memory_clear(llama_get_memory(context_), true);
/**
* Clear KV cache to ensure clean inference state (no residual context)
*/
llama_memory_clear(llama_get_memory(context_), true);
/**
* TOKENIZATION PHASE
* Convert text prompt into token IDs (integers) that the model understands
*/
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
/**
* TOKENIZATION PHASE
* Convert text prompt into token IDs (integers) that the model understands
*/
std::vector<llama_token> prompt_tokens(formatted_prompt.size() +
kPromptTokenSlack);
int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
/**
* If buffer too small, negative return indicates required size
*/
if (token_count < 0) {
prompt_tokens.resize(static_cast<std::size_t>(-token_count));
token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
}
/**
* If buffer too small, negative return indicates required size
*/
if (token_count < 0) {
prompt_tokens.resize(static_cast<std::size_t>(-token_count));
token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
}
if (token_count < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
if (token_count < 0) {
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
}
/**
* CONTEXT SIZE VALIDATION
* Validate and compute effective token budgets based on context window
* constraints
*/
const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0)
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
/**
* CONTEXT SIZE VALIDATION
* Validate and compute effective token budgets based on context window
* constraints
*/
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
/**
* Clamp generation limit to available context window, reserve space for
* output
*/
const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
/**
* Prompt can use remaining context after reserving space for generation
*/
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
/**
* Clamp generation limit to available context window, reserve space for
* output
*/
const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
/**
* Prompt can use remaining context after reserving space for generation
*/
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
/**
* Truncate prompt if necessary to fit within constraints
*/
prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens to fit n_batch/n_ctx limits",
token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
token_count = prompt_budget;
}
/**
* Truncate prompt if necessary to fit within constraints
*/
prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens to fit n_batch/n_ctx limits",
token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
token_count = prompt_budget;
}
/**
* PROMPT PROCESSING PHASE
* Create a batch containing all prompt tokens and feed through the model
* This computes internal representations and fills the KV cache
*/
const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
/**
* PROMPT PROCESSING PHASE
* Create a batch containing all prompt tokens and feed through the model
* This computes internal representations and fills the KV cache
*/
const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0) {
throw std::runtime_error("LlamaGenerator: prompt decode failed");
}
/**
* SAMPLER CONFIGURATION PHASE
* Set up the probabilistic token selection pipeline (sampler chain)
* Samplers are applied in sequence: temperature -> top-p -> distribution
*/
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
/**
* TOKEN GENERATION LOOP
* Iteratively generate tokens one at a time until max_tokens or
* end-of-sequence
*/
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
/**
* Temperature: scales logits before softmax (controls randomness)
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
/**
* Top-P: nucleus sampling - filters to most likely tokens summing to top_p
* probability
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
/**
* Distribution sampler: selects actual token using configured seed for
* reproducibility
*/
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
if (sampler_ == nullptr || sampler_->chain == nullptr) {
throw std::runtime_error("LlamaGenerator: sampler not initialized");
}
/**
* TOKEN GENERATION LOOP
* Iteratively generate tokens one at a time until max_tokens or
* end-of-sequence
*/
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) {
/**
* Sample next token using configured sampler chain and model logits
* Index -1 means use the last output position from previous batch
*/
const llama_token next =
llama_sampler_sample(sampler_->chain, context_, -1);
/**
* Stop if model predicts end-of-generation token (EOS/EOT)
*/
if (llama_vocab_is_eog(vocab, next)) {
break;
}
generated_tokens.push_back(next);
/**
* Feed the sampled token back into model for next iteration
* (autoregressive)
*/
llama_token decode_token = next;
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
if (llama_decode(context_, one_token_batch) != 0) {
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
}
for (int i = 0; i < effective_max_tokens; ++i) {
/**
* Sample next token using configured sampler chain and model logits
* Index -1 means use the last output position from previous batch
*/
const llama_token next =
llama_sampler_sample(sampler.get(), context_, -1);
/**
* Stop if model predicts end-of-generation token (EOS/EOT)
*/
if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next);
/**
* Feed the sampled token back into model for next iteration
* (autoregressive)
*/
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
/**
* DETOKENIZATION PHASE
* Convert generated token IDs back to text using vocabulary
*/
std::string output;
for (const llama_token token : generated_tokens) {
AppendTokenPiecePublic(vocab, token, output);
}
/**
* DETOKENIZATION PHASE
* Convert generated token IDs back to text using vocabulary
*/
std::string output;
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
return output;
return output;
}

View File

@@ -0,0 +1,125 @@
/**
* @file data_generation/llama/llama_generator.cpp
* @brief LlamaGenerator constructor and destructor implementation.
*/
#include "data_generation/llama_generator.h"
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include "data_model/application_options.h"
#include "llama.h"
static constexpr uint32_t kMaxContextSize = 32768U;
struct SamplerConfig {
float temperature;
float top_p;
uint32_t top_k;
};
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
std::mt19937& rng) {
const llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler) {
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
}
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(config.temperature));
llama_sampler_chain_add(
sampler.get(),
llama_sampler_init_top_k(static_cast<int32_t>(config.top_k)));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(config.top_p, 1));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng()));
return sampler;
}
LlamaGenerator::SamplerState::~SamplerState() {
if (chain != nullptr) {
llama_sampler_free(chain);
chain = nullptr;
}
}
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path)
: rng_(std::random_device{}()) {
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
if (options.temperature < 0.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (options.top_k == 0U) {
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
}
if (options.seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_top_k_ = options.top_k;
if (options.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
} else {
rng_.seed(static_cast<uint32_t>(options.seed));
}
n_ctx_ = options.n_ctx;
this->Load(model_path);
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
sampling_top_k_};
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
sampler_.reset(new SamplerState());
sampler_->chain = sampler_chain.release();
}
LlamaGenerator::~LlamaGenerator() {
sampler_.reset();
/**
* Free the inference context (contains KV cache and computation state)
*/
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
/**
* Free the loaded model (contains weights and vocabulary)
*/
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
}

View File

@@ -23,7 +23,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
model_ = nullptr;
}
llama_model_params model_params = llama_model_default_params();
const llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
throw std::runtime_error(

View File

@@ -1,13 +1,14 @@
/**
* @file data_generation/llama/load_brewery_prompt.cpp
* @brief Resolves brewery system prompt content from cache or filesystem
* search paths and provides a robust inline fallback prompt when absent.
* @brief Resolves brewery system prompt content from cache or a configured
* filesystem path and provides a robust inline fallback prompt when absent.
*/
#include <spdlog/spdlog.h>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include "data_generation/llama_generator.h"
@@ -17,81 +18,43 @@ namespace fs = std::filesystem;
* @brief Loads brewery system prompt from disk or cache.
*
* @param prompt_file_path Preferred prompt file location.
* @return Prompt text loaded from disk or fallback content.
* @return Prompt text loaded from disk.
*/
std::string LlamaGenerator::LoadBrewerySystemPrompt(
const std::string& prompt_file_path) {
const std::string& prompt_file_path) {
// Return cached version if already loaded
if (!brewery_system_prompt_.empty()) {
return brewery_system_prompt_;
}
// Try multiple path locations
std::vector<std::string> paths_to_try = {
prompt_file_path, // As provided
"../" + prompt_file_path, // One level up
"../../" + prompt_file_path, // Two levels up
};
for (const auto& path : paths_to_try) {
std::ifstream prompt_file(path);
if (prompt_file.is_open()) {
std::string prompt((std::istreambuf_iterator<char>(prompt_file)),
std::istreambuf_iterator<char>());
prompt_file.close();
if (!prompt.empty()) {
spdlog::info(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} "
"chars)",
path, prompt.length());
brewery_system_prompt_ = prompt;
return brewery_system_prompt_;
}
}
// Try the provided path only
const fs::path prompt_path(prompt_file_path);
std::ifstream prompt_file(prompt_path);
if (!prompt_file.is_open()) {
spdlog::error(
"LlamaGenerator: Failed to open brewery system prompt file '{}'",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: missing brewery system prompt file: " +
prompt_path.string());
}
spdlog::warn(
"LlamaGenerator: Could not open brewery system prompt file at any of "
"the "
"expected locations. Using fallback inline prompt.");
return GetFallbackBreweryPrompt();
}
const std::string prompt((std::istreambuf_iterator(prompt_file)),
std::istreambuf_iterator<char>());
prompt_file.close();
/**
* @brief Provides an inline fallback brewery system prompt.
*
* @return Default fallback prompt text.
*/
std::string LlamaGenerator::GetFallbackBreweryPrompt() {
return "You are an experienced brewmaster and owner of a local craft "
"brewery. "
"Create a distinctive, authentic name and detailed description that "
"genuinely reflects your specific location, brewing philosophy, "
"local "
"culture, and community connection. The brewery must feel real and "
"grounded—not generic or interchangeable.\n\n"
"AVOID REPETITIVE PHRASES - Never use:\n"
"Love letter to, tribute to, rolling hills, picturesque, every sip "
"tells a story, Come for X stay for Y, rich history, passion, woven "
"into, ancient roots, timeless, where tradition meets innovation\n\n"
"OPENING APPROACHES - Choose ONE:\n"
"1. Start with specific beer style and its regional origins\n"
"2. Begin with specific brewing challenge (water, altitude, "
"climate)\n"
"3. Open with founding story or personal motivation\n"
"4. Lead with specific local ingredient or resource\n"
"5. Start with unexpected angle or contradiction\n"
"6. Open with local event, tradition, or cultural moment\n"
"7. Begin with tangible architectural or geographic detail\n\n"
"BE SPECIFIC - Include:\n"
"- At least ONE concrete proper noun (landmark, river, "
"neighborhood)\n"
"- Specific beer styles relevant to the REGION'S culture\n"
"- Concrete brewing challenges or advantages\n"
"- Sensory details SPECIFIC to place—not generic adjectives\n\n"
"LENGTH: 150-250 words. TONE: Can be soulful, irreverent, "
"matter-of-fact, unpretentious, or minimalist.\n\n"
"Output ONLY a raw JSON object with keys name and description. "
"No markdown, backticks, preamble, or trailing text.";
}
if (prompt.empty()) {
spdlog::error(
"LlamaGenerator: Brewery system prompt file '{}' is empty",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: empty brewery system prompt file: " +
prompt_path.string());
}
spdlog::info(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)",
prompt_path.string(), prompt.length());
brewery_system_prompt_ = prompt;
return brewery_system_prompt_;
}