Code format updates

This commit is contained in:
Aaron Po
2026-04-11 23:51:08 -04:00
parent 823599a96f
commit 1cd30488eb
33 changed files with 985 additions and 993 deletions

View File

@@ -16,26 +16,26 @@
* @brief Interface for data generator implementations.
*/
class DataGenerator {
public:
virtual ~DataGenerator() = default;
public:
virtual ~DataGenerator() = default;
/**
* @brief Generates brewery data for a location.
*
* @param location Location data
* @param region_context Additional regional context text.
* @return Brewery generation result.
*/
virtual BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) = 0;
/**
* @brief Generates brewery data for a location.
*
* @param location Location data
* @param region_context Additional regional context text.
* @return Brewery generation result.
*/
virtual BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) = 0;
/**
* @brief Generates a user profile for a locale.
*
* @param locale Locale hint used by generator.
* @return User generation result.
*/
virtual UserResult GenerateUser(const std::string& locale) = 0;
/**
* @brief Generates a user profile for a locale.
*
* @param locale Locale hint used by generator.
* @return User generation result.
*/
virtual UserResult GenerateUser(const std::string& locale) = 0;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_DATA_GENERATOR_H_

View File

@@ -34,8 +34,7 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
* @return Pair containing first and second parsed fields.
*/
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message);
const std::string& raw, const std::string& error_message);
/**
* @brief Applies model chat template to system and user prompts.
@@ -68,7 +67,8 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
* @return Validation error message if invalid, or std::nullopt on success.
*/
std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out, std::string& description_out);
const std::string& raw, std::string& name_out,
std::string& description_out);
/**
* @brief Extracts the last balanced JSON object from text.

View File

@@ -16,109 +16,108 @@
* @brief Mock generator used for deterministic, model-free outputs.
*/
class MockGenerator final : public DataGenerator {
public:
/**
* @brief Generates deterministic brewery data for a location.
*
* @param location City and country names.
* @param region_context Unused for mock generation.
* @return Generated brewery result.
*/
BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) override;
public:
/**
* @brief Generates deterministic brewery data for a location.
*
* @param location City and country names.
* @param region_context Unused for mock generation.
* @return Generated brewery result.
*/
BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) override;
/**
* @brief Generates deterministic user data for a locale.
*
* @param locale Locale hint.
* @return Generated user result.
*/
UserResult GenerateUser(const std::string& locale) override;
/**
* @brief Generates deterministic user data for a locale.
*
* @param locale Locale hint.
* @return Generated user result.
*/
UserResult GenerateUser(const std::string& locale) override;
private:
/**
* @brief Combines two strings into a stable hash value.
*
* @param location City and country names.
* @return Deterministic hash value.
*/
static std::size_t DeterministicHash(const Location& location);
private:
/**
* @brief Combines two strings into a stable hash value.
*
* @param location City and country names.
* @return Deterministic hash value.
*/
static std::size_t DeterministicHash(const Location& location);
static constexpr std::array<std::string_view, 18> kBreweryAdjectives =
{"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"};
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"};
static constexpr std::array<std::string_view, 18> kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
static constexpr std::array<std::string_view, 18> kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
static constexpr std::array<std::string_view, 18>
kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark "
"porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy "
"pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky "
"lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food "
"pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional "
"grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward "
"sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal "
"blends.",
"Modern taproom focused on sessionable lagers and classic pub "
"styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp "
"lagers.",
"Creative brewery offering rotating collaborations and limited "
"draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich "
"stouts."};
static constexpr std::array<std::string_view, 18> kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark "
"porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy "
"pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky "
"lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food "
"pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional "
"grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward "
"sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal "
"blends.",
"Modern taproom focused on sessionable lagers and classic pub "
"styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp "
"lagers.",
"Creative brewery offering rotating collaborations and limited "
"draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich "
"stouts."};
static constexpr std::array<std::string_view, 18> kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
static constexpr std::array<std::string_view, 18> kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
static constexpr std::array<std::string_view, 18> kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer "
"gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward "
"styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean "
"fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};
static constexpr std::array<std::string_view, 18> kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer "
"gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward "
"styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean "
"fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_MOCK_GENERATOR_H_

View File

@@ -13,30 +13,30 @@
* @brief Program options for the Biergarten pipeline application.
*/
struct ApplicationOptions {
/// @brief Path to the LLM model file (gguf format); mutually exclusive with
/// use_mocked.
std::string model_path;
/// @brief Path to the LLM model file (gguf format); mutually exclusive with
/// use_mocked.
std::string model_path;
/// @brief Use mocked generator instead of LLM; mutually exclusive with
/// model_path.
bool use_mocked = false;
/// @brief Use mocked generator instead of LLM; mutually exclusive with
/// model_path.
bool use_mocked = false;
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 1.0F;
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 1.0F;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
/// random).
float top_p = 0.95F;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
/// random).
float top_p = 0.95F;
/// @brief LLM top-k sampling parameter.
uint32_t top_k = 64;
/// @brief LLM top-k sampling parameter.
uint32_t top_k = 64;
/// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory.
uint32_t n_ctx = 8192;
/// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory.
uint32_t n_ctx = 8192;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_

View File

@@ -12,11 +12,11 @@
* @brief Non-owning brewery location input.
*/
struct BreweryLocation {
/// @brief City name.
std::string_view city_name;
/// @brief City name.
std::string_view city_name;
/// @brief Country name.
std::string_view country_name;
/// @brief Country name.
std::string_view country_name;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_LOCATION_H_

View File

@@ -12,11 +12,11 @@
* @brief Generated brewery payload.
*/
struct BreweryResult {
/// @brief Brewery display name.
std::string name{};
/// @brief Brewery display name.
std::string name{};
/// @brief Brewery description text.
std::string description{};
/// @brief Brewery description text.
std::string description{};
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_RESULT_H_

View File

@@ -14,8 +14,8 @@
* @brief Enriched city data with Wikipedia context.
*/
struct EnrichedCity {
Location location;
std::string region_context{};
Location location;
std::string region_context{};
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_ENRICHED_CITY_H_

View File

@@ -13,8 +13,8 @@
* @brief Helper struct to store generated brewery data.
*/
struct GeneratedBrewery {
Location location;
BreweryResult brewery;
Location location;
BreweryResult brewery;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_GENERATED_BREWERY_H_

View File

@@ -12,26 +12,26 @@
* @brief Canonical location record for city-level generation.
*/
struct Location {
/// @brief City name.
std::string city{};
/// @brief City name.
std::string city{};
/// @brief State or province name.
std::string state_province{};
/// @brief State or province name.
std::string state_province{};
/// @brief ISO 3166-2 subdivision code.
std::string iso3166_2{};
/// @brief ISO 3166-2 subdivision code.
std::string iso3166_2{};
/// @brief Country name.
std::string country{};
/// @brief Country name.
std::string country{};
/// @brief ISO 3166-1 country code.
std::string iso3166_1{};
/// @brief ISO 3166-1 country code.
std::string iso3166_1{};
/// @brief Latitude in decimal degrees.
double latitude{};
/// @brief Latitude in decimal degrees.
double latitude{};
/// @brief Longitude in decimal degrees.
double longitude{};
/// @brief Longitude in decimal degrees.
double longitude{};
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_LOCATION_H_

View File

@@ -12,11 +12,11 @@
* @brief Generated user profile payload.
*/
struct UserResult {
/// @brief Username handle.
std::string username{};
/// @brief Username handle.
std::string username{};
/// @brief Short user biography.
std::string bio{};
/// @brief Short user biography.
std::string bio{};
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_USER_RESULT_H_

View File

@@ -13,10 +13,10 @@
/// @brief Loads curated world locations from a JSON file into memory.
class JsonLoader {
public:
/// @brief Parses a JSON array file and returns all location records.
static std::vector<Location> LoadLocations(
const std::filesystem::path& filepath);
public:
/// @brief Parses a JSON array file and returns all location records.
static std::vector<Location> LoadLocations(
const std::filesystem::path& filepath);
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -15,18 +15,18 @@
* it alive for application lifetime.
*/
class LlamaBackendState {
public:
/// @brief Initializes global llama backend state.
LlamaBackendState() { llama_backend_init(); }
public:
/// @brief Initializes global llama backend state.
LlamaBackendState() { llama_backend_init(); }
/// @brief Cleans up global llama backend state.
~LlamaBackendState() { llama_backend_free(); }
/// @brief Cleans up global llama backend state.
~LlamaBackendState() { llama_backend_free(); }
/// @brief Non-copyable type.
LlamaBackendState(const LlamaBackendState&) = delete;
/// @brief Non-copyable type.
LlamaBackendState(const LlamaBackendState&) = delete;
/// @brief Non-copyable type.
LlamaBackendState& operator=(const LlamaBackendState&) = delete;
/// @brief Non-copyable type.
LlamaBackendState& operator=(const LlamaBackendState&) = delete;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_LLAMA_BACKEND_STATE_H_

View File

@@ -14,17 +14,17 @@
* @brief Interface for services that can enrich a location with context.
*/
class IEnrichmentService {
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~IEnrichmentService() = default;
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~IEnrichmentService() = default;
/**
* @brief Resolves contextual enrichment for a location.
*
* @param loc Location to enrich.
* @return Context text, or an empty string if unavailable.
*/
virtual std::string GetLocationContext(const Location& loc) = 0;
/**
* @brief Resolves contextual enrichment for a location.
*
* @param loc Location to enrich.
* @return Context text, or an empty string if unavailable.
*/
virtual std::string GetLocationContext(const Location& loc) = 0;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_ENRICHMENT_SERVICE_H_

View File

@@ -15,40 +15,40 @@
* alive for application lifetime.
*/
class CurlGlobalState {
public:
/// @brief Initializes global libcurl state.
CurlGlobalState();
public:
/// @brief Initializes global libcurl state.
CurlGlobalState();
/// @brief Cleans up global libcurl state.
~CurlGlobalState();
/// @brief Cleans up global libcurl state.
~CurlGlobalState();
/// @brief Non-copyable type.
CurlGlobalState(const CurlGlobalState&) = delete;
/// @brief Non-copyable type.
CurlGlobalState(const CurlGlobalState&) = delete;
/// @brief Non-copyable type.
CurlGlobalState& operator=(const CurlGlobalState&) = delete;
/// @brief Non-copyable type.
CurlGlobalState& operator=(const CurlGlobalState&) = delete;
};
/**
* @brief WebClient implementation backed by libcurl.
*/
class CURLWebClient : public WebClient {
public:
/**
* @brief Executes an HTTP GET request.
*
* @param url Request URL.
* @return Response body.
*/
std::string Get(const std::string& url) override;
public:
/**
* @brief Executes an HTTP GET request.
*
* @param url Request URL.
* @return Response body.
*/
std::string Get(const std::string& url) override;
/**
* @brief URL-encodes a string value.
*
* @param value Raw value.
* @return URL-encoded string.
*/
std::string UrlEncode(const std::string& value) override;
/**
* @brief URL-encodes a string value.
*
* @param value Raw value.
* @return URL-encoded string.
*/
std::string UrlEncode(const std::string& value) override;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_CURL_WEB_CLIENT_H_

View File

@@ -12,25 +12,25 @@
* @brief Abstract web client interface.
*/
class WebClient {
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~WebClient() = default;
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~WebClient() = default;
/**
* @brief Executes an HTTP GET request.
*
* @param url Request URL.
* @return Response body.
*/
virtual std::string Get(const std::string& url) = 0;
/**
* @brief Executes an HTTP GET request.
*
* @param url Request URL.
* @return Response body.
*/
virtual std::string Get(const std::string& url) = 0;
/**
* @brief URL-encodes a string value.
*
* @param value Raw string value.
* @return Encoded value safe for URL usage.
*/
virtual std::string UrlEncode(const std::string& value) = 0;
/**
* @brief URL-encodes a string value.
*
* @param value Raw string value.
* @return Encoded value safe for URL usage.
*/
virtual std::string UrlEncode(const std::string& value) = 0;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -9,31 +9,31 @@
void BiergartenDataGenerator::GenerateBreweries(
std::span<const EnrichedCity> cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
generated_breweries_.clear();
size_t skipped_count = 0;
generated_breweries_.clear();
size_t skipped_count = 0;
for (const auto& [location, region_context] : cities) {
try {
const BreweryResult brewery =
generator_->GenerateBrewery(location, region_context);
for (const auto& [location, region_context] : cities) {
try {
const BreweryResult brewery =
generator_->GenerateBrewery(location, region_context);
const GeneratedBrewery gen{.location = location, .brewery = brewery};
const GeneratedBrewery gen{.location = location, .brewery = brewery};
generated_breweries_.push_back(gen);
} catch (const std::exception& e) {
++skipped_count;
generated_breweries_.push_back(gen);
} catch (const std::exception& e) {
++skipped_count;
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): brewery generation failed: "
"{}",
location.city, location.country, e.what());
}
}
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): brewery generation failed: "
"{}",
location.city, location.country, e.what());
}
}
if (skipped_count > 0) {
spdlog::warn("[Pipeline] Skipped {} city/cities due to generation errors",
skipped_count);
}
if (skipped_count > 0) {
spdlog::warn("[Pipeline] Skipped {} city/cities due to generation errors",
skipped_count);
}
}

View File

@@ -8,16 +8,16 @@
#include "biergarten_data_generator.h"
void BiergartenDataGenerator::LogResults() const {
spdlog::info("\n=== GENERATED DATA DUMP ===");
size_t index = 1;
for (const auto& [location, brewery] : generated_breweries_) {
spdlog::info(
"{}. city=\"{}\" country=\"{}\" state=\"{}\" "
"iso3166_2={} lat={} lon={}",
index, location.city, location.country, location.state_province,
location.iso3166_2, location.latitude, location.longitude);
spdlog::info(" brewery_name=\"{}\"", brewery.name);
spdlog::info(" brewery_description=\"{}\"", brewery.description);
++index;
}
spdlog::info("\n=== GENERATED DATA DUMP ===");
size_t index = 1;
for (const auto& [location, brewery] : generated_breweries_) {
spdlog::info(
"{}. city=\"{}\" country=\"{}\" state=\"{}\" "
"iso3166_2={} lat={} lon={}",
index, location.city, location.country, location.state_province,
location.iso3166_2, location.latitude, location.longitude);
spdlog::info(" brewery_name=\"{}\"", brewery.name);
spdlog::info(" brewery_description=\"{}\"", brewery.description);
++index;
}
}

View File

@@ -16,25 +16,25 @@
static constexpr std::size_t kBreweryAmount = 4;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
const std::filesystem::path locations_path = "locations.json";
const std::filesystem::path locations_path = "locations.json";
auto all_locations = JsonLoader::LoadLocations(locations_path);
spdlog::info(" Locations available: {}", all_locations.size());
auto all_locations = JsonLoader::LoadLocations(locations_path);
spdlog::info(" Locations available: {}", all_locations.size());
const std::size_t sample_count =
std::min(kBreweryAmount, all_locations.size());
const auto sample_count_signed =
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
sample_count);
std::vector<Location> sampled_locations;
sampled_locations.reserve(sample_count);
const std::size_t sample_count =
std::min(kBreweryAmount, all_locations.size());
const auto sample_count_signed =
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
sample_count);
std::vector<Location> sampled_locations;
sampled_locations.reserve(sample_count);
std::random_device random_generator;
std::ranges::sample(all_locations, std::back_inserter(sampled_locations),
sample_count_signed, random_generator);
std::random_device random_generator;
std::ranges::sample(all_locations, std::back_inserter(sampled_locations),
sample_count_signed, random_generator);
spdlog::info(" Sampled locations: {}", sampled_locations.size());
return sampled_locations;
spdlog::info(" Sampled locations: {}", sampled_locations.size());
return sampled_locations;
}

View File

@@ -8,40 +8,40 @@
#include "biergarten_data_generator.h"
bool BiergartenDataGenerator::Run() {
try {
const std::vector<Location> cities = QueryCitiesWithCountries();
std::vector<EnrichedCity> enriched;
enriched.reserve(cities.size());
try {
const std::vector<Location> cities = QueryCitiesWithCountries();
std::vector<EnrichedCity> enriched;
enriched.reserve(cities.size());
size_t skipped_count = 0;
for (const auto& city : cities) {
try {
const std::string region_context =
context_service_->GetLocationContext(city);
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context);
size_t skipped_count = 0;
for (const auto& city : cities) {
try {
const std::string region_context =
context_service_->GetLocationContext(city);
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context);
enriched.push_back(EnrichedCity{.location = city,
.region_context = region_context});
} catch (const std::exception& exception) {
++skipped_count;
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): context lookup failed: {}",
city.city, city.country, exception.what());
}
enriched.push_back(
EnrichedCity{.location = city, .region_context = region_context});
} catch (const std::exception& exception) {
++skipped_count;
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): context lookup failed: {}",
city.city, city.country, exception.what());
}
}
if (skipped_count > 0) {
spdlog::warn(
"[Pipeline] Skipped {} city/cities due to context lookup errors",
skipped_count);
}
if (skipped_count > 0) {
spdlog::warn(
"[Pipeline] Skipped {} city/cities due to context lookup errors",
skipped_count);
}
this->GenerateBreweries(enriched);
this->LogResults();
return true;
} catch (const std::exception& e) {
spdlog::error("Pipeline execution failed with error: {}", e.what());
return false;
}
this->GenerateBreweries(enriched);
this->LogResults();
return true;
} catch (const std::exception& e) {
spdlog::error("Pipeline execution failed with error: {}", e.what());
return false;
}
}

View File

@@ -16,135 +16,134 @@
#include "data_generation/llama_generator_helpers.h"
static std::string ExtractFinalJsonPayload(std::string raw_response) {
auto trim = [](const std::string_view text) -> std::string_view {
const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) {
return {};
}
auto trim = [](const std::string_view text) -> std::string_view {
const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) {
return {};
}
const std::size_t last = text.find_last_not_of(" \t\n\r");
return text.substr(first, last - first + 1);
};
const std::size_t last = text.find_last_not_of(" \t\n\r");
return text.substr(first, last - first + 1);
};
static constexpr std::array<std::string_view, 6> separator_tokens = {
"<|think|>", "<think|>", "<|turn|>",
"<turn|>", "<channel|>", "<|channel|>"};
static constexpr std::array<std::string_view, 6> separator_tokens = {
"<|think|>", "<think|>", "<|turn|>",
"<turn|>", "<channel|>", "<|channel|>"};
std::size_t separator_pos = std::string::npos;
std::size_t separator_length = 0;
for (const std::string_view token : separator_tokens) {
const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos ||
candidate_pos > separator_pos)) {
separator_pos = candidate_pos;
separator_length = token.size();
}
}
std::size_t separator_pos = std::string::npos;
std::size_t separator_length = 0;
for (const std::string_view token : separator_tokens) {
const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos || candidate_pos > separator_pos)) {
separator_pos = candidate_pos;
separator_length = token.size();
}
}
if (separator_pos != std::string::npos) {
raw_response.erase(0, separator_pos + separator_length);
}
if (separator_pos != std::string::npos) {
raw_response.erase(0, separator_pos + separator_length);
}
const std::string_view trimmed = trim(raw_response);
const std::string json_candidate =
ExtractLastJsonObjectPublic(std::string(trimmed));
const std::string_view trimmed = trim(raw_response);
const std::string json_candidate =
ExtractLastJsonObjectPublic(std::string(trimmed));
if (!json_candidate.empty()) {
return ExtractLastJsonObjectPublic(std::string(trimmed));
}
if (!json_candidate.empty()) {
return ExtractLastJsonObjectPublic(std::string(trimmed));
}
return std::string(trimmed);
return std::string(trimmed);
}
BreweryResult LlamaGenerator::GenerateBrewery(
const Location& location, const std::string& region_context) {
/**
* Preprocess and truncate region context to manageable size
*/
const std::string safe_region_context =
PrepareRegionContextPublic(region_context);
/**
* Preprocess and truncate region context to manageable size
*/
const std::string safe_region_context =
PrepareRegionContextPublic(region_context);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string region_suffix =
safe_region_context.empty()
? "."
: std::format(". Regional context: {}", safe_region_context);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string region_suffix =
safe_region_context.empty()
? "."
: std::format(". Regional context: {}", safe_region_context);
/**
* Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found
*/
const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/system.md");
/**
* Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found
*/
const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/system.md");
/**
* User prompt: provides geographic context to guide generation towards
* culturally relevant and locally-inspired brewery attributes
*/
std::string prompt = std::format(
"Write a brewery name and place-specific long description for a craft "
"brewery in {}{}{}",
location.city, country_suffix, region_suffix);
/**
* User prompt: provides geographic context to guide generation towards
* culturally relevant and locally-inspired brewery attributes
*/
std::string prompt = std::format(
"Write a brewery name and place-specific long description for a craft "
"brewery in {}{}{}",
location.city, country_suffix, region_suffix);
/**
* Store location context for retry prompts (without repeating full context)
*/
const std::string retry_location =
std::format("Location: {}{}", location.city, country_suffix);
/**
* Store location context for retry prompts (without repeating full context)
*/
const std::string retry_location =
std::format("Location: {}{}", location.city, country_suffix);
/**
* RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based
* refinement
*/
constexpr int max_attempts = 3;
std::string raw;
std::string last_error;
/**
* RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based
* refinement
*/
constexpr int max_attempts = 3;
std::string raw;
std::string last_error;
// Limit output length to keep it concise and focused
for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052;
// Generate brewery data from LLM
raw = this->Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
// Limit output length to keep it concise and focused
for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052;
// Generate brewery data from LLM
raw = this->Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
// Validate output: parse JSON and check required fields
// Validate output: parse JSON and check required fields
std::string name;
std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error =
ValidateBreweryJsonPublic(json_only, name, description);
if (!validation_error.has_value()) {
// Success: return parsed brewery data
return BreweryResult{.name = std::move(name),
.description = std::move(description)};
}
std::string name;
std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error =
ValidateBreweryJsonPublic(json_only, name, description);
if (!validation_error.has_value()) {
// Success: return parsed brewery data
return BreweryResult{.name = std::move(name),
.description = std::move(description)};
}
// Validation failed: log error and prepare corrective feedback
// Validation failed: log error and prepare corrective feedback
last_error = *validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, *validation_error);
last_error = *validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, *validation_error);
// Update prompt with error details to guide LLM toward correct output.
prompt = std::format(
R"(Your previous response was invalid. Error: {}
// Update prompt with error details to guide LLM toward correct output.
prompt = std::format(
R"(Your previous response was invalid. Error: {}
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}.
Do not include markdown, comments, extra keys, or literal placeholder values.
{})",
*validation_error, retry_location);
}
*validation_error, retry_location);
}
// All retry attempts exhausted: log failure and throw exception
spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
// All retry attempts exhausted: log failure and throw exception
spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
}

View File

@@ -13,6 +13,6 @@
#include "data_generation/llama_generator_helpers.h"
UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
return {.username = "test_user",
.bio = "This is a test user profile from " + locale + "."};
return {.username = "test_user",
.bio = "This is a test user profile from " + locale + "."};
}

View File

@@ -24,14 +24,14 @@
* String trimming: removes leading and trailing whitespace
*/
static std::string Trim(std::string_view value) {
constexpr std::string_view whitespace = " \t\n\r\f\v";
const std::size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) {
return {};
}
constexpr std::string_view whitespace = " \t\n\r\f\v";
const std::size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) {
return {};
}
const std::size_t last_index = value.find_last_not_of(whitespace);
return std::string(value.substr(first_index, last_index - first_index + 1));
const std::size_t last_index = value.find_last_not_of(whitespace);
return std::string(value.substr(first_index, last_index - first_index + 1));
}
/**
@@ -39,26 +39,26 @@ static std::string Trim(std::string_view value) {
* spaces
*/
static std::string CondenseWhitespace(std::string_view text) {
std::string out;
out.reserve(text.size());
std::string out;
out.reserve(text.size());
bool pending_space = false;
for (const unsigned char chr : text) {
if (std::isspace(chr) != 0) {
if (!out.empty()) {
pending_space = true;
}
continue;
bool pending_space = false;
for (const unsigned char chr : text) {
if (std::isspace(chr) != 0) {
if (!out.empty()) {
pending_space = true;
}
continue;
}
if (pending_space) {
out.push_back(' ');
pending_space = false;
}
out.push_back(static_cast<char>(chr));
}
if (pending_space) {
out.push_back(' ');
pending_space = false;
}
out.push_back(static_cast<char>(chr));
}
return out;
return out;
}
/**
@@ -67,286 +67,285 @@ static std::string CondenseWhitespace(std::string_view text) {
*/
static std::string PrepareRegionContext(std::string_view region_context,
const size_t max_chars) {
std::string normalized = CondenseWhitespace(region_context);
if (normalized.size() <= max_chars) {
return normalized;
}
std::string normalized = CondenseWhitespace(region_context);
if (normalized.size() <= max_chars) {
return normalized;
}
normalized.resize(max_chars);
const size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space);
}
normalized.resize(max_chars);
const size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space);
}
normalized += "...";
return normalized;
normalized += "...";
return normalized;
}
static std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
std::string combined_prompt;
combined_prompt.append(system_prompt);
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
const std::string& system_prompt,
const std::string& user_prompt) {
std::string combined_prompt;
combined_prompt.append(system_prompt);
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
// No template found, fallback to raw text
spdlog::warn(
"LlamaGenerator: missing chat template; using raw prompt fallback");
return combined_prompt;
}
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
// No template found, fallback to raw text
spdlog::warn(
"LlamaGenerator: missing chat template; using raw prompt fallback");
return combined_prompt;
}
const std::array<llama_chat_message, 2> messages = {
{{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}};
const std::array<llama_chat_message, 2> messages = {
{{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}};
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
auto apply_template_with_resize =
[&](const llama_chat_message* chat_messages,
int32_t message_count) -> int32_t {
int32_t result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (result < 0) {
return result;
}
if (result >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
}
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
int32_t message_count) -> int32_t {
int32_t result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (result < 0) {
return result;
};
}
int32_t template_result = apply_template_with_resize(messages.data(), 2);
if (result >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(tmpl, chat_messages, message_count,
true, buffer.data(),
static_cast<int32_t>(buffer.size()));
}
if (template_result >= 0) {
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
return result;
};
spdlog::warn(
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
int32_t template_result = apply_template_with_resize(messages.data(), 2);
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
if (template_result >= 0) {
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
template_result = apply_template_with_resize(fallback_msg.data(), 1);
spdlog::warn(
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
// Ultimate fallback: if GGUF template parsing still fails, use raw text.
if (template_result < 0) {
spdlog::warn(
"LlamaGenerator: chat template fallback failed (result {}); using "
"raw prompt text",
template_result);
return combined_prompt;
}
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
return {buffer.data(), static_cast<std::size_t>(template_result)};
template_result = apply_template_with_resize(fallback_msg.data(), 1);
// Ultimate fallback: if GGUF template parsing still fails, use raw text.
if (template_result < 0) {
spdlog::warn(
"LlamaGenerator: chat template fallback failed (result {}); using "
"raw prompt text",
template_result);
return combined_prompt;
}
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
std::array<char, 256> buffer{};
int32_t bytes = llama_token_to_piece(vocab, token, buffer.data(),
buffer.size(), 0, true);
std::array<char, 256> buffer{};
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true);
if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()),
0, true);
if (bytes < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece");
}
if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), 0,
true);
if (bytes < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece");
}
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
return;
}
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
return;
}
output.append(buffer.data(), static_cast<std::size_t>(bytes));
output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
static bool ExtractLastJsonObject(const std::string& text,
std::string& json_out) {
std::size_t start = std::string::npos;
int depth = 0;
bool in_string = false;
bool escaped = false;
bool found = false;
std::string candidate;
std::size_t start = std::string::npos;
int depth = 0;
bool in_string = false;
bool escaped = false;
bool found = false;
std::string candidate;
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
if (in_string) {
if (escaped) {
escaped = false;
} else if (ch == '\\') {
escaped = true;
} else if (ch == '"') {
in_string = false;
}
continue;
if (in_string) {
if (escaped) {
escaped = false;
} else if (ch == '\\') {
escaped = true;
} else if (ch == '"') {
in_string = false;
}
continue;
}
if (ch == '"') {
in_string = true;
continue;
if (ch == '"') {
in_string = true;
continue;
}
if (ch == '{') {
if (depth == 0) {
start = i;
}
++depth;
continue;
}
if (ch == '{') {
if (depth == 0) {
start = i;
}
++depth;
continue;
if (ch == '}') {
if (depth == 0) {
continue;
}
if (ch == '}') {
if (depth == 0) {
continue;
}
--depth;
if (depth == 0 && start != std::string::npos) {
candidate = text.substr(start, i - start + 1);
found = true;
}
--depth;
if (depth == 0 && start != std::string::npos) {
candidate = text.substr(start, i - start + 1);
found = true;
}
}
}
}
if (!found) {
return false;
}
if (!found) {
return false;
}
json_out = std::move(candidate);
return true;
json_out = std::move(candidate);
return true;
}
std::string ExtractLastJsonObjectPublic(const std::string& text) {
std::string extracted;
if (ExtractLastJsonObject(text, extracted)) {
return extracted;
}
std::string extracted;
if (ExtractLastJsonObject(text, extracted)) {
return extracted;
}
return {};
return {};
}
static std::optional<std::string> ValidateBreweryJson(
const std::string& raw, std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
error_out = "JSON root must be an object";
return false;
}
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
error_out = "JSON root must be an object";
return false;
}
const auto& obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string";
return false;
}
const auto& obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string";
return false;
}
if (!obj.contains("description") || !obj.at("description").is_string()) {
error_out = "JSON field 'description' is missing or not a string";
return false;
}
if (!obj.contains("description") || !obj.at("description").is_string()) {
error_out = "JSON field 'description' is missing or not a string";
return false;
}
const auto& name_value = obj.at("name").as_string();
const auto& description_value = obj.at("description").as_string();
name_out = Trim(std::string_view(name_value.data(), name_value.size()));
description_out = Trim(
std::string_view(description_value.data(), description_value.size()));
const auto& name_value = obj.at("name").as_string();
const auto& description_value = obj.at("description").as_string();
name_out = Trim(std::string_view(name_value.data(), name_value.size()));
description_out = Trim(
std::string_view(description_value.data(), description_value.size()));
if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty";
return false;
}
if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty";
return false;
}
if (description_out.empty()) {
error_out = "JSON field 'description' must not be empty";
return false;
}
if (description_out.empty()) {
error_out = "JSON field 'description' must not be empty";
return false;
}
std::string name_lower = name_out;
std::string description_lower = description_out;
std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
std::string name_lower = name_out;
std::string description_lower = description_out;
std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content";
return false;
}
if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content";
return false;
}
error_out.clear();
return true;
};
error_out.clear();
return true;
};
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
std::string validation_error;
if (ec) {
std::string extracted;
if (!ExtractLastJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
std::string validation_error;
if (ec) {
std::string extracted;
if (!ExtractLastJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return std::nullopt;
}
if (!validate_object(jv, validation_error)) {
if (!validate_object(jv, validation_error)) {
return validation_error;
}
}
return std::nullopt;
return std::nullopt;
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return std::nullopt;
}
// Forward declarations for helper functions exposed to other translation units
std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars) {
return PrepareRegionContext(region_context, max_chars);
return PrepareRegionContext(region_context, max_chars);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
return ToChatPrompt(model, system_prompt, user_prompt);
return ToChatPrompt(model, system_prompt, user_prompt);
}
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output) {
AppendTokenPiece(vocab, token, output);
AppendTokenPiece(vocab, token, output);
}
std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out,
std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out);
return ValidateBreweryJson(raw, name_out, description_out);
}

View File

@@ -14,32 +14,32 @@
#include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) {
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
const llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path);
}
const llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path);
}
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = n_ctx_;
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = n_ctx_;
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {
llama_model_free(model_);
model_ = nullptr;
throw std::runtime_error("LlamaGenerator: failed to create context");
}
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {
llama_model_free(model_);
model_ = nullptr;
throw std::runtime_error("LlamaGenerator: failed to create context");
}
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}

View File

@@ -21,40 +21,39 @@ namespace fs = std::filesystem;
* @return Prompt text loaded from disk.
*/
std::string LlamaGenerator::LoadBrewerySystemPrompt(
const std::string& prompt_file_path) {
// Return cached version if already loaded
if (!brewery_system_prompt_.empty()) {
return brewery_system_prompt_;
}
const std::string& prompt_file_path) {
// Return cached version if already loaded
if (!brewery_system_prompt_.empty()) {
return brewery_system_prompt_;
}
// Try the provided path only
const fs::path prompt_path(prompt_file_path);
std::ifstream prompt_file(prompt_path);
if (!prompt_file.is_open()) {
spdlog::error(
"LlamaGenerator: Failed to open brewery system prompt file '{}'",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: missing brewery system prompt file: " +
prompt_path.string());
}
// Try the provided path only
const fs::path prompt_path(prompt_file_path);
std::ifstream prompt_file(prompt_path);
if (!prompt_file.is_open()) {
spdlog::error(
"LlamaGenerator: Failed to open brewery system prompt file '{}'",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: missing brewery system prompt file: " +
prompt_path.string());
}
const std::string prompt((std::istreambuf_iterator(prompt_file)),
std::istreambuf_iterator<char>());
prompt_file.close();
const std::string prompt((std::istreambuf_iterator(prompt_file)),
std::istreambuf_iterator<char>());
prompt_file.close();
if (prompt.empty()) {
spdlog::error(
"LlamaGenerator: Brewery system prompt file '{}' is empty",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: empty brewery system prompt file: " +
prompt_path.string());
}
if (prompt.empty()) {
spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: empty brewery system prompt file: " +
prompt_path.string());
}
spdlog::info(
spdlog::info(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)",
prompt_path.string(), prompt.length());
brewery_system_prompt_ = prompt;
return brewery_system_prompt_;
brewery_system_prompt_ = prompt;
return brewery_system_prompt_;
}

View File

@@ -9,8 +9,8 @@
#include "data_generation/mock_generator.h"
size_t MockGenerator::DeterministicHash(const Location& location) {
size_t seed = 0;
boost::hash_combine(seed, location.city);
boost::hash_combine(seed, location.country);
return seed;
size_t seed = 0;
boost::hash_combine(seed, location.city);
boost::hash_combine(seed, location.country);
return seed;
}

View File

@@ -12,31 +12,31 @@
BreweryResult MockGenerator::GenerateBrewery(
const Location& location, const std::string& /*region_context*/) {
const std::size_t hash = DeterministicHash(location);
const std::size_t hash = DeterministicHash(location);
const std::string_view adjective =
kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
const std::string_view noun =
kBreweryNouns.at(hash / 7 % kBreweryNouns.size());
const std::string_view base_description =
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size());
const std::string_view adjective =
kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
const std::string_view noun =
kBreweryNouns.at(hash / 7 % kBreweryNouns.size());
const std::string_view base_description =
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size());
const std::string name =
std::format("{} {} {}", location.city, adjective, noun);
const std::string name =
std::format("{} {} {}", location.city, adjective, noun);
const std::string state_suffix =
location.state_province.empty()
? std::string{}
: std::format(", {}", location.state_province);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string description = std::format(
"{} Located in {}{}{}.", base_description, location.city,
state_suffix, country_suffix);
const std::string state_suffix =
location.state_province.empty()
? std::string{}
: std::format(", {}", location.state_province);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string description =
std::format("{} Located in {}{}{}.", base_description, location.city,
state_suffix, country_suffix);
return {
.name = name,
.description = description,
};
return {
.name = name,
.description = description,
};
}

View File

@@ -11,12 +11,12 @@
#include "data_generation/mock_generator.h"
UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
const std::string_view username = kUsernames[hash % kUsernames.size()];
const std::string_view bio = kBios[hash / 11 % kBios.size()];
result.username = username;
result.bio = bio;
return result;
UserResult result;
const std::string_view username = kUsernames[hash % kUsernames.size()];
const std::string_view bio = kBios[hash / 11 % kBios.size()];
result.username = username;
result.bio = bio;
return result;
}

View File

@@ -16,72 +16,72 @@
static std::string ReadRequiredString(const boost::json::object& object,
const char* key) {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_string()) {
throw std::runtime_error(
std::string("Missing or invalid string field: ") + key);
}
const std::string_view text = value->as_string();
return std::string(text);
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_string()) {
throw std::runtime_error(std::string("Missing or invalid string field: ") +
key);
}
const std::string_view text = value->as_string();
return std::string(text);
}
static double ReadRequiredNumber(const boost::json::object& object,
const char* key) {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_number()) {
throw std::runtime_error(
std::string("Missing or invalid numeric field: ") + key);
}
return value->to_number<double>();
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_number()) {
throw std::runtime_error(std::string("Missing or invalid numeric field: ") +
key);
}
return value->to_number<double>();
}
std::vector<Location> JsonLoader::LoadLocations(
const std::filesystem::path& filepath) {
std::ifstream input(filepath);
if (!input.is_open()) {
throw std::runtime_error("Failed to open locations file: " +
filepath.string());
}
std::ifstream input(filepath);
if (!input.is_open()) {
throw std::runtime_error("Failed to open locations file: " +
filepath.string());
}
std::stringstream buffer;
buffer << input.rdbuf();
const std::string content = buffer.str();
std::stringstream buffer;
buffer << input.rdbuf();
const std::string content = buffer.str();
boost::system::error_code error;
boost::json::value root = boost::json::parse(content, error);
if (error) {
throw std::runtime_error("Failed to parse locations JSON: " +
error.message());
}
boost::system::error_code error;
boost::json::value root = boost::json::parse(content, error);
if (error) {
throw std::runtime_error("Failed to parse locations JSON: " +
error.message());
}
if (!root.is_array()) {
if (!root.is_array()) {
throw std::runtime_error(
"Invalid locations JSON: root element must be an array");
}
std::vector<Location> locations;
const auto& items = root.as_array();
locations.reserve(items.size());
for (const auto& item : items) {
if (!item.is_object()) {
throw std::runtime_error(
"Invalid locations JSON: root element must be an array");
}
"Invalid locations JSON: each entry must be an object");
}
std::vector<Location> locations;
const auto& items = root.as_array();
locations.reserve(items.size());
const auto& object = item.as_object();
locations.push_back(Location{
.city = ReadRequiredString(object, "city"),
.state_province = ReadRequiredString(object, "state_province"),
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
.country = ReadRequiredString(object, "country"),
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
.latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"),
});
}
for (const auto& item : items) {
if (!item.is_object()) {
throw std::runtime_error(
"Invalid locations JSON: each entry must be an object");
}
const auto& object = item.as_object();
locations.push_back(Location{
.city = ReadRequiredString(object, "city"),
.state_province = ReadRequiredString(object, "state_province"),
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
.country = ReadRequiredString(object, "country"),
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
.latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"),
});
}
spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(),
filepath.string());
return locations;
spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(),
filepath.string());
return locations;
}

View File

@@ -34,151 +34,150 @@ namespace di = boost::di;
* @return Parsed ApplicationOptions if parsing succeeded, std::nullopt
* otherwise.
*/
std::optional<ApplicationOptions> ParseArguments(const int argc,
char** argv) {
prog_opts::options_description desc("Pipeline Options");
std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
prog_opts::options_description desc("Pipeline Options");
auto opt = desc.add_options();
auto opt = desc.add_options();
opt("help,h", "Produce help message");
opt("help,h", "Produce help message");
opt("mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data");
opt("mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data");
opt("model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)");
opt("model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)");
opt("temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)");
opt("temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)");
opt("top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)");
opt("top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)");
opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)");
opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)");
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)");
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)");
opt("seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
opt("seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
// Handle the "no arguments" or "help" case
if (argc == 1) {
spdlog::info("Biergarten Pipeline");
std::stringstream usage_stream;
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(usage_stream.str());
// Handle the "no arguments" or "help" case
if (argc == 1) {
spdlog::info("Biergarten Pipeline");
std::stringstream usage_stream;
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(usage_stream.str());
return std::nullopt;
}
try {
prog_opts::variables_map variables_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
if (variables_map.contains("help")) {
std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return std::nullopt;
}
}
try {
prog_opts::variables_map variables_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
const auto use_mocked = variables_map["mocked"].as<bool>();
const auto model_path = variables_map["model"].as<std::string>();
if (variables_map.contains("help")) {
std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return std::nullopt;
}
const auto use_mocked = variables_map["mocked"].as<bool>();
const auto model_path = variables_map["model"].as<std::string>();
if (use_mocked && !model_path.empty()) {
spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive");
return std::nullopt;
}
if (!use_mocked && model_path.empty()) {
spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
return std::nullopt;
}
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
!variables_map["top-p"].defaulted() ||
!variables_map["top-k"].defaulted() ||
!variables_map["seed"].defaulted() = false;
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
}
ApplicationOptions options;
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.top_k = variables_map["top-k"].as<uint32_t>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
return options;
} catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}",
exception.what());
if (use_mocked && !model_path.empty()) {
spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive");
return std::nullopt;
} catch (...) {
spdlog::error("Failed to parse command-line arguments: unknown error");
}
if (!use_mocked && model_path.empty()) {
spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
return std::nullopt;
}
}
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
!variables_map["top-p"].defaulted() ||
!variables_map["top-k"].defaulted() ||
!variables_map["seed"].defaulted() = false;
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
}
ApplicationOptions options;
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.top_k = variables_map["top-k"].as<uint32_t>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
return options;
} catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}",
exception.what());
return std::nullopt;
} catch (...) {
spdlog::error("Failed to parse command-line arguments: unknown error");
return std::nullopt;
}
}
int main(const int argc, char** argv) {
try {
const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
try {
const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
const auto parsed_options = ParseArguments(argc, argv);
if (!parsed_options.has_value()) {
return 0;
}
const auto parsed_options = ParseArguments(argc, argv);
if (!parsed_options.has_value()) {
return 0;
}
const auto options = *parsed_options;
const auto options = *parsed_options;
const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<std::string>().to(options.model_path),
di::bind<DataGenerator>().to([options](const auto& inj)
-> std::unique_ptr<DataGenerator> {
if (options.use_mocked) {
const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<std::string>().to(options.model_path),
di::bind<DataGenerator>().to(
[options](const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.use_mocked) {
spdlog::info(
"[Generator] Using MockGenerator (no model path provided)");
return std::make_unique<MockGenerator>();
}
}
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p,
options.top_k, options.n_ctx, options.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>();
}));
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p,
options.top_k, options.n_ctx, options.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>();
}));
auto generator = injector.create<BiergartenDataGenerator>();
auto generator = injector.create<BiergartenDataGenerator>();
if (!generator.Run()) {
spdlog::error("Pipeline execution failed");
return 1;
}
spdlog::info("Pipeline executed successfully");
return 0;
} catch (const std::exception& exception) {
spdlog::critical("Unhandled fatal error in main: {}", exception.what());
if (!generator.Run()) {
spdlog::error("Pipeline execution failed");
return 1;
} catch (...) {
spdlog::critical("Unhandled fatal non-standard exception in main");
return 1;
}
}
spdlog::info("Pipeline executed successfully");
return 0;
} catch (const std::exception& exception) {
spdlog::critical("Unhandled fatal error in main: {}", exception.what());
return 1;
} catch (...) {
spdlog::critical("Unhandled fatal non-standard exception in main");
return 1;
}
}

View File

@@ -12,51 +12,50 @@
#include "services/wikipedia_service.h"
std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) {
return cache_it->second;
}
const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) {
return cache_it->second;
}
const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json";
const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json";
const std::string body = this->client_->Get(url);
const std::string body = this->client_->Get(url);
boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error);
boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error);
if (!parse_error && doc.is_object()) {
try {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
const std::string_view extract_view =
page.at("extract").as_string();
std::string extract(extract_view);
if (!parse_error && doc.is_object()) {
try {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
const std::string_view extract_view = page.at("extract").as_string();
std::string extract(extract_view);
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
this->extract_cache_.emplace(cache_key, extract);
return extract;
}
}
this->extract_cache_.emplace(cache_key, std::string{});
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
this->extract_cache_.emplace(cache_key, extract);
return extract;
}
}
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
parse_error.message());
}
this->extract_cache_.emplace(cache_key, std::string{});
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
}
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
parse_error.message());
}
return {};
return {};
}

View File

@@ -10,10 +10,10 @@
#include "web_client/curl_web_client.h"
CurlGlobalState::CurlGlobalState() {
if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl globally");
}
if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl globally");
}
}
CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); }

View File

@@ -15,63 +15,61 @@
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (handle == nullptr) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
CURL* handle = curl_easy_init();
if (handle == nullptr) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
static void set_common_get_options(CURL* curl, const std::string& url) {
constexpr uint64_t connection_timeout = 10;
constexpr uint64_t request_timeout = 30;
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connection_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, request_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
constexpr uint64_t connection_timeout = 10;
constexpr uint64_t request_timeout = 30;
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connection_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, request_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
// curl write callback that appends response data into a std::string
static size_t WriteCallbackString(void* contents, const size_t size,
const size_t nmemb,
void* userp) {
const size_t real_size = size * nmemb;
auto* str = static_cast<std::string*>(userp);
str->append(static_cast<char*>(contents), real_size);
return real_size;
const size_t nmemb, void* userp) {
const size_t real_size = size * nmemb;
auto* str = static_cast<std::string*>(userp);
str->append(static_cast<char*>(contents), real_size);
return real_size;
}
std::string CURLWebClient::Get(const std::string& url) {
const CurlHandle curl = create_handle();
const CurlHandle curl = create_handle();
std::string response_string;
std::string response_string;
set_common_get_options(curl.get(), url);
set_common_get_options(curl.get(), url);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
CURLcode res = curl_easy_perform(curl.get());
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
const auto error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error);
}
if (res != CURLE_OK) {
const auto error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error);
}
int64_t httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
int64_t httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) {
const std::string error = "[CURLWebClient] HTTP error " +
std::to_string(httpCode) +
" for URL " + url;
throw std::runtime_error(error);
}
if (httpCode != 200) {
const std::string error = "[CURLWebClient] HTTP error " +
std::to_string(httpCode) + " for URL " + url;
throw std::runtime_error(error);
}
return response_string;
return response_string;
}

View File

@@ -11,14 +11,14 @@
#include "web_client/curl_web_client.h"
std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char* output = curl_easy_escape(nullptr, value.c_str(), 0);
// A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (!output) {
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
}
if (!output) {
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
}
std::string result(output);
curl_free(output);
return result;
std::string result(output);
curl_free(output);
return result;
}