Refactor data generator constructor and update web client handling; enhance README with detailed pipeline overview and class diagram

This commit is contained in:
Aaron Po
2026-04-09 18:19:12 -04:00
parent 028786b8b5
commit 5d93d76e99
10 changed files with 249 additions and 164 deletions

View File

@@ -8,5 +8,6 @@
#include "biergarten_data_generator.h"
BiergartenDataGenerator::BiergartenDataGenerator(
ApplicationOptions options, std::unique_ptr<WebClient> web_client)
: options_(std::move(options)), webClient_(std::move(web_client)) {}
ApplicationOptions const& options, std::shared_ptr<WebClient> web_client)
: options_(options), webClient_(std::move(web_client)) {
}

View File

@@ -12,11 +12,9 @@
#include "biergarten_data_generator.h"
#include "wikipedia/wikipedia_service.h"
namespace {
auto TryGetRegionContext(const std::shared_ptr<WebClient>& web_client,
const Location* city_ptr,
std::atomic<size_t>* skipped_enrichment_count) noexcept
static auto TryGetRegionContext(
const std::shared_ptr<WebClient>& web_client, const Location* city_ptr,
std::atomic<size_t>* skipped_enrichment_count) noexcept
-> std::optional<std::string> {
try {
WikipediaService wikipedia_service(web_client);
@@ -27,8 +25,6 @@ auto TryGetRegionContext(const std::shared_ptr<WebClient>& web_client,
}
}
} // namespace
auto BiergartenDataGenerator::EnrichWithWikipedia(
const std::vector<Location>& cities) -> std::vector<EnrichedCity> {
std::vector<EnrichedCity> enriched;

View File

@@ -16,12 +16,10 @@
#include "data_generation/llama_generator.h"
#include "llama.h"
namespace {
/**
* String trimming: removes leading and trailing whitespace
*/
std::string Trim(std::string value) {
static std::string Trim(std::string value) {
auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
value.erase(value.begin(),
@@ -36,7 +34,7 @@ std::string Trim(std::string value) {
* Normalize whitespace: collapses multiple spaces/tabs/newlines into single
* spaces
*/
std::string CondenseWhitespace(std::string text) {
static std::string CondenseWhitespace(std::string text) {
std::string out;
out.reserve(text.size());
@@ -61,8 +59,8 @@ std::string CondenseWhitespace(std::string text) {
* Truncate region context to fit within max length while preserving word
* boundaries
*/
std::string PrepareRegionContext(std::string_view region_context,
std::size_t max_chars) {
static std::string PrepareRegionContext(std::string_view region_context,
std::size_t max_chars) {
std::string normalized = CondenseWhitespace(std::string(region_context));
if (normalized.size() <= max_chars) {
return normalized;
@@ -81,7 +79,7 @@ std::string PrepareRegionContext(std::string_view region_context,
/**
* Remove common bullet points, numbers, and field labels added by LLM in output
*/
std::string StripCommonPrefix(std::string line) {
static std::string StripCommonPrefix(std::string line) {
line = Trim(std::move(line));
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
@@ -126,7 +124,7 @@ std::string StripCommonPrefix(std::string line) {
* Parse two-line response from LLM: normalize line endings, strip formatting,
* filter spurious output, and combine remaining lines if needed
*/
std::pair<std::string, std::string> ParseTwoLineResponse(
static std::pair<std::string, std::string> ParseTwoLineResponse(
const std::string& raw, const std::string& error_message) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
@@ -177,8 +175,8 @@ std::pair<std::string, std::string> ParseTwoLineResponse(
/**
* Apply model's chat template to user-only prompt, formatting it for the model
*/
std::string ToChatPrompt(const llama_model* model,
const std::string& user_prompt) {
static std::string ToChatPrompt(const llama_model* model,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return user_prompt;
@@ -214,9 +212,9 @@ std::string ToChatPrompt(const llama_model* model,
* Apply model's chat template to system+user prompt pair, formatting for the
* model
*/
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
static std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return system_prompt + "\n\n" + user_prompt;
@@ -249,8 +247,8 @@ std::string ToChatPrompt(const llama_model* model,
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
std::array<char, 256> buffer{};
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(),
@@ -273,7 +271,8 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
bool ExtractFirstJsonObject(const std::string& text, std::string& json_out) {
static bool ExtractFirstJsonObject(const std::string& text,
std::string& json_out) {
std::size_t start = std::string::npos;
int depth = 0;
bool in_string = false;
@@ -321,8 +320,9 @@ bool ExtractFirstJsonObject(const std::string& text, std::string& json_out) {
return false;
}
std::string ValidateBreweryJson(const std::string& raw, std::string& name_out,
std::string& description_out) {
static std::string ValidateBreweryJson(const std::string& raw,
std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
@@ -403,8 +403,6 @@ std::string ValidateBreweryJson(const std::string& raw, std::string& name_out,
return {};
}
} // namespace
// Forward declarations for helper functions exposed to other translation units
std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars) {

View File

@@ -13,10 +13,8 @@
#include <sstream>
#include <stdexcept>
namespace {
auto ReadRequiredString(const boost::json::object& object, const char* key)
-> std::string {
static auto ReadRequiredString(const boost::json::object& object,
const char* key) -> std::string {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_string()) {
throw std::runtime_error(
@@ -25,8 +23,8 @@ auto ReadRequiredString(const boost::json::object& object, const char* key)
return std::string(value->as_string().c_str());
}
auto ReadRequiredNumber(const boost::json::object& object, const char* key)
-> double {
static auto ReadRequiredNumber(const boost::json::object& object,
const char* key) -> double {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_number()) {
throw std::runtime_error(
@@ -35,8 +33,6 @@ auto ReadRequiredNumber(const boost::json::object& object, const char* key)
return value->to_number<double>();
}
} // namespace
auto JsonLoader::LoadLocations(const std::string& filepath)
-> std::vector<Location> {
std::ifstream input(filepath);

View File

@@ -27,25 +27,18 @@ namespace prog_opts = boost::program_options;
auto ParseArguments(const int argc, char** argv,
ApplicationOptions& options) noexcept -> bool {
prog_opts::options_description desc("Pipeline Options");
desc.add_options()
("help,h", "Produce help message")
("mocked",
prog_opts::bool_switch(),
"Use mocked generator for brewery/user data")
("model,m",
prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")
("temperature",
prog_opts::value<float>()->default_value(0.8f),
"Sampling temperature (higher = more random)")
("top-p",
prog_opts::value<float>()->default_value(0.92f),
"Nucleus sampling top-p in (0,1] (higher = more random)")
("n-ctx",
prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)")
("seed",
prog_opts::value<int>()->default_value(-1),
desc.add_options()("help,h", "Produce help message")(
"mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data")(
"model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")(
"temperature", prog_opts::value<float>()->default_value(0.8f),
"Sampling temperature (higher = more random)")(
"top-p", prog_opts::value<float>()->default_value(0.92f),
"Nucleus sampling top-p in (0,1] (higher = more random)")(
"n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)")(
"seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
// Handle the "no arguments" or "help" case
@@ -74,13 +67,13 @@ auto ParseArguments(const int argc, char** argv,
if (use_mocked && !model_path.empty()) {
spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive");
"Invalid arguments: --mocked and --model are mutually exclusive");
return false;
}
if (!use_mocked && model_path.empty()) {
spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
"Invalid arguments: Either --mocked or --model must be specified");
return false;
}
@@ -90,8 +83,8 @@ auto ParseArguments(const int argc, char** argv,
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --seed) are"
" ignored when using --mocked");
"Sampling parameters (--temperature, --top-p, --seed) are"
" ignored when using --mocked");
}
options.use_mocked = use_mocked;
@@ -122,7 +115,7 @@ auto main(const int argc, char** argv) noexcept -> int {
return 0;
}
auto webClient = std::make_unique<CURLWebClient>();
auto webClient = std::make_shared<CURLWebClient>();
BiergartenDataGenerator generator(options, std::move(webClient));
if (!generator.Run()) {
@@ -139,4 +132,4 @@ auto main(const int argc, char** argv) noexcept -> int {
spdlog::critical("Unhandled fatal non-standard exception in main");
return 1;
}
}
}

View File

@@ -13,11 +13,10 @@
#include "web_client/curl_web_client.h"
namespace {
// RAII wrapper for CURL handle using unique_ptr
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
CurlHandle create_handle() {
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (!handle) {
throw std::runtime_error(
@@ -26,8 +25,8 @@ CurlHandle create_handle() {
return CurlHandle(handle, &curl_easy_cleanup);
}
void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) {
static void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
@@ -38,14 +37,13 @@ void set_common_get_options(CURL* curl, const std::string& url,
}
// curl write callback that writes to a file stream
size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb,
void* userp) {
static size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb,
void* userp) {
size_t realsize = size * nmemb;
auto* outFile = static_cast<std::ofstream*>(userp);
outFile->write(static_cast<char*>(contents), realsize);
return realsize;
}
} // namespace
void CURLWebClient::DownloadToFile(const std::string& url,
const std::string& file_path) {

View File

@@ -12,11 +12,10 @@
#include "web_client/curl_web_client.h"
namespace {
// RAII wrapper for CURL handle using unique_ptr
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
CurlHandle create_handle() {
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (!handle) {
throw std::runtime_error(
@@ -25,8 +24,8 @@ CurlHandle create_handle() {
return CurlHandle(handle, &curl_easy_cleanup);
}
void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) {
static void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
@@ -37,14 +36,13 @@ void set_common_get_options(CURL* curl, const std::string& url,
}
// curl write callback that appends response data into a std::string
size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
void* userp) {
static size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
void* userp) {
size_t realsize = size * nmemb;
auto* s = static_cast<std::string*>(userp);
s->append(static_cast<char*>(contents), realsize);
return realsize;
}
} // namespace
std::string CURLWebClient::Get(const std::string& url) {
auto curl = create_handle();