fix: address critical correctness, reliability, and design issues in pipeline

CORRECTNESS FIXES:
- json_loader: Add RollbackTransaction() and call it on exception instead of
  CommitTransaction(). Prevents partial data corruption on parse/disk errors.
- wikipedia_service: Fix invalid MediaWiki API parameter explaintext=true ->
  explaintext=1. Now returns plain text instead of HTML markup in contexts.
- helpers: Fix ParseTwoLineResponse filter to only remove known thinking tags
  (<think>, <reasoning>, <reflect>) instead of any <...> pattern. Prevents
  silently removing legitimate output like <username>content</username>.

RELIABILITY & DESIGN IMPROVEMENTS:
- load/main: Make n_ctx (context window size) configurable via --n-ctx flag
  (default 2048, range 1-32768) to support larger models like Qwen3-14B.
- generate_brewery: Prevent retry prompt growth by extracting location context
  into constant and using compact retry format (error + schema + location only).
  Avoids token truncation on final retry attempts.
- database: Fix data representativeness by changing QueryCities from
  ORDER BY name (alphabetic bias) to ORDER BY RANDOM() for unbiased sampling.
  Convert all SQLITE_STATIC to SQLITE_TRANSIENT to prevent use-after-free risks.

POLISH:
- infer: Advance sampling seed between generation calls to improve diversity
  across brewery and user generation.
- data_downloader: Remove unnecessary commit hash truncation; use full hash.
- json_loader: Fix misleading log message from "RapidJSON" to "Boost.JSON".
This commit is contained in:
Aaron Po
2026-04-03 11:58:00 -04:00
parent 8d306bf691
commit e4e16a5084
14 changed files with 202 additions and 121 deletions

View File

@@ -33,6 +33,10 @@ struct ApplicationOptions {
/// random). /// random).
float top_p = 0.92f; float top_p = 0.92f;
/// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory.
uint32_t n_ctx = 2048;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative). /// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1; int seed = -1;

View File

@@ -16,6 +16,8 @@ class LlamaGenerator final : public DataGenerator {
void SetSamplingOptions(float temperature, float top_p, int seed = -1); void SetSamplingOptions(float temperature, float top_p, int seed = -1);
void SetContextSize(uint32_t n_ctx);
void Load(const std::string& model_path) override; void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string& city_name, BreweryResult GenerateBrewery(const std::string& city_name,
const std::string& country_name, const std::string& country_name,
@@ -39,6 +41,7 @@ class LlamaGenerator final : public DataGenerator {
float sampling_temperature_ = 0.8f; float sampling_temperature_ = 0.8f;
float sampling_top_p_ = 0.92f; float sampling_top_p_ = 0.92f;
uint32_t sampling_seed_ = 0xFFFFFFFFu; uint32_t sampling_seed_ = 0xFFFFFFFFu;
uint32_t n_ctx_ = 2048;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -59,6 +59,9 @@ class SqliteDatabase {
/// @brief Commits the active database transaction. /// @brief Commits the active database transaction.
void CommitTransaction(); void CommitTransaction();
/// @brief Rolls back the active database transaction.
void RollbackTransaction();
/// @brief Inserts a country row. /// @brief Inserts a country row.
void InsertCountry(int id, const std::string& name, const std::string& iso2, void InsertCountry(int id, const std::string& name, const std::string& iso2,
const std::string& iso3); const std::string& iso3);

View File

@@ -28,11 +28,12 @@ std::unique_ptr<DataGenerator> BiergartenDataGenerator::InitializeGenerator() {
auto llama_generator = std::make_unique<LlamaGenerator>(); auto llama_generator = std::make_unique<LlamaGenerator>();
llama_generator->SetSamplingOptions(options_.temperature, options_.top_p, llama_generator->SetSamplingOptions(options_.temperature, options_.top_p,
options_.seed); options_.seed);
llama_generator->SetContextSize(options_.n_ctx);
spdlog::info( spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, " "[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, "
"seed={})", "n_ctx={}, seed={})",
options_.model_path, options_.temperature, options_.top_p, options_.model_path, options_.temperature, options_.top_p,
options_.seed); options_.n_ctx, options_.seed);
generator = std::move(llama_generator); generator = std::move(llama_generator);
} }
generator->Load(options_.model_path); generator->Load(options_.model_path);

View File

@@ -25,15 +25,10 @@ std::string DataDownloader::DownloadCountriesDatabase(
return cache_path; return cache_path;
} }
std::string short_commit = commit;
if (commit.length() > 7) {
short_commit = commit.substr(0, 7);
}
std::string url = std::string url =
"https://raw.githubusercontent.com/dr5hn/" "https://raw.githubusercontent.com/dr5hn/"
"countries-states-cities-database/" + "countries-states-cities-database/" +
short_commit + "/json/countries+states+cities.json"; commit + "/json/countries+states+cities.json";
spdlog::info("[DataDownloader] Downloading: {}", url); spdlog::info("[DataDownloader] Downloading: {}", url);

View File

@@ -50,6 +50,14 @@ BreweryResult LlamaGenerator::GenerateBrewery(
? std::string(".") ? std::string(".")
: std::string(". Regional context: ") + safe_region_context); : std::string(". Regional context: ") + safe_region_context);
/**
* Store location context for retry prompts (without repeating full context)
*/
const std::string retry_location =
"Location: " + city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name);
/** /**
* RETRY LOOP with validation and error correction * RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based * Attempts to generate valid brewery data up to 3 times, with feedback-based
@@ -84,19 +92,16 @@ BreweryResult LlamaGenerator::GenerateBrewery(
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error); attempt + 1, validation_error);
// Update prompt with error details to guide LLM toward correct output // Update prompt with error details to guide LLM toward correct output.
// For retries, use a compact prompt format to avoid exceeding token
// limits.
prompt = prompt =
"Your previous response was invalid. Error: " + validation_error + "Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with this exact schema: " "\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}." "{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys." "\nDo not include markdown, comments, or extra keys."
"\n\nLocation: " + "\n\n" +
city_name + retry_location;
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string("")
: std::string("\nRegional context: ") + safe_region_context);
} }
// All retry attempts exhausted: log failure and throw exception // All retry attempts exhausted: log failure and throw exception

View File

@@ -147,7 +147,17 @@ std::pair<std::string, std::string> ParseTwoLineResponse(
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c)); return static_cast<char>(std::tolower(c));
}); });
if (!l.empty() && l.front() == '<' && low.back() == '>') continue; // Filter known thinking tags like <think>...</think>, but be conservative
// to avoid removing legitimate output. Only filter specific known
// patterns.
if (!l.empty() && l.front() == '<' && low.back() == '>') {
// Only filter if it's a known thinking tag: <think>, <reasoning>, etc.
if (low.find("think") != std::string::npos ||
low.find("reasoning") != std::string::npos ||
low.find("reflect") != std::string::npos) {
continue;
}
}
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue; if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
filtered.push_back(std::move(l)); filtered.push_back(std::move(l));
} }

View File

@@ -186,5 +186,11 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
std::string output; std::string output;
for (const llama_token token : generated_tokens) for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output); AppendTokenPiecePublic(vocab, token, output);
/**
* Advance seed for next generation to improve output diversity
*/
sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1;
return output; return output;
} }

View File

@@ -42,7 +42,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
} }
llama_context_params context_params = llama_context_default_params(); llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = 2048; context_params.n_ctx = n_ctx_;
context_ = llama_init_from_model(model_, context_params); context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) { if (context_ == nullptr) {

View File

@@ -48,3 +48,18 @@ void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED) sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed); : static_cast<uint32_t>(seed);
} }
void LlamaGenerator::SetContextSize(uint32_t n_ctx) {
/**
* Validate context size: must be positive and reasonable for the model
*/
if (n_ctx == 0 || n_ctx > 32768) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
/**
* Store context size for use during model loading
*/
n_ctx_ = n_ctx;
}

View File

@@ -80,6 +80,16 @@ void SqliteDatabase::CommitTransaction() {
} }
} }
void SqliteDatabase::RollbackTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_);
char* err = nullptr;
if (sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, &err) != SQLITE_OK) {
std::string msg = err ? err : "unknown";
sqlite3_free(err);
throw std::runtime_error("RollbackTransaction failed: " + msg);
}
}
void SqliteDatabase::InsertCountry(int id, const std::string& name, void SqliteDatabase::InsertCountry(int id, const std::string& name,
const std::string& iso2, const std::string& iso2,
const std::string& iso3) { const std::string& iso3) {
@@ -96,9 +106,9 @@ void SqliteDatabase::InsertCountry(int id, const std::string& name,
throw std::runtime_error("Failed to prepare country insert"); throw std::runtime_error("Failed to prepare country insert");
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_TRANSIENT);
if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert country"); throw std::runtime_error("Failed to insert country");
@@ -123,8 +133,8 @@ void SqliteDatabase::InsertState(int id, int country_id,
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, country_id); sqlite3_bind_int(stmt, 2, country_id);
sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_TRANSIENT);
if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert state"); throw std::runtime_error("Failed to insert state");
@@ -150,7 +160,7 @@ void SqliteDatabase::InsertCity(int id, int state_id, int country_id,
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, state_id); sqlite3_bind_int(stmt, 2, state_id);
sqlite3_bind_int(stmt, 3, country_id); sqlite3_bind_int(stmt, 3, country_id);
sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_double(stmt, 5, latitude); sqlite3_bind_double(stmt, 5, latitude);
sqlite3_bind_double(stmt, 6, longitude); sqlite3_bind_double(stmt, 6, longitude);
@@ -165,7 +175,8 @@ std::vector<City> SqliteDatabase::QueryCities() {
std::vector<City> cities; std::vector<City> cities;
sqlite3_stmt* stmt = nullptr; sqlite3_stmt* stmt = nullptr;
const char* query = "SELECT id, name, country_id FROM cities ORDER BY name"; const char* query =
"SELECT id, name, country_id FROM cities ORDER BY RANDOM()";
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {

View File

@@ -11,7 +11,7 @@ void JsonLoader::LoadWorldCities(const std::string& json_path,
constexpr size_t kBatchSize = 10000; constexpr size_t kBatchSize = 10000;
auto startTime = std::chrono::high_resolution_clock::now(); auto startTime = std::chrono::high_resolution_clock::now();
spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", json_path); spdlog::info("\nLoading {} (streaming Boost.JSON SAX)...", json_path);
db.BeginTransaction(); db.BeginTransaction();
bool transactionOpen = true; bool transactionOpen = true;
@@ -44,7 +44,8 @@ void JsonLoader::LoadWorldCities(const std::string& json_path,
} }
} catch (...) { } catch (...) {
if (transactionOpen) { if (transactionOpen) {
db.CommitTransaction(); db.RollbackTransaction();
transactionOpen = false;
} }
throw; throw;
} }

View File

@@ -1,12 +1,12 @@
#include <spdlog/spdlog.h>
#include <boost/program_options.hpp>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <boost/program_options.hpp>
#include <spdlog/spdlog.h>
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
#include "web_client/curl_web_client.h"
#include "database/database.h" #include "database/database.h"
#include "web_client/curl_web_client.h"
namespace po = boost::program_options; namespace po = boost::program_options;
@@ -21,18 +21,29 @@ namespace po = boost::program_options;
bool ParseArguments(int argc, char** argv, ApplicationOptions& options) { bool ParseArguments(int argc, char** argv, ApplicationOptions& options) {
// If no arguments provided, display usage and exit // If no arguments provided, display usage and exit
if (argc == 1) { if (argc == 1) {
std::cout << "Biergarten Pipeline - Geographic Data Pipeline with Brewery Generation\n\n"; std::cout << "Biergarten Pipeline - Geographic Data Pipeline with "
"Brewery Generation\n\n";
std::cout << "Usage: biergarten-pipeline [options]\n\n"; std::cout << "Usage: biergarten-pipeline [options]\n\n";
std::cout << "Options:\n"; std::cout << "Options:\n";
std::cout << " --mocked Use mocked generator for brewery/user data\n"; std::cout << " --mocked Use mocked generator for "
std::cout << " --model, -m PATH Path to LLM model file (gguf) for generation\n"; "brewery/user data\n";
std::cout << " --cache-dir, -c DIR Directory for cached JSON (default: /tmp)\n"; std::cout << " --model, -m PATH Path to LLM model file (gguf) for "
std::cout << " --temperature TEMP LLM sampling temperature 0.0-1.0 (default: 0.8)\n"; "generation\n";
std::cout << " --top-p VALUE Nucleus sampling parameter 0.0-1.0 (default: 0.92)\n"; std::cout << " --cache-dir, -c DIR Directory for cached JSON (default: "
std::cout << " --seed SEED Random seed: -1 for random (default: -1)\n"; "/tmp)\n";
std::cout << " --temperature TEMP LLM sampling temperature 0.0-1.0 "
"(default: 0.8)\n";
std::cout << " --top-p VALUE Nucleus sampling parameter 0.0-1.0 "
"(default: 0.92)\n";
std::cout << " --n-ctx SIZE Context window size in tokens "
"(default: 2048)\n";
std::cout << " --seed SEED Random seed: -1 for random "
"(default: -1)\n";
std::cout << " --help, -h Show this help message\n\n"; std::cout << " --help, -h Show this help message\n\n";
std::cout << "Note: --mocked and --model are mutually exclusive. Exactly one must be provided.\n"; std::cout << "Note: --mocked and --model are mutually exclusive. Exactly "
std::cout << "Data source is always pinned to commit c5eb7772 (stable 2026-03-28).\n"; "one must be provided.\n";
std::cout << "Data source is always pinned to commit c5eb7772 (stable "
"2026-03-28).\n";
return false; return false;
} }
@@ -48,6 +59,8 @@ bool ParseArguments(int argc, char **argv, ApplicationOptions &options) {
"Sampling temperature (higher = more random)")( "Sampling temperature (higher = more random)")(
"top-p", po::value<float>()->default_value(0.92f), "top-p", po::value<float>()->default_value(0.92f),
"Nucleus sampling top-p in (0,1] (higher = more random)")( "Nucleus sampling top-p in (0,1] (higher = more random)")(
"n-ctx", po::value<uint32_t>()->default_value(2048),
"Context window size in tokens (1-32768)")(
"seed", po::value<int>()->default_value(-1), "seed", po::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer"); "Sampler seed: -1 for random, otherwise non-negative integer");
@@ -81,7 +94,9 @@ bool ParseArguments(int argc, char **argv, ApplicationOptions &options) {
bool hasSeed = vm["seed"].defaulted() == false; bool hasSeed = vm["seed"].defaulted() == false;
if (hasTemperature || hasTopP || hasSeed) { if (hasTemperature || hasTopP || hasSeed) {
spdlog::warn("WARNING: Sampling parameters (--temperature, --top-p, --seed) are ignored when using --mocked"); spdlog::warn(
"WARNING: Sampling parameters (--temperature, --top-p, --seed) "
"are ignored when using --mocked");
} }
} }
@@ -90,6 +105,7 @@ bool ParseArguments(int argc, char **argv, ApplicationOptions &options) {
options.cache_dir = vm["cache-dir"].as<std::string>(); options.cache_dir = vm["cache-dir"].as<std::string>();
options.temperature = vm["temperature"].as<float>(); options.temperature = vm["temperature"].as<float>();
options.top_p = vm["top-p"].as<float>(); options.top_p = vm["top-p"].as<float>();
options.n_ctx = vm["n-ctx"].as<uint32_t>();
options.seed = vm["seed"].as<int>(); options.seed = vm["seed"].as<int>();
// commit is always pinned to c5eb7772 // commit is always pinned to c5eb7772

View File

@@ -11,7 +11,7 @@ std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string encoded = client_->UrlEncode(std::string(query)); const std::string encoded = client_->UrlEncode(std::string(query));
const std::string url = const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=true&format=json"; "&prop=extracts&explaintext=1&format=json";
const std::string body = client_->Get(url); const std::string body = client_->Get(url);
@@ -19,6 +19,7 @@ std::string WikipediaService::FetchExtract(std::string_view query) {
boost::json::value doc = boost::json::parse(body, ec); boost::json::value doc = boost::json::parse(body, ec);
if (!ec && doc.is_object()) { if (!ec && doc.is_object()) {
try {
auto& pages = doc.at("query").at("pages").get_object(); auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) { if (!pages.empty()) {
auto& page = pages.begin()->value().get_object(); auto& page = pages.begin()->value().get_object();
@@ -29,6 +30,16 @@ std::string WikipediaService::FetchExtract(std::string_view query) {
return extract; return extract;
} }
} }
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
}
} else if (ec) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
ec.message());
} }
return {}; return {};