mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
feat: make Gemma 4 the default model, enable thinking mode
This commit is contained in:
@@ -27,6 +27,10 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||
"LlamaGenerator: sampling top-p must be in (0, 1]");
|
||||
}
|
||||
|
||||
if (options.top_k == 0U) {
|
||||
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
|
||||
}
|
||||
|
||||
if (options.seed < -1) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: seed must be >= 0, or -1 for random");
|
||||
@@ -39,6 +43,7 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||
|
||||
sampling_temperature_ = options.temperature;
|
||||
sampling_top_p_ = options.top_p;
|
||||
sampling_top_k_ = options.top_k;
|
||||
if (options.seed == -1) {
|
||||
std::random_device random_device;
|
||||
rng_.seed(random_device());
|
||||
|
||||
@@ -6,15 +6,65 @@
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include <array>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "data_generation/llama_generator.h"
|
||||
#include "data_generation/llama_generator_helpers.h"
|
||||
|
||||
BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
const std::string& city_name, const std::string& country_name,
|
||||
const std::string& region_context) {
|
||||
namespace {
|
||||
|
||||
/**
 * Isolate the final JSON object from a raw model response.
 *
 * Steps:
 *   1. Trim surrounding whitespace and drop any control tokens that
 *      terminate the response, so a trailing token cannot swallow the payload.
 *   2. Discard everything up to and including the last remaining control
 *      token ("<|turn|>", "<turn|>", "<channel|>", "<|channel|>"), keeping
 *      only the text that follows it.
 *   3. Return the span from the first '{' to the last '}' of that text.
 *
 * Fallbacks:
 *   - no '{' present: the trimmed text is returned unchanged;
 *   - '{' present but no '}' after it: everything from the first '{' onward
 *     is returned.
 *
 * @param raw_response Raw text produced by the model (taken by value; the
 *                     result is copied out of it).
 * @return The extracted payload; empty when the response contains nothing
 *         but whitespace and control tokens.
 */
auto ExtractFinalJsonPayload(std::string raw_response) -> std::string {
  static constexpr std::string_view kWhitespace = " \t\n\r";
  static constexpr std::array<std::string_view, 4> kSeparatorTokens = {
      "<|turn|>", "<turn|>", "<channel|>", "<|channel|>"};

  const auto trim = [](std::string_view text) -> std::string_view {
    const std::size_t first = text.find_first_not_of(kWhitespace);
    if (first == std::string_view::npos) {
      return {};
    }

    const std::size_t last = text.find_last_not_of(kWhitespace);
    return text.substr(first, last - first + 1);
  };

  // Views into raw_response; the parameter outlives every use below.
  std::string_view payload = trim(raw_response);

  // A control token that ends the response has no payload after it. Strip
  // such tokens first, otherwise the "text after the last separator" rule
  // below would throw away the whole answer.
  bool stripped_token = true;
  while (stripped_token) {
    stripped_token = false;
    for (const std::string_view token : kSeparatorTokens) {
      if (payload.size() >= token.size() &&
          payload.compare(payload.size() - token.size(), token.size(),
                          token) == 0) {
        payload = trim(payload.substr(0, payload.size() - token.size()));
        stripped_token = true;
      }
    }
  }

  // Locate the right-most separator; only the text after it is the answer.
  std::size_t separator_pos = std::string_view::npos;
  std::size_t separator_length = 0;
  for (const std::string_view token : kSeparatorTokens) {
    const std::size_t candidate_pos = payload.rfind(token);
    if (candidate_pos != std::string_view::npos &&
        (separator_pos == std::string_view::npos ||
         candidate_pos > separator_pos)) {
      separator_pos = candidate_pos;
      separator_length = token.size();
    }
  }

  if (separator_pos != std::string_view::npos) {
    payload = trim(payload.substr(separator_pos + separator_length));
  }

  const std::size_t first_brace = payload.find('{');
  if (first_brace == std::string_view::npos) {
    // Nothing that looks like JSON: hand back the trimmed text as-is.
    return std::string(payload);
  }

  const std::size_t last_brace = payload.rfind('}');
  if (last_brace == std::string_view::npos || last_brace < first_brace) {
    // Unterminated object: return what exists from the opening brace onward.
    return std::string(payload.substr(first_brace));
  }

  return std::string(payload.substr(first_brace, last_brace - first_brace + 1));
}
|
||||
|
||||
} // namespace
|
||||
|
||||
auto LlamaGenerator::GenerateBrewery(const BreweryLocation& location,
|
||||
const std::string& region_context)
|
||||
-> BreweryResult {
|
||||
/**
|
||||
* Preprocess and truncate region context to manageable size
|
||||
*/
|
||||
@@ -24,10 +74,9 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
/**
|
||||
* Load brewery system prompt from file
|
||||
* Falls back to minimal inline prompt if file not found
|
||||
* Default path: prompts/brewery_system_prompt_expanded.txt
|
||||
*/
|
||||
const std::string system_prompt =
|
||||
LoadBrewerySystemPrompt("prompts/brewery_system_prompt_expanded.txt");
|
||||
LoadBrewerySystemPrompt("prompts/system.md");
|
||||
|
||||
/**
|
||||
* User prompt: provides geographic context to guide generation towards
|
||||
@@ -35,21 +84,28 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
*/
|
||||
std::string prompt =
|
||||
"Write a brewery name and place-specific long description for a craft "
|
||||
"brewery in " +
|
||||
city_name +
|
||||
(country_name.empty() ? std::string("")
|
||||
: std::string(", ") + country_name) +
|
||||
(safe_region_context.empty()
|
||||
? std::string(".")
|
||||
: std::string(". Regional context: ") + safe_region_context);
|
||||
"brewery in ";
|
||||
prompt.append(location.city_name);
|
||||
if (!location.country_name.empty()) {
|
||||
prompt.append(", ");
|
||||
prompt.append(location.country_name);
|
||||
}
|
||||
if (safe_region_context.empty()) {
|
||||
prompt.append(".");
|
||||
} else {
|
||||
prompt.append(". Regional context: ");
|
||||
prompt.append(safe_region_context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Store location context for retry prompts (without repeating full context)
|
||||
*/
|
||||
const std::string retry_location =
|
||||
"Location: " + city_name +
|
||||
(country_name.empty() ? std::string("")
|
||||
: std::string(", ") + country_name);
|
||||
std::string retry_location = "Location: ";
|
||||
retry_location.append(location.city_name);
|
||||
if (!location.country_name.empty()) {
|
||||
retry_location.append(", ");
|
||||
retry_location.append(location.country_name);
|
||||
}
|
||||
|
||||
/**
|
||||
* RETRY LOOP with validation and error correction
|
||||
@@ -72,8 +128,9 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
|
||||
std::string name;
|
||||
std::string description;
|
||||
const std::string json_only = ExtractFinalJsonPayload(raw);
|
||||
const std::string validation_error =
|
||||
ValidateBreweryJsonPublic(raw, name, description);
|
||||
ValidateBreweryJsonPublic(json_only, name, description);
|
||||
if (validation_error.empty()) {
|
||||
// Success: return parsed brewery data
|
||||
return {std::move(name), std::move(description)};
|
||||
@@ -92,9 +149,9 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
"Your previous response was invalid. Error: " + validation_error +
|
||||
"\nReturn ONLY valid JSON with this exact schema: "
|
||||
"{\"name\": \"string\", \"description\": \"string\"}."
|
||||
"\nDo not include markdown, comments, or extra keys."
|
||||
"\n\n" +
|
||||
retry_location;
|
||||
"\nDo not include markdown, comments, or extra keys.";
|
||||
prompt += "\n\n";
|
||||
prompt += retry_location;
|
||||
}
|
||||
|
||||
// All retry attempts exhausted: log failure and throw exception
|
||||
|
||||
@@ -119,7 +119,8 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
/**
|
||||
* SAMPLER CONFIGURATION PHASE
|
||||
* Set up the probabilistic token selection pipeline (sampler chain)
|
||||
* Samplers are applied in sequence: temperature -> top-p -> distribution
|
||||
* Samplers are applied in sequence: temperature -> top-k -> top-p ->
|
||||
* distribution
|
||||
*/
|
||||
llama_sampler_chain_params sampler_params =
|
||||
llama_sampler_chain_default_params();
|
||||
@@ -135,6 +136,13 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
*/
|
||||
llama_sampler_chain_add(sampler.get(),
|
||||
llama_sampler_init_temp(sampling_temperature_));
|
||||
/**
|
||||
* Top-K: limits sampling to the most likely tokens before nucleus
|
||||
* sampling
|
||||
*/
|
||||
llama_sampler_chain_add(
|
||||
sampler.get(),
|
||||
llama_sampler_init_top_k(static_cast<int32_t>(sampling_top_k_)));
|
||||
/**
|
||||
* Top-P: nucleus sampling - filters to most likely tokens summing to top_p
|
||||
* probability
|
||||
|
||||
@@ -9,10 +9,10 @@
|
||||
|
||||
#include "data_generation/mock_generator.h"
|
||||
|
||||
std::size_t MockGenerator::DeterministicHash(const std::string& a,
|
||||
const std::string& b) {
|
||||
auto MockGenerator::DeterministicHash(const BreweryLocation& location)
|
||||
-> std::size_t {
|
||||
std::size_t seed = 0;
|
||||
boost::hash_combine(seed, a);
|
||||
boost::hash_combine(seed, b);
|
||||
boost::hash_combine(seed, location.city_name);
|
||||
boost::hash_combine(seed, location.country_name);
|
||||
return seed;
|
||||
}
|
||||
|
||||
@@ -8,11 +8,10 @@
|
||||
|
||||
#include "data_generation/mock_generator.h"
|
||||
|
||||
auto MockGenerator::GenerateBrewery(const std::string& city_name,
|
||||
const std::string& country_name,
|
||||
auto MockGenerator::GenerateBrewery(const BreweryLocation& location,
|
||||
const std::string& /*region_context*/)
|
||||
-> BreweryResult {
|
||||
const std::size_t hash = DeterministicHash(city_name, country_name);
|
||||
const std::size_t hash = DeterministicHash(location);
|
||||
|
||||
const std::string& adjective =
|
||||
kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
|
||||
@@ -21,11 +20,20 @@ auto MockGenerator::GenerateBrewery(const std::string& city_name,
|
||||
const std::string& base_description =
|
||||
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size());
|
||||
|
||||
const std::string name = city_name + " " + adjective + " " + noun;
|
||||
const std::string description =
|
||||
base_description + " Based in " + city_name +
|
||||
(country_name.empty() ? std::string(".")
|
||||
: std::string(", ") + country_name + ".");
|
||||
std::string name(location.city_name);
|
||||
name.append(" ");
|
||||
name.append(adjective);
|
||||
name.append(" ");
|
||||
name.append(noun);
|
||||
|
||||
std::string description = base_description;
|
||||
description.append(" Based in ");
|
||||
description.append(location.city_name);
|
||||
if (!location.country_name.empty()) {
|
||||
description.append(", ");
|
||||
description.append(location.country_name);
|
||||
}
|
||||
description.append(".");
|
||||
|
||||
return {name, description};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user