Pipeline: add CURL/WebClient & Wikipedia service

Introduce a pluggable web client interface and a concrete CURL implementation: add IWebClient, CURLWebClient, and CurlGlobalState (headers + curl_web_client.cpp). DataDownloader now accepts an IWebClient and delegates downloads to it. Add WikipediaService for cached Wikipedia summary lookups. Refactor SqliteDatabase to return full City records and update consumers accordingly. Improve JsonLoader to use batched transactions during streaming parses. Enhance LlamaGenerator with sampling options, increased token limits, JSON extraction/validation, and other parsing helpers. Modernize CMake: set policy/version, add project_options, simplify FetchContent usage (spdlog), require Boost components (program_options/json), list pipeline sources explicitly, and tweak post-build/memcheck targets. Update README to match implementation changes and new CLI/config conventions.
This commit is contained in:
Aaron Po
2026-04-02 16:29:16 -04:00
parent ac136f7179
commit 98083ab40c
16 changed files with 1125 additions and 794 deletions

View File

@@ -1,35 +1,66 @@
#include <algorithm>
#include <filesystem>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <vector>
#include <boost/program_options.hpp>
#include <spdlog/spdlog.h>
#include "curl_web_client.h"
#include "data_downloader.h"
#include "data_generator.h"
#include "database.h"
#include "json_loader.h"
#include "llama_generator.h"
#include "mock_generator.h"
#include <curl/curl.h>
#include <filesystem>
#include <memory>
#include <spdlog/spdlog.h>
#include <vector>
#include "wikipedia_service.h"
/// Reports whether `filePath` names an existing filesystem entry.
///
/// Delegates to the throwing overload of std::filesystem::exists, so an
/// underlying OS error surfaces as std::filesystem::filesystem_error rather
/// than as a `false` result.
static bool FileExists(const std::string &filePath) {
  const std::filesystem::path candidate{filePath};
  const bool present = std::filesystem::exists(candidate);
  return present;
}
namespace po = boost::program_options;
// Pipeline entry point: reads its options, loads or downloads geographic data
// into a SqliteDatabase, selects a brewery generator, and generates sample
// breweries for the queried cities.
// Returns 0 on success (or after printing help) and 1 when a std::exception
// is caught.
// NOTE(review): this listing is a rendered diff. Removed and added lines are
// interleaved without +/- markers and the "@@ ... @@" lines mark elided
// regions, so several names below are declared twice and the text is not
// compilable as shown.
int main(int argc, char *argv[]) {
try {
// Two forms of libcurl global setup appear here: an explicit
// curl_global_init call and a CurlGlobalState object. Presumably the object
// supersedes the manual call (and the curl_global_cleanup calls near the end
// of main) -- confirm against curl_web_client.h.
curl_global_init(CURL_GLOBAL_DEFAULT);
const CurlGlobalState curl_state;
// Positional-argument form: argv[1] = model path, argv[2] = cache directory,
// argv[3] = commit; each falls back to a default when absent.
std::string modelPath = argc > 1 ? argv[1] : "";
std::string cacheDir = argc > 2 ? argv[2] : "/tmp";
std::string commit =
argc > 3 ? argv[3] : "c5eb7772"; // Default: stable 2026-03-28
// Named-option form: declares --help/-h, --model/-m, --cache-dir/-c,
// --temperature, --top-p, --seed and --commit with their defaults.
po::options_description desc("Pipeline Options");
desc.add_options()("help,h", "Produce help message")(
"model,m", po::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")(
"cache-dir,c", po::value<std::string>()->default_value("/tmp"),
"Directory for cached JSON")(
"temperature", po::value<float>()->default_value(0.8f),
"Sampling temperature (higher = more random)")(
"top-p", po::value<float>()->default_value(0.92f),
"Nucleus sampling top-p in (0,1] (higher = more random)")(
"seed", po::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer")(
"commit", po::value<std::string>()->default_value("c5eb7772"),
"Git commit hash for DB consistency");
// Positional country name (argv[4]); it is the second argument of the
// generateBrewery(cityName, countryName, "") call in the sample loop below.
std::string countryName = argc > 4 ? argv[4] : "";
// Parse the command line into `vm`.
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
po::notify(vm);
// --help prints the option summary and exits successfully without running
// the rest of the pipeline.
if (vm.count("help")) {
std::cout << desc << "\n";
return 0;
}
// Pull the option values out of the variables map. modelPath and commit are
// also declared in the positional form above (diff artifact).
std::string modelPath = vm["model"].as<std::string>();
std::string cacheDir = vm["cache-dir"].as<std::string>();
float temperature = vm["temperature"].as<float>();
float topP = vm["top-p"].as<float>();
int seed = vm["seed"].as<int>();
std::string commit = vm["commit"].as<std::string>();
// Both cache artifacts live directly under cacheDir.
std::string jsonPath = cacheDir + "/countries+states+cities.json";
std::string dbPath = cacheDir + "/biergarten-pipeline.db";
// The same two flags are computed twice: once through the FileExists helper
// and once by calling std::filesystem::exists directly (diff artifact).
bool hasJsonCache = FileExists(jsonPath);
bool hasDbCache = FileExists(dbPath);
bool hasJsonCache = std::filesystem::exists(jsonPath);
bool hasDbCache = std::filesystem::exists(dbPath);
// One web client instance, held by shared_ptr and handed to both the
// DataDownloader and the WikipediaService below.
auto webClient = std::make_shared<CURLWebClient>();
SqliteDatabase db;
@@ -40,7 +71,7 @@ int main(int argc, char *argv[]) {
// The branch condition sits in an elided region; per the log text this arm
// is the cache-hit case and the else arm downloads and parses the data.
spdlog::info("[Pipeline] Cache hit: skipping download and parse");
} else {
spdlog::info("\n[Pipeline] Downloading geographic data from GitHub...");
// `downloader` is declared twice: default-constructed and constructed with
// the shared web client (diff artifact).
DataDownloader downloader;
DataDownloader downloader(webClient);
// Fetch the countries JSON for the given commit to jsonPath, then load it
// into the database.
downloader.DownloadCountriesDatabase(jsonPath, commit);
JsonLoader::LoadWorldCities(jsonPath, db);
@@ -52,17 +83,30 @@ int main(int argc, char *argv[]) {
// Generator selection. The condition is elided; the log text indicates the
// MockGenerator arm is taken when no model path was provided.
generator = std::make_unique<MockGenerator>();
spdlog::info("[Generator] Using MockGenerator (no model path provided)");
} else {
// Two variants of the LlamaGenerator arm appear: direct assignment with a
// short log line, and a local instance that receives the sampling options
// before being moved into `generator`.
generator = std::make_unique<LlamaGenerator>();
spdlog::info("[Generator] Using LlamaGenerator: {}", modelPath);
auto llamaGenerator = std::make_unique<LlamaGenerator>();
llamaGenerator->setSamplingOptions(temperature, topP, seed);
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, "
"seed={})",
modelPath, temperature, topP, seed);
generator = std::move(llamaGenerator);
}
// load() is invoked after the if/else, so it runs for either generator type.
generator->load(modelPath);
// Supplies the per-city region context used in the sample loop; it uses the
// same web client instance as the downloader.
WikipediaService wikipediaService(webClient);
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
// Presumably the argument is a row limit (50 rows here) -- confirm in
// database.h.
auto countries = db.QueryCountries(50);
auto states = db.QueryStates(50);
auto cities = db.QueryCities();
// Build a quick map of country id -> name for per-city lookups.
// Presumably an argument of 0 means "no limit", given that the size of the
// same query is logged as the total count below -- confirm in database.h.
auto allCountries = db.QueryCountries(0);
std::unordered_map<int, std::string> countryMap;
for (const auto &c : allCountries)
countryMap[c.id] = c.name;
// The totals are obtained by re-running the queries with argument 0 rather
// than reusing allCountries.
spdlog::info("\nTotal records loaded:");
spdlog::info(" Countries: {}", db.QueryCountries(0).size());
spdlog::info(" States: {}", db.QueryStates(0).size());
@@ -79,8 +123,23 @@ int main(int argc, char *argv[]) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
// sampleCount is defined in an elided region.
for (size_t i = 0; i < sampleCount; i++) {
// Variant 1: cities[i] destructured as an (id, name) pair, with the
// command-line countryName and an empty region context.
const auto &[cityId, cityName] = cities[i];
auto brewery = generator->generateBrewery(cityName, countryName, "");
// Variant 2: cities[i] used as a record with id, name and countryId members.
const auto &city = cities[i];
const int cityId = city.id;
const std::string cityName = city.name;
// Resolve the country name through countryMap; localCountry stays empty when
// the city's countryId has no entry.
std::string localCountry;
const auto countryIt = countryMap.find(city.countryId);
if (countryIt != countryMap.end()) {
localCountry = countryIt->second;
}
// Look up a summary for the city and pass it to the generator as the third
// argument.
const std::string regionContext =
wikipediaService.GetSummary(cityName, localCountry);
spdlog::debug("[Pipeline] Region context for {}: {}", cityName,
regionContext);
auto brewery =
generator->generateBrewery(cityName, localCountry, regionContext);
// Record the result together with the city it was generated for.
generatedBreweries.push_back({cityId, cityName, brewery});
}
@@ -95,12 +154,10 @@ int main(int argc, char *argv[]) {
spdlog::info("\nOK: Pipeline completed successfully");
// Explicit cleanup on the success path; see the note on curl setup at the
// top of main.
curl_global_cleanup();
return 0;
} catch (const std::exception &e) {
// Any std::exception thrown inside the try block is logged and turned into
// exit status 1.
spdlog::error("ERROR: Pipeline failed: {}", e.what());
curl_global_cleanup();
return 1;
}
}