Refactor BiergartenDataGenerator to use dependency injection container

This commit is contained in:
Aaron Po
2026-04-09 20:33:48 -04:00
parent 5d93d76e99
commit 824f5b2b4f
23 changed files with 332 additions and 394 deletions

View File

@@ -8,6 +8,7 @@
#include "biergarten_data_generator.h"
BiergartenDataGenerator::BiergartenDataGenerator(
ApplicationOptions const& options, std::shared_ptr<WebClient> web_client)
: options_(options), webClient_(std::move(web_client)) {
}
std::shared_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator)
: context_service_(std::move(context_service)),
generator_(std::move(generator)) {}

View File

@@ -1,65 +0,0 @@
/**
* @file biergarten_data_generator/enrich_with_wikipedia.cpp
* @brief BiergartenDataGenerator::EnrichWithWikipedia() implementation.
*/
#include <spdlog/spdlog.h>
#include <atomic>
#include <future>
#include <optional>
#include "biergarten_data_generator.h"
#include "wikipedia/wikipedia_service.h"
static auto TryGetRegionContext(
const std::shared_ptr<WebClient>& web_client, const Location* city_ptr,
std::atomic<size_t>* skipped_enrichment_count) noexcept
-> std::optional<std::string> {
try {
WikipediaService wikipedia_service(web_client);
return wikipedia_service.GetSummary(city_ptr->city, city_ptr->country);
} catch (...) {
skipped_enrichment_count->fetch_add(1, std::memory_order_relaxed);
return std::nullopt;
}
}
auto BiergartenDataGenerator::EnrichWithWikipedia(
const std::vector<Location>& cities) -> std::vector<EnrichedCity> {
std::vector<EnrichedCity> enriched;
enriched.reserve(cities.size());
std::atomic<size_t> skipped_enrichment_count = 0;
std::vector<std::future<std::optional<std::string>>> pending;
pending.reserve(cities.size());
for (const auto& city : cities) {
const Location* city_ptr = &city;
pending.push_back(std::async(std::launch::async, TryGetRegionContext,
webClient_, city_ptr,
&skipped_enrichment_count));
}
auto city_it = cities.cbegin();
for (auto& task : pending) {
auto maybe_region_context = task.get();
if (maybe_region_context.has_value()) {
spdlog::debug("[Pipeline] Region context for {}: {}", city_it->city,
*maybe_region_context);
enriched.push_back(
EnrichedCity{.location = *city_it,
.region_context = std::move(*maybe_region_context)});
}
++city_it;
}
if (skipped_enrichment_count.load(std::memory_order_relaxed) > 0) {
spdlog::warn(
"[Pipeline] Skipped {} city/cities due to Wikipedia enrichment "
"errors",
skipped_enrichment_count.load(std::memory_order_relaxed));
}
return enriched;
}

View File

@@ -8,7 +8,7 @@
#include "biergarten_data_generator.h"
void BiergartenDataGenerator::GenerateBreweries(
DataGenerator& generator, const std::vector<EnrichedCity>& cities) {
const std::vector<EnrichedCity>& cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
generatedBreweries_.clear();
@@ -16,7 +16,7 @@ void BiergartenDataGenerator::GenerateBreweries(
for (const auto& enriched_city : cities) {
try {
auto brewery = generator.GenerateBrewery(
auto brewery = generator_->GenerateBrewery(
enriched_city.location.city, enriched_city.location.country,
enriched_city.region_context);
generatedBreweries_.push_back(GeneratedBrewery{

View File

@@ -1,35 +0,0 @@
/**
* @file biergarten_data_generator/initialize_generator.cpp
* @brief BiergartenDataGenerator::InitializeGenerator() implementation.
*/
#include <spdlog/spdlog.h>
#include "biergarten_data_generator.h"
#include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h"
auto BiergartenDataGenerator::InitializeGenerator() const
-> std::unique_ptr<DataGenerator> {
spdlog::info("Initializing brewery generator...");
std::unique_ptr<DataGenerator> generator;
if (options_.model_path.empty()) {
generator = std::make_unique<MockGenerator>();
spdlog::info("[Generator] Using MockGenerator (no model path provided)");
} else {
auto llama_generator = std::make_unique<LlamaGenerator>();
llama_generator->SetSamplingOptions(options_.temperature, options_.top_p,
options_.seed);
llama_generator->SetContextSize(options_.n_ctx);
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, "
"n_ctx={}, seed={})",
options_.model_path, options_.temperature, options_.top_p,
options_.n_ctx, options_.seed);
generator = std::move(llama_generator);
}
generator->Load(options_.model_path);
return generator;
}

View File

@@ -9,10 +9,35 @@
auto BiergartenDataGenerator::Run() -> bool {
try {
const std::unique_ptr<DataGenerator> generator = InitializeGenerator();
const std::vector<Location> cities = QueryCitiesWithCountries();
const std::vector<EnrichedCity> enriched = EnrichWithWikipedia(cities);
this->GenerateBreweries(*generator, enriched);
std::vector<EnrichedCity> enriched;
enriched.reserve(cities.size());
size_t skipped_count = 0;
for (const auto& city : cities) {
try {
const std::string region_context =
context_service_->GetLocationContext(city);
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context);
enriched.push_back(EnrichedCity{.location = city,
.region_context = region_context});
} catch (const std::exception& exception) {
++skipped_count;
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): context lookup failed: {}",
city.city, city.country, exception.what());
}
}
if (skipped_count > 0) {
spdlog::warn(
"[Pipeline] Skipped {} city/cities due to context lookup errors",
skipped_count);
}
this->GenerateBreweries(enriched);
this->LogResults();
return true;
} catch (const std::exception& e) {

View File

@@ -0,0 +1,53 @@
/**
* @file data_generation/llama/constructor.cpp
* @brief LlamaGenerator constructor implementation.
*/
#include <llama.h>
#include <stdexcept>
#include <string>
#include "biergarten_data_generator.h"
#include "data_generation/llama_generator.h"
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path) {
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
if (options.temperature < 0.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (options.seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
if (options.n_ctx == 0 || options.n_ctx > 32768) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_seed_ = (options.seed < 0)
? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(options.seed);
n_ctx_ = options.n_ctx;
try {
Load(model_path);
} catch (...) {
llama_backend_free();
throw;
}
}

View File

@@ -1,7 +1,7 @@
/**
* @file data_generation/llama/load.cpp
* @brief Initializes llama backend, loads model weights, creates inference
* context, and resets prior resources during model reload.
* context, and resets prior resources during model initialization.
*/
#include <spdlog/spdlog.h>
@@ -13,12 +13,6 @@
#include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) {
/**
* Validate input and clean up any previously loaded model/context
*/
if (model_path.empty())
throw std::runtime_error("LlamaGenerator: model path must not be empty");
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;

View File

@@ -1,64 +0,0 @@
/**
* @file data_generation/llama/set_sampling_options.cpp
* @brief Validates and stores sampling temperature, top-p, seed, and context
* size configuration used by subsequent LlamaGenerator inference calls.
*/
#include <stdexcept>
#include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
int seed) {
/**
* Validate temperature: controls randomness in output distribution
* 0.0 = deterministic (always pick highest probability token)
* Higher values = more random/diverse output
*/
if (temperature < 0.0f) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
/**
* Validate top-p (nucleus sampling): only sample from top cumulative
* probability e.g., top-p=0.9 means sample from tokens that make up 90% of
* probability mass
*/
if (!(top_p > 0.0f && top_p <= 1.0f)) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
/**
* Validate seed: for reproducible results (-1 uses random seed)
*/
if (seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
/**
* Store sampling parameters for use during token generation
*/
sampling_temperature_ = temperature;
sampling_top_p_ = top_p;
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed);
}
void LlamaGenerator::SetContextSize(uint32_t n_ctx) {
/**
* Validate context size: must be positive and reasonable for the model
*/
if (n_ctx == 0 || n_ctx > 32768) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
/**
* Store context size for use during model loading
*/
n_ctx_ = n_ctx;
}

View File

@@ -1,15 +0,0 @@
/**
* @file data_generation/mock/load.cpp
* @brief Provides MockGenerator initialization behavior, which is a no-op load
* path that logs readiness without model resources.
*/
#include <spdlog/spdlog.h>
#include <string>
#include "data_generation/mock_generator.h"
void MockGenerator::Load(const std::string& /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}

View File

@@ -6,15 +6,22 @@
#include <spdlog/spdlog.h>
#include <boost/di.hpp>
#include <boost/program_options.hpp>
#include <exception>
#include <memory>
#include <sstream>
#include <string>
#include "biergarten_data_generator.h"
#include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h"
#include "services/enrichment_service.h"
#include "services/wikipedia_service.h"
#include "web_client/curl_web_client.h"
namespace prog_opts = boost::program_options;
namespace di = boost::di;
/**
* @brief Parse command-line arguments into ApplicationOptions.
@@ -44,26 +51,27 @@ auto ParseArguments(const int argc, char** argv,
// Handle the "no arguments" or "help" case
if (argc == 1) {
spdlog::info("Biergarten Pipeline");
std::stringstream ss;
ss << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(ss.str());
std::stringstream usage_stream;
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(usage_stream.str());
return false;
}
try {
prog_opts::variables_map vm;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc), vm);
prog_opts::notify(vm);
prog_opts::variables_map variables_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
if (vm.contains("help")) {
std::stringstream ss;
ss << "\n" << desc;
spdlog::info(ss.str());
if (variables_map.contains("help")) {
std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return false;
}
const auto use_mocked = vm["mocked"].as<bool>();
const auto model_path = vm["model"].as<std::string>();
const auto use_mocked = variables_map["mocked"].as<bool>();
const auto model_path = variables_map["model"].as<std::string>();
if (use_mocked && !model_path.empty()) {
spdlog::error(
@@ -77,9 +85,9 @@ auto ParseArguments(const int argc, char** argv,
return false;
}
const bool has_llm_params = !vm["temperature"].defaulted() ||
!vm["top-p"].defaulted() ||
!vm["seed"].defaulted();
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
!variables_map["top-p"].defaulted() ||
!variables_map["seed"].defaulted();
if (use_mocked && has_llm_params) {
spdlog::warn(
@@ -89,10 +97,10 @@ auto ParseArguments(const int argc, char** argv,
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = vm["temperature"].as<float>();
options.top_p = vm["top-p"].as<float>();
options.n_ctx = vm["n-ctx"].as<uint32_t>();
options.seed = vm["seed"].as<int>();
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
return true;
} catch (const std::exception& exception) {
@@ -115,8 +123,29 @@ auto main(const int argc, char** argv) noexcept -> int {
return 0;
}
auto webClient = std::make_shared<CURLWebClient>();
BiergartenDataGenerator generator(options, std::move(webClient));
const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<std::string>().to(options.model_path),
di::bind<DataGenerator>().to([options](const auto& injector)
-> std::unique_ptr<DataGenerator> {
if (options.use_mocked) {
spdlog::info(
"[Generator] Using MockGenerator (no model path provided)");
return std::make_unique<MockGenerator>();
}
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, "
"n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p,
options.n_ctx, options.seed);
return injector.template create<std::unique_ptr<LlamaGenerator>>();
}));
auto generator = injector.create<BiergartenDataGenerator>();
if (!generator.Run()) {
spdlog::error("Pipeline execution failed");

View File

@@ -5,7 +5,7 @@
#include <utility>
#include "wikipedia/wikipedia_service.h"
#include "services/wikipedia_service.h"
WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {}

View File

@@ -9,7 +9,7 @@
#include <string>
#include <string_view>
#include "wikipedia/wikipedia_service.h"
#include "services/wikipedia_service.h"
auto WikipediaService::FetchExtract(std::string_view query) const
-> std::string {

View File

@@ -0,0 +1,54 @@
/**
* @file wikipedia/get_summary.cpp
* @brief WikipediaService::GetLocationContext() implementation.
*/
#include <spdlog/spdlog.h>
#include <string>
#include "services/wikipedia_service.h"
auto WikipediaService::GetLocationContext(const Location& loc) -> std::string {
const std::string cache_key = loc.city + "|" + loc.country;
const auto cache_it = cache_.find(cache_key);
if (cache_it != cache_.end()) {
return cache_it->second;
}
std::string result;
if (!client_) {
cache_.emplace(cache_key, result);
return result;
}
std::string region_query(loc.city);
if (!loc.country.empty()) {
region_query += ", ";
region_query += loc.country;
}
const std::string beer_query = "beer in " + loc.country;
try {
const std::string region_extract = FetchExtract(region_query);
const std::string beer_extract = FetchExtract(beer_query);
if (!region_extract.empty()) {
result += region_extract;
}
if (!beer_extract.empty()) {
if (!result.empty()) {
result += "\n\n";
}
result += beer_extract;
}
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
e.what());
}
cache_.emplace(cache_key, result);
return result;
}

View File

@@ -1,55 +0,0 @@
/**
* @file wikipedia/get_summary.cpp
* @brief WikipediaService::GetSummary() implementation.
*/
#include <spdlog/spdlog.h>
#include <string>
#include "wikipedia/wikipedia_service.h"
auto WikipediaService::GetSummary(std::string_view city,
std::string_view country) -> std::string {
const std::string key = std::string(city) + "|" + std::string(country);
const auto cacheIt = cache_.find(key);
if (cacheIt != cache_.end()) {
return cacheIt->second;
}
std::string result;
if (!client_) {
cache_.emplace(key, result);
return result;
}
std::string regionQuery(city);
if (!country.empty()) {
regionQuery += ", ";
regionQuery += country;
}
const std::string beerQuery = "beer in " + std::string(country);
try {
const std::string regionExtract = FetchExtract(regionQuery);
const std::string beerExtract = FetchExtract(beerQuery);
if (!regionExtract.empty()) {
result += regionExtract;
}
if (!beerExtract.empty()) {
if (!result.empty()) {
result += "\n\n";
}
result += beerExtract;
}
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery,
e.what());
}
cache_.emplace(key, result);
return result;
}