mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
fix llama grammar
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
#include "biergarten_data_generator.h"
|
||||
#include "json_handling/json_loader.h"
|
||||
|
||||
static constexpr size_t kBreweryAmount = 50;
|
||||
static constexpr size_t kBreweryAmount = 5;
|
||||
|
||||
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
|
||||
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#include "data_generation/llama_generator_helpers.h"
|
||||
|
||||
static constexpr std::string_view kBreweryJsonGrammar = R"json_brewery(
|
||||
root ::= ws "{" ws "\"name\"" ws ":" ws string ws "," ws "\"description\"" ws ":" ws string ws "}" ws
|
||||
root ::= ws "{" ws "\"reasoning\"" ws ":" ws string ws "," ws "\"name\"" ws ":" ws string ws "," ws "\"description\"" ws ":" ws string ws "}" ws
|
||||
ws ::= [ \t\n\r]*
|
||||
string ::= "\"" char+ "\""
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] escape
|
||||
@@ -36,11 +36,6 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
const std::string country_suffix =
|
||||
location.country.empty() ? std::string{}
|
||||
: std::format(", {}", location.country);
|
||||
const std::string region_suffix =
|
||||
safe_region_context.empty()
|
||||
? "."
|
||||
: std::format(". Regional context: {}", safe_region_context);
|
||||
|
||||
/**
|
||||
* Load brewery system prompt from file
|
||||
* Falls back to minimal inline prompt if file not found
|
||||
@@ -53,9 +48,8 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
* culturally relevant and locally-inspired brewery attributes
|
||||
*/
|
||||
std::string prompt = std::format(
|
||||
"Write a brewery name and place-specific long description for a craft "
|
||||
"brewery in {}{}{}",
|
||||
location.city, country_suffix, region_suffix);
|
||||
"## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## CONTEXT:\n{}",
|
||||
location.city, location.country, safe_region_context);
|
||||
|
||||
/**
|
||||
* Store location context for retry prompts (without repeating full context)
|
||||
@@ -101,7 +95,7 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
// Update prompt with error details to guide LLM toward correct output.
|
||||
prompt = std::format(
|
||||
R"(Your previous response was invalid. Error: {}
|
||||
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}.
|
||||
Return ONLY valid JSON with exactly these keys, in this exact order: {{"reasoning": "<brief planning summary>", "name": "<brewery name>", "description": "<single-paragraph description>"}}.
|
||||
Do not include markdown, comments, extra keys, or literal placeholder values.
|
||||
|
||||
{})",
|
||||
|
||||
@@ -84,9 +84,8 @@ std::string PrepareRegionContext(std::string_view region_context,
|
||||
std::string ToChatPrompt(const llama_model* model,
|
||||
const std::string& system_prompt,
|
||||
const std::string& user_prompt) {
|
||||
std::string combined_prompt = system_prompt;
|
||||
combined_prompt.append("\n\n");
|
||||
combined_prompt.append(user_prompt);
|
||||
std::string combined_prompt =
|
||||
std::format("{}\n\n{}", system_prompt, user_prompt);
|
||||
|
||||
const char* tmpl = llama_model_chat_template(model, nullptr);
|
||||
if (tmpl == nullptr) {
|
||||
@@ -103,9 +102,9 @@ std::string ToChatPrompt(const llama_model* model,
|
||||
|
||||
constexpr std::size_t min_template_buffer_size = 1024;
|
||||
|
||||
std::vector<char> buffer(std::max<std::size_t>(
|
||||
min_template_buffer_size,
|
||||
(system_prompt.size() + user_prompt.size()) * 4));
|
||||
std::vector<char> buffer(
|
||||
std::max<std::size_t>(min_template_buffer_size,
|
||||
(system_prompt.size() + user_prompt.size()) * 4));
|
||||
|
||||
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
|
||||
int32_t message_count) -> int32_t {
|
||||
|
||||
@@ -101,7 +101,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
.temperature = sampling_temperature_,
|
||||
.top_k = sampling_top_k_,
|
||||
.top_p = sampling_top_p_,
|
||||
.seed = rng_(),
|
||||
.seed = static_cast<uint32_t>(rng_()),
|
||||
};
|
||||
auto sampler = MakeSamplerChain(vocab, sampler_config, grammar);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user