This commit is contained in:
Aaron Po
2026-04-15 00:22:15 -04:00
parent 15853c62fd
commit ddf4bcb981
12 changed files with 198 additions and 198 deletions

View File

@@ -17,7 +17,7 @@
#include <string_view>
#include <vector>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
#include "llama.h"
/**
@@ -25,12 +25,12 @@
*/
static std::string Trim(std::string_view value) {
constexpr std::string_view whitespace = " \t\n\r\f\v";
const std::size_t first_index = value.find_first_not_of(whitespace);
const size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) {
return {};
}
const std::size_t last_index = value.find_last_not_of(whitespace);
const size_t last_index = value.find_last_not_of(whitespace);
return std::string(value.substr(first_index, last_index - first_index + 1));
}
@@ -43,7 +43,7 @@ static std::string CondenseWhitespace(std::string_view text) {
out.reserve(text.size());
bool pending_space = false;
for (const unsigned char chr : text) {
for (const char chr : text) {
if (std::isspace(chr) != 0) {
if (!out.empty()) {
pending_space = true;
@@ -55,7 +55,7 @@ static std::string CondenseWhitespace(std::string_view text) {
out.push_back(' ');
pending_space = false;
}
out.push_back(static_cast<char>(chr));
out.push_back(chr);
}
return out;
@@ -65,8 +65,8 @@ static std::string CondenseWhitespace(std::string_view text) {
* Truncate region context to fit within max length while preserving word
* boundaries
*/
static std::string PrepareRegionContext(std::string_view region_context,
const size_t max_chars) {
std::string PrepareRegionContext(std::string_view region_context,
const size_t max_chars) {
std::string normalized = CondenseWhitespace(region_context);
if (normalized.size() <= max_chars) {
return normalized;
@@ -82,11 +82,10 @@ static std::string PrepareRegionContext(std::string_view region_context,
return normalized;
}
static std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
std::string combined_prompt;
combined_prompt.append(system_prompt);
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
std::string combined_prompt = system_prompt;
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
@@ -127,7 +126,7 @@ static std::string ToChatPrompt(const llama_model* model,
int32_t template_result = apply_template_with_resize(messages.data(), 2);
if (template_result >= 0) {
return {buffer.data(), static_cast<std::size_t>(template_result)};
return {buffer.data(), static_cast<size_t>(template_result)};
}
spdlog::warn(
@@ -151,74 +150,114 @@ static std::string ToChatPrompt(const llama_model* model,
return combined_prompt;
}
return {buffer.data(), static_cast<std::size_t>(template_result)};
return {buffer.data(), static_cast<size_t>(template_result)};
}
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
std::array<char, 256> buffer{};
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
constexpr size_t initial_buffer_size = 256;
std::array<char, initial_buffer_size> buffer{};
// serialize the sampled token into UTF-8 bytes
auto buffer_too_small = [](int32_t result) -> bool { return result < 0; };
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true);
if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), 0,
true);
if (bytes < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece");
}
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
if (!buffer_too_small(bytes)) {
// Append the decoded bytes from the stack buffer.
output.append(buffer.data(), static_cast<size_t>(bytes));
return;
}
output.append(buffer.data(), static_cast<std::size_t>(bytes));
const int32_t required_size = -bytes;
std::vector<char> dynamic_buffer(static_cast<size_t>(required_size));
// Retry token decoding against the larger heap buffer.
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), 0,
true);
if (!buffer_too_small(bytes)) {
output.append(dynamic_buffer.data(), static_cast<size_t>(bytes));
}
throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece");
}
// Shared parser used by the public extractor and JSON validation.
static bool ExtractLastJsonObject(const std::string& text,
std::string& json_out) {
std::size_t start = std::string::npos;
// Remember where the most recent balanced object started.
size_t start = std::string::npos;
// Track nested braces outside of quoted strings.
int depth = 0;
// Track whether the scan is currently inside a quoted string.
bool in_string = false;
// Track escape sequences so quotes inside strings are handled correctly.
bool escaped = false;
// Record whether at least one complete object was found.
bool found = false;
// Keep the latest complete object candidate.
std::string candidate;
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
// Scan the input text one character at a time.
for (size_t i = 0; i < text.size(); ++i) {
// Inspect the current character.
const char chr = text[i];
// Inside a string literal, only escapes and quotes affect state.
if (in_string) {
if (escaped) {
// The current character was escaped, so clear the escape flag.
escaped = false;
} else if (ch == '\\') {
} else if (chr == '\\') {
// Mark the next character as escaped.
escaped = true;
} else if (ch == '"') {
} else if (chr == '"') {
// Closing quote ends the string literal.
in_string = false;
}
continue;
}
if (ch == '"') {
// Opening quotes enter string mode.
if (chr == '"') {
in_string = true;
continue;
}
if (ch == '{') {
// Opening braces begin or nest a JSON object.
if (chr == '{') {
if (depth == 0) {
// Record the start of the outermost object.
start = i;
}
// Increase nesting depth for the active object.
++depth;
continue;
}
if (ch == '}') {
// Closing braces may complete an object.
if (chr == '}') {
if (depth == 0) {
// Ignore stray closing braces.
continue;
}
// Drop one level of nesting.
--depth;
if (depth == 0 && start != std::string::npos) {
// Capture the latest complete object seen so far.
candidate = text.substr(start, i - start + 1);
found = true;
}
@@ -229,22 +268,14 @@ static bool ExtractLastJsonObject(const std::string& text,
return false;
}
// Return the captured object text to the caller.
json_out = std::move(candidate);
return true;
}
std::string ExtractLastJsonObjectPublic(const std::string& text) {
std::string extracted;
if (ExtractLastJsonObject(text, extracted)) {
return extracted;
}
return {};
}
static std::optional<std::string> ValidateBreweryJson(
const std::string& raw, std::string& name_out,
std::string& description_out) {
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
@@ -281,9 +312,11 @@ static std::optional<std::string> ValidateBreweryJson(
std::string name_lower = name_out;
std::string description_lower = description_out;
std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
@@ -327,25 +360,12 @@ static std::optional<std::string> ValidateBreweryJson(
return std::nullopt;
}
// Forward declarations for helper functions exposed to other translation units
std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars) {
return PrepareRegionContext(region_context, max_chars);
}
std::string ExtractLastJsonObject(const std::string& text) {
// Reuse the internal parser and return an empty string if none was found.
std::string extracted;
if (ExtractLastJsonObject(text, extracted)) {
return extracted;
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
return ToChatPrompt(model, system_prompt, user_prompt);
}
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output) {
AppendTokenPiece(vocab, token, output);
}
std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out,
std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out);
return {};
}