diff --git a/tooling/pipeline/includes/data_model/models.h b/tooling/pipeline/includes/data_model/models.h index f08cf41..c046557 100644 --- a/tooling/pipeline/includes/data_model/models.h +++ b/tooling/pipeline/includes/data_model/models.h @@ -83,6 +83,9 @@ struct SamplingOptions { /// @brief Random seed (-1 for random, otherwise non-negative). int seed = -1; + + /// @brief Number of layers to offload to GPU. + int n_gpu_layers = 0; }; /** @@ -95,8 +98,7 @@ struct GeneratorOptions { /// @brief Use mocked generator instead of actual LLM inference. bool use_mocked = false; - /// @brief Number of layers to offload to GPU. - int n_gpu_layers = 0; + /// @brief Specific sampling parameters for this generator. /// If nullopt, the application should use global defaults. diff --git a/tooling/pipeline/includes/web_client/http_web_client.h b/tooling/pipeline/includes/web_client/http_web_client.h index 778d5d3..a38beba 100644 --- a/tooling/pipeline/includes/web_client/http_web_client.h +++ b/tooling/pipeline/includes/web_client/http_web_client.h @@ -42,7 +42,7 @@ public: * @param value Raw string to encode. * @return Percent-encoded string safe for use in a URL. */ - std::string UrlEncode(const std::string& value) override; + std::string EncodeURL(const std::string& value) override; }; diff --git a/tooling/pipeline/includes/web_client/web_client.h b/tooling/pipeline/includes/web_client/web_client.h index bb16323..641eb12 100644 --- a/tooling/pipeline/includes/web_client/web_client.h +++ b/tooling/pipeline/includes/web_client/web_client.h @@ -30,7 +30,7 @@ class WebClient { * @param value Raw string value. * @return Encoded value safe for URL usage. 
*/ - virtual std::string UrlEncode(const std::string& value) = 0; + virtual std::string EncodeURL(const std::string& value) = 0; }; #endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_ diff --git a/tooling/pipeline/src/application_options/parse_arguments.cc b/tooling/pipeline/src/application_options/parse_arguments.cc index b06c1b7..e568bd9 100644 --- a/tooling/pipeline/src/application_options/parse_arguments.cc +++ b/tooling/pipeline/src/application_options/parse_arguments.cc @@ -30,6 +30,8 @@ std::optional ParseArguments(const int argc, char** argv) { "Context window size in tokens"); opt("seed", prog_opts::value()->default_value(sampling_defaults.seed), "Sampler seed: -1 for random, otherwise non-negative integer"); + opt("n-gpu-layers", prog_opts::value()->default_value(0), + "Number of layers to offload to GPU"); }; // --mocked and --model are mutually exclusive; validation is enforced below @@ -50,8 +52,7 @@ std::optional ParseArguments(const int argc, char** argv) { opt("prompt-dir", prog_opts::value()->default_value(""), "Directory containing named prompt files (e.g. BREWERY_GENERATION.md)." " Required when not using --mocked."); - opt("n-gpu-layers", prog_opts::value()->default_value(0), - "Number of layers to offload to GPU"); + }; add_sampling_options(); @@ -113,7 +114,7 @@ std::optional ParseArguments(const int argc, char** argv) { options.generator.use_mocked = use_mocked; options.generator.model_path = model_path; - options.generator.n_gpu_layers = n_gpu_layers; + // options.generator.n_gpu_layers = n_gpu_layers; // Only populate sampling config when the user explicitly overrides at // least one value. 
Leaving it as std::nullopt lets LlamaGenerator fall @@ -122,7 +123,7 @@ std::optional ParseArguments(const int argc, char** argv) { const bool user_provided_sampling = !var_map["temperature"].defaulted() || !var_map["top-p"].defaulted() || !var_map["top-k"].defaulted() || !var_map["n-ctx"].defaulted() || - !var_map["seed"].defaulted(); + !var_map["seed"].defaulted() || !var_map["n-gpu-layers"].defaulted(); if (user_provided_sampling) { // Warn but do not fail — the run is still valid, the flags are just @@ -136,6 +137,7 @@ std::optional ParseArguments(const int argc, char** argv) { sampling.top_k = var_map["top-k"].as(); sampling.n_ctx = var_map["n-ctx"].as(); sampling.seed = var_map["seed"].as(); + sampling.n_gpu_layers = var_map["n-gpu-layers"].as(); options.generator.sampling = sampling; } diff --git a/tooling/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc b/tooling/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc index 5cf60b6..2427a15 100644 --- a/tooling/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc +++ b/tooling/pipeline/src/biergarten_data_generator/query_cities_with_countries.cc @@ -13,7 +13,7 @@ #include "biergarten_data_generator.h" #include "json_handling/json_loader.h" -static constexpr size_t kBreweryAmount = 50; +static constexpr size_t kBreweryAmount = 40; std::vector BiergartenDataGenerator::QueryCitiesWithCountries() { spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); diff --git a/tooling/pipeline/src/biergarten_data_generator/run.cc b/tooling/pipeline/src/biergarten_data_generator/run.cc index 82ebbfc..4ee2b46 100644 --- a/tooling/pipeline/src/biergarten_data_generator/run.cc +++ b/tooling/pipeline/src/biergarten_data_generator/run.cc @@ -21,8 +21,8 @@ bool BiergartenDataGenerator::Run() { for (auto& city : cities) { try { std::string region_context = context_service_->GetLocationContext(city); - spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}", - city.city, 
city.country, region_context); + // spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}", + // city.city, city.iso3166_2, region_context); enriched.push_back( EnrichedCity{.location = std::move(city), diff --git a/tooling/pipeline/src/data_generation/llama/llama_generator.cc b/tooling/pipeline/src/data_generation/llama/llama_generator.cc index 72a888e..d780f2f 100644 --- a/tooling/pipeline/src/data_generation/llama/llama_generator.cc +++ b/tooling/pipeline/src/data_generation/llama/llama_generator.cc @@ -89,7 +89,7 @@ LlamaGenerator::LlamaGenerator( } n_ctx_ = sampling.n_ctx; - n_gpu_layers_ = options.generator.n_gpu_layers; + n_gpu_layers_ = sampling.n_gpu_layers; this->Load(model_path); } diff --git a/tooling/pipeline/src/main.cc b/tooling/pipeline/src/main.cc index 3b2a3ce..caaed6d 100644 --- a/tooling/pipeline/src/main.cc +++ b/tooling/pipeline/src/main.cc @@ -8,11 +8,9 @@ #include #include - #include #include #include - #include #include "biergarten_data_generator.h" @@ -21,12 +19,12 @@ #include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h" #include "data_model/models.h" #include "llama_backend_state.h" -#include "services/enrichment/enrichment_service.h" #include "services/database/export_service.h" -#include "services/prompting/prompt_directory.h" #include "services/database/sqlite_export_service.h" #include "services/datetime/timer.h" +#include "services/enrichment/enrichment_service.h" #include "services/enrichment/wikipedia_service.h" +#include "services/prompting/prompt_directory.h" #include "web_client/http_web_client.h" namespace di = boost::di; @@ -43,7 +41,9 @@ int main(const int argc, char** argv) { spdlog::set_level(spdlog::level::debug); #endif - const auto parsed_options = ParseArguments(argc, argv); + const std::optional parsed_options = + ParseArguments(argc, argv); + if (!parsed_options.has_value()) { return 0; } @@ -73,7 +73,7 @@ int main(const int argc, char** argv) { di::bind().to(model_path), di::bind().to( 
[options, model_path, sampling, &prompt_directory]( - const auto& inj) -> std::unique_ptr { + const auto& inj) -> std::unique_ptr { if (options.generator.use_mocked) { spdlog::info( "[Generator] Using MockGenerator (no model path provided)"); @@ -89,7 +89,9 @@ int main(const int argc, char** argv) { options, model_path, inj.template create>(), std::move(prompt_directory)); - })); + }) + + ); auto generator = injector.create>(); diff --git a/tooling/pipeline/src/services/wikipedia/fetch_extract.cc b/tooling/pipeline/src/services/wikipedia/fetch_extract.cc index 748ed36..e4ab7d0 100644 --- a/tooling/pipeline/src/services/wikipedia/fetch_extract.cc +++ b/tooling/pipeline/src/services/wikipedia/fetch_extract.cc @@ -1,61 +1,106 @@ /** * @file wikipedia/fetch_extract.cc - * @brief WikipediaService::FetchExtract() implementation. */ #include #include +#include #include #include #include "services/enrichment/wikipedia_service.h" +using namespace boost; + std::string WikipediaService::FetchExtract(std::string_view query) { + const std::string cache_key(query); - const auto cache_it = this->extract_cache_.find(cache_key); - if (cache_it != this->extract_cache_.end()) { + + // 1. 
Cache Lookup + if (const auto cache_it = this->extract_cache_.find(cache_key); + cache_it != this->extract_cache_.end()) { + spdlog::debug("Wikipedia: Cache hit for {}!", cache_key); return cache_it->second; } - const std::string encoded = this->client_->UrlEncode(cache_key); - const std::string url = - "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + - "&prop=extracts&explaintext=1&format=json"; + const std::string encoded = this->client_->EncodeURL(cache_key); + const std::string url = std::format( + "https://en.wikipedia.org/w/" + "api.php?action=query&titles={}&prop=extracts&explaintext=1&format=json", + encoded); + const std::string body = this->client_->Get(url); - - boost::system::error_code parse_error; - boost::json::value doc = boost::json::parse(body, parse_error); - - if (!parse_error && doc.is_object()) { - try { - auto& pages = doc.at("query").at("pages").get_object(); - if (!pages.empty()) { - auto& page = pages.begin()->value().get_object(); - if (page.contains("extract") && page.at("extract").is_string()) { - const std::string_view extract_view = page.at("extract").as_string(); - std::string extract(extract_view); - - spdlog::debug("WikipediaService fetched {} chars for '{}'", - extract.size(), query); - - this->extract_cache_.emplace(cache_key, extract); - return extract; - } - } - this->extract_cache_.emplace(cache_key, std::string{}); - } catch (const std::exception& e) { - spdlog::warn( - "WikipediaService: failed to parse response structure for '{}': " - "{}", - query, e.what()); - return {}; - } - } else if (parse_error) { - spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query, - parse_error.message()); + { + using namespace std::literals::chrono_literals; + std::this_thread::sleep_for(1s); } - return {}; -} + // 2. 
Parse JSON + system::error_code ec; + json::value doc = json::parse(body, ec); + + if (ec) { + spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query, + ec.message()); + return {}; + } + + // 3. Safe Extraction + const json::object* obj = doc.if_object(); + if (!obj) { + spdlog::warn("WikipediaService: Expected root object for '{}'", query); + return {}; + } + + const json::value* query_ptr = obj->if_contains("query"); + const json::value* pages_ptr = + (query_ptr && query_ptr->is_object()) + ? query_ptr->get_object().if_contains("pages") + : nullptr; + + if (!pages_ptr || !pages_ptr->is_object()) { + spdlog::warn("WikipediaService: Missing query.pages for '{}'", query); + return {}; + } + + const json::object& pages = pages_ptr->get_object(); + if (pages.empty()) { + spdlog::warn("WikipediaService: No pages returned for '{}'", query); + this->extract_cache_.emplace(cache_key, ""); + return {}; + } + + // Wikipedia returns the page under a dynamic ID key; we just want the first + // one + const json::value& page_val = pages.begin()->value(); + if (!page_val.is_object()) { + spdlog::warn("WikipediaService: Unexpected page format for '{}'", query); + return {}; + } + + const json::object& page = page_val.get_object(); + + // Handle 404/Missing status + if (page.contains("missing")) { + spdlog::warn("WikipediaService: Page '{}' does not exist", query); + this->extract_cache_.emplace(cache_key, ""); + return {}; + } + + const json::value* extract_ptr = page.if_contains("extract"); + if (!extract_ptr || !extract_ptr->is_string()) { + spdlog::warn("WikipediaService: No extract string found for '{}'", query); + this->extract_cache_.emplace(cache_key, ""); + return {}; + } + + // 4. 
Success + std::string extract(extract_ptr->as_string()); + spdlog::info("WikipediaService: Fetched {} chars for '{}'", extract.size(), + query); + + this->extract_cache_.insert_or_assign(cache_key, extract); + return extract; +} \ No newline at end of file diff --git a/tooling/pipeline/src/services/wikipedia/get_summary.cc b/tooling/pipeline/src/services/wikipedia/get_summary.cc index 16fc7b6..e58bbf0 100644 --- a/tooling/pipeline/src/services/wikipedia/get_summary.cc +++ b/tooling/pipeline/src/services/wikipedia/get_summary.cc @@ -5,25 +5,31 @@ #include +#include #include +#include #include "services/enrichment/wikipedia_service.h" - std::string WikipediaService::GetLocationContext(const Location& loc) { - if (!client_) { + using namespace std::literals::chrono_literals; + if (!this->client_) { + spdlog::warn("Client is nullptr."); return {}; } std::string result; - std::string region_query(loc.city); - if (!loc.country.empty()) { - region_query += ", "; - region_query += loc.country; - } + // std::string region_query(loc.city); + // if (!loc.country.empty()) { + // region_query += loc.state_province, + // region_query += ", "; + // region_query += loc.country; + // } - const std::string beer_query = "beer in " + loc.country; - const std::string city_beer_query = "beer in " + loc.city; + constexpr std::string_view brewing_query = "brewing"; + const std::string location_query = + std::format("{}, {}", loc.city, loc.iso3166_2); + const std::string beer_query = std::format("beer in {}", loc.country); auto append_extract = [&result](const std::string& extract) -> void { if (extract.empty()) { @@ -36,11 +42,14 @@ std::string WikipediaService::GetLocationContext(const Location& loc) { }; try { - append_extract(FetchExtract(region_query)); + append_extract(FetchExtract(brewing_query)); append_extract(FetchExtract(beer_query)); - append_extract(FetchExtract(city_beer_query)); + spdlog::info("Done fetching for {}. 
Sleeping for 10 seconds.", + location_query); + std::this_thread::sleep_for(10s); + } catch (const std::runtime_error& e) { - spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query, + spdlog::debug("WikipediaService lookup failed for '{}': {}", location_query, e.what()); } return result; diff --git a/tooling/pipeline/src/web_client/http_web_client.cc b/tooling/pipeline/src/web_client/http_web_client.cc index aba30cf..4653102 100644 --- a/tooling/pipeline/src/web_client/http_web_client.cc +++ b/tooling/pipeline/src/web_client/http_web_client.cc @@ -12,6 +12,8 @@ #include #include +#include "spdlog/spdlog.h" + namespace { constexpr time_t kConnectionTimeoutSeconds = 5; constexpr time_t kReadTimeoutSeconds = 10; @@ -38,8 +40,12 @@ std::string HttpWebClient::Get(const std::string& url) { client.set_follow_location(true); client.set_connection_timeout(kConnectionTimeoutSeconds); client.set_read_timeout(kReadTimeoutSeconds); + client.set_default_headers({ + {"Accept", "application/json"}, + {"User-Agent", "biergarten-pipeline/1.0"} + }); - const auto result = client.Get(path); + const httplib::Result result = client.Get(path); if (!result) { throw std::runtime_error( @@ -48,6 +54,7 @@ std::string HttpWebClient::Get(const std::string& url) { } if (result->status < kSuccessMin || result->status >= kSuccessMax) { + spdlog::error("[HttpWebClient] Request failed for URL: " + url); throw std::runtime_error( "[HttpWebClient] HTTP " + std::to_string(result->status) + " for URL: " + url); @@ -56,6 +63,6 @@ std::string HttpWebClient::Get(const std::string& url) { return result->body; } -std::string HttpWebClient::UrlEncode(const std::string& value) { +std::string HttpWebClient::EncodeURL(const std::string& value) { return httplib::encode_uri_component(value); } \ No newline at end of file