mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
cleanup
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "data_generation/llama_generator.h"
|
||||
#include "data_generation/llama_generator_helpers.h"
|
||||
#include "llama.h"
|
||||
|
||||
/**
|
||||
@@ -25,12 +25,12 @@
|
||||
*/
|
||||
static std::string Trim(std::string_view value) {
  // Characters treated as trimmable whitespace at either end of the input.
  constexpr std::string_view whitespace = " \t\n\r\f\v";

  // An input with no non-whitespace character trims to the empty string.
  const size_t first_index = value.find_first_not_of(whitespace);
  if (first_index == std::string_view::npos) {
    return {};
  }

  // A non-whitespace character exists, so this search cannot return npos.
  const size_t last_index = value.find_last_not_of(whitespace);
  return std::string(value.substr(first_index, last_index - first_index + 1));
}
|
||||
|
||||
@@ -43,7 +43,7 @@ static std::string CondenseWhitespace(std::string_view text) {
|
||||
out.reserve(text.size());
|
||||
|
||||
bool pending_space = false;
|
||||
for (const unsigned char chr : text) {
|
||||
for (const char chr : text) {
|
||||
if (std::isspace(chr) != 0) {
|
||||
if (!out.empty()) {
|
||||
pending_space = true;
|
||||
@@ -55,7 +55,7 @@ static std::string CondenseWhitespace(std::string_view text) {
|
||||
out.push_back(' ');
|
||||
pending_space = false;
|
||||
}
|
||||
out.push_back(static_cast<char>(chr));
|
||||
out.push_back(chr);
|
||||
}
|
||||
|
||||
return out;
|
||||
@@ -65,8 +65,8 @@ static std::string CondenseWhitespace(std::string_view text) {
|
||||
* Truncate region context to fit within max length while preserving word
|
||||
* boundaries
|
||||
*/
|
||||
static std::string PrepareRegionContext(std::string_view region_context,
|
||||
const size_t max_chars) {
|
||||
std::string PrepareRegionContext(std::string_view region_context,
|
||||
const size_t max_chars) {
|
||||
std::string normalized = CondenseWhitespace(region_context);
|
||||
if (normalized.size() <= max_chars) {
|
||||
return normalized;
|
||||
@@ -82,11 +82,10 @@ static std::string PrepareRegionContext(std::string_view region_context,
|
||||
return normalized;
|
||||
}
|
||||
|
||||
static std::string ToChatPrompt(const llama_model* model,
|
||||
const std::string& system_prompt,
|
||||
const std::string& user_prompt) {
|
||||
std::string combined_prompt;
|
||||
combined_prompt.append(system_prompt);
|
||||
std::string ToChatPrompt(const llama_model* model,
|
||||
const std::string& system_prompt,
|
||||
const std::string& user_prompt) {
|
||||
std::string combined_prompt = system_prompt;
|
||||
combined_prompt.append("\n\n");
|
||||
combined_prompt.append(user_prompt);
|
||||
|
||||
@@ -127,7 +126,7 @@ static std::string ToChatPrompt(const llama_model* model,
|
||||
int32_t template_result = apply_template_with_resize(messages.data(), 2);
|
||||
|
||||
if (template_result >= 0) {
|
||||
return {buffer.data(), static_cast<std::size_t>(template_result)};
|
||||
return {buffer.data(), static_cast<size_t>(template_result)};
|
||||
}
|
||||
|
||||
spdlog::warn(
|
||||
@@ -151,74 +150,114 @@ static std::string ToChatPrompt(const llama_model* model,
|
||||
return combined_prompt;
|
||||
}
|
||||
|
||||
return {buffer.data(), static_cast<std::size_t>(template_result)};
|
||||
return {buffer.data(), static_cast<size_t>(template_result)};
|
||||
}
|
||||
|
||||
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
||||
std::string& output) {
|
||||
std::array<char, 256> buffer{};
|
||||
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
||||
std::string& output) {
|
||||
constexpr size_t initial_buffer_size = 256;
|
||||
|
||||
std::array<char, initial_buffer_size> buffer{};
|
||||
|
||||
// serialize the sampled token into UTF-8 bytes
|
||||
|
||||
auto buffer_too_small = [](int32_t result) -> bool { return result < 0; };
|
||||
|
||||
int32_t bytes =
|
||||
llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true);
|
||||
|
||||
if (bytes < 0) {
|
||||
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
|
||||
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
|
||||
static_cast<int32_t>(dynamic_buffer.size()), 0,
|
||||
true);
|
||||
if (bytes < 0) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: failed to decode sampled token piece");
|
||||
}
|
||||
|
||||
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
|
||||
if (!buffer_too_small(bytes)) {
|
||||
// Append the decoded bytes from the stack buffer.
|
||||
output.append(buffer.data(), static_cast<size_t>(bytes));
|
||||
return;
|
||||
}
|
||||
|
||||
output.append(buffer.data(), static_cast<std::size_t>(bytes));
|
||||
const int32_t required_size = -bytes;
|
||||
std::vector<char> dynamic_buffer(static_cast<size_t>(required_size));
|
||||
|
||||
// Retry token decoding against the larger heap buffer.
|
||||
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
|
||||
static_cast<int32_t>(dynamic_buffer.size()), 0,
|
||||
true);
|
||||
|
||||
if (!buffer_too_small(bytes)) {
|
||||
output.append(dynamic_buffer.data(), static_cast<size_t>(bytes));
|
||||
}
|
||||
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: failed to decode sampled token piece");
|
||||
}
|
||||
|
||||
// Shared parser used by the public extractor and JSON validation.
|
||||
static bool ExtractLastJsonObject(const std::string& text,
|
||||
std::string& json_out) {
|
||||
std::size_t start = std::string::npos;
|
||||
// Remember where the most recent balanced object started.
|
||||
size_t start = std::string::npos;
|
||||
|
||||
// Track nested braces outside of quoted strings.
|
||||
int depth = 0;
|
||||
|
||||
// Track whether the scan is currently inside a quoted string.
|
||||
bool in_string = false;
|
||||
|
||||
// Track escape sequences so quotes inside strings are handled correctly.
|
||||
bool escaped = false;
|
||||
|
||||
// Record whether at least one complete object was found.
|
||||
bool found = false;
|
||||
|
||||
// Keep the latest complete object candidate.
|
||||
std::string candidate;
|
||||
|
||||
for (std::size_t i = 0; i < text.size(); ++i) {
|
||||
const char ch = text[i];
|
||||
// Scan the input text one character at a time.
|
||||
for (size_t i = 0; i < text.size(); ++i) {
|
||||
// Inspect the current character.
|
||||
const char chr = text[i];
|
||||
|
||||
// Inside a string literal, only escapes and quotes affect state.
|
||||
if (in_string) {
|
||||
if (escaped) {
|
||||
// The current character was escaped, so clear the escape flag.
|
||||
escaped = false;
|
||||
} else if (ch == '\\') {
|
||||
} else if (chr == '\\') {
|
||||
// Mark the next character as escaped.
|
||||
escaped = true;
|
||||
} else if (ch == '"') {
|
||||
} else if (chr == '"') {
|
||||
// Closing quote ends the string literal.
|
||||
in_string = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch == '"') {
|
||||
// Opening quotes enter string mode.
|
||||
if (chr == '"') {
|
||||
in_string = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch == '{') {
|
||||
// Opening braces begin or nest a JSON object.
|
||||
if (chr == '{') {
|
||||
if (depth == 0) {
|
||||
// Record the start of the outermost object.
|
||||
start = i;
|
||||
}
|
||||
|
||||
// Increase nesting depth for the active object.
|
||||
++depth;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch == '}') {
|
||||
// Closing braces may complete an object.
|
||||
if (chr == '}') {
|
||||
if (depth == 0) {
|
||||
// Ignore stray closing braces.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Drop one level of nesting.
|
||||
--depth;
|
||||
if (depth == 0 && start != std::string::npos) {
|
||||
// Capture the latest complete object seen so far.
|
||||
candidate = text.substr(start, i - start + 1);
|
||||
found = true;
|
||||
}
|
||||
@@ -229,22 +268,14 @@ static bool ExtractLastJsonObject(const std::string& text,
|
||||
return false;
|
||||
}
|
||||
|
||||
// Return the captured object text to the caller.
|
||||
json_out = std::move(candidate);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string ExtractLastJsonObjectPublic(const std::string& text) {
|
||||
std::string extracted;
|
||||
if (ExtractLastJsonObject(text, extracted)) {
|
||||
return extracted;
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::optional<std::string> ValidateBreweryJson(
|
||||
const std::string& raw, std::string& name_out,
|
||||
std::string& description_out) {
|
||||
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||
std::string& name_out,
|
||||
std::string& description_out) {
|
||||
auto validate_object = [&](const boost::json::value& jv,
|
||||
std::string& error_out) -> bool {
|
||||
if (!jv.is_object()) {
|
||||
@@ -281,9 +312,11 @@ static std::optional<std::string> ValidateBreweryJson(
|
||||
|
||||
std::string name_lower = name_out;
|
||||
std::string description_lower = description_out;
|
||||
|
||||
std::transform(
|
||||
name_lower.begin(), name_lower.end(), name_lower.begin(),
|
||||
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||
|
||||
std::transform(description_lower.begin(), description_lower.end(),
|
||||
description_lower.begin(), [](unsigned char c) {
|
||||
return static_cast<char>(std::tolower(c));
|
||||
@@ -327,25 +360,12 @@ static std::optional<std::string> ValidateBreweryJson(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Forward declarations for helper functions exposed to other translation units
|
||||
std::string PrepareRegionContextPublic(std::string_view region_context,
|
||||
std::size_t max_chars) {
|
||||
return PrepareRegionContext(region_context, max_chars);
|
||||
}
|
||||
std::string ExtractLastJsonObject(const std::string& text) {
|
||||
// Reuse the internal parser and return an empty string if none was found.
|
||||
std::string extracted;
|
||||
if (ExtractLastJsonObject(text, extracted)) {
|
||||
return extracted;
|
||||
}
|
||||
|
||||
std::string ToChatPromptPublic(const llama_model* model,
|
||||
const std::string& system_prompt,
|
||||
const std::string& user_prompt) {
|
||||
return ToChatPrompt(model, system_prompt, user_prompt);
|
||||
}
|
||||
|
||||
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
|
||||
std::string& output) {
|
||||
AppendTokenPiece(vocab, token, output);
|
||||
}
|
||||
|
||||
std::optional<std::string> ValidateBreweryJsonPublic(
|
||||
const std::string& raw, std::string& name_out,
|
||||
std::string& description_out) {
|
||||
return ValidateBreweryJson(raw, name_out, description_out);
|
||||
return {};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user