From ac136f717921c9af16328e4fb0f2e490db84d45d Mon Sep 17 00:00:00 2001 From: Aaron Po Date: Thu, 2 Apr 2026 01:04:41 -0400 Subject: [PATCH] Enhance brewery generation: add country name parameter and improve prompt handling --- pipeline/includes/data_generator.h | 1 + pipeline/includes/llama_generator.h | 9 +- pipeline/includes/mock_generator.h | 1 + pipeline/src/llama_generator.cpp | 340 ++++++++++++++++++++++++++-- pipeline/src/main.cpp | 42 ++-- pipeline/src/mock_generator.cpp | 7 +- 6 files changed, 357 insertions(+), 43 deletions(-) diff --git a/pipeline/includes/data_generator.h b/pipeline/includes/data_generator.h index 61ae07e..b3f324e 100644 --- a/pipeline/includes/data_generator.h +++ b/pipeline/includes/data_generator.h @@ -19,6 +19,7 @@ public: virtual void load(const std::string &modelPath) = 0; virtual BreweryResult generateBrewery(const std::string &cityName, + const std::string &countryName, const std::string ®ionContext) = 0; virtual UserResult generateUser(const std::string &locale) = 0; diff --git a/pipeline/includes/llama_generator.h b/pipeline/includes/llama_generator.h index 0865ecc..8696602 100644 --- a/pipeline/includes/llama_generator.h +++ b/pipeline/includes/llama_generator.h @@ -13,11 +13,18 @@ public: void load(const std::string &modelPath) override; BreweryResult generateBrewery(const std::string &cityName, + const std::string &countryName, const std::string ®ionContext) override; UserResult generateUser(const std::string &locale) override; private: - std::string infer(const std::string &prompt, int maxTokens = 256); + std::string infer(const std::string &prompt, int maxTokens = 5000); + // Overload that allows passing a system message separately so chat-capable + // models receive a proper system role instead of having the system text + // concatenated into the user prompt (helps avoid revealing internal + // reasoning or instructions in model output). + std::string infer(const std::string &systemPrompt, const std::string &prompt, + int maxTokens = 5000); llama_model *model_ = nullptr; llama_context *context_ = nullptr; diff --git a/pipeline/includes/mock_generator.h b/pipeline/includes/mock_generator.h index efc4d3f..ca3f1d7 100644 --- a/pipeline/includes/mock_generator.h +++ b/pipeline/includes/mock_generator.h @@ -8,6 +8,7 @@ class MockGenerator final : public IDataGenerator { public: void load(const std::string &modelPath) override; BreweryResult generateBrewery(const std::string &cityName, + const std::string &countryName, const std::string ®ionContext) override; UserResult generateUser(const std::string &locale) override; diff --git a/pipeline/src/llama_generator.cpp b/pipeline/src/llama_generator.cpp index 952f5fc..bc53ff6 100644 --- a/pipeline/src/llama_generator.cpp +++ b/pipeline/src/llama_generator.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -25,6 +26,117 @@ std::string trim(std::string value) { return value; } +std::string stripCommonPrefix(std::string line) { + line = trim(std::move(line)); + + // Strip simple list markers like "- ", "* ", "1. ", "2) ". + if (!line.empty() && (line[0] == '-' || line[0] == '*')) { + line = trim(line.substr(1)); + } else { + std::size_t i = 0; + while (i < line.size() && + std::isdigit(static_cast(line[i]))) { + ++i; + } + if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) { + line = trim(line.substr(i + 1)); + } + } + + auto stripLabel = [&line](const std::string &label) { + if (line.size() >= label.size()) { + bool matches = true; + for (std::size_t i = 0; i < label.size(); ++i) { + if (std::tolower(static_cast(line[i])) != + std::tolower(static_cast(label[i]))) { + matches = false; + break; + } + } + if (matches) { + line = trim(line.substr(label.size())); + } + } + }; + + stripLabel("name:"); + stripLabel("brewery name:"); + stripLabel("description:"); + stripLabel("username:"); + stripLabel("bio:"); + + return trim(std::move(line)); +} + +std::string toChatPrompt(const llama_model *model, + const std::string &userPrompt) { + const char *tmpl = llama_model_chat_template(model, nullptr); + if (tmpl == nullptr) { + return userPrompt; + } + + const llama_chat_message message{ + "user", + userPrompt.c_str(), + }; + + std::vector buffer(std::max(1024, userPrompt.size() * 4)); + int32_t required = + llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), + static_cast(buffer.size())); + + if (required < 0) { + throw std::runtime_error("LlamaGenerator: failed to apply chat template"); + } + + if (required >= static_cast(buffer.size())) { + buffer.resize(static_cast(required) + 1); + required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), + static_cast(buffer.size())); + if (required < 0) { + throw std::runtime_error("LlamaGenerator: failed to apply chat template"); + } + } + + return std::string(buffer.data(), static_cast(required)); +} + +std::string toChatPrompt(const llama_model *model, + const std::string &systemPrompt, + const std::string &userPrompt) { + const char *tmpl = llama_model_chat_template(model, nullptr); + if (tmpl == nullptr) { + // Fall back to concatenating but keep system and user parts distinct. + return systemPrompt + "\n\n" + userPrompt; + } + + const llama_chat_message messages[2] = { + {"system", systemPrompt.c_str()}, + {"user", userPrompt.c_str()}, + }; + + std::vector buffer(std::max( + 1024, (systemPrompt.size() + userPrompt.size()) * 4)); + int32_t required = + llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), + static_cast(buffer.size())); + + if (required < 0) { + throw std::runtime_error("LlamaGenerator: failed to apply chat template"); + } + + if (required >= static_cast(buffer.size())) { + buffer.resize(static_cast(required) + 1); + required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), + static_cast(buffer.size())); + if (required < 0) { + throw std::runtime_error("LlamaGenerator: failed to apply chat template"); + } + } + + return std::string(buffer.data(), static_cast(required)); +} + void appendTokenPiece(const llama_vocab *vocab, llama_token token, std::string &output) { std::array buffer{}; @@ -51,13 +163,63 @@ void appendTokenPiece(const llama_vocab *vocab, llama_token token, std::pair parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) { - const auto newlinePos = raw.find('\n'); - if (newlinePos == std::string::npos) { + std::string normalized = raw; + std::replace(normalized.begin(), normalized.end(), '\r', '\n'); + + std::vector lines; + std::stringstream stream(normalized); + std::string line; + while (std::getline(stream, line)) { + line = stripCommonPrefix(std::move(line)); + if (!line.empty()) { + lines.push_back(std::move(line)); + } + } + + // Filter out obvious internal-thought / meta lines that sometimes leak from + // models (e.g. "", "Okay, so the user is asking me..."). + std::vector filtered; + for (auto &l : lines) { + std::string low = l; + std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + + // Skip single-token angle-bracket markers like or <...> + if (!l.empty() && l.front() == '<' && l.back() == '>') { + continue; + } + + // Skip short internal commentary that starts with common discourse markers + if (low.rfind("okay,", 0) == 0 || low.rfind("wait,", 0) == 0 || + low.rfind("hmm", 0) == 0) { + continue; + } + + // Skip lines that look like self-descriptions of what the model is doing + if (low.find("user is asking") != std::string::npos || + low.find("protocol") != std::string::npos || + low.find("parse") != std::string::npos || + low.find("return only") != std::string::npos) { + continue; + } + + filtered.push_back(std::move(l)); + } + + if (filtered.size() < 2) { throw std::runtime_error(errorMessage); } - std::string first = trim(raw.substr(0, newlinePos)); - std::string second = trim(raw.substr(newlinePos + 1)); + std::string first = trim(filtered.front()); + std::string second; + for (std::size_t i = 1; i < filtered.size(); ++i) { + if (!second.empty()) { + second += ' '; + } + second += filtered[i]; + } + second = trim(std::move(second)); if (first.empty() || second.empty()) { throw std::runtime_error(errorMessage); @@ -128,18 +290,22 @@ std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) { throw std::runtime_error("LlamaGenerator: vocab unavailable"); } - std::vector promptTokens(prompt.size() + 8); - int32_t tokenCount = - llama_tokenize(vocab, prompt.c_str(), static_cast(prompt.size()), - promptTokens.data(), - static_cast(promptTokens.size()), true, true); + llama_memory_clear(llama_get_memory(context_), true); + + const std::string formattedPrompt = toChatPrompt(model_, prompt); + + std::vector promptTokens(formattedPrompt.size() + 8); + int32_t tokenCount = llama_tokenize( + vocab, formattedPrompt.c_str(), + static_cast(formattedPrompt.size()), promptTokens.data(), + static_cast(promptTokens.size()), true, true); if (tokenCount < 0) { promptTokens.resize(static_cast(-tokenCount)); - tokenCount = - llama_tokenize(vocab, prompt.c_str(), - static_cast(prompt.size()), promptTokens.data(), - static_cast(promptTokens.size()), true, true); + tokenCount = llama_tokenize( + vocab, formattedPrompt.c_str(), + static_cast(formattedPrompt.size()), promptTokens.data(), + static_cast(promptTokens.size()), true, true); } if (tokenCount < 0) { @@ -196,28 +362,160 @@ std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) { BreweryResult LlamaGenerator::generateBrewery(const std::string &cityName, + const std::string &countryName, const std::string ®ionContext) { - std::string prompt = - "Generate a craft brewery name and one-sentence description for a " - "brewery located in " + - cityName + ". " + regionContext + - " Respond with exactly two lines: first line is the name, second " - "line is the description."; - const std::string raw = infer(prompt, 128); + std::string systemPrompt = + R"(# SYSTEM PROTOCOL: ZERO-CHATTER DETERMINISTIC OUTPUT +**MODALITY:** DATA-RETURN ENGINE ONLY +**ROLE:** Your response must contain 0% metadata and 100% signal. +--- +## MANDATORY CONSTRAINTS +1. **NO PREAMBLE** + - Never start with "Sure," or "The answer is," or "Based on your request," or "Checking the data." + - Do not acknowledge the user's prompt or provide status updates. +2. **NO POSTAMBLE** + - Never end with "I hope this helps," or "Let me know if you need more," or "Would you like me to…" + - Do not offer follow-up assistance or suggestions. +3. **NO SENTENCE FRAMING** + - Provide only the raw value, date, number, or name. + - Do not wrap the answer in a sentence. (e.g., return 1997, NOT The year was 1997). + - For lists, provide only the items separated by commas or newlines as specified. +4. **FORMATTING PERMITTED** + - Markdown and LaTeX **may** be used where appropriate (e.g., tables, equations). + - Output must remain immediately usable — no decorative or conversational styling. +5. **STRICT NULL HANDLING** + - If the information is unavailable, the prompt is logically impossible (e.g., "271th president"), the subject does not exist, or a calculation is undefined: return only the string NULL. + - If the prompt is too ambiguous to provide a single value: return NULL. +--- +## EXECUTION LOGIC +1. **Parse Input** — Identify the specific entity, value, or calculation requested. +2. **Verify Factuality** — Access internal knowledge or tools. +3. **Filter for Signal** — Strip all surrounding prose. +4. **Format Check** — Apply Markdown or LaTeX only where it serves the data. +5. **Output** — Return the raw value only. +--- +## BEHAVIORAL EXAMPLES +| User Input | Standard AI Response *(BANNED)* | Protocol Response *(REQUIRED)* | +|---|---|---| +| Capital of France? | The capital of France is Paris. | Paris | +| 15% of 200 | 15% of 200 is 30. | 30 | +| Who wrote '1984'? | George Orwell wrote that novel. | George Orwell | +| ISO code for Japan | The code is JP. | JP | +| $\sqrt{x}$ where $x$ is a potato | A potato has no square root. | NULL | +| 500th US President | There haven't been that many. | NULL | +| Pythagorean theorem | The theorem states... | $a^2 + b^2 = c^2$ | +--- +## FINAL INSTRUCTION +Total silence is preferred over conversational error. Any deviation from the raw-value-only format is a protocol failure. Proceed with next input.)"; + + std::string prompt = + "Generate a craft brewery name and 1000 character description for a " + "brewery located in " + + cityName + + (countryName.empty() ? std::string("") + : std::string(", ") + countryName) + + ". " + regionContext + + " Respond with exactly two lines: first line is the name, second line is " + "the description. Do not include bullets, numbering, or any extra text."; + + const std::string raw = infer(systemPrompt, prompt, 512); auto [name, description] = parseTwoLineResponse(raw, "LlamaGenerator: malformed brewery response"); return {name, description}; } +std::string LlamaGenerator::infer(const std::string &systemPrompt, + const std::string &prompt, int maxTokens) { + if (model_ == nullptr || context_ == nullptr) { + throw std::runtime_error("LlamaGenerator: model not loaded"); + } + + const llama_vocab *vocab = llama_model_get_vocab(model_); + if (vocab == nullptr) { + throw std::runtime_error("LlamaGenerator: vocab unavailable"); + } + + llama_memory_clear(llama_get_memory(context_), true); + + const std::string formattedPrompt = + toChatPrompt(model_, systemPrompt, prompt); + + std::vector promptTokens(formattedPrompt.size() + 8); + int32_t tokenCount = llama_tokenize( + vocab, formattedPrompt.c_str(), + static_cast(formattedPrompt.size()), promptTokens.data(), + static_cast(promptTokens.size()), true, true); + + if (tokenCount < 0) { + promptTokens.resize(static_cast(-tokenCount)); + tokenCount = llama_tokenize( + vocab, formattedPrompt.c_str(), + static_cast(formattedPrompt.size()), promptTokens.data(), + static_cast(promptTokens.size()), true, true); + } + + if (tokenCount < 0) { + throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); + } + + promptTokens.resize(static_cast(tokenCount)); + + const llama_batch promptBatch = llama_batch_get_one( + promptTokens.data(), static_cast(promptTokens.size())); + if (llama_decode(context_, promptBatch) != 0) { + throw std::runtime_error("LlamaGenerator: prompt decode failed"); + } + + llama_sampler_chain_params samplerParams = + llama_sampler_chain_default_params(); + using SamplerPtr = + std::unique_ptr; + SamplerPtr sampler(llama_sampler_chain_init(samplerParams), + &llama_sampler_free); + + if (!sampler) { + throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); + } + + llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy()); + + std::vector generatedTokens; + generatedTokens.reserve(static_cast(maxTokens)); + + for (int i = 0; i < maxTokens; ++i) { + const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); + if (llama_vocab_is_eog(vocab, next)) { + break; + } + + generatedTokens.push_back(next); + + llama_token token = next; + const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1); + if (llama_decode(context_, oneTokenBatch) != 0) { + throw std::runtime_error( + "LlamaGenerator: decode failed during generation"); + } + } + + std::string output; + for (const llama_token token : generatedTokens) { + appendTokenPiece(vocab, token, output); + } + + return output; +} + UserResult LlamaGenerator::generateUser(const std::string &locale) { std::string prompt = "Generate a plausible craft beer enthusiast username and a one-sentence " "bio. Locale: " + locale + ". Respond with exactly two lines: first line is the username (no " - "spaces), second line is the bio."; + "spaces), second line is the bio. Do not include bullets, numbering, " + "or any extra text."; const std::string raw = infer(prompt, 128); auto [username, bio] = diff --git a/pipeline/src/main.cpp b/pipeline/src/main.cpp index 4b4d4c1..97e058d 100644 --- a/pipeline/src/main.cpp +++ b/pipeline/src/main.cpp @@ -8,6 +8,7 @@ #include #include #include +#include static bool FileExists(const std::string &filePath) { return std::filesystem::exists(filePath); @@ -22,6 +23,8 @@ int main(int argc, char *argv[]) { std::string commit = argc > 3 ? argv[3] : "c5eb7772"; // Default: stable 2026-03-28 + std::string countryName = argc > 4 ? argv[4] : ""; + std::string jsonPath = cacheDir + "/countries+states+cities.json"; std::string dbPath = cacheDir + "/biergarten-pipeline.db"; @@ -65,28 +68,29 @@ int main(int argc, char *argv[]) { spdlog::info(" States: {}", db.QueryStates(0).size()); spdlog::info(" Cities: {}", cities.size()); - spdlog::info("\n--- 50 COUNTRIES ---"); - for (size_t i = 0; i < countries.size(); i++) { - spdlog::info("{}. {} ({}) {}", (i + 1), countries[i].iso2, - countries[i].iso3, countries[i].name); - } + struct GeneratedBrewery { + int cityId; + std::string cityName; + BreweryResult brewery; + }; - spdlog::info("\n--- 50 STATES ---"); - for (size_t i = 0; i < states.size(); i++) { - spdlog::info("{}. {}: {}", (i + 1), states[i].iso2, states[i].name); - } + std::vector generatedBreweries; + const size_t sampleCount = std::min(size_t(30), cities.size()); - spdlog::info("\n--- 50 CITIES ---"); - for (size_t i = 0; i < std::min(size_t(50), cities.size()); i++) { - spdlog::info("{}. {}", (i + 1), cities[i].second); - } - - spdlog::info("\n=== SAMPLE BREWERY GENERATION ===\n"); - for (size_t i = 0; i < std::min(size_t(5), cities.size()); i++) { + spdlog::info("\n=== SAMPLE BREWERY GENERATION ==="); + for (size_t i = 0; i < sampleCount; i++) { const auto &[cityId, cityName] = cities[i]; - auto brewery = generator->generateBrewery(cityName, ""); - spdlog::info(" {}: {}", cityName, brewery.name); - spdlog::info(" -> {}", brewery.description); + auto brewery = generator->generateBrewery(cityName, countryName, ""); + generatedBreweries.push_back({cityId, cityName, brewery}); + } + + spdlog::info("\n=== GENERATED DATA DUMP ==="); + for (size_t i = 0; i < generatedBreweries.size(); i++) { + const auto &entry = generatedBreweries[i]; + spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.cityId, + entry.cityName); + spdlog::info(" brewery_name=\"{}\"", entry.brewery.name); + spdlog::info(" brewery_description=\"{}\"", entry.brewery.description); } spdlog::info("\nOK: Pipeline completed successfully"); diff --git a/pipeline/src/mock_generator.cpp b/pipeline/src/mock_generator.cpp index 7255d27..126bc66 100644 --- a/pipeline/src/mock_generator.cpp +++ b/pipeline/src/mock_generator.cpp @@ -78,10 +78,13 @@ std::size_t MockGenerator::deterministicHash(const std::string &a, } BreweryResult MockGenerator::generateBrewery(const std::string &cityName, + const std::string &countryName, const std::string ®ionContext) { + const std::string locationKey = + countryName.empty() ? cityName : cityName + "," + countryName; const std::size_t hash = regionContext.empty() - ? std::hash{}(cityName) - : deterministicHash(cityName, regionContext); + ? std::hash{}(locationKey) + : deterministicHash(locationKey, regionContext); BreweryResult result; result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +