fix llama grammar

This commit is contained in:
Aaron Po
2026-04-15 23:28:27 -04:00
parent 62dfb5e14a
commit 6682b5de01
7 changed files with 23 additions and 28 deletions

View File

@@ -13,7 +13,7 @@
#include "biergarten_data_generator.h"
#include "json_handling/json_loader.h"
static constexpr size_t kBreweryAmount = 50;
static constexpr size_t kBreweryAmount = 5;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");

View File

@@ -17,7 +17,7 @@
#include "data_generation/llama_generator_helpers.h"
static constexpr std::string_view kBreweryJsonGrammar = R"json_brewery(
root ::= ws "{" ws "\"name\"" ws ":" ws string ws "," ws "\"description\"" ws ":" ws string ws "}" ws
root ::= ws "{" ws "\"reasoning\"" ws ":" ws string ws "," ws "\"name\"" ws ":" ws string ws "," ws "\"description\"" ws ":" ws string ws "}" ws
ws ::= [ \t\n\r]*
string ::= "\"" char+ "\""
char ::= [^"\\\x7F\x00-\x1F] | [\\] escape
@@ -36,11 +36,6 @@ BreweryResult LlamaGenerator::GenerateBrewery(
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string region_suffix =
safe_region_context.empty()
? "."
: std::format(". Regional context: {}", safe_region_context);
/**
* Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found
@@ -53,9 +48,8 @@ BreweryResult LlamaGenerator::GenerateBrewery(
* culturally relevant and locally-inspired brewery attributes
*/
std::string prompt = std::format(
"Write a brewery name and place-specific long description for a craft "
"brewery in {}{}{}",
location.city, country_suffix, region_suffix);
"## CITY:\n{}\n\n## COUNTRY:\n{}\n\n## CONTEXT:\n{}",
location.city, location.country, safe_region_context);
/**
* Store location context for retry prompts (without repeating full context)
@@ -101,7 +95,7 @@ BreweryResult LlamaGenerator::GenerateBrewery(
// Update prompt with error details to guide LLM toward correct output.
prompt = std::format(
R"(Your previous response was invalid. Error: {}
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}.
Return ONLY valid JSON with exactly these keys, in this exact order: {{"reasoning": "<brief planning summary>", "name": "<brewery name>", "description": "<single-paragraph description>"}}.
Do not include markdown, comments, extra keys, or literal placeholder values.
{})",

View File

@@ -84,9 +84,8 @@ std::string PrepareRegionContext(std::string_view region_context,
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
std::string combined_prompt = system_prompt;
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
std::string combined_prompt =
std::format("{}\n\n{}", system_prompt, user_prompt);
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
@@ -103,9 +102,9 @@ std::string ToChatPrompt(const llama_model* model,
constexpr std::size_t min_template_buffer_size = 1024;
std::vector<char> buffer(std::max<std::size_t>(
min_template_buffer_size,
(system_prompt.size() + user_prompt.size()) * 4));
std::vector<char> buffer(
std::max<std::size_t>(min_template_buffer_size,
(system_prompt.size() + user_prompt.size()) * 4));
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
int32_t message_count) -> int32_t {

View File

@@ -101,7 +101,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
.temperature = sampling_temperature_,
.top_k = sampling_top_k_,
.top_p = sampling_top_p_,
.seed = rng_(),
.seed = static_cast<uint32_t>(rng_()),
};
auto sampler = MakeSamplerChain(vocab, sampler_config, grammar);