mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
Enhance ValidateBreweryJson to include reasoning output and update GenerateBrewery to use user_prompt
Add gemma parser
This commit is contained in:
@@ -107,6 +107,7 @@ set(SOURCES
|
|||||||
src/data_generation/llama/infer.cc
|
src/data_generation/llama/infer.cc
|
||||||
src/data_generation/llama/load.cc
|
src/data_generation/llama/load.cc
|
||||||
src/data_generation/llama/load_brewery_prompt.cc
|
src/data_generation/llama/load_brewery_prompt.cc
|
||||||
|
src/data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.cc
|
||||||
src/data_generation/mock/deterministic_hash.cc
|
src/data_generation/mock/deterministic_hash.cc
|
||||||
src/data_generation/mock/generate_brewery.cc
|
src/data_generation/mock/generate_brewery.cc
|
||||||
src/data_generation/mock/generate_user.cc
|
src/data_generation/mock/generate_user.cc
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
#include <string_view>
|
#include <string_view>
|
||||||
|
|
||||||
#include "data_generation/data_generator.h"
|
#include "data_generation/data_generator.h"
|
||||||
|
#include "data_generation/prompt_formatting/prompt_formatter.h"
|
||||||
#include "data_model/application_options.h"
|
#include "data_model/application_options.h"
|
||||||
|
|
||||||
struct llama_model;
|
struct llama_model;
|
||||||
@@ -31,9 +32,11 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
*
|
*
|
||||||
* @param options Parsed application options.
|
* @param options Parsed application options.
|
||||||
* @param model_path Filesystem path to GGUF model assets.
|
* @param model_path Filesystem path to GGUF model assets.
|
||||||
|
* @param prompt_formatter Formatter that produces model-specific prompts.
|
||||||
*/
|
*/
|
||||||
LlamaGenerator(const ApplicationOptions& options,
|
LlamaGenerator(const ApplicationOptions& options,
|
||||||
const std::string& model_path);
|
const std::string& model_path,
|
||||||
|
std::shared_ptr<IPromptFormatter> prompt_formatter);
|
||||||
|
|
||||||
~LlamaGenerator() override;
|
~LlamaGenerator() override;
|
||||||
|
|
||||||
@@ -132,6 +135,7 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
std::mt19937 rng_;
|
std::mt19937 rng_;
|
||||||
uint32_t n_ctx_ = kDefaultContextSize;
|
uint32_t n_ctx_ = kDefaultContextSize;
|
||||||
std::string brewery_system_prompt_;
|
std::string brewery_system_prompt_;
|
||||||
|
std::shared_ptr<IPromptFormatter> prompt_formatter_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_
|
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_
|
||||||
|
|||||||
@@ -12,7 +12,6 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
|
|
||||||
struct llama_model;
|
|
||||||
struct llama_vocab;
|
struct llama_vocab;
|
||||||
using llama_token = int32_t;
|
using llama_token = int32_t;
|
||||||
|
|
||||||
@@ -26,18 +25,6 @@ using llama_token = int32_t;
|
|||||||
std::string PrepareRegionContext(std::string_view region_context,
|
std::string PrepareRegionContext(std::string_view region_context,
|
||||||
size_t max_chars = 2000);
|
size_t max_chars = 2000);
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Applies model chat template to system and user prompts.
|
|
||||||
*
|
|
||||||
* @param model Loaded llama model.
|
|
||||||
* @param system_prompt System prompt text.
|
|
||||||
* @param user_prompt User prompt text.
|
|
||||||
* @return Model-formatted prompt.
|
|
||||||
*/
|
|
||||||
std::string ToChatPrompt(const llama_model* model,
|
|
||||||
const std::string& system_prompt,
|
|
||||||
const std::string& user_prompt);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Decodes a sampled token and appends it to output text.
|
* @brief Decodes a sampled token and appends it to output text.
|
||||||
*
|
*
|
||||||
@@ -58,6 +45,7 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
|||||||
*/
|
*/
|
||||||
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||||
std::string& name_out,
|
std::string& name_out,
|
||||||
std::string& description_out);
|
std::string& description_out,
|
||||||
|
std::string& reasoning_out);
|
||||||
|
|
||||||
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
|
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
|
||||||
|
#include "data_generation/prompt_formatting/prompt_formatter.h"
|
||||||
|
|
||||||
|
class Gemma4JinjaPromptFormatter final : public IPromptFormatter {
|
||||||
|
public:
|
||||||
|
Gemma4JinjaPromptFormatter() = default;
|
||||||
|
~Gemma4JinjaPromptFormatter() override = default;
|
||||||
|
|
||||||
|
[[nodiscard]] std::string Format(std::string_view system_prompt,
|
||||||
|
std::string_view user_prompt) const override;
|
||||||
|
};
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
|
||||||
|
class IPromptFormatter {
|
||||||
|
public:
|
||||||
|
IPromptFormatter() = default;
|
||||||
|
IPromptFormatter(const IPromptFormatter&) = delete;
|
||||||
|
IPromptFormatter& operator=(const IPromptFormatter&) = delete;
|
||||||
|
IPromptFormatter(IPromptFormatter&&) = delete;
|
||||||
|
IPromptFormatter& operator=(IPromptFormatter&&) = delete;
|
||||||
|
virtual ~IPromptFormatter() = default;
|
||||||
|
|
||||||
|
[[nodiscard]] virtual std::string Format(
|
||||||
|
std::string_view system_prompt,
|
||||||
|
std::string_view user_prompt) const = 0;
|
||||||
|
};
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
<|think|>
|
|
||||||
Return only one raw JSON object as the final answer, with exactly three keys: "reasoning", "name", and "description".
|
Return only one raw JSON object as the final answer, with exactly three keys: "reasoning", "name", and "description".
|
||||||
The "reasoning" key MUST be the first key in the object.
|
The "reasoning" key MUST be the first key in the object.
|
||||||
No markdown, code fences, preamble, or extra keys.
|
No markdown, code fences, preamble, or extra keys.
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ escape ::= ["\\/bfnrt] | "u" hex hex hex hex
|
|||||||
hex ::= [0-9a-fA-F]
|
hex ::= [0-9a-fA-F]
|
||||||
)json_brewery";
|
)json_brewery";
|
||||||
|
|
||||||
|
static constexpr int kBreweryInitialMaxTokens = 2800;
|
||||||
|
static constexpr int kBreweryTruncationRetryTokenBump = 700;
|
||||||
|
static constexpr int kBreweryMaxTokensCeiling = 5000;
|
||||||
|
|
||||||
BreweryResult LlamaGenerator::GenerateBrewery(
|
BreweryResult LlamaGenerator::GenerateBrewery(
|
||||||
const Location& location, const std::string& region_context) {
|
const Location& location, const std::string& region_context) {
|
||||||
/**
|
/**
|
||||||
@@ -43,11 +47,8 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
|||||||
const std::string system_prompt =
|
const std::string system_prompt =
|
||||||
LoadBrewerySystemPrompt("prompts/system.md");
|
LoadBrewerySystemPrompt("prompts/system.md");
|
||||||
|
|
||||||
/**
|
|
||||||
* User prompt: provides geographic context to guide generation towards
|
std::string user_prompt = std::format(
|
||||||
* culturally relevant and locally-inspired brewery attributes
|
|
||||||
*/
|
|
||||||
std::string prompt = std::format(
|
|
||||||
"## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## CONTEXT:\n{}",
|
"## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## CONTEXT:\n{}",
|
||||||
location.city, location.country, safe_region_context);
|
location.city, location.country, safe_region_context);
|
||||||
|
|
||||||
@@ -66,11 +67,14 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
|||||||
std::string raw;
|
std::string raw;
|
||||||
std::string last_error;
|
std::string last_error;
|
||||||
|
|
||||||
|
// Token budget: too small risks truncating valid JSON mid-string.
|
||||||
|
// Start conservatively but allow adaptive increases on truncation.
|
||||||
|
int max_tokens = kBreweryInitialMaxTokens;
|
||||||
|
|
||||||
// Limit output length to keep it concise and focused
|
// Limit output length to keep it concise and focused
|
||||||
for (int attempt = 0; attempt < max_attempts; ++attempt) {
|
for (int attempt = 0; attempt < max_attempts; ++attempt) {
|
||||||
constexpr int max_tokens = 1052;
|
|
||||||
// Generate brewery data from LLM
|
// Generate brewery data from LLM
|
||||||
raw = this->Infer(system_prompt, prompt, max_tokens, kBreweryJsonGrammar);
|
raw = this->Infer(system_prompt, user_prompt, max_tokens, kBreweryJsonGrammar);
|
||||||
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
|
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
|
||||||
raw);
|
raw);
|
||||||
|
|
||||||
@@ -78,10 +82,16 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
|||||||
|
|
||||||
std::string name;
|
std::string name;
|
||||||
std::string description;
|
std::string description;
|
||||||
|
std::string reasoning;
|
||||||
const std::optional<std::string> validation_error =
|
const std::optional<std::string> validation_error =
|
||||||
ValidateBreweryJson(raw, name, description);
|
ValidateBreweryJson(raw, name, description, reasoning);
|
||||||
if (!validation_error.has_value()) {
|
if (!validation_error.has_value()) {
|
||||||
// Success: return parsed brewery data
|
// Success: return parsed brewery data
|
||||||
|
|
||||||
|
spdlog::info(
|
||||||
|
"LlamaGenerator: successfully generated brewery data on attempt {}:\n reasoning='{}',\n name='{}',\n description='{}'",
|
||||||
|
attempt + 1, reasoning, name, description);
|
||||||
|
|
||||||
return BreweryResult{.name = std::move(name),
|
return BreweryResult{.name = std::move(name),
|
||||||
.description = std::move(description)};
|
.description = std::move(description)};
|
||||||
}
|
}
|
||||||
@@ -92,12 +102,27 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
|||||||
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
|
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
|
||||||
attempt + 1, *validation_error);
|
attempt + 1, *validation_error);
|
||||||
|
|
||||||
|
|
||||||
|
if (last_error == "JSON parse error: incomplete JSON") {
|
||||||
|
const int previous_max_tokens = max_tokens;
|
||||||
|
max_tokens = std::min(max_tokens + kBreweryTruncationRetryTokenBump,
|
||||||
|
kBreweryMaxTokensCeiling);
|
||||||
|
spdlog::info(
|
||||||
|
"LlamaGenerator: detected truncated JSON; increasing max_tokens from {} to {} and retrying",
|
||||||
|
previous_max_tokens, max_tokens);
|
||||||
|
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Update prompt with error details to guide LLM toward correct output.
|
// Update prompt with error details to guide LLM toward correct output.
|
||||||
prompt = std::format(
|
user_prompt = std::format(
|
||||||
R"(Your previous response was invalid. Error: {}
|
R"(Your previous response was invalid. Error: {}
|
||||||
Return ONLY valid JSON with exactly these keys, in this exact order: {{"reasoning": "<brief planning summary>", "name": "<brewery name>", "description": "<single-paragraph description>"}}.
|
Return ONLY valid JSON with exactly these keys, in this exact order: {{"reasoning": "<brief planning summary>", "name": "<brewery name>", "description": "<single-paragraph description>"}}.
|
||||||
Do not include markdown, comments, extra keys, or literal placeholder values.
|
Do not include markdown, comments, extra keys, or literal placeholder values.
|
||||||
|
|
||||||
|
Keep the JSON strings concise enough to fit within the token budget.
|
||||||
|
|
||||||
{})",
|
{})",
|
||||||
*validation_error, retry_location);
|
*validation_error, retry_location);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,6 @@
|
|||||||
* parsing, token decoding, and JSON validation helpers for Llama modules.
|
* parsing, token decoding, and JSON validation helpers for Llama modules.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <spdlog/spdlog.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <boost/json.hpp>
|
#include <boost/json.hpp>
|
||||||
@@ -81,89 +79,6 @@ std::string PrepareRegionContext(std::string_view region_context,
|
|||||||
return normalized;
|
return normalized;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ToChatPrompt(const llama_model* model,
|
|
||||||
const std::string& system_prompt,
|
|
||||||
const std::string& user_prompt) {
|
|
||||||
std::string combined_prompt =
|
|
||||||
std::format("{}\n\n{}", system_prompt, user_prompt);
|
|
||||||
|
|
||||||
const char* template_str = llama_model_chat_template(model, nullptr);
|
|
||||||
|
|
||||||
// If metadata is missing (nullptr), attempt to use the built-in "gemma" alias
|
|
||||||
// to leverage the library's interleaved template for Gemma 4 support.
|
|
||||||
if (template_str == nullptr) {
|
|
||||||
template_str = "gemma";
|
|
||||||
spdlog::info(
|
|
||||||
"LlamaGenerator: model chat template metadata missing; attempting "
|
|
||||||
"built-in 'gemma' alias");
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::array<llama_chat_message, 2> messages = {{
|
|
||||||
{.role = "system", .content = system_prompt.c_str()},
|
|
||||||
{.role = "user", .content = user_prompt.c_str()},
|
|
||||||
}};
|
|
||||||
|
|
||||||
constexpr std::size_t min_template_buffer_size = 1024;
|
|
||||||
|
|
||||||
std::vector<char> buffer(
|
|
||||||
std::max<std::size_t>(min_template_buffer_size,
|
|
||||||
(system_prompt.size() + user_prompt.size()) * 4));
|
|
||||||
|
|
||||||
auto apply_template_with_resize = [&](const char* tmpl,
|
|
||||||
const llama_chat_message* chat_messages,
|
|
||||||
int32_t message_count) -> int32_t {
|
|
||||||
int32_t result = llama_chat_apply_template(
|
|
||||||
tmpl, chat_messages, message_count, true, buffer.data(),
|
|
||||||
static_cast<int32_t>(buffer.size()));
|
|
||||||
|
|
||||||
if (result < 0) {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto buffer_size = static_cast<int32_t>(buffer.size());
|
|
||||||
if (result >= buffer_size) {
|
|
||||||
buffer.resize(static_cast<std::size_t>(result) + 1);
|
|
||||||
result = llama_chat_apply_template(
|
|
||||||
tmpl, chat_messages, message_count, true, buffer.data(),
|
|
||||||
static_cast<int32_t>(buffer.size()));
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
};
|
|
||||||
|
|
||||||
int32_t template_result =
|
|
||||||
apply_template_with_resize(template_str, messages.data(), 2);
|
|
||||||
|
|
||||||
if (template_result >= 0) {
|
|
||||||
return {buffer.data(), static_cast<size_t>(template_result)};
|
|
||||||
}
|
|
||||||
|
|
||||||
spdlog::warn(
|
|
||||||
"LlamaGenerator: chat template rejected system/user messages (result "
|
|
||||||
"{}); trying single user fallback",
|
|
||||||
template_result);
|
|
||||||
|
|
||||||
// FALLBACK: If the template fails (e.g., model rejecting the "system" role),
|
|
||||||
// combine the system and user prompts into a single "user" message.
|
|
||||||
const std::array<llama_chat_message, 1> fallback_msg = {{
|
|
||||||
{.role = "user", .content = combined_prompt.c_str()},
|
|
||||||
}};
|
|
||||||
|
|
||||||
template_result =
|
|
||||||
apply_template_with_resize(template_str, fallback_msg.data(), 1);
|
|
||||||
|
|
||||||
// Ultimate fallback: if GGUF template parsing still fails, use raw text.
|
|
||||||
if (template_result < 0) {
|
|
||||||
spdlog::warn(
|
|
||||||
"LlamaGenerator: chat template fallback failed (result {}); using "
|
|
||||||
"raw prompt text",
|
|
||||||
template_result);
|
|
||||||
return combined_prompt;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {buffer.data(), static_cast<size_t>(template_result)};
|
|
||||||
}
|
|
||||||
|
|
||||||
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
||||||
std::string& output) {
|
std::string& output) {
|
||||||
constexpr size_t initial_buffer_size = 256;
|
constexpr size_t initial_buffer_size = 256;
|
||||||
@@ -193,6 +108,7 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
|||||||
|
|
||||||
if (!buffer_too_small(bytes)) {
|
if (!buffer_too_small(bytes)) {
|
||||||
output.append(dynamic_buffer.data(), static_cast<size_t>(bytes));
|
output.append(dynamic_buffer.data(), static_cast<size_t>(bytes));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@@ -201,7 +117,8 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
|||||||
|
|
||||||
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||||
std::string& name_out,
|
std::string& name_out,
|
||||||
std::string& description_out) {
|
std::string& description_out,
|
||||||
|
std::string& reasoning_out) {
|
||||||
auto validate_object = [&](const boost::json::value& json_value,
|
auto validate_object = [&](const boost::json::value& json_value,
|
||||||
std::string& error_out) -> bool {
|
std::string& error_out) -> bool {
|
||||||
if (!json_value.is_object()) {
|
if (!json_value.is_object()) {
|
||||||
@@ -209,7 +126,14 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
const auto& obj = json_value.get_object();
|
const auto& obj = json_value.get_object();
|
||||||
|
|
||||||
|
if (!obj.contains("reasoning") || !obj.at("reasoning").is_string()) {
|
||||||
|
error_out = "JSON field 'reasoning' is missing or not a string";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (!obj.contains("name") || !obj.at("name").is_string()) {
|
if (!obj.contains("name") || !obj.at("name").is_string()) {
|
||||||
error_out = "JSON field 'name' is missing or not a string";
|
error_out = "JSON field 'name' is missing or not a string";
|
||||||
return false;
|
return false;
|
||||||
@@ -219,6 +143,12 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
|||||||
error_out = "JSON field 'description' is missing or not a string";
|
error_out = "JSON field 'description' is missing or not a string";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
const auto& reasoning_value = obj.at("reasoning").as_string();
|
||||||
|
reasoning_out = Trim(std::string_view(reasoning_value.data(), reasoning_value.size()));
|
||||||
|
if (reasoning_out.empty()) {
|
||||||
|
error_out = "JSON field 'reasoning' must not be empty";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
const auto& name_value = obj.at("name").as_string();
|
const auto& name_value = obj.at("name").as_string();
|
||||||
const auto& description_value = obj.at("description").as_string();
|
const auto& description_value = obj.at("description").as_string();
|
||||||
@@ -239,15 +169,16 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
|||||||
std::string name_lower = name_out;
|
std::string name_lower = name_out;
|
||||||
std::string description_lower = description_out;
|
std::string description_lower = description_out;
|
||||||
|
|
||||||
std::ranges::transform(name_lower, name_lower.begin(),
|
|
||||||
[](unsigned char character) {
|
|
||||||
return static_cast<char>(std::tolower(character));
|
|
||||||
});
|
|
||||||
|
|
||||||
std::ranges::transform(description_lower, description_lower.begin(),
|
auto string_to_lower = [](std::string& str_out) {
|
||||||
[](unsigned char character) {
|
std::ranges::transform(str_out, str_out.begin(),
|
||||||
return static_cast<char>(std::tolower(character));
|
[](unsigned char character) {
|
||||||
});
|
return static_cast<char>(std::tolower(character));
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
string_to_lower(name_lower);
|
||||||
|
string_to_lower(description_lower);
|
||||||
|
|
||||||
if (name_lower == "string" || description_lower == "string") {
|
if (name_lower == "string" || description_lower == "string") {
|
||||||
error_out = "JSON appears to be a schema placeholder, not content";
|
error_out = "JSON appears to be a schema placeholder, not content";
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
|||||||
const std::string& prompt,
|
const std::string& prompt,
|
||||||
const int max_tokens,
|
const int max_tokens,
|
||||||
std::string_view grammar) {
|
std::string_view grammar) {
|
||||||
return InferFormatted(ToChatPrompt(model_.get(), system_prompt, prompt),
|
return InferFormatted(prompt_formatter_->Format(system_prompt, prompt),
|
||||||
max_tokens, grammar);
|
max_tokens, grammar);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -31,12 +31,19 @@ void LlamaGenerator::ContextDeleter::operator()(
|
|||||||
}
|
}
|
||||||
|
|
||||||
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||||
const std::string& model_path)
|
const std::string& model_path,
|
||||||
: rng_(std::random_device{}()) {
|
std::shared_ptr<IPromptFormatter> prompt_formatter)
|
||||||
|
: rng_(std::random_device{}()),
|
||||||
|
prompt_formatter_(std::move(prompt_formatter)) {
|
||||||
if (model_path.empty()) {
|
if (model_path.empty()) {
|
||||||
throw std::runtime_error("LlamaGenerator: model path must not be empty");
|
throw std::runtime_error("LlamaGenerator: model path must not be empty");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!prompt_formatter_) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"LlamaGenerator: prompt formatter dependency must not be null");
|
||||||
|
}
|
||||||
|
|
||||||
if (options.temperature < 0.0F) {
|
if (options.temperature < 0.0F) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: sampling temperature must be >= 0");
|
"LlamaGenerator: sampling temperature must be >= 0");
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h"
|
||||||
|
|
||||||
|
#include <format>
|
||||||
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
|
||||||
|
static constexpr std::string_view kWhitespace = " \t\n\r\f\v";
|
||||||
|
|
||||||
|
// Strips leading and trailing whitespace to ensure clean prompt injection.
|
||||||
|
static std::string_view Trim(std::string_view value) {
|
||||||
|
const size_t first_index = value.find_first_not_of(kWhitespace);
|
||||||
|
|
||||||
|
const bool is_all_whitespace = (first_index == std::string_view::npos);
|
||||||
|
if (is_all_whitespace) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t last_index = value.find_last_not_of(kWhitespace);
|
||||||
|
return value.substr(first_index, last_index - first_index + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Gemma4JinjaPromptFormatter::Format(
|
||||||
|
std::string_view system_prompt, std::string_view user_prompt) const {
|
||||||
|
std::string_view trimmed_system = Trim(system_prompt);
|
||||||
|
std::string_view trimmed_user = Trim(user_prompt);
|
||||||
|
|
||||||
|
return std::format(
|
||||||
|
"<|turn|>system\n<|think|>\n{}\n<|turn|>\n"
|
||||||
|
"<|turn|>user\n{}\n<|turn|>\n"
|
||||||
|
"<|turn|>model\n<|channel>thought\n",
|
||||||
|
trimmed_system, trimmed_user);
|
||||||
|
}
|
||||||
@@ -17,6 +17,7 @@
|
|||||||
#include "biergarten_data_generator.h"
|
#include "biergarten_data_generator.h"
|
||||||
#include "data_generation/llama_generator.h"
|
#include "data_generation/llama_generator.h"
|
||||||
#include "data_generation/mock_generator.h"
|
#include "data_generation/mock_generator.h"
|
||||||
|
#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h"
|
||||||
#include "data_model/application_options.h"
|
#include "data_model/application_options.h"
|
||||||
#include "llama_backend_state.h"
|
#include "llama_backend_state.h"
|
||||||
#include "services/enrichment_service.h"
|
#include "services/enrichment_service.h"
|
||||||
@@ -147,6 +148,7 @@ int main(const int argc, char** argv) {
|
|||||||
di::bind<WebClient>().to<CURLWebClient>(),
|
di::bind<WebClient>().to<CURLWebClient>(),
|
||||||
di::bind<ApplicationOptions>().to(options),
|
di::bind<ApplicationOptions>().to(options),
|
||||||
di::bind<IEnrichmentService>().to<WikipediaService>(),
|
di::bind<IEnrichmentService>().to<WikipediaService>(),
|
||||||
|
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
|
||||||
di::bind<std::string>().to(options.model_path),
|
di::bind<std::string>().to(options.model_path),
|
||||||
di::bind<DataGenerator>().to(
|
di::bind<DataGenerator>().to(
|
||||||
[options](const auto& inj) -> std::unique_ptr<DataGenerator> {
|
[options](const auto& inj) -> std::unique_ptr<DataGenerator> {
|
||||||
|
|||||||
Reference in New Issue
Block a user