Add timeout for enrichment, refactor json deserialization

This commit is contained in:
Aaron Po
2026-05-13 12:39:06 -04:00
parent b7c0b1c8d4
commit 773e7c774b
11 changed files with 140 additions and 73 deletions

View File

@@ -83,6 +83,9 @@ struct SamplingOptions {
/// @brief Random seed (-1 for random, otherwise non-negative).
int seed = -1;
/// @brief Number of layers to offload to GPU.
int n_gpu_layers = 0;
};
/**
@@ -95,8 +98,7 @@ struct GeneratorOptions {
/// @brief Use mocked generator instead of actual LLM inference.
bool use_mocked = false;
/// @brief Number of layers to offload to GPU.
int n_gpu_layers = 0;
/// @brief Specific sampling parameters for this generator.
/// If nullopt, the application should use global defaults.

View File

@@ -42,7 +42,7 @@ public:
* @param value Raw string to encode.
* @return Percent-encoded string safe for use in a URL.
*/
std::string UrlEncode(const std::string& value) override;
std::string EncodeURL(const std::string& value) override;
};

View File

@@ -30,7 +30,7 @@ class WebClient {
* @param value Raw string value.
* @return Encoded value safe for URL usage.
*/
virtual std::string UrlEncode(const std::string& value) = 0;
virtual std::string EncodeURL(const std::string& value) = 0;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -30,6 +30,8 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
"Context window size in tokens");
opt("seed", prog_opts::value<int>()->default_value(sampling_defaults.seed),
"Sampler seed: -1 for random, otherwise non-negative integer");
opt("n-gpu-layers", prog_opts::value<int>()->default_value(0),
"Number of layers to offload to GPU");
};
// --mocked and --model are mutually exclusive; validation is enforced below
@@ -50,8 +52,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
opt("prompt-dir", prog_opts::value<std::string>()->default_value(""),
"Directory containing named prompt files (e.g. BREWERY_GENERATION.md)."
" Required when not using --mocked.");
opt("n-gpu-layers", prog_opts::value<int>()->default_value(0),
"Number of layers to offload to GPU");
};
add_sampling_options();
@@ -113,7 +114,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
options.generator.use_mocked = use_mocked;
options.generator.model_path = model_path;
options.generator.n_gpu_layers = n_gpu_layers;
// options.generator.n_gpu_layers = n_gpu_layers;
// Only populate sampling config when the user explicitly overrides at
// least one value. Leaving it as std::nullopt lets LlamaGenerator fall
@@ -122,7 +123,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
const bool user_provided_sampling =
!var_map["temperature"].defaulted() || !var_map["top-p"].defaulted() ||
!var_map["top-k"].defaulted() || !var_map["n-ctx"].defaulted() ||
!var_map["seed"].defaulted();
!var_map["seed"].defaulted() || !var_map["n_gpu_layers"].defaulted();
if (user_provided_sampling) {
// Warn but do not fail — the run is still valid, the flags are just
@@ -136,6 +137,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
sampling.top_k = var_map["top-k"].as<uint32_t>();
sampling.n_ctx = var_map["n-ctx"].as<uint32_t>();
sampling.seed = var_map["seed"].as<int>();
sampling.n_gpu_layers = var_map["n-gpu-layers"].as<int>();
options.generator.sampling = sampling;
}

View File

@@ -13,7 +13,7 @@
#include "biergarten_data_generator.h"
#include "json_handling/json_loader.h"
static constexpr size_t kBreweryAmount = 50;
static constexpr size_t kBreweryAmount = 40;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");

View File

@@ -21,8 +21,8 @@ bool BiergartenDataGenerator::Run() {
for (auto& city : cities) {
try {
std::string region_context = context_service_->GetLocationContext(city);
spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context);
// spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}",
// city.city, city.iso3166_2, region_context);
enriched.push_back(
EnrichedCity{.location = std::move(city),

View File

@@ -89,7 +89,7 @@ LlamaGenerator::LlamaGenerator(
}
n_ctx_ = sampling.n_ctx;
n_gpu_layers_ = options.generator.n_gpu_layers;
n_gpu_layers_ = sampling.n_gpu_layers;
this->Load(model_path);
}

View File

@@ -8,11 +8,9 @@
#include <boost/di.hpp>
#include <boost/program_options.hpp>
#include <exception>
#include <memory>
#include <optional>
#include <string>
#include "biergarten_data_generator.h"
@@ -21,12 +19,12 @@
#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h"
#include "data_model/models.h"
#include "llama_backend_state.h"
#include "services/enrichment/enrichment_service.h"
#include "services/database/export_service.h"
#include "services/prompting/prompt_directory.h"
#include "services/database/sqlite_export_service.h"
#include "services/datetime/timer.h"
#include "services/enrichment/enrichment_service.h"
#include "services/enrichment/wikipedia_service.h"
#include "services/prompting/prompt_directory.h"
#include "web_client/http_web_client.h"
namespace di = boost::di;
@@ -43,7 +41,9 @@ int main(const int argc, char** argv) {
spdlog::set_level(spdlog::level::debug);
#endif
const auto parsed_options = ParseArguments(argc, argv);
const std::optional<ApplicationOptions> parsed_options =
ParseArguments(argc, argv);
if (!parsed_options.has_value()) {
return 0;
}
@@ -73,7 +73,7 @@ int main(const int argc, char** argv) {
di::bind<std::string>().to(model_path),
di::bind<DataGenerator>().to(
[options, model_path, sampling, &prompt_directory](
const auto& inj) -> std::unique_ptr<DataGenerator> {
const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.generator.use_mocked) {
spdlog::info(
"[Generator] Using MockGenerator (no model path provided)");
@@ -89,7 +89,9 @@ int main(const int argc, char** argv) {
options, model_path,
inj.template create<std::unique_ptr<IPromptFormatter>>(),
std::move(prompt_directory));
}));
})
);
auto generator =
injector.create<std::unique_ptr<BiergartenDataGenerator>>();

View File

@@ -1,61 +1,106 @@
/**
* @file wikipedia/fetch_extract.cc
* @brief WikipediaService::FetchExtract() implementation.
*/
#include <spdlog/spdlog.h>
#include <boost/json.hpp>
#include <format>
#include <string>
#include <string_view>
#include "services/enrichment/wikipedia_service.h"
using namespace boost;
std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) {
// 1. Cache Lookup
if (const auto cache_it = this->extract_cache_.find(cache_key);
cache_it != this->extract_cache_.end()) {
spdlog::debug("Wikipedia: Cache hit for {}!", cache_key);
return cache_it->second;
}
const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json";
const std::string encoded = this->client_->EncodeURL(cache_key);
const std::string url = std::format(
"https://en.wikipedia.org/w/"
"api.php?action=query&titles={}&prop=extracts&explaintext=1&format=json",
encoded);
const std::string body = this->client_->Get(url);
boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error);
if (!parse_error && doc.is_object()) {
try {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
const std::string_view extract_view = page.at("extract").as_string();
std::string extract(extract_view);
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
this->extract_cache_.emplace(cache_key, extract);
return extract;
}
}
this->extract_cache_.emplace(cache_key, std::string{});
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
}
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
parse_error.message());
{
using namespace std::literals::chrono_literals;
std::this_thread::sleep_for(1s);
}
return {};
}
// 2. Parse JSON
system::error_code ec;
json::value doc = json::parse(body, ec);
if (ec) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
ec.message());
return {};
}
// 3. Safe Extraction
const json::object* obj = doc.if_object();
if (!obj) {
spdlog::warn("WikipediaService: Expected root object for '{}'", query);
return {};
}
const json::value* query_ptr = obj->if_contains("query");
const json::value* pages_ptr =
(query_ptr && query_ptr->is_object())
? query_ptr->get_object().if_contains("pages")
: nullptr;
if (!pages_ptr || !pages_ptr->is_object()) {
spdlog::warn("WikipediaService: Missing query.pages for '{}'", query);
return {};
}
const json::object& pages = pages_ptr->get_object();
if (pages.empty()) {
spdlog::warn("WikipediaService: No pages returned for '{}'", query);
this->extract_cache_.emplace(cache_key, "");
return {};
}
// Wikipedia returns the page under a dynamic ID key; we just want the first
// one
const json::value& page_val = pages.begin()->value();
if (!page_val.is_object()) {
spdlog::warn("WikipediaService: Unexpected page format for '{}'", query);
return {};
}
const json::object& page = page_val.get_object();
// Handle 404/Missing status
if (page.contains("missing")) {
spdlog::warn("WikipediaService: Page '{}' does not exist", query);
this->extract_cache_.emplace(cache_key, "");
return {};
}
const json::value* extract_ptr = page.if_contains("extract");
if (!extract_ptr || !extract_ptr->is_string()) {
spdlog::warn("WikipediaService: No extract string found for '{}'", query);
this->extract_cache_.emplace(cache_key, "");
return {};
}
// 4. Success
std::string extract(extract_ptr->as_string());
spdlog::info("WikipediaService: Fetched {} chars for '{}'", extract.size(),
query);
this->extract_cache_.insert_or_assign(cache_key, extract);
return extract;
}

View File

@@ -5,25 +5,31 @@
#include <spdlog/spdlog.h>
#include <chrono>
#include <string>
#include <thread>
#include "services/enrichment/wikipedia_service.h"
std::string WikipediaService::GetLocationContext(const Location& loc) {
if (!client_) {
using namespace std::literals::chrono_literals;
if (!this->client_) {
spdlog::warn("Client is nullptr.");
return {};
}
std::string result;
std::string region_query(loc.city);
if (!loc.country.empty()) {
region_query += ", ";
region_query += loc.country;
}
// std::string region_query(loc.city);
// if (!loc.country.empty()) {
// region_query += loc.state_province,
// region_query += ", ";
// region_query += loc.country;
// }
const std::string beer_query = "beer in " + loc.country;
const std::string city_beer_query = "beer in " + loc.city;
constexpr std::string_view brewing_query = "brewing";
const std::string location_query =
std::format("{}, {}", loc.city, loc.iso3166_2);
const std::string beer_query = std::format("beer in {}", loc.country);
auto append_extract = [&result](const std::string& extract) -> void {
if (extract.empty()) {
@@ -36,11 +42,14 @@ std::string WikipediaService::GetLocationContext(const Location& loc) {
};
try {
append_extract(FetchExtract(region_query));
append_extract(FetchExtract(brewing_query));
append_extract(FetchExtract(beer_query));
append_extract(FetchExtract(city_beer_query));
spdlog::info("Done fetching for {}. Sleeping for 10 seconds.",
location_query);
std::this_thread::sleep_for(10s);
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
spdlog::debug("WikipediaService lookup failed for '{}': {}", location_query,
e.what());
}
return result;

View File

@@ -12,6 +12,8 @@
#include <string>
#include <utility>
#include "spdlog/spdlog.h"
namespace {
constexpr time_t kConnectionTimeoutSeconds = 5;
constexpr time_t kReadTimeoutSeconds = 10;
@@ -38,8 +40,12 @@ std::string HttpWebClient::Get(const std::string& url) {
client.set_follow_location(true);
client.set_connection_timeout(kConnectionTimeoutSeconds);
client.set_read_timeout(kReadTimeoutSeconds);
client.set_default_headers({
{"Accept", "application/json"},
{"User-Agent", "biergarten-pipeline/1.0"}
});
const auto result = client.Get(path);
const httplib::Result result = client.Get(path);
if (!result) {
throw std::runtime_error(
@@ -48,6 +54,7 @@ std::string HttpWebClient::Get(const std::string& url) {
}
if (result->status < kSuccessMin || result->status >= kSuccessMax) {
spdlog::error("[HttpWebClient] Request failed for URL: " + url);
throw std::runtime_error(
"[HttpWebClient] HTTP " + std::to_string(result->status) +
" for URL: " + url);
@@ -56,6 +63,6 @@ std::string HttpWebClient::Get(const std::string& url) {
return result->body;
}
std::string HttpWebClient::UrlEncode(const std::string& value) {
std::string HttpWebClient::EncodeURL(const std::string& value) {
return httplib::encode_uri_component(value);
}