fix: address critical correctness, reliability, and design issues in pipeline

CORRECTNESS FIXES:
- json_loader: Add RollbackTransaction() and call it on exception instead of
  CommitTransaction(). Prevents partial data corruption on parse/disk errors.
- wikipedia_service: Fix invalid MediaWiki API parameter explaintext=true ->
  explaintext=1. Now returns plain text instead of HTML markup in contexts.
- helpers: Fix ParseTwoLineResponse filter to only remove known thinking tags
  (<think>, <reasoning>, <reflect>) instead of any <...> pattern. Prevents
  silently removing legitimate output like <username>content</username>.

RELIABILITY & DESIGN IMPROVEMENTS:
- load/main: Make n_ctx (context window size) configurable via --n-ctx flag
  (default 2048, range 1-32768) to support larger models like Qwen3-14B.
- generate_brewery: Prevent retry prompt growth by extracting location context
  into constant and using compact retry format (error + schema + location only).
  Avoids token truncation on final retry attempts.
- database: Fix data representativeness by changing QueryCities from
  ORDER BY name (alphabetic bias) to ORDER BY RANDOM() for unbiased sampling.
  Convert all SQLITE_STATIC to SQLITE_TRANSIENT to prevent use-after-free risks.

POLISH:
- infer: Advance sampling seed between generation calls to improve diversity
  across brewery and user generation.
- data_downloader: Remove unnecessary commit hash truncation; use full hash.
- json_loader: Fix misleading log message from "RapidJSON" to "Boost.JSON".
This commit is contained in:
Aaron Po
2026-04-03 11:58:00 -04:00
parent 8d306bf691
commit e4e16a5084
14 changed files with 202 additions and 121 deletions

View File

@@ -33,6 +33,10 @@ struct ApplicationOptions {
/// random).
float top_p = 0.92f;
/// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory.
uint32_t n_ctx = 2048;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1;

View File

@@ -16,6 +16,8 @@ class LlamaGenerator final : public DataGenerator {
void SetSamplingOptions(float temperature, float top_p, int seed = -1);
void SetContextSize(uint32_t n_ctx);
void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string& city_name,
const std::string& country_name,
@@ -39,6 +41,7 @@ class LlamaGenerator final : public DataGenerator {
float sampling_temperature_ = 0.8f;
float sampling_top_p_ = 0.92f;
uint32_t sampling_seed_ = 0xFFFFFFFFu;
uint32_t n_ctx_ = 2048;
};
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -59,6 +59,9 @@ class SqliteDatabase {
/// @brief Commits the active database transaction.
void CommitTransaction();
/// @brief Rolls back the active database transaction.
void RollbackTransaction();
/// @brief Inserts a country row.
void InsertCountry(int id, const std::string& name, const std::string& iso2,
const std::string& iso3);

View File

@@ -28,11 +28,12 @@ std::unique_ptr<DataGenerator> BiergartenDataGenerator::InitializeGenerator() {
auto llama_generator = std::make_unique<LlamaGenerator>();
llama_generator->SetSamplingOptions(options_.temperature, options_.top_p,
options_.seed);
llama_generator->SetContextSize(options_.n_ctx);
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, "
"seed={})",
"n_ctx={}, seed={})",
options_.model_path, options_.temperature, options_.top_p,
options_.seed);
options_.n_ctx, options_.seed);
generator = std::move(llama_generator);
}
generator->Load(options_.model_path);

View File

@@ -25,15 +25,10 @@ std::string DataDownloader::DownloadCountriesDatabase(
return cache_path;
}
std::string short_commit = commit;
if (commit.length() > 7) {
short_commit = commit.substr(0, 7);
}
std::string url =
"https://raw.githubusercontent.com/dr5hn/"
"countries-states-cities-database/" +
short_commit + "/json/countries+states+cities.json";
commit + "/json/countries+states+cities.json";
spdlog::info("[DataDownloader] Downloading: {}", url);

View File

@@ -50,6 +50,14 @@ BreweryResult LlamaGenerator::GenerateBrewery(
? std::string(".")
: std::string(". Regional context: ") + safe_region_context);
/**
* Store location context for retry prompts (without repeating full context)
*/
const std::string retry_location =
"Location: " + city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name);
/**
* RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based
@@ -84,19 +92,16 @@ BreweryResult LlamaGenerator::GenerateBrewery(
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error);
// Update prompt with error details to guide LLM toward correct output
// Update prompt with error details to guide LLM toward correct output.
// For retries, use a compact prompt format to avoid exceeding token
// limits.
prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys."
"\n\nLocation: " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string("")
: std::string("\nRegional context: ") + safe_region_context);
"\n\n" +
retry_location;
}
// All retry attempts exhausted: log failure and throw exception

View File

@@ -147,7 +147,17 @@ std::pair<std::string, std::string> ParseTwoLineResponse(
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (!l.empty() && l.front() == '<' && low.back() == '>') continue;
// Filter known thinking tags like <think>...</think>, but be conservative
// to avoid removing legitimate output. Only filter specific known
// patterns.
if (!l.empty() && l.front() == '<' && low.back() == '>') {
// Only filter if it's a known thinking tag: <think>, <reasoning>, etc.
if (low.find("think") != std::string::npos ||
low.find("reasoning") != std::string::npos ||
low.find("reflect") != std::string::npos) {
continue;
}
}
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
filtered.push_back(std::move(l));
}

View File

@@ -186,5 +186,11 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
std::string output;
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
/**
* Advance seed for next generation to improve output diversity
*/
sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1;
return output;
}

View File

@@ -42,7 +42,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
}
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = 2048;
context_params.n_ctx = n_ctx_;
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {

View File

@@ -48,3 +48,18 @@ void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed);
}
void LlamaGenerator::SetContextSize(uint32_t n_ctx) {
/**
* Validate context size: must be positive and reasonable for the model
*/
if (n_ctx == 0 || n_ctx > 32768) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
/**
* Store context size for use during model loading
*/
n_ctx_ = n_ctx;
}

View File

@@ -80,6 +80,16 @@ void SqliteDatabase::CommitTransaction() {
}
}
void SqliteDatabase::RollbackTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_);
char* err = nullptr;
if (sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, &err) != SQLITE_OK) {
std::string msg = err ? err : "unknown";
sqlite3_free(err);
throw std::runtime_error("RollbackTransaction failed: " + msg);
}
}
void SqliteDatabase::InsertCountry(int id, const std::string& name,
const std::string& iso2,
const std::string& iso3) {
@@ -96,9 +106,9 @@ void SqliteDatabase::InsertCountry(int id, const std::string& name,
throw std::runtime_error("Failed to prepare country insert");
sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_TRANSIENT);
if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert country");
@@ -123,8 +133,8 @@ void SqliteDatabase::InsertState(int id, int country_id,
sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, country_id);
sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_TRANSIENT);
if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert state");
@@ -150,7 +160,7 @@ void SqliteDatabase::InsertCity(int id, int state_id, int country_id,
sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, state_id);
sqlite3_bind_int(stmt, 3, country_id);
sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_double(stmt, 5, latitude);
sqlite3_bind_double(stmt, 6, longitude);
@@ -165,7 +175,8 @@ std::vector<City> SqliteDatabase::QueryCities() {
std::vector<City> cities;
sqlite3_stmt* stmt = nullptr;
const char* query = "SELECT id, name, country_id FROM cities ORDER BY name";
const char* query =
"SELECT id, name, country_id FROM cities ORDER BY RANDOM()";
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) {

View File

@@ -11,7 +11,7 @@ void JsonLoader::LoadWorldCities(const std::string& json_path,
constexpr size_t kBatchSize = 10000;
auto startTime = std::chrono::high_resolution_clock::now();
spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", json_path);
spdlog::info("\nLoading {} (streaming Boost.JSON SAX)...", json_path);
db.BeginTransaction();
bool transactionOpen = true;
@@ -44,7 +44,8 @@ void JsonLoader::LoadWorldCities(const std::string& json_path,
}
} catch (...) {
if (transactionOpen) {
db.CommitTransaction();
db.RollbackTransaction();
transactionOpen = false;
}
throw;
}

View File

@@ -1,12 +1,12 @@
#include <spdlog/spdlog.h>
#include <boost/program_options.hpp>
#include <iostream>
#include <memory>
#include <boost/program_options.hpp>
#include <spdlog/spdlog.h>
#include "biergarten_data_generator.h"
#include "web_client/curl_web_client.h"
#include "database/database.h"
#include "web_client/curl_web_client.h"
namespace po = boost::program_options;
@@ -18,101 +18,117 @@ namespace po = boost::program_options;
* @param options Output ApplicationOptions struct.
* @return true if parsing succeeded and should proceed, false otherwise.
*/
bool ParseArguments(int argc, char **argv, ApplicationOptions &options) {
// If no arguments provided, display usage and exit
if (argc == 1) {
std::cout << "Biergarten Pipeline - Geographic Data Pipeline with Brewery Generation\n\n";
std::cout << "Usage: biergarten-pipeline [options]\n\n";
std::cout << "Options:\n";
std::cout << " --mocked Use mocked generator for brewery/user data\n";
std::cout << " --model, -m PATH Path to LLM model file (gguf) for generation\n";
std::cout << " --cache-dir, -c DIR Directory for cached JSON (default: /tmp)\n";
std::cout << " --temperature TEMP LLM sampling temperature 0.0-1.0 (default: 0.8)\n";
std::cout << " --top-p VALUE Nucleus sampling parameter 0.0-1.0 (default: 0.92)\n";
std::cout << " --seed SEED Random seed: -1 for random (default: -1)\n";
std::cout << " --help, -h Show this help message\n\n";
std::cout << "Note: --mocked and --model are mutually exclusive. Exactly one must be provided.\n";
std::cout << "Data source is always pinned to commit c5eb7772 (stable 2026-03-28).\n";
return false;
}
bool ParseArguments(int argc, char** argv, ApplicationOptions& options) {
// If no arguments provided, display usage and exit
if (argc == 1) {
std::cout << "Biergarten Pipeline - Geographic Data Pipeline with "
"Brewery Generation\n\n";
std::cout << "Usage: biergarten-pipeline [options]\n\n";
std::cout << "Options:\n";
std::cout << " --mocked Use mocked generator for "
"brewery/user data\n";
std::cout << " --model, -m PATH Path to LLM model file (gguf) for "
"generation\n";
std::cout << " --cache-dir, -c DIR Directory for cached JSON (default: "
"/tmp)\n";
std::cout << " --temperature TEMP LLM sampling temperature 0.0-1.0 "
"(default: 0.8)\n";
std::cout << " --top-p VALUE Nucleus sampling parameter 0.0-1.0 "
"(default: 0.92)\n";
std::cout << " --n-ctx SIZE Context window size in tokens "
"(default: 2048)\n";
std::cout << " --seed SEED Random seed: -1 for random "
"(default: -1)\n";
std::cout << " --help, -h Show this help message\n\n";
std::cout << "Note: --mocked and --model are mutually exclusive. Exactly "
"one must be provided.\n";
std::cout << "Data source is always pinned to commit c5eb7772 (stable "
"2026-03-28).\n";
return false;
}
po::options_description desc("Pipeline Options");
desc.add_options()("help,h", "Produce help message")(
"mocked", po::bool_switch(),
"Use mocked generator for brewery/user data")(
"model,m", po::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")(
"cache-dir,c", po::value<std::string>()->default_value("/tmp"),
"Directory for cached JSON")(
"temperature", po::value<float>()->default_value(0.8f),
"Sampling temperature (higher = more random)")(
"top-p", po::value<float>()->default_value(0.92f),
"Nucleus sampling top-p in (0,1] (higher = more random)")(
"seed", po::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
po::options_description desc("Pipeline Options");
desc.add_options()("help,h", "Produce help message")(
"mocked", po::bool_switch(),
"Use mocked generator for brewery/user data")(
"model,m", po::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")(
"cache-dir,c", po::value<std::string>()->default_value("/tmp"),
"Directory for cached JSON")(
"temperature", po::value<float>()->default_value(0.8f),
"Sampling temperature (higher = more random)")(
"top-p", po::value<float>()->default_value(0.92f),
"Nucleus sampling top-p in (0,1] (higher = more random)")(
"n-ctx", po::value<uint32_t>()->default_value(2048),
"Context window size in tokens (1-32768)")(
"seed", po::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
po::notify(vm);
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
po::notify(vm);
if (vm.count("help")) {
std::cout << desc << "\n";
return false;
}
if (vm.count("help")) {
std::cout << desc << "\n";
return false;
}
// Check for mutually exclusive --mocked and --model flags
bool use_mocked = vm["mocked"].as<bool>();
std::string model_path = vm["model"].as<std::string>();
// Check for mutually exclusive --mocked and --model flags
bool use_mocked = vm["mocked"].as<bool>();
std::string model_path = vm["model"].as<std::string>();
if (use_mocked && !model_path.empty()) {
spdlog::error("ERROR: --mocked and --model are mutually exclusive");
return false;
}
if (use_mocked && !model_path.empty()) {
spdlog::error("ERROR: --mocked and --model are mutually exclusive");
return false;
}
if (!use_mocked && model_path.empty()) {
spdlog::error("ERROR: Either --mocked or --model must be specified");
return false;
}
if (!use_mocked && model_path.empty()) {
spdlog::error("ERROR: Either --mocked or --model must be specified");
return false;
}
// Warn if sampling parameters are provided with --mocked
if (use_mocked) {
bool hasTemperature = vm["temperature"].defaulted() == false;
bool hasTopP = vm["top-p"].defaulted() == false;
bool hasSeed = vm["seed"].defaulted() == false;
// Warn if sampling parameters are provided with --mocked
if (use_mocked) {
bool hasTemperature = vm["temperature"].defaulted() == false;
bool hasTopP = vm["top-p"].defaulted() == false;
bool hasSeed = vm["seed"].defaulted() == false;
if (hasTemperature || hasTopP || hasSeed) {
spdlog::warn("WARNING: Sampling parameters (--temperature, --top-p, --seed) are ignored when using --mocked");
}
}
if (hasTemperature || hasTopP || hasSeed) {
spdlog::warn(
"WARNING: Sampling parameters (--temperature, --top-p, --seed) "
"are ignored when using --mocked");
}
}
options.use_mocked = use_mocked;
options.model_path = model_path;
options.cache_dir = vm["cache-dir"].as<std::string>();
options.temperature = vm["temperature"].as<float>();
options.top_p = vm["top-p"].as<float>();
options.seed = vm["seed"].as<int>();
// commit is always pinned to c5eb7772
options.use_mocked = use_mocked;
options.model_path = model_path;
options.cache_dir = vm["cache-dir"].as<std::string>();
options.temperature = vm["temperature"].as<float>();
options.top_p = vm["top-p"].as<float>();
options.n_ctx = vm["n-ctx"].as<uint32_t>();
options.seed = vm["seed"].as<int>();
// commit is always pinned to c5eb7772
return true;
return true;
}
int main(int argc, char *argv[]) {
try {
const CurlGlobalState curl_state;
int main(int argc, char* argv[]) {
try {
const CurlGlobalState curl_state;
ApplicationOptions options;
if (!ParseArguments(argc, argv, options)) {
return 0;
}
ApplicationOptions options;
if (!ParseArguments(argc, argv, options)) {
return 0;
}
auto webClient = std::make_shared<CURLWebClient>();
SqliteDatabase database;
auto webClient = std::make_shared<CURLWebClient>();
SqliteDatabase database;
BiergartenDataGenerator generator(options, webClient, database);
return generator.Run();
BiergartenDataGenerator generator(options, webClient, database);
return generator.Run();
} catch (const std::exception &e) {
spdlog::error("ERROR: Application failed: {}", e.what());
return 1;
}
} catch (const std::exception& e) {
spdlog::error("ERROR: Application failed: {}", e.what());
return 1;
}
}

View File

@@ -11,7 +11,7 @@ std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string encoded = client_->UrlEncode(std::string(query));
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=true&format=json";
"&prop=extracts&explaintext=1&format=json";
const std::string body = client_->Get(url);
@@ -19,16 +19,27 @@ std::string WikipediaService::FetchExtract(std::string_view query) {
boost::json::value doc = boost::json::parse(body, ec);
if (!ec && doc.is_object()) {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
return extract;
try {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
return extract;
}
}
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
}
} else if (ec) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
ec.message());
}
return {};