Pipeline: add CURL/WebClient & Wikipedia service

Introduce a pluggable web client interface and concrete CURL implementation: adds IWebClient, CURLWebClient, and CurlGlobalState (headers + curl_web_client.cpp). DataDownloader now accepts an IWebClient and delegates downloads. Add WikipediaService for cached Wikipedia summary lookups. Refactor SqliteDatabase to return full City records and update consumers accordingly. Improve JsonLoader to use batched transactions during streaming parses. Enhance LlamaGenerator with sampling options, increased token limits, JSON extraction/validation, and other parsing helpers. Modernize CMake: set policy/version, add project_options, simplify FetchContent usage (spdlog), require Boost components (program_options/json), list pipeline sources explicitly, and tweak post-build/memcheck targets. Update README to match implementation changes and new CLI/config conventions.
This commit is contained in:
Aaron Po
2026-04-02 16:29:16 -04:00
parent ac136f7179
commit 98083ab40c
16 changed files with 1125 additions and 794 deletions

View File

@@ -0,0 +1,139 @@
#include "curl_web_client.h"
#include <cstdio>
#include <curl/curl.h>
#include <fstream>
#include <memory>
#include <sstream>
#include <stdexcept>
CurlGlobalState::CurlGlobalState() {
if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl globally");
}
}
CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); }
namespace {
// curl write callback that appends response data into a std::string
static size_t WriteCallbackString(void *contents, size_t size, size_t nmemb,
void *userp) {
size_t realsize = size * nmemb;
auto *s = static_cast<std::string *>(userp);
s->append(static_cast<char *>(contents), realsize);
return realsize;
}
// curl write callback that writes to a file stream
static size_t WriteCallbackFile(void *contents, size_t size, size_t nmemb,
void *userp) {
size_t realsize = size * nmemb;
auto *outFile = static_cast<std::ofstream *>(userp);
outFile->write(static_cast<char *>(contents), realsize);
return realsize;
}
// RAII wrapper for CURL handle using unique_ptr
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
CurlHandle create_handle() {
CURL *handle = curl_easy_init();
if (!handle) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
void set_common_get_options(CURL *curl, const std::string &url,
long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
} // namespace
CURLWebClient::CURLWebClient() {}
CURLWebClient::~CURLWebClient() {}
void CURLWebClient::DownloadToFile(const std::string &url,
const std::string &filePath) {
auto curl = create_handle();
std::ofstream outFile(filePath, std::ios::binary);
if (!outFile.is_open()) {
throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " +
filePath);
}
set_common_get_options(curl.get(), url, 30L, 300L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
static_cast<void *>(&outFile));
CURLcode res = curl_easy_perform(curl.get());
outFile.close();
if (res != CURLE_OK) {
std::remove(filePath.c_str());
std::string error = std::string("[CURLWebClient] Download failed: ") +
curl_easy_strerror(res);
throw std::runtime_error(error);
}
long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) {
std::remove(filePath.c_str());
std::stringstream ss;
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(ss.str());
}
}
std::string CURLWebClient::Get(const std::string &url) {
auto curl = create_handle();
std::string response_string;
set_common_get_options(curl.get(), url, 10L, 20L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
std::string error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error);
}
long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) {
std::stringstream ss;
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(ss.str());
}
return response_string;
}
std::string CURLWebClient::UrlEncode(const std::string &value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char *output = curl_easy_escape(nullptr, value.c_str(), 0);
if (output) {
std::string result(output);
curl_free(output);
return result;
}
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
}

View File

@@ -1,20 +1,13 @@
#include "data_downloader.h"
#include <cstdio>
#include <curl/curl.h>
#include "web_client.h"
#include <filesystem>
#include <fstream>
#include <spdlog/spdlog.h>
#include <sstream>
#include <stdexcept>
static size_t WriteCallback(void *contents, size_t size, size_t nmemb,
void *userp) {
size_t realsize = size * nmemb;
std::ofstream *outFile = static_cast<std::ofstream *>(userp);
outFile->write(static_cast<char *>(contents), realsize);
return realsize;
}
DataDownloader::DataDownloader() {}
DataDownloader::DataDownloader(std::shared_ptr<IWebClient> webClient)
: m_webClient(std::move(webClient)) {}
DataDownloader::~DataDownloader() {}
@@ -41,56 +34,7 @@ DataDownloader::DownloadCountriesDatabase(const std::string &cachePath,
spdlog::info("[DataDownloader] Downloading: {}", url);
CURL *curl = curl_easy_init();
if (!curl) {
throw std::runtime_error("[DataDownloader] Failed to initialize libcurl");
}
std::ofstream outFile(cachePath, std::ios::binary);
if (!outFile.is_open()) {
curl_easy_cleanup(curl);
throw std::runtime_error("[DataDownloader] Cannot open file for writing: " +
cachePath);
}
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, static_cast<void *>(&outFile));
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 30L);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, 300L);
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
CURLcode res = curl_easy_perform(curl);
outFile.close();
if (res != CURLE_OK) {
curl_easy_cleanup(curl);
std::remove(cachePath.c_str());
std::string error = std::string("[DataDownloader] Download failed: ") +
curl_easy_strerror(res);
throw std::runtime_error(error);
}
long httpCode = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &httpCode);
curl_easy_cleanup(curl);
if (httpCode != 200) {
std::remove(cachePath.c_str());
std::stringstream ss;
ss << "[DataDownloader] HTTP error " << httpCode
<< " (commit: " << shortCommit << ")";
throw std::runtime_error(ss.str());
}
m_webClient->DownloadToFile(url, cachePath);
std::ifstream fileCheck(cachePath, std::ios::binary | std::ios::ate);
std::streamsize size = fileCheck.tellg();

View File

@@ -157,13 +157,12 @@ void SqliteDatabase::InsertCity(int id, int stateId, int countryId,
sqlite3_finalize(stmt);
}
std::vector<std::pair<int, std::string>> SqliteDatabase::QueryCities() {
std::vector<City> SqliteDatabase::QueryCities() {
std::lock_guard<std::mutex> lock(dbMutex);
std::vector<std::pair<int, std::string>> cities;
std::vector<City> cities;
sqlite3_stmt *stmt = nullptr;
const char *query = "SELECT id, name FROM cities ORDER BY name";
const char *query = "SELECT id, name, country_id FROM cities ORDER BY name";
int rc = sqlite3_prepare_v2(db, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) {
@@ -174,7 +173,8 @@ std::vector<std::pair<int, std::string>> SqliteDatabase::QueryCities() {
int id = sqlite3_column_int(stmt, 0);
const char *name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1));
cities.push_back({id, name ? std::string(name) : ""});
int countryId = sqlite3_column_int(stmt, 2);
cities.push_back({id, name ? std::string(name) : "", countryId});
}
sqlite3_finalize(stmt);

View File

@@ -1,32 +1,52 @@
#include <chrono>
#include <spdlog/spdlog.h>
#include "json_loader.h"
#include "stream_parser.h"
#include <chrono>
#include <spdlog/spdlog.h>
void JsonLoader::LoadWorldCities(const std::string &jsonPath,
SqliteDatabase &db) {
constexpr size_t kBatchSize = 10000;
auto startTime = std::chrono::high_resolution_clock::now();
spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", jsonPath);
db.BeginTransaction();
bool transactionOpen = true;
size_t citiesProcessed = 0;
StreamingJsonParser::Parse(
jsonPath, db,
[&](const CityRecord &record) {
db.InsertCity(record.id, record.state_id, record.country_id,
record.name, record.latitude, record.longitude);
citiesProcessed++;
},
[&](size_t current, size_t total) {
if (current % 10000 == 0 && current > 0) {
spdlog::info(" [Progress] Parsed {} cities...", current);
}
});
try {
StreamingJsonParser::Parse(
jsonPath, db,
[&](const CityRecord &record) {
db.InsertCity(record.id, record.state_id, record.country_id,
record.name, record.latitude, record.longitude);
++citiesProcessed;
spdlog::info(" OK: Parsed all cities from JSON");
if (citiesProcessed % kBatchSize == 0) {
db.CommitTransaction();
db.BeginTransaction();
}
},
[&](size_t current, size_t /*total*/) {
if (current % kBatchSize == 0 && current > 0) {
spdlog::info(" [Progress] Parsed {} cities...", current);
}
});
db.CommitTransaction();
spdlog::info(" OK: Parsed all cities from JSON");
if (transactionOpen) {
db.CommitTransaction();
transactionOpen = false;
}
} catch (...) {
if (transactionOpen) {
db.CommitTransaction();
}
throw;
}
auto endTime = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(

View File

@@ -1,7 +1,3 @@
#include "llama_generator.h"
#include "llama.h"
#include <algorithm>
#include <array>
#include <cctype>
@@ -11,8 +7,12 @@
#include <string>
#include <vector>
#include "llama.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h>
#include "llama_generator.h"
namespace {
std::string trim(std::string value) {
@@ -26,10 +26,47 @@ std::string trim(std::string value) {
return value;
}
std::string CondenseWhitespace(std::string text) {
std::string out;
out.reserve(text.size());
bool inWhitespace = false;
for (unsigned char ch : text) {
if (std::isspace(ch)) {
if (!inWhitespace) {
out.push_back(' ');
inWhitespace = true;
}
continue;
}
inWhitespace = false;
out.push_back(static_cast<char>(ch));
}
return trim(std::move(out));
}
std::string PrepareRegionContext(std::string_view regionContext,
std::size_t maxChars = 700) {
std::string normalized = CondenseWhitespace(std::string(regionContext));
if (normalized.size() <= maxChars) {
return normalized;
}
normalized.resize(maxChars);
const std::size_t lastSpace = normalized.find_last_of(' ');
if (lastSpace != std::string::npos && lastSpace > maxChars / 2) {
normalized.resize(lastSpace);
}
normalized += "...";
return normalized;
}
std::string stripCommonPrefix(std::string line) {
line = trim(std::move(line));
// Strip simple list markers like "- ", "* ", "1. ", "2) ".
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = trim(line.substr(1));
} else {
@@ -68,6 +105,50 @@ std::string stripCommonPrefix(std::string line) {
return trim(std::move(line));
}
std::pair<std::string, std::string>
parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = stripCommonPrefix(std::move(line));
if (!line.empty())
lines.push_back(std::move(line));
}
std::vector<std::string> filtered;
for (auto &l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (!l.empty() && l.front() == '<' && low.back() == '>')
continue;
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0)
continue;
filtered.push_back(std::move(l));
}
if (filtered.size() < 2)
throw std::runtime_error(errorMessage);
std::string first = trim(filtered.front());
std::string second;
for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty())
second += ' ';
second += filtered[i];
}
second = trim(std::move(second));
if (first.empty() || second.empty())
throw std::runtime_error(errorMessage);
return {first, second};
}
std::string toChatPrompt(const llama_model *model,
const std::string &userPrompt) {
const char *tmpl = llama_model_chat_template(model, nullptr);
@@ -75,10 +156,7 @@ std::string toChatPrompt(const llama_model *model,
return userPrompt;
}
const llama_chat_message message{
"user",
userPrompt.c_str(),
};
const llama_chat_message message{"user", userPrompt.c_str()};
std::vector<char> buffer(std::max<std::size_t>(1024, userPrompt.size() * 4));
int32_t required =
@@ -106,14 +184,11 @@ std::string toChatPrompt(const llama_model *model,
const std::string &userPrompt) {
const char *tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
// Fall back to concatenating but keep system and user parts distinct.
return systemPrompt + "\n\n" + userPrompt;
}
const llama_chat_message messages[2] = {
{"system", systemPrompt.c_str()},
{"user", userPrompt.c_str()},
};
const llama_chat_message messages[2] = {{"system", systemPrompt.c_str()},
{"user", userPrompt.c_str()}};
std::vector<char> buffer(std::max<std::size_t>(
1024, (systemPrompt.size() + userPrompt.size()) * 4));
@@ -161,73 +236,135 @@ void appendTokenPiece(const llama_vocab *vocab, llama_token token,
output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
std::pair<std::string, std::string>
parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
bool extractFirstJsonObject(const std::string &text, std::string &jsonOut) {
std::size_t start = std::string::npos;
int depth = 0;
bool inString = false;
bool escaped = false;
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = stripCommonPrefix(std::move(line));
if (!line.empty()) {
lines.push_back(std::move(line));
}
}
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
// Filter out obvious internal-thought / meta lines that sometimes leak from
// models (e.g. "<think>", "Okay, so the user is asking me...").
std::vector<std::string> filtered;
for (auto &l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
// Skip single-token angle-bracket markers like <think> or <...>
if (!l.empty() && l.front() == '<' && l.back() == '>') {
if (inString) {
if (escaped) {
escaped = false;
} else if (ch == '\\') {
escaped = true;
} else if (ch == '"') {
inString = false;
}
continue;
}
// Skip short internal commentary that starts with common discourse markers
if (low.rfind("okay,", 0) == 0 || low.rfind("wait,", 0) == 0 ||
low.rfind("hmm", 0) == 0) {
if (ch == '"') {
inString = true;
continue;
}
// Skip lines that look like self-descriptions of what the model is doing
if (low.find("user is asking") != std::string::npos ||
low.find("protocol") != std::string::npos ||
low.find("parse") != std::string::npos ||
low.find("return only") != std::string::npos) {
if (ch == '{') {
if (depth == 0) {
start = i;
}
++depth;
continue;
}
filtered.push_back(std::move(l));
}
if (filtered.size() < 2) {
throw std::runtime_error(errorMessage);
}
std::string first = trim(filtered.front());
std::string second;
for (std::size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) {
second += ' ';
if (ch == '}') {
if (depth == 0) {
continue;
}
--depth;
if (depth == 0 && start != std::string::npos) {
jsonOut = text.substr(start, i - start + 1);
return true;
}
}
second += filtered[i];
}
second = trim(std::move(second));
if (first.empty() || second.empty()) {
throw std::runtime_error(errorMessage);
}
return {first, second};
return false;
}
std::string ValidateBreweryJson(const std::string &raw, std::string &nameOut,
std::string &descriptionOut) {
auto validateObject = [&](const boost::json::value &jv,
std::string &errorOut) -> bool {
if (!jv.is_object()) {
errorOut = "JSON root must be an object";
return false;
}
const auto &obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
errorOut = "JSON field 'name' is missing or not a string";
return false;
}
if (!obj.contains("description") || !obj.at("description").is_string()) {
errorOut = "JSON field 'description' is missing or not a string";
return false;
}
nameOut = trim(std::string(obj.at("name").as_string().c_str()));
descriptionOut =
trim(std::string(obj.at("description").as_string().c_str()));
if (nameOut.empty()) {
errorOut = "JSON field 'name' must not be empty";
return false;
}
if (descriptionOut.empty()) {
errorOut = "JSON field 'description' must not be empty";
return false;
}
std::string nameLower = nameOut;
std::string descriptionLower = descriptionOut;
std::transform(
nameLower.begin(), nameLower.end(), nameLower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(descriptionLower.begin(), descriptionLower.end(),
descriptionLower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (nameLower == "string" || descriptionLower == "string") {
errorOut = "JSON appears to be a schema placeholder, not content";
return false;
}
errorOut.clear();
return true;
};
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
std::string validationError;
if (ec) {
std::string extracted;
if (!extractFirstJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
if (!validateObject(jv, validationError)) {
return validationError;
}
return {};
}
if (!validateObject(jv, validationError)) {
return validationError;
}
return {};
}
} // namespace
LlamaGenerator::~LlamaGenerator() {
@@ -244,10 +381,30 @@ LlamaGenerator::~LlamaGenerator() {
llama_backend_free();
}
void LlamaGenerator::load(const std::string &modelPath) {
if (modelPath.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
void LlamaGenerator::setSamplingOptions(float temperature, float topP,
int seed) {
if (temperature < 0.0f) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (!(topP > 0.0f && topP <= 1.0f)) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
sampling_temperature_ = temperature;
sampling_top_p_ = topP;
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed);
}
void LlamaGenerator::load(const std::string &modelPath) {
if (modelPath.empty())
throw std::runtime_error("LlamaGenerator: model path must not be empty");
if (context_ != nullptr) {
llama_free(context_);
@@ -261,7 +418,7 @@ void LlamaGenerator::load(const std::string &modelPath) {
llama_backend_init();
llama_model_params modelParams = llama_model_default_params();
model_ = llama_load_model_from_file(modelPath.c_str(), modelParams);
model_ = llama_model_load_from_file(modelPath.c_str(), modelParams);
if (model_ == nullptr) {
throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + modelPath);
@@ -281,14 +438,12 @@ void LlamaGenerator::load(const std::string &modelPath) {
}
std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) {
if (model_ == nullptr || context_ == nullptr) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
}
const llama_vocab *vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) {
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
}
llama_memory_clear(llama_get_memory(context_), true);
@@ -308,17 +463,33 @@ std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) {
static_cast<int32_t>(promptTokens.size()), true, true);
}
if (tokenCount < 0) {
if (tokenCount < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
if (nCtx <= 1 || nBatch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
const int32_t effectiveMaxTokens = std::max(1, std::min(maxTokens, nCtx - 1));
int32_t promptBudget = std::min(nBatch, nCtx - effectiveMaxTokens);
promptBudget = std::max<int32_t>(1, promptBudget);
promptTokens.resize(static_cast<std::size_t>(tokenCount));
if (tokenCount > promptBudget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
"to fit n_batch/n_ctx limits",
tokenCount, promptBudget);
promptTokens.resize(static_cast<std::size_t>(promptBudget));
tokenCount = promptBudget;
}
const llama_batch promptBatch = llama_batch_get_one(
promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
if (llama_decode(context_, promptBatch) != 0) {
if (llama_decode(context_, promptBatch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
}
llama_sampler_chain_params samplerParams =
llama_sampler_chain_default_params();
@@ -326,116 +497,45 @@ std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) {
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(samplerParams),
&llama_sampler_free);
if (!sampler) {
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
}
llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy());
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generatedTokens;
generatedTokens.reserve(static_cast<std::size_t>(maxTokens));
for (int i = 0; i < maxTokens; ++i) {
for (int i = 0; i < effectiveMaxTokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
if (llama_vocab_is_eog(vocab, next)) {
if (llama_vocab_is_eog(vocab, next))
break;
}
generatedTokens.push_back(next);
llama_token token = next;
const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, oneTokenBatch) != 0) {
if (llama_decode(context_, oneTokenBatch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
}
std::string output;
for (const llama_token token : generatedTokens) {
for (const llama_token token : generatedTokens)
appendTokenPiece(vocab, token, output);
}
return output;
}
BreweryResult
LlamaGenerator::generateBrewery(const std::string &cityName,
const std::string &countryName,
const std::string &regionContext) {
std::string systemPrompt =
R"(# SYSTEM PROTOCOL: ZERO-CHATTER DETERMINISTIC OUTPUT
**MODALITY:** DATA-RETURN ENGINE ONLY
**ROLE:** Your response must contain 0% metadata and 100% signal.
---
## MANDATORY CONSTRAINTS
1. **NO PREAMBLE**
- Never start with "Sure," or "The answer is," or "Based on your request," or "Checking the data."
- Do not acknowledge the user's prompt or provide status updates.
2. **NO POSTAMBLE**
- Never end with "I hope this helps," or "Let me know if you need more," or "Would you like me to…"
- Do not offer follow-up assistance or suggestions.
3. **NO SENTENCE FRAMING**
- Provide only the raw value, date, number, or name.
- Do not wrap the answer in a sentence. (e.g., return 1997, NOT The year was 1997).
- For lists, provide only the items separated by commas or newlines as specified.
4. **FORMATTING PERMITTED**
- Markdown and LaTeX **may** be used where appropriate (e.g., tables, equations).
- Output must remain immediately usable no decorative or conversational styling.
5. **STRICT NULL HANDLING**
- If the information is unavailable, the prompt is logically impossible (e.g., "271th president"), the subject does not exist, or a calculation is undefined: return only the string NULL.
- If the prompt is too ambiguous to provide a single value: return NULL.
---
## EXECUTION LOGIC
1. **Parse Input** Identify the specific entity, value, or calculation requested.
2. **Verify Factuality** Access internal knowledge or tools.
3. **Filter for Signal** Strip all surrounding prose.
4. **Format Check** Apply Markdown or LaTeX only where it serves the data.
5. **Output** Return the raw value only.
---
## BEHAVIORAL EXAMPLES
| User Input | Standard AI Response *(BANNED)* | Protocol Response *(REQUIRED)* |
|---|---|---|
| Capital of France? | The capital of France is Paris. | Paris |
| 15% of 200 | 15% of 200 is 30. | 30 |
| Who wrote '1984'? | George Orwell wrote that novel. | George Orwell |
| ISO code for Japan | The code is JP. | JP |
| $\sqrt{x}$ where $x$ is a potato | A potato has no square root. | NULL |
| 500th US President | There haven't been that many. | NULL |
| Pythagorean theorem | The theorem states... | $a^2 + b^2 = c^2$ |
---
## FINAL INSTRUCTION
Total silence is preferred over conversational error. Any deviation from the raw-value-only format is a protocol failure. Proceed with next input.)";
std::string prompt =
"Generate a craft brewery name and 1000 character description for a "
"brewery located in " +
cityName +
(countryName.empty() ? std::string("")
: std::string(", ") + countryName) +
". " + regionContext +
" Respond with exactly two lines: first line is the name, second line is "
"the description. Do not include bullets, numbering, or any extra text.";
const std::string raw = infer(systemPrompt, prompt, 512);
auto [name, description] =
parseTwoLineResponse(raw, "LlamaGenerator: malformed brewery response");
return {name, description};
}
std::string LlamaGenerator::infer(const std::string &systemPrompt,
const std::string &prompt, int maxTokens) {
if (model_ == nullptr || context_ == nullptr) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
}
const llama_vocab *vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) {
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
}
llama_memory_clear(llama_get_memory(context_), true);
@@ -456,17 +556,33 @@ std::string LlamaGenerator::infer(const std::string &systemPrompt,
static_cast<int32_t>(promptTokens.size()), true, true);
}
if (tokenCount < 0) {
if (tokenCount < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
if (nCtx <= 1 || nBatch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
const int32_t effectiveMaxTokens = std::max(1, std::min(maxTokens, nCtx - 1));
int32_t promptBudget = std::min(nBatch, nCtx - effectiveMaxTokens);
promptBudget = std::max<int32_t>(1, promptBudget);
promptTokens.resize(static_cast<std::size_t>(tokenCount));
if (tokenCount > promptBudget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
"to fit n_batch/n_ctx limits",
tokenCount, promptBudget);
promptTokens.resize(static_cast<std::size_t>(promptBudget));
tokenCount = promptBudget;
}
const llama_batch promptBatch = llama_batch_get_one(
promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
if (llama_decode(context_, promptBatch) != 0) {
if (llama_decode(context_, promptBatch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
}
llama_sampler_chain_params samplerParams =
llama_sampler_chain_default_params();
@@ -474,61 +590,145 @@ std::string LlamaGenerator::infer(const std::string &systemPrompt,
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(samplerParams),
&llama_sampler_free);
if (!sampler) {
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
}
llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy());
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generatedTokens;
generatedTokens.reserve(static_cast<std::size_t>(maxTokens));
for (int i = 0; i < maxTokens; ++i) {
for (int i = 0; i < effectiveMaxTokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
if (llama_vocab_is_eog(vocab, next)) {
if (llama_vocab_is_eog(vocab, next))
break;
}
generatedTokens.push_back(next);
llama_token token = next;
const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, oneTokenBatch) != 0) {
if (llama_decode(context_, oneTokenBatch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
}
std::string output;
for (const llama_token token : generatedTokens) {
for (const llama_token token : generatedTokens)
appendTokenPiece(vocab, token, output);
}
return output;
}
UserResult LlamaGenerator::generateUser(const std::string &locale) {
BreweryResult
LlamaGenerator::generateBrewery(const std::string &cityName,
const std::string &countryName,
const std::string &regionContext) {
const std::string safeRegionContext = PrepareRegionContext(regionContext);
const std::string systemPrompt =
"You are a copywriter for a craft beer travel guide. "
"Your writing is vivid, specific to place, and avoids generic beer "
"cliches. "
"You must output ONLY valid JSON. "
"The JSON schema must be exactly: {\"name\": \"string\", "
"\"description\": \"string\"}. "
"Do not include markdown formatting or backticks.";
std::string prompt =
"Generate a plausible craft beer enthusiast username and a one-sentence "
"bio. Locale: " +
locale +
". Respond with exactly two lines: first line is the username (no "
"spaces), second line is the bio. Do not include bullets, numbering, "
"or any extra text.";
"Write a brewery name and place-specific description for a craft "
"brewery in " +
cityName +
(countryName.empty() ? std::string("")
: std::string(", ") + countryName) +
(safeRegionContext.empty()
? std::string(".")
: std::string(". Regional context: ") + safeRegionContext);
const std::string raw = infer(prompt, 128);
auto [username, bio] =
parseTwoLineResponse(raw, "LlamaGenerator: malformed user response");
const int maxAttempts = 3;
std::string raw;
std::string lastError;
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
raw = infer(systemPrompt, prompt, 384);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
std::string name;
std::string description;
const std::string validationError =
ValidateBreweryJson(raw, name, description);
if (validationError.empty()) {
return {std::move(name), std::move(description)};
}
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
lastError = validationError;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validationError);
prompt = "Your previous response was invalid. Error: " + validationError +
"\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys."
"\n\nLocation: " +
cityName +
(countryName.empty() ? std::string("")
: std::string(", ") + countryName) +
(safeRegionContext.empty()
? std::string("")
: std::string("\nRegional context: ") + safeRegionContext);
}
return {username, bio};
spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
maxAttempts, lastError.empty() ? raw : lastError);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
}
UserResult LlamaGenerator::generateUser(const std::string &locale) {
const std::string systemPrompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
const int maxAttempts = 3;
std::string raw;
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
raw = infer(systemPrompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
auto [username, bio] =
parseTwoLineResponse(raw, "LlamaGenerator: malformed user response");
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
if (bio.size() > 200)
bio = bio.substr(0, 200);
return {username, bio};
} catch (const std::exception &e) {
spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}",
maxAttempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -1,35 +1,66 @@
#include <algorithm>
#include <filesystem>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <vector>
#include <boost/program_options.hpp>
#include <spdlog/spdlog.h>
#include "curl_web_client.h"
#include "data_downloader.h"
#include "data_generator.h"
#include "database.h"
#include "json_loader.h"
#include "llama_generator.h"
#include "mock_generator.h"
#include <curl/curl.h>
#include <filesystem>
#include <memory>
#include <spdlog/spdlog.h>
#include <vector>
#include "wikipedia_service.h"
static bool FileExists(const std::string &filePath) {
return std::filesystem::exists(filePath);
}
namespace po = boost::program_options;
int main(int argc, char *argv[]) {
try {
curl_global_init(CURL_GLOBAL_DEFAULT);
const CurlGlobalState curl_state;
std::string modelPath = argc > 1 ? argv[1] : "";
std::string cacheDir = argc > 2 ? argv[2] : "/tmp";
std::string commit =
argc > 3 ? argv[3] : "c5eb7772"; // Default: stable 2026-03-28
po::options_description desc("Pipeline Options");
desc.add_options()("help,h", "Produce help message")(
"model,m", po::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")(
"cache-dir,c", po::value<std::string>()->default_value("/tmp"),
"Directory for cached JSON")(
"temperature", po::value<float>()->default_value(0.8f),
"Sampling temperature (higher = more random)")(
"top-p", po::value<float>()->default_value(0.92f),
"Nucleus sampling top-p in (0,1] (higher = more random)")(
"seed", po::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer")(
"commit", po::value<std::string>()->default_value("c5eb7772"),
"Git commit hash for DB consistency");
std::string countryName = argc > 4 ? argv[4] : "";
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
po::notify(vm);
if (vm.count("help")) {
std::cout << desc << "\n";
return 0;
}
std::string modelPath = vm["model"].as<std::string>();
std::string cacheDir = vm["cache-dir"].as<std::string>();
float temperature = vm["temperature"].as<float>();
float topP = vm["top-p"].as<float>();
int seed = vm["seed"].as<int>();
std::string commit = vm["commit"].as<std::string>();
std::string jsonPath = cacheDir + "/countries+states+cities.json";
std::string dbPath = cacheDir + "/biergarten-pipeline.db";
bool hasJsonCache = FileExists(jsonPath);
bool hasDbCache = FileExists(dbPath);
bool hasJsonCache = std::filesystem::exists(jsonPath);
bool hasDbCache = std::filesystem::exists(dbPath);
auto webClient = std::make_shared<CURLWebClient>();
SqliteDatabase db;
@@ -40,7 +71,7 @@ int main(int argc, char *argv[]) {
spdlog::info("[Pipeline] Cache hit: skipping download and parse");
} else {
spdlog::info("\n[Pipeline] Downloading geographic data from GitHub...");
DataDownloader downloader;
DataDownloader downloader(webClient);
downloader.DownloadCountriesDatabase(jsonPath, commit);
JsonLoader::LoadWorldCities(jsonPath, db);
@@ -52,17 +83,30 @@ int main(int argc, char *argv[]) {
generator = std::make_unique<MockGenerator>();
spdlog::info("[Generator] Using MockGenerator (no model path provided)");
} else {
generator = std::make_unique<LlamaGenerator>();
spdlog::info("[Generator] Using LlamaGenerator: {}", modelPath);
auto llamaGenerator = std::make_unique<LlamaGenerator>();
llamaGenerator->setSamplingOptions(temperature, topP, seed);
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, "
"seed={})",
modelPath, temperature, topP, seed);
generator = std::move(llamaGenerator);
}
generator->load(modelPath);
WikipediaService wikipediaService(webClient);
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
auto countries = db.QueryCountries(50);
auto states = db.QueryStates(50);
auto cities = db.QueryCities();
// Build a quick map of country id -> name for per-city lookups.
auto allCountries = db.QueryCountries(0);
std::unordered_map<int, std::string> countryMap;
for (const auto &c : allCountries)
countryMap[c.id] = c.name;
spdlog::info("\nTotal records loaded:");
spdlog::info(" Countries: {}", db.QueryCountries(0).size());
spdlog::info(" States: {}", db.QueryStates(0).size());
@@ -79,8 +123,23 @@ int main(int argc, char *argv[]) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
for (size_t i = 0; i < sampleCount; i++) {
const auto &[cityId, cityName] = cities[i];
auto brewery = generator->generateBrewery(cityName, countryName, "");
const auto &city = cities[i];
const int cityId = city.id;
const std::string cityName = city.name;
std::string localCountry;
const auto countryIt = countryMap.find(city.countryId);
if (countryIt != countryMap.end()) {
localCountry = countryIt->second;
}
const std::string regionContext =
wikipediaService.GetSummary(cityName, localCountry);
spdlog::debug("[Pipeline] Region context for {}: {}", cityName,
regionContext);
auto brewery =
generator->generateBrewery(cityName, localCountry, regionContext);
generatedBreweries.push_back({cityId, cityName, brewery});
}
@@ -95,12 +154,10 @@ int main(int argc, char *argv[]) {
spdlog::info("\nOK: Pipeline completed successfully");
curl_global_cleanup();
return 0;
} catch (const std::exception &e) {
spdlog::error("ERROR: Pipeline failed: {}", e.what());
curl_global_cleanup();
return 1;
}
}

View File

@@ -1,15 +1,22 @@
#include "stream_parser.h"
#include "database.h"
#include <cstdio>
#include <rapidjson/filereadstream.h>
#include <rapidjson/reader.h>
#include <rapidjson/stringbuffer.h>
#include <stdexcept>
#include <boost/json.hpp>
#include <boost/json/basic_parser_impl.hpp>
#include <spdlog/spdlog.h>
using namespace rapidjson;
#include "database.h"
#include "stream_parser.h"
class CityRecordHandler {
friend class boost::json::basic_parser<CityRecordHandler>;
class CityRecordHandler : public BaseReaderHandler<UTF8<>, CityRecordHandler> {
public:
static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1);
struct ParseContext {
SqliteDatabase *db = nullptr;
std::function<void(const CityRecord &)> on_city;
@@ -20,11 +27,35 @@ public:
int states_inserted = 0;
};
CityRecordHandler(ParseContext &ctx) : context(ctx) {}
explicit CityRecordHandler(ParseContext &ctx) : context(ctx) {}
bool StartArray() {
private:
ParseContext &context;
int depth = 0;
bool in_countries_array = false;
bool in_country_object = false;
bool in_states_array = false;
bool in_state_object = false;
bool in_cities_array = false;
bool building_city = false;
int current_country_id = 0;
int current_state_id = 0;
CityRecord current_city = {};
std::string current_key;
std::string current_key_val;
std::string current_string_val;
std::string country_info[3];
std::string state_info[2];
// Boost.JSON SAX Hooks
bool on_document_begin(boost::system::error_code &) { return true; }
bool on_document_end(boost::system::error_code &) { return true; }
bool on_array_begin(boost::system::error_code &) {
depth++;
if (depth == 1) {
in_countries_array = true;
} else if (depth == 3 && current_key == "states") {
@@ -35,7 +66,7 @@ public:
return true;
}
bool EndArray(SizeType /*elementCount*/) {
bool on_array_end(std::size_t, boost::system::error_code &) {
if (depth == 1) {
in_countries_array = false;
} else if (depth == 3) {
@@ -47,9 +78,8 @@ public:
return true;
}
bool StartObject() {
bool on_object_begin(boost::system::error_code &) {
depth++;
if (depth == 2 && in_countries_array) {
in_country_object = true;
current_country_id = 0;
@@ -68,7 +98,7 @@ public:
return true;
}
bool EndObject(SizeType /*memberCount*/) {
bool on_object_end(std::size_t, boost::system::error_code &) {
if (depth == 6 && building_city) {
if (current_city.id > 0 && current_state_id > 0 &&
current_country_id > 0) {
@@ -84,7 +114,7 @@ public:
context.total_file_size);
}
} catch (const std::exception &e) {
spdlog::warn(" WARN: Failed to emit city: {}", e.what());
spdlog::warn("Record parsing failed: {}", e.what());
}
}
building_city = false;
@@ -95,7 +125,7 @@ public:
state_info[0], state_info[1]);
context.states_inserted++;
} catch (const std::exception &e) {
spdlog::warn(" WARN: Failed to insert state: {}", e.what());
spdlog::warn("Record parsing failed: {}", e.what());
}
}
in_state_object = false;
@@ -106,7 +136,7 @@ public:
country_info[1], country_info[2]);
context.countries_inserted++;
} catch (const std::exception &e) {
spdlog::warn(" WARN: Failed to insert country: {}", e.what());
spdlog::warn("Record parsing failed: {}", e.what());
}
}
in_country_object = false;
@@ -116,46 +146,71 @@ public:
return true;
}
bool Key(const char *str, SizeType len, bool /*copy*/) {
current_key.assign(str, len);
bool on_key_part(boost::json::string_view s, std::size_t,
boost::system::error_code &) {
current_key_val.append(s.data(), s.size());
return true;
}
bool String(const char *str, SizeType len, bool /*copy*/) {
bool on_key(boost::json::string_view s, std::size_t,
boost::system::error_code &) {
current_key_val.append(s.data(), s.size());
current_key = current_key_val;
current_key_val.clear();
return true;
}
bool on_string_part(boost::json::string_view s, std::size_t,
boost::system::error_code &) {
current_string_val.append(s.data(), s.size());
return true;
}
bool on_string(boost::json::string_view s, std::size_t,
boost::system::error_code &) {
current_string_val.append(s.data(), s.size());
if (building_city && current_key == "name") {
current_city.name.assign(str, len);
current_city.name = current_string_val;
} else if (in_state_object && current_key == "name") {
state_info[0].assign(str, len);
state_info[0] = current_string_val;
} else if (in_state_object && current_key == "iso2") {
state_info[1].assign(str, len);
state_info[1] = current_string_val;
} else if (in_country_object && current_key == "name") {
country_info[0].assign(str, len);
country_info[0] = current_string_val;
} else if (in_country_object && current_key == "iso2") {
country_info[1].assign(str, len);
country_info[1] = current_string_val;
} else if (in_country_object && current_key == "iso3") {
country_info[2].assign(str, len);
country_info[2] = current_string_val;
}
current_string_val.clear();
return true;
}
bool Int(int i) {
bool on_number_part(boost::json::string_view, boost::system::error_code &) {
return true;
}
bool on_int64(int64_t i, boost::json::string_view,
boost::system::error_code &) {
if (building_city && current_key == "id") {
current_city.id = i;
current_city.id = static_cast<int>(i);
} else if (in_state_object && current_key == "id") {
current_state_id = i;
current_state_id = static_cast<int>(i);
} else if (in_country_object && current_key == "id") {
current_country_id = i;
current_country_id = static_cast<int>(i);
}
return true;
}
bool Uint(unsigned i) { return Int(static_cast<int>(i)); }
bool on_uint64(uint64_t u, boost::json::string_view,
boost::system::error_code &ec) {
return on_int64(static_cast<int64_t>(u), "", ec);
}
bool Int64(int64_t i) { return Int(static_cast<int>(i)); }
bool Uint64(uint64_t i) { return Int(static_cast<int>(i)); }
bool Double(double d) {
bool on_double(double d, boost::json::string_view,
boost::system::error_code &) {
if (building_city) {
if (current_key == "latitude") {
current_city.latitude = d;
@@ -166,27 +221,14 @@ public:
return true;
}
bool Bool(bool /*b*/) { return true; }
bool Null() { return true; }
private:
ParseContext &context;
int depth = 0;
bool in_countries_array = false;
bool in_country_object = false;
bool in_states_array = false;
bool in_state_object = false;
bool in_cities_array = false;
bool building_city = false;
int current_country_id = 0;
int current_state_id = 0;
CityRecord current_city = {};
std::string current_key;
std::string country_info[3];
std::string state_info[2];
bool on_bool(bool, boost::system::error_code &) { return true; }
bool on_null(boost::system::error_code &) { return true; }
bool on_comment_part(boost::json::string_view, boost::system::error_code &) {
return true;
}
bool on_comment(boost::json::string_view, boost::system::error_code &) {
return true;
}
};
void StreamingJsonParser::Parse(
@@ -194,7 +236,7 @@ void StreamingJsonParser::Parse(
std::function<void(const CityRecord &)> onCity,
std::function<void(size_t, size_t)> onProgress) {
spdlog::info(" Streaming parse of {}...", filePath);
spdlog::info(" Streaming parse of {} (Boost.JSON)...", filePath);
FILE *file = std::fopen(filePath.c_str(), "rb");
if (!file) {
@@ -212,23 +254,35 @@ void StreamingJsonParser::Parse(
CityRecordHandler::ParseContext ctx{&db, onCity, onProgress, 0,
total_size, 0, 0};
CityRecordHandler handler(ctx);
boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx);
Reader reader;
char buf[65536];
FileReadStream frs(file, buf, sizeof(buf));
size_t bytes_read;
boost::system::error_code ec;
if (!reader.Parse(frs, handler)) {
ParseErrorCode errCode = reader.GetParseErrorCode();
size_t errOffset = reader.GetErrorOffset();
std::fclose(file);
throw std::runtime_error(std::string("JSON parse error at offset ") +
std::to_string(errOffset) +
" (code: " + std::to_string(errCode) + ")");
while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) {
char const *p = buf;
std::size_t remain = bytes_read;
while (remain > 0) {
std::size_t consumed = parser.write_some(true, p, remain, ec);
if (ec) {
std::fclose(file);
throw std::runtime_error("JSON parse error: " + ec.message());
}
p += consumed;
remain -= consumed;
}
}
parser.write_some(false, nullptr, 0, ec); // Signal EOF
std::fclose(file);
if (ec) {
throw std::runtime_error("JSON parse error at EOF: " + ec.message());
}
spdlog::info(" OK: Parsed {} countries, {} states, {} cities",
ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted);
}

View File

@@ -0,0 +1,77 @@
#include "wikipedia_service.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h>
WikipediaService::WikipediaService(std::shared_ptr<IWebClient> client)
: client_(std::move(client)) {}
std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string encoded = client_->UrlEncode(std::string(query));
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=true&format=json";
const std::string body = client_->Get(url);
boost::system::error_code ec;
boost::json::value doc = boost::json::parse(body, ec);
if (!ec && doc.is_object()) {
auto &pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto &page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
return extract;
}
}
}
return {};
}
std::string WikipediaService::GetSummary(std::string_view city,
std::string_view country) {
const std::string key = std::string(city) + "|" + std::string(country);
const auto cacheIt = cache_.find(key);
if (cacheIt != cache_.end()) {
return cacheIt->second;
}
std::string result;
if (!client_) {
cache_.emplace(key, result);
return result;
}
std::string regionQuery(city);
if (!country.empty()) {
regionQuery += ", ";
regionQuery += country;
}
const std::string beerQuery = "beer in " + std::string(city);
try {
const std::string regionExtract = FetchExtract(regionQuery);
const std::string beerExtract = FetchExtract(beerQuery);
if (!regionExtract.empty()) {
result += regionExtract;
}
if (!beerExtract.empty()) {
if (!result.empty())
result += "\n\n";
result += beerExtract;
}
} catch (const std::runtime_error &e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery,
e.what());
}
cache_.emplace(key, result);
return result;
}