Refactor Llama generator, helpers, and build assets

make Gemma 4 the default model, enable thinking mode
style updates
This commit is contained in:
Aaron Po
2026-04-10 00:03:45 -04:00
parent 7ca651a886
commit 56ec728ba7
61 changed files with 1430 additions and 1905 deletions

View File

@@ -4,13 +4,17 @@
* parsing, token decoding, and JSON validation helpers for Llama modules.
*/
#include <spdlog/spdlog.h>
#include <algorithm>
#include <array>
#include <boost/json.hpp>
#include <cctype>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
#include "data_generation/llama_generator.h"
@@ -19,40 +23,42 @@
/**
* String trimming: removes leading and trailing whitespace
*/
static std::string Trim(std::string value) {
auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
static std::string Trim(std::string_view value) {
constexpr std::string_view whitespace = " \t\n\r\f\v";
const std::size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) {
return {};
}
value.erase(value.begin(),
std::find_if(value.begin(), value.end(), not_space));
value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(),
value.end());
return value;
const std::size_t last_index = value.find_last_not_of(whitespace);
return std::string(value.substr(first_index, last_index - first_index + 1));
}
/**
* Normalize whitespace: collapses multiple spaces/tabs/newlines into single
* spaces
*/
static std::string CondenseWhitespace(std::string text) {
static std::string CondenseWhitespace(std::string_view text) {
std::string out;
out.reserve(text.size());
bool in_whitespace = false;
for (unsigned char ch : text) {
if (std::isspace(ch)) {
if (!in_whitespace) {
out.push_back(' ');
in_whitespace = true;
bool pending_space = false;
for (const unsigned char chr : text) {
if (std::isspace(chr) != 0) {
if (!out.empty()) {
pending_space = true;
}
continue;
}
in_whitespace = false;
out.push_back(static_cast<char>(ch));
if (pending_space) {
out.push_back(' ');
pending_space = false;
}
out.push_back(static_cast<char>(chr));
}
return Trim(std::move(out));
return out;
}
/**
@@ -60,14 +66,14 @@ static std::string CondenseWhitespace(std::string text) {
* boundaries
*/
static std::string PrepareRegionContext(std::string_view region_context,
std::size_t max_chars) {
std::string normalized = CondenseWhitespace(std::string(region_context));
const size_t max_chars) {
std::string normalized = CondenseWhitespace(region_context);
if (normalized.size() <= max_chars) {
return normalized;
}
normalized.resize(max_chars);
const std::size_t last_space = normalized.find_last_of(' ');
const size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space);
}
@@ -76,108 +82,20 @@ static std::string PrepareRegionContext(std::string_view region_context,
return normalized;
}
/**
* Remove common bullet points, numbers, and field labels added by LLM in output
*/
static std::string StripCommonPrefix(std::string line) {
line = Trim(std::move(line));
static std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
std::string combined_prompt;
combined_prompt.append(system_prompt);
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = Trim(line.substr(1));
} else {
std::size_t i = 0;
while (i < line.size() &&
std::isdigit(static_cast<unsigned char>(line[i]))) {
++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = Trim(line.substr(i + 1));
}
}
auto strip_label = [&line](const std::string& label) {
if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = Trim(line.substr(label.size()));
}
}
};
strip_label("name:");
strip_label("brewery name:");
strip_label("description:");
strip_label("username:");
strip_label("bio:");
return Trim(std::move(line));
}
/**
* Parse two-line response from LLM: normalize line endings, strip formatting,
* filter spurious output, and combine remaining lines if needed
*/
static std::pair<std::string, std::string> ParseTwoLineResponse(
const std::string& raw, const std::string& error_message) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = StripCommonPrefix(std::move(line));
if (!line.empty()) lines.push_back(std::move(line));
}
std::vector<std::string> filtered;
for (auto& l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
// Filter known thinking tags like <think>...</think>, but be conservative
// to avoid removing legitimate output. Only filter specific known
// patterns.
if (!l.empty() && l.front() == '<' && low.back() == '>') {
// Only filter if it's a known thinking tag: <think>, <reasoning>, etc.
if (low.find("think") != std::string::npos ||
low.find("reasoning") != std::string::npos ||
low.find("reflect") != std::string::npos) {
continue;
}
}
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
filtered.push_back(std::move(l));
}
if (filtered.size() < 2) throw std::runtime_error(error_message);
std::string first = Trim(filtered.front());
std::string second;
for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) second += ' ';
second += filtered[i];
}
second = Trim(std::move(second));
if (first.empty() || second.empty()) throw std::runtime_error(error_message);
return {first, second};
}
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
// No template found, fallback to raw text
return system_prompt + "\n\n" + user_prompt;
spdlog::warn(
"LlamaGenerator: missing chat template; using raw prompt fallback");
return combined_prompt;
}
const std::array<llama_chat_message, 2> messages = {
@@ -186,65 +104,62 @@ std::string ToChatPrompt(const llama_model* model,
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
int32_t required =
llama_chat_apply_template(tmpl, messages.data(), 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
auto apply_template_with_resize =
[&](const llama_chat_message* chat_messages,
int32_t message_count) -> int32_t {
int32_t result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
// FALLBACK: If the template fails (e.g., Gemma rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
if (required < 0) {
std::string combined_prompt = system_prompt + "\n\n" + user_prompt;
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
required = llama_chat_apply_template(tmpl, fallback_msg.data(), 1, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// THE FIX: Ultimate fallback. If the GGUF's internal template is
// completely unparseable (which happens with complex Jinja macros),
// degrade gracefully to raw text instead of throwing a runtime_error.
if (required < 0) {
return combined_prompt;
if (result < 0) {
return result;
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(
tmpl, fallback_msg.data(), 1, true, buffer.data(),
if (result >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
return combined_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
return result;
};
int32_t template_result = apply_template_with_resize(messages.data(), 2);
if (template_result >= 0) {
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
// Standard buffer resize if the original "system" + "user" array succeeded
// but needed more space
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages.data(), 2, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
spdlog::warn(
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
// Final safety net on resize
if (required < 0) {
return system_prompt + "\n\n" + user_prompt;
}
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
template_result = apply_template_with_resize(fallback_msg.data(), 1);
// Ultimate fallback: if GGUF template parsing still fails, use raw text.
if (template_result < 0) {
spdlog::warn(
"LlamaGenerator: chat template fallback failed (result {}); using "
"raw prompt text",
template_result);
return combined_prompt;
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
std::array<char, 256> buffer{};
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true);
int32_t bytes = llama_token_to_piece(vocab, token, buffer.data(),
buffer.size(), 0, true);
if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
@@ -263,12 +178,14 @@ static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
static bool ExtractFirstJsonObject(const std::string& text,
std::string& json_out) {
static bool ExtractLastJsonObject(const std::string& text,
std::string& json_out) {
std::size_t start = std::string::npos;
int depth = 0;
bool in_string = false;
bool escaped = false;
bool found = false;
std::string candidate;
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
@@ -303,18 +220,32 @@ static bool ExtractFirstJsonObject(const std::string& text,
}
--depth;
if (depth == 0 && start != std::string::npos) {
json_out = text.substr(start, i - start + 1);
return true;
candidate = text.substr(start, i - start + 1);
found = true;
}
}
}
return false;
if (!found) {
return false;
}
json_out = std::move(candidate);
return true;
}
static std::string ValidateBreweryJson(const std::string& raw,
std::string& name_out,
std::string& description_out) {
std::string ExtractLastJsonObjectPublic(const std::string& text) {
std::string extracted;
if (ExtractLastJsonObject(text, extracted)) {
return extracted;
}
return {};
}
static std::optional<std::string> ValidateBreweryJson(
const std::string& raw, std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
@@ -333,9 +264,11 @@ static std::string ValidateBreweryJson(const std::string& raw,
return false;
}
name_out = Trim(std::string(obj.at("name").as_string().c_str()));
description_out =
Trim(std::string(obj.at("description").as_string().c_str()));
const auto& name_value = obj.at("name").as_string();
const auto& description_value = obj.at("description").as_string();
name_out = Trim(std::string_view(name_value.data(), name_value.size()));
description_out = Trim(
std::string_view(description_value.data(), description_value.size()));
if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty";
@@ -371,7 +304,7 @@ static std::string ValidateBreweryJson(const std::string& raw,
std::string validation_error;
if (ec) {
std::string extracted;
if (!ExtractFirstJsonObject(raw, extracted)) {
if (!ExtractLastJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
@@ -385,14 +318,14 @@ static std::string ValidateBreweryJson(const std::string& raw,
return validation_error;
}
return {};
return std::nullopt;
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
return std::nullopt;
}
// Forward declarations for helper functions exposed to other translation units
@@ -401,16 +334,6 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
return PrepareRegionContext(region_context, max_chars);
}
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message) {
return ParseTwoLineResponse(raw, error_message);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt) {
return ToChatPrompt(model, user_prompt, "");
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
@@ -422,8 +345,8 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
AppendTokenPiece(vocab, token, output);
}
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out) {
std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out,
std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out);
}