Code format updates

This commit is contained in:
Aaron Po
2026-04-11 23:51:08 -04:00
parent 823599a96f
commit 1cd30488eb
33 changed files with 985 additions and 993 deletions

View File

@@ -16,26 +16,26 @@
* @brief Interface for data generator implementations. * @brief Interface for data generator implementations.
*/ */
class DataGenerator { class DataGenerator {
public: public:
virtual ~DataGenerator() = default; virtual ~DataGenerator() = default;
/** /**
* @brief Generates brewery data for a location. * @brief Generates brewery data for a location.
* *
* @param location Location data * @param location Location data
* @param region_context Additional regional context text. * @param region_context Additional regional context text.
* @return Brewery generation result. * @return Brewery generation result.
*/ */
virtual BreweryResult GenerateBrewery(const Location& location, virtual BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) = 0; const std::string& region_context) = 0;
/** /**
* @brief Generates a user profile for a locale. * @brief Generates a user profile for a locale.
* *
* @param locale Locale hint used by generator. * @param locale Locale hint used by generator.
* @return User generation result. * @return User generation result.
*/ */
virtual UserResult GenerateUser(const std::string& locale) = 0; virtual UserResult GenerateUser(const std::string& locale) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_DATA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_DATA_GENERATOR_H_

View File

@@ -34,8 +34,7 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
* @return Pair containing first and second parsed fields. * @return Pair containing first and second parsed fields.
*/ */
std::pair<std::string, std::string> ParseTwoLineResponsePublic( std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message); const std::string& raw, const std::string& error_message);
/** /**
* @brief Applies model chat template to system and user prompts. * @brief Applies model chat template to system and user prompts.
@@ -68,7 +67,8 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
* @return Validation error message if invalid, or std::nullopt on success. * @return Validation error message if invalid, or std::nullopt on success.
*/ */
std::optional<std::string> ValidateBreweryJsonPublic( std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out, std::string& description_out); const std::string& raw, std::string& name_out,
std::string& description_out);
/** /**
* @brief Extracts the last balanced JSON object from text. * @brief Extracts the last balanced JSON object from text.

View File

@@ -16,109 +16,108 @@
* @brief Mock generator used for deterministic, model-free outputs. * @brief Mock generator used for deterministic, model-free outputs.
*/ */
class MockGenerator final : public DataGenerator { class MockGenerator final : public DataGenerator {
public: public:
/** /**
* @brief Generates deterministic brewery data for a location. * @brief Generates deterministic brewery data for a location.
* *
* @param location City and country names. * @param location City and country names.
* @param region_context Unused for mock generation. * @param region_context Unused for mock generation.
* @return Generated brewery result. * @return Generated brewery result.
*/ */
BreweryResult GenerateBrewery(const Location& location, BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) override; const std::string& region_context) override;
/** /**
* @brief Generates deterministic user data for a locale. * @brief Generates deterministic user data for a locale.
* *
* @param locale Locale hint. * @param locale Locale hint.
* @return Generated user result. * @return Generated user result.
*/ */
UserResult GenerateUser(const std::string& locale) override; UserResult GenerateUser(const std::string& locale) override;
private: private:
/** /**
* @brief Combines two strings into a stable hash value. * @brief Combines two strings into a stable hash value.
* *
* @param location City and country names. * @param location City and country names.
* @return Deterministic hash value. * @return Deterministic hash value.
*/ */
static std::size_t DeterministicHash(const Location& location); static std::size_t DeterministicHash(const Location& location);
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = static constexpr std::array<std::string_view, 18> kBreweryAdjectives = {
{"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel", "Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"}; "Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"};
static constexpr std::array<std::string_view, 18> kBreweryNouns = { static constexpr std::array<std::string_view, 18> kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works", "Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective", "House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op", "Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"}; "Lab", "Beer Hall", "Guild"};
static constexpr std::array<std::string_view, 18> static constexpr std::array<std::string_view, 18> kBreweryDescriptions = {
kBreweryDescriptions = { "Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Handcrafted pale ales and seasonal IPAs with local ingredients.", "Traditional lagers and experimental sours in small batches.",
"Traditional lagers and experimental sours in small batches.", "Award-winning stouts and wildly hoppy blonde ales.",
"Award-winning stouts and wildly hoppy blonde ales.", "Craft brewery specializing in Belgian-style triples and dark "
"Craft brewery specializing in Belgian-style triples and dark " "porters.",
"porters.", "Modern brewery blending tradition with bold experimental flavors.",
"Modern brewery blending tradition with bold experimental flavors.", "Neighborhood-focused taproom pouring crisp pilsners and citrusy "
"Neighborhood-focused taproom pouring crisp pilsners and citrusy " "pale "
"pale " "ales.",
"ales.", "Small-batch brewery known for barrel-aged releases and smoky "
"Small-batch brewery known for barrel-aged releases and smoky " "lagers.",
"lagers.", "Independent brewhouse pairing farmhouse ales with rotating food "
"Independent brewhouse pairing farmhouse ales with rotating food " "pop-ups.",
"pop-ups.", "Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.", "Experimental nanobrewery exploring local yeast and regional "
"Experimental nanobrewery exploring local yeast and regional " "grains.",
"grains.", "Family-run brewery producing smooth amber ales and robust porters.",
"Family-run brewery producing smooth amber ales and robust porters.", "Urban brewery crafting clean lagers and bright, fruit-forward "
"Urban brewery crafting clean lagers and bright, fruit-forward " "sours.",
"sours.", "Riverfront brewhouse featuring oak-matured ales and seasonal "
"Riverfront brewhouse featuring oak-matured ales and seasonal " "blends.",
"blends.", "Modern taproom focused on sessionable lagers and classic pub "
"Modern taproom focused on sessionable lagers and classic pub " "styles.",
"styles.", "Brewery rooted in tradition with a lineup of malty reds and crisp "
"Brewery rooted in tradition with a lineup of malty reds and crisp " "lagers.",
"lagers.", "Creative brewery offering rotating collaborations and limited "
"Creative brewery offering rotating collaborations and limited " "draft-only "
"draft-only " "pours.",
"pours.", "Locally inspired brewery serving approachable ales with bold hop "
"Locally inspired brewery serving approachable ales with bold hop " "character.",
"character.", "Destination taproom known for balanced IPAs and cocoa-rich "
"Destination taproom known for balanced IPAs and cocoa-rich " "stouts."};
"stouts."};
static constexpr std::array<std::string_view, 18> kUsernames = { static constexpr std::array<std::string_view, 18> kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane", "hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist", "barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps", "brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal", "mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"}; "sessionwander", "kettlekeeper"};
static constexpr std::array<std::string_view, 18> kBios = { static constexpr std::array<std::string_view, 18> kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.", "Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.", "Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer " "Documenting tiny brewpubs, fresh pours, and unforgettable beer "
"gardens.", "gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.", "Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.", "Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward " "Hops-first drinker who still saves room for classic malt-forward "
"styles.", "styles.",
"Finding hidden tap lists and sharing the best seasonal releases.", "Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean " "Brewery road-tripper focused on local ingredients and clean "
"fermentation.", "fermentation.",
"Always comparing house lagers and ranking patio pint vibes.", "Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.", "Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.", "Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.", "Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.", "Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.", "Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.", "Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.", "Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.", "Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."}; "Keeping a running list of must-try collab releases and tap takeovers."};
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_MOCK_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_MOCK_GENERATOR_H_

View File

@@ -13,30 +13,30 @@
* @brief Program options for the Biergarten pipeline application. * @brief Program options for the Biergarten pipeline application.
*/ */
struct ApplicationOptions { struct ApplicationOptions {
/// @brief Path to the LLM model file (gguf format); mutually exclusive with /// @brief Path to the LLM model file (gguf format); mutually exclusive with
/// use_mocked. /// use_mocked.
std::string model_path; std::string model_path;
/// @brief Use mocked generator instead of LLM; mutually exclusive with /// @brief Use mocked generator instead of LLM; mutually exclusive with
/// model_path. /// model_path.
bool use_mocked = false; bool use_mocked = false;
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random). /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 1.0F; float temperature = 1.0F;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more /// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
/// random). /// random).
float top_p = 0.95F; float top_p = 0.95F;
/// @brief LLM top-k sampling parameter. /// @brief LLM top-k sampling parameter.
uint32_t top_k = 64; uint32_t top_k = 64;
/// @brief Context window size (tokens) for LLM inference. Higher values /// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory. /// support longer prompts but use more memory.
uint32_t n_ctx = 8192; uint32_t n_ctx = 8192;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative). /// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1; int seed = -1;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_

View File

@@ -12,11 +12,11 @@
* @brief Non-owning brewery location input. * @brief Non-owning brewery location input.
*/ */
struct BreweryLocation { struct BreweryLocation {
/// @brief City name. /// @brief City name.
std::string_view city_name; std::string_view city_name;
/// @brief Country name. /// @brief Country name.
std::string_view country_name; std::string_view country_name;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_LOCATION_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_LOCATION_H_

View File

@@ -12,11 +12,11 @@
* @brief Generated brewery payload. * @brief Generated brewery payload.
*/ */
struct BreweryResult { struct BreweryResult {
/// @brief Brewery display name. /// @brief Brewery display name.
std::string name{}; std::string name{};
/// @brief Brewery description text. /// @brief Brewery description text.
std::string description{}; std::string description{};
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_RESULT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_RESULT_H_

View File

@@ -14,8 +14,8 @@
* @brief Enriched city data with Wikipedia context. * @brief Enriched city data with Wikipedia context.
*/ */
struct EnrichedCity { struct EnrichedCity {
Location location; Location location;
std::string region_context{}; std::string region_context{};
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_ENRICHED_CITY_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_ENRICHED_CITY_H_

View File

@@ -13,8 +13,8 @@
* @brief Helper struct to store generated brewery data. * @brief Helper struct to store generated brewery data.
*/ */
struct GeneratedBrewery { struct GeneratedBrewery {
Location location; Location location;
BreweryResult brewery; BreweryResult brewery;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_GENERATED_BREWERY_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_GENERATED_BREWERY_H_

View File

@@ -12,26 +12,26 @@
* @brief Canonical location record for city-level generation. * @brief Canonical location record for city-level generation.
*/ */
struct Location { struct Location {
/// @brief City name. /// @brief City name.
std::string city{}; std::string city{};
/// @brief State or province name. /// @brief State or province name.
std::string state_province{}; std::string state_province{};
/// @brief ISO 3166-2 subdivision code. /// @brief ISO 3166-2 subdivision code.
std::string iso3166_2{}; std::string iso3166_2{};
/// @brief Country name. /// @brief Country name.
std::string country{}; std::string country{};
/// @brief ISO 3166-1 country code. /// @brief ISO 3166-1 country code.
std::string iso3166_1{}; std::string iso3166_1{};
/// @brief Latitude in decimal degrees. /// @brief Latitude in decimal degrees.
double latitude{}; double latitude{};
/// @brief Longitude in decimal degrees. /// @brief Longitude in decimal degrees.
double longitude{}; double longitude{};
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_LOCATION_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_LOCATION_H_

View File

@@ -12,11 +12,11 @@
* @brief Generated user profile payload. * @brief Generated user profile payload.
*/ */
struct UserResult { struct UserResult {
/// @brief Username handle. /// @brief Username handle.
std::string username{}; std::string username{};
/// @brief Short user biography. /// @brief Short user biography.
std::string bio{}; std::string bio{};
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_USER_RESULT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_USER_RESULT_H_

View File

@@ -13,10 +13,10 @@
/// @brief Loads curated world locations from a JSON file into memory. /// @brief Loads curated world locations from a JSON file into memory.
class JsonLoader { class JsonLoader {
public: public:
/// @brief Parses a JSON array file and returns all location records. /// @brief Parses a JSON array file and returns all location records.
static std::vector<Location> LoadLocations( static std::vector<Location> LoadLocations(
const std::filesystem::path& filepath); const std::filesystem::path& filepath);
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_JSON_HANDLING_JSON_LOADER_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -15,18 +15,18 @@
* it alive for application lifetime. * it alive for application lifetime.
*/ */
class LlamaBackendState { class LlamaBackendState {
public: public:
/// @brief Initializes global llama backend state. /// @brief Initializes global llama backend state.
LlamaBackendState() { llama_backend_init(); } LlamaBackendState() { llama_backend_init(); }
/// @brief Cleans up global llama backend state. /// @brief Cleans up global llama backend state.
~LlamaBackendState() { llama_backend_free(); } ~LlamaBackendState() { llama_backend_free(); }
/// @brief Non-copyable type. /// @brief Non-copyable type.
LlamaBackendState(const LlamaBackendState&) = delete; LlamaBackendState(const LlamaBackendState&) = delete;
/// @brief Non-copyable type. /// @brief Non-copyable type.
LlamaBackendState& operator=(const LlamaBackendState&) = delete; LlamaBackendState& operator=(const LlamaBackendState&) = delete;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_LLAMA_BACKEND_STATE_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_LLAMA_BACKEND_STATE_H_

View File

@@ -14,17 +14,17 @@
* @brief Interface for services that can enrich a location with context. * @brief Interface for services that can enrich a location with context.
*/ */
class IEnrichmentService { class IEnrichmentService {
public: public:
/// @brief Virtual destructor for polymorphic cleanup. /// @brief Virtual destructor for polymorphic cleanup.
virtual ~IEnrichmentService() = default; virtual ~IEnrichmentService() = default;
/** /**
* @brief Resolves contextual enrichment for a location. * @brief Resolves contextual enrichment for a location.
* *
* @param loc Location to enrich. * @param loc Location to enrich.
* @return Context text, or an empty string if unavailable. * @return Context text, or an empty string if unavailable.
*/ */
virtual std::string GetLocationContext(const Location& loc) = 0; virtual std::string GetLocationContext(const Location& loc) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_ENRICHMENT_SERVICE_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_ENRICHMENT_SERVICE_H_

View File

@@ -15,40 +15,40 @@
* alive for application lifetime. * alive for application lifetime.
*/ */
class CurlGlobalState { class CurlGlobalState {
public: public:
/// @brief Initializes global libcurl state. /// @brief Initializes global libcurl state.
CurlGlobalState(); CurlGlobalState();
/// @brief Cleans up global libcurl state. /// @brief Cleans up global libcurl state.
~CurlGlobalState(); ~CurlGlobalState();
/// @brief Non-copyable type. /// @brief Non-copyable type.
CurlGlobalState(const CurlGlobalState&) = delete; CurlGlobalState(const CurlGlobalState&) = delete;
/// @brief Non-copyable type. /// @brief Non-copyable type.
CurlGlobalState& operator=(const CurlGlobalState&) = delete; CurlGlobalState& operator=(const CurlGlobalState&) = delete;
}; };
/** /**
* @brief WebClient implementation backed by libcurl. * @brief WebClient implementation backed by libcurl.
*/ */
class CURLWebClient : public WebClient { class CURLWebClient : public WebClient {
public: public:
/** /**
* @brief Executes an HTTP GET request. * @brief Executes an HTTP GET request.
* *
* @param url Request URL. * @param url Request URL.
* @return Response body. * @return Response body.
*/ */
std::string Get(const std::string& url) override; std::string Get(const std::string& url) override;
/** /**
* @brief URL-encodes a string value. * @brief URL-encodes a string value.
* *
* @param value Raw value. * @param value Raw value.
* @return URL-encoded string. * @return URL-encoded string.
*/ */
std::string UrlEncode(const std::string& value) override; std::string UrlEncode(const std::string& value) override;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_CURL_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_CURL_WEB_CLIENT_H_

View File

@@ -12,25 +12,25 @@
* @brief Abstract web client interface. * @brief Abstract web client interface.
*/ */
class WebClient { class WebClient {
public: public:
/// @brief Virtual destructor for polymorphic cleanup. /// @brief Virtual destructor for polymorphic cleanup.
virtual ~WebClient() = default; virtual ~WebClient() = default;
/** /**
* @brief Executes an HTTP GET request. * @brief Executes an HTTP GET request.
* *
* @param url Request URL. * @param url Request URL.
* @return Response body. * @return Response body.
*/ */
virtual std::string Get(const std::string& url) = 0; virtual std::string Get(const std::string& url) = 0;
/** /**
* @brief URL-encodes a string value. * @brief URL-encodes a string value.
* *
* @param value Raw string value. * @param value Raw string value.
* @return Encoded value safe for URL usage. * @return Encoded value safe for URL usage.
*/ */
virtual std::string UrlEncode(const std::string& value) = 0; virtual std::string UrlEncode(const std::string& value) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -9,31 +9,31 @@
void BiergartenDataGenerator::GenerateBreweries( void BiergartenDataGenerator::GenerateBreweries(
std::span<const EnrichedCity> cities) { std::span<const EnrichedCity> cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ==="); spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
generated_breweries_.clear(); generated_breweries_.clear();
size_t skipped_count = 0; size_t skipped_count = 0;
for (const auto& [location, region_context] : cities) { for (const auto& [location, region_context] : cities) {
try { try {
const BreweryResult brewery = const BreweryResult brewery =
generator_->GenerateBrewery(location, region_context); generator_->GenerateBrewery(location, region_context);
const GeneratedBrewery gen{.location = location, .brewery = brewery}; const GeneratedBrewery gen{.location = location, .brewery = brewery};
generated_breweries_.push_back(gen); generated_breweries_.push_back(gen);
} catch (const std::exception& e) { } catch (const std::exception& e) {
++skipped_count; ++skipped_count;
spdlog::warn( spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): brewery generation failed: " "[Pipeline] Skipping city '{}' ({}): brewery generation failed: "
"{}", "{}",
location.city, location.country, e.what()); location.city, location.country, e.what());
} }
} }
if (skipped_count > 0) { if (skipped_count > 0) {
spdlog::warn("[Pipeline] Skipped {} city/cities due to generation errors", spdlog::warn("[Pipeline] Skipped {} city/cities due to generation errors",
skipped_count); skipped_count);
} }
} }

View File

@@ -8,16 +8,16 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
void BiergartenDataGenerator::LogResults() const { void BiergartenDataGenerator::LogResults() const {
spdlog::info("\n=== GENERATED DATA DUMP ==="); spdlog::info("\n=== GENERATED DATA DUMP ===");
size_t index = 1; size_t index = 1;
for (const auto& [location, brewery] : generated_breweries_) { for (const auto& [location, brewery] : generated_breweries_) {
spdlog::info( spdlog::info(
"{}. city=\"{}\" country=\"{}\" state=\"{}\" " "{}. city=\"{}\" country=\"{}\" state=\"{}\" "
"iso3166_2={} lat={} lon={}", "iso3166_2={} lat={} lon={}",
index, location.city, location.country, location.state_province, index, location.city, location.country, location.state_province,
location.iso3166_2, location.latitude, location.longitude); location.iso3166_2, location.latitude, location.longitude);
spdlog::info(" brewery_name=\"{}\"", brewery.name); spdlog::info(" brewery_name=\"{}\"", brewery.name);
spdlog::info(" brewery_description=\"{}\"", brewery.description); spdlog::info(" brewery_description=\"{}\"", brewery.description);
++index; ++index;
} }
} }

View File

@@ -16,25 +16,25 @@
static constexpr std::size_t kBreweryAmount = 4; static constexpr std::size_t kBreweryAmount = 4;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() { std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
const std::filesystem::path locations_path = "locations.json"; const std::filesystem::path locations_path = "locations.json";
auto all_locations = JsonLoader::LoadLocations(locations_path); auto all_locations = JsonLoader::LoadLocations(locations_path);
spdlog::info(" Locations available: {}", all_locations.size()); spdlog::info(" Locations available: {}", all_locations.size());
const std::size_t sample_count = const std::size_t sample_count =
std::min(kBreweryAmount, all_locations.size()); std::min(kBreweryAmount, all_locations.size());
const auto sample_count_signed = const auto sample_count_signed =
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>( static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
sample_count); sample_count);
std::vector<Location> sampled_locations; std::vector<Location> sampled_locations;
sampled_locations.reserve(sample_count); sampled_locations.reserve(sample_count);
std::random_device random_generator; std::random_device random_generator;
std::ranges::sample(all_locations, std::back_inserter(sampled_locations), std::ranges::sample(all_locations, std::back_inserter(sampled_locations),
sample_count_signed, random_generator); sample_count_signed, random_generator);
spdlog::info(" Sampled locations: {}", sampled_locations.size()); spdlog::info(" Sampled locations: {}", sampled_locations.size());
return sampled_locations; return sampled_locations;
} }

View File

@@ -8,40 +8,40 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
bool BiergartenDataGenerator::Run() { bool BiergartenDataGenerator::Run() {
try { try {
const std::vector<Location> cities = QueryCitiesWithCountries(); const std::vector<Location> cities = QueryCitiesWithCountries();
std::vector<EnrichedCity> enriched; std::vector<EnrichedCity> enriched;
enriched.reserve(cities.size()); enriched.reserve(cities.size());
size_t skipped_count = 0; size_t skipped_count = 0;
for (const auto& city : cities) { for (const auto& city : cities) {
try { try {
const std::string region_context = const std::string region_context =
context_service_->GetLocationContext(city); context_service_->GetLocationContext(city);
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}", spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context); city.city, city.country, region_context);
enriched.push_back(EnrichedCity{.location = city, enriched.push_back(
.region_context = region_context}); EnrichedCity{.location = city, .region_context = region_context});
} catch (const std::exception& exception) { } catch (const std::exception& exception) {
++skipped_count; ++skipped_count;
spdlog::warn( spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): context lookup failed: {}", "[Pipeline] Skipping city '{}' ({}): context lookup failed: {}",
city.city, city.country, exception.what()); city.city, city.country, exception.what());
}
} }
}
if (skipped_count > 0) { if (skipped_count > 0) {
spdlog::warn( spdlog::warn(
"[Pipeline] Skipped {} city/cities due to context lookup errors", "[Pipeline] Skipped {} city/cities due to context lookup errors",
skipped_count); skipped_count);
} }
this->GenerateBreweries(enriched); this->GenerateBreweries(enriched);
this->LogResults(); this->LogResults();
return true; return true;
} catch (const std::exception& e) { } catch (const std::exception& e) {
spdlog::error("Pipeline execution failed with error: {}", e.what()); spdlog::error("Pipeline execution failed with error: {}", e.what());
return false; return false;
} }
} }

View File

@@ -16,135 +16,134 @@
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
static std::string ExtractFinalJsonPayload(std::string raw_response) { static std::string ExtractFinalJsonPayload(std::string raw_response) {
auto trim = [](const std::string_view text) -> std::string_view { auto trim = [](const std::string_view text) -> std::string_view {
const std::size_t first = text.find_first_not_of(" \t\n\r"); const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) { if (first == std::string_view::npos) {
return {}; return {};
} }
const std::size_t last = text.find_last_not_of(" \t\n\r"); const std::size_t last = text.find_last_not_of(" \t\n\r");
return text.substr(first, last - first + 1); return text.substr(first, last - first + 1);
}; };
static constexpr std::array<std::string_view, 6> separator_tokens = { static constexpr std::array<std::string_view, 6> separator_tokens = {
"<|think|>", "<think|>", "<|turn|>", "<|think|>", "<think|>", "<|turn|>",
"<turn|>", "<channel|>", "<|channel|>"}; "<turn|>", "<channel|>", "<|channel|>"};
std::size_t separator_pos = std::string::npos; std::size_t separator_pos = std::string::npos;
std::size_t separator_length = 0; std::size_t separator_length = 0;
for (const std::string_view token : separator_tokens) { for (const std::string_view token : separator_tokens) {
const std::size_t candidate_pos = raw_response.rfind(token); const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos && if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos || (separator_pos == std::string::npos || candidate_pos > separator_pos)) {
candidate_pos > separator_pos)) { separator_pos = candidate_pos;
separator_pos = candidate_pos; separator_length = token.size();
separator_length = token.size(); }
} }
}
if (separator_pos != std::string::npos) { if (separator_pos != std::string::npos) {
raw_response.erase(0, separator_pos + separator_length); raw_response.erase(0, separator_pos + separator_length);
} }
const std::string_view trimmed = trim(raw_response); const std::string_view trimmed = trim(raw_response);
const std::string json_candidate = const std::string json_candidate =
ExtractLastJsonObjectPublic(std::string(trimmed)); ExtractLastJsonObjectPublic(std::string(trimmed));
if (!json_candidate.empty()) { if (!json_candidate.empty()) {
return ExtractLastJsonObjectPublic(std::string(trimmed)); return ExtractLastJsonObjectPublic(std::string(trimmed));
} }
return std::string(trimmed); return std::string(trimmed);
} }
BreweryResult LlamaGenerator::GenerateBrewery( BreweryResult LlamaGenerator::GenerateBrewery(
const Location& location, const std::string& region_context) { const Location& location, const std::string& region_context) {
/** /**
* Preprocess and truncate region context to manageable size * Preprocess and truncate region context to manageable size
*/ */
const std::string safe_region_context = const std::string safe_region_context =
PrepareRegionContextPublic(region_context); PrepareRegionContextPublic(region_context);
const std::string country_suffix = const std::string country_suffix =
location.country.empty() ? std::string{} location.country.empty() ? std::string{}
: std::format(", {}", location.country); : std::format(", {}", location.country);
const std::string region_suffix = const std::string region_suffix =
safe_region_context.empty() safe_region_context.empty()
? "." ? "."
: std::format(". Regional context: {}", safe_region_context); : std::format(". Regional context: {}", safe_region_context);
/** /**
* Load brewery system prompt from file * Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found * Falls back to minimal inline prompt if file not found
*/ */
const std::string system_prompt = const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/system.md"); LoadBrewerySystemPrompt("prompts/system.md");
/** /**
* User prompt: provides geographic context to guide generation towards * User prompt: provides geographic context to guide generation towards
* culturally relevant and locally-inspired brewery attributes * culturally relevant and locally-inspired brewery attributes
*/ */
std::string prompt = std::format( std::string prompt = std::format(
"Write a brewery name and place-specific long description for a craft " "Write a brewery name and place-specific long description for a craft "
"brewery in {}{}{}", "brewery in {}{}{}",
location.city, country_suffix, region_suffix); location.city, country_suffix, region_suffix);
/** /**
* Store location context for retry prompts (without repeating full context) * Store location context for retry prompts (without repeating full context)
*/ */
const std::string retry_location = const std::string retry_location =
std::format("Location: {}{}", location.city, country_suffix); std::format("Location: {}{}", location.city, country_suffix);
/** /**
* RETRY LOOP with validation and error correction * RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based * Attempts to generate valid brewery data up to 3 times, with feedback-based
* refinement * refinement
*/ */
constexpr int max_attempts = 3; constexpr int max_attempts = 3;
std::string raw; std::string raw;
std::string last_error; std::string last_error;
// Limit output length to keep it concise and focused // Limit output length to keep it concise and focused
for (int attempt = 0; attempt < max_attempts; ++attempt) { for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052; constexpr int max_tokens = 1052;
// Generate brewery data from LLM // Generate brewery data from LLM
raw = this->Infer(system_prompt, prompt, max_tokens); raw = this->Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw); raw);
// Validate output: parse JSON and check required fields // Validate output: parse JSON and check required fields
std::string name; std::string name;
std::string description; std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw); const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error = const std::optional<std::string> validation_error =
ValidateBreweryJsonPublic(json_only, name, description); ValidateBreweryJsonPublic(json_only, name, description);
if (!validation_error.has_value()) { if (!validation_error.has_value()) {
// Success: return parsed brewery data // Success: return parsed brewery data
return BreweryResult{.name = std::move(name), return BreweryResult{.name = std::move(name),
.description = std::move(description)}; .description = std::move(description)};
} }
// Validation failed: log error and prepare corrective feedback // Validation failed: log error and prepare corrective feedback
last_error = *validation_error; last_error = *validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, *validation_error); attempt + 1, *validation_error);
// Update prompt with error details to guide LLM toward correct output. // Update prompt with error details to guide LLM toward correct output.
prompt = std::format( prompt = std::format(
R"(Your previous response was invalid. Error: {} R"(Your previous response was invalid. Error: {}
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}. Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}.
Do not include markdown, comments, extra keys, or literal placeholder values. Do not include markdown, comments, extra keys, or literal placeholder values.
{})", {})",
*validation_error, retry_location); *validation_error, retry_location);
} }
// All retry attempts exhausted: log failure and throw exception // All retry attempts exhausted: log failure and throw exception
spdlog::error( spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: " "LlamaGenerator: malformed brewery response after {} attempts: "
"{}", "{}",
max_attempts, last_error.empty() ? raw : last_error); max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response"); throw std::runtime_error("LlamaGenerator: malformed brewery response");
} }

View File

@@ -13,6 +13,6 @@
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
UserResult LlamaGenerator::GenerateUser(const std::string& locale) { UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
return {.username = "test_user", return {.username = "test_user",
.bio = "This is a test user profile from " + locale + "."}; .bio = "This is a test user profile from " + locale + "."};
} }

View File

@@ -24,14 +24,14 @@
* String trimming: removes leading and trailing whitespace * String trimming: removes leading and trailing whitespace
*/ */
static std::string Trim(std::string_view value) { static std::string Trim(std::string_view value) {
constexpr std::string_view whitespace = " \t\n\r\f\v"; constexpr std::string_view whitespace = " \t\n\r\f\v";
const std::size_t first_index = value.find_first_not_of(whitespace); const std::size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) { if (first_index == std::string_view::npos) {
return {}; return {};
} }
const std::size_t last_index = value.find_last_not_of(whitespace); const std::size_t last_index = value.find_last_not_of(whitespace);
return std::string(value.substr(first_index, last_index - first_index + 1)); return std::string(value.substr(first_index, last_index - first_index + 1));
} }
/** /**
@@ -39,26 +39,26 @@ static std::string Trim(std::string_view value) {
* spaces * spaces
*/ */
static std::string CondenseWhitespace(std::string_view text) { static std::string CondenseWhitespace(std::string_view text) {
std::string out; std::string out;
out.reserve(text.size()); out.reserve(text.size());
bool pending_space = false; bool pending_space = false;
for (const unsigned char chr : text) { for (const unsigned char chr : text) {
if (std::isspace(chr) != 0) { if (std::isspace(chr) != 0) {
if (!out.empty()) { if (!out.empty()) {
pending_space = true; pending_space = true;
}
continue;
} }
continue;
}
if (pending_space) { if (pending_space) {
out.push_back(' '); out.push_back(' ');
pending_space = false; pending_space = false;
} }
out.push_back(static_cast<char>(chr)); out.push_back(static_cast<char>(chr));
} }
return out; return out;
} }
/** /**
@@ -67,286 +67,285 @@ static std::string CondenseWhitespace(std::string_view text) {
*/ */
static std::string PrepareRegionContext(std::string_view region_context, static std::string PrepareRegionContext(std::string_view region_context,
const size_t max_chars) { const size_t max_chars) {
std::string normalized = CondenseWhitespace(region_context); std::string normalized = CondenseWhitespace(region_context);
if (normalized.size() <= max_chars) { if (normalized.size() <= max_chars) {
return normalized; return normalized;
} }
normalized.resize(max_chars); normalized.resize(max_chars);
const size_t last_space = normalized.find_last_of(' '); const size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) { if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space); normalized.resize(last_space);
} }
normalized += "..."; normalized += "...";
return normalized; return normalized;
} }
static std::string ToChatPrompt(const llama_model* model, static std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt) { const std::string& user_prompt) {
std::string combined_prompt; std::string combined_prompt;
combined_prompt.append(system_prompt); combined_prompt.append(system_prompt);
combined_prompt.append("\n\n"); combined_prompt.append("\n\n");
combined_prompt.append(user_prompt); combined_prompt.append(user_prompt);
const char* tmpl = llama_model_chat_template(model, nullptr); const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) { if (tmpl == nullptr) {
// No template found, fallback to raw text // No template found, fallback to raw text
spdlog::warn( spdlog::warn(
"LlamaGenerator: missing chat template; using raw prompt fallback"); "LlamaGenerator: missing chat template; using raw prompt fallback");
return combined_prompt; return combined_prompt;
} }
const std::array<llama_chat_message, 2> messages = { const std::array<llama_chat_message, 2> messages = {
{{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}}; {{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}};
std::vector<char> buffer(std::max<std::size_t>( std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4)); 1024, (system_prompt.size() + user_prompt.size()) * 4));
auto apply_template_with_resize = auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
[&](const llama_chat_message* chat_messages, int32_t message_count) -> int32_t {
int32_t message_count) -> int32_t { int32_t result = llama_chat_apply_template(
int32_t result = llama_chat_apply_template( tmpl, chat_messages, message_count, true, buffer.data(),
tmpl, chat_messages, message_count, true, buffer.data(), static_cast<int32_t>(buffer.size()));
static_cast<int32_t>(buffer.size()));
if (result < 0) {
return result;
}
if (result >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
}
if (result < 0) {
return result; return result;
}; }
int32_t template_result = apply_template_with_resize(messages.data(), 2); if (result >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(tmpl, chat_messages, message_count,
true, buffer.data(),
static_cast<int32_t>(buffer.size()));
}
if (template_result >= 0) { return result;
return {buffer.data(), static_cast<std::size_t>(template_result)}; };
}
spdlog::warn( int32_t template_result = apply_template_with_resize(messages.data(), 2);
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role), if (template_result >= 0) {
// combine the system and user prompts into a single "user" message. return {buffer.data(), static_cast<std::size_t>(template_result)};
const std::array<llama_chat_message, 1> fallback_msg = { }
{{"user", combined_prompt.c_str()}}};
template_result = apply_template_with_resize(fallback_msg.data(), 1); spdlog::warn(
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
// Ultimate fallback: if GGUF template parsing still fails, use raw text. // FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
if (template_result < 0) { // combine the system and user prompts into a single "user" message.
spdlog::warn( const std::array<llama_chat_message, 1> fallback_msg = {
"LlamaGenerator: chat template fallback failed (result {}); using " {{"user", combined_prompt.c_str()}}};
"raw prompt text",
template_result);
return combined_prompt;
}
return {buffer.data(), static_cast<std::size_t>(template_result)}; template_result = apply_template_with_resize(fallback_msg.data(), 1);
// Ultimate fallback: if GGUF template parsing still fails, use raw text.
if (template_result < 0) {
spdlog::warn(
"LlamaGenerator: chat template fallback failed (result {}); using "
"raw prompt text",
template_result);
return combined_prompt;
}
return {buffer.data(), static_cast<std::size_t>(template_result)};
} }
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token, static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
std::array<char, 256> buffer{}; std::array<char, 256> buffer{};
int32_t bytes = llama_token_to_piece(vocab, token, buffer.data(), int32_t bytes =
buffer.size(), 0, true); llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true);
if (bytes < 0) { if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes)); std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), static_cast<int32_t>(dynamic_buffer.size()), 0,
0, true); true);
if (bytes < 0) { if (bytes < 0) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece"); "LlamaGenerator: failed to decode sampled token piece");
} }
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes)); output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
return; return;
} }
output.append(buffer.data(), static_cast<std::size_t>(bytes)); output.append(buffer.data(), static_cast<std::size_t>(bytes));
} }
static bool ExtractLastJsonObject(const std::string& text, static bool ExtractLastJsonObject(const std::string& text,
std::string& json_out) { std::string& json_out) {
std::size_t start = std::string::npos; std::size_t start = std::string::npos;
int depth = 0; int depth = 0;
bool in_string = false; bool in_string = false;
bool escaped = false; bool escaped = false;
bool found = false; bool found = false;
std::string candidate; std::string candidate;
for (std::size_t i = 0; i < text.size(); ++i) { for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i]; const char ch = text[i];
if (in_string) { if (in_string) {
if (escaped) { if (escaped) {
escaped = false; escaped = false;
} else if (ch == '\\') { } else if (ch == '\\') {
escaped = true; escaped = true;
} else if (ch == '"') { } else if (ch == '"') {
in_string = false; in_string = false;
}
continue;
} }
continue;
}
if (ch == '"') { if (ch == '"') {
in_string = true; in_string = true;
continue; continue;
}
if (ch == '{') {
if (depth == 0) {
start = i;
} }
++depth;
continue;
}
if (ch == '{') { if (ch == '}') {
if (depth == 0) { if (depth == 0) {
start = i; continue;
}
++depth;
continue;
} }
--depth;
if (ch == '}') { if (depth == 0 && start != std::string::npos) {
if (depth == 0) { candidate = text.substr(start, i - start + 1);
continue; found = true;
}
--depth;
if (depth == 0 && start != std::string::npos) {
candidate = text.substr(start, i - start + 1);
found = true;
}
} }
} }
}
if (!found) { if (!found) {
return false; return false;
} }
json_out = std::move(candidate); json_out = std::move(candidate);
return true; return true;
} }
std::string ExtractLastJsonObjectPublic(const std::string& text) { std::string ExtractLastJsonObjectPublic(const std::string& text) {
std::string extracted; std::string extracted;
if (ExtractLastJsonObject(text, extracted)) { if (ExtractLastJsonObject(text, extracted)) {
return extracted; return extracted;
} }
return {}; return {};
} }
static std::optional<std::string> ValidateBreweryJson( static std::optional<std::string> ValidateBreweryJson(
const std::string& raw, std::string& name_out, const std::string& raw, std::string& name_out,
std::string& description_out) { std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv, auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool { std::string& error_out) -> bool {
if (!jv.is_object()) { if (!jv.is_object()) {
error_out = "JSON root must be an object"; error_out = "JSON root must be an object";
return false; return false;
} }
const auto& obj = jv.get_object(); const auto& obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) { if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string"; error_out = "JSON field 'name' is missing or not a string";
return false; return false;
} }
if (!obj.contains("description") || !obj.at("description").is_string()) { if (!obj.contains("description") || !obj.at("description").is_string()) {
error_out = "JSON field 'description' is missing or not a string"; error_out = "JSON field 'description' is missing or not a string";
return false; return false;
} }
const auto& name_value = obj.at("name").as_string(); const auto& name_value = obj.at("name").as_string();
const auto& description_value = obj.at("description").as_string(); const auto& description_value = obj.at("description").as_string();
name_out = Trim(std::string_view(name_value.data(), name_value.size())); name_out = Trim(std::string_view(name_value.data(), name_value.size()));
description_out = Trim( description_out = Trim(
std::string_view(description_value.data(), description_value.size())); std::string_view(description_value.data(), description_value.size()));
if (name_out.empty()) { if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty"; error_out = "JSON field 'name' must not be empty";
return false; return false;
} }
if (description_out.empty()) { if (description_out.empty()) {
error_out = "JSON field 'description' must not be empty"; error_out = "JSON field 'description' must not be empty";
return false; return false;
} }
std::string name_lower = name_out; std::string name_lower = name_out;
std::string description_lower = description_out; std::string description_lower = description_out;
std::transform( std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(), name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); }); [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(), std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) { description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c)); return static_cast<char>(std::tolower(c));
}); });
if (name_lower == "string" || description_lower == "string") { if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content"; error_out = "JSON appears to be a schema placeholder, not content";
return false; return false;
} }
error_out.clear(); error_out.clear();
return true; return true;
}; };
boost::system::error_code ec; boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec); boost::json::value jv = boost::json::parse(raw, ec);
std::string validation_error; std::string validation_error;
if (ec) { if (ec) {
std::string extracted; std::string extracted;
if (!ExtractLastJsonObject(raw, extracted)) { if (!ExtractLastJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message(); return "JSON parse error: " + ec.message();
} }
ec.clear(); ec.clear();
jv = boost::json::parse(extracted, ec); jv = boost::json::parse(extracted, ec);
if (ec) { if (ec) {
return "JSON parse error: " + ec.message(); return "JSON parse error: " + ec.message();
} }
if (!validate_object(jv, validation_error)) { if (!validate_object(jv, validation_error)) {
return validation_error;
}
return std::nullopt;
}
if (!validate_object(jv, validation_error)) {
return validation_error; return validation_error;
} }
return std::nullopt; return std::nullopt;
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return std::nullopt;
} }
// Forward declarations for helper functions exposed to other translation units // Forward declarations for helper functions exposed to other translation units
std::string PrepareRegionContextPublic(std::string_view region_context, std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars) { std::size_t max_chars) {
return PrepareRegionContext(region_context, max_chars); return PrepareRegionContext(region_context, max_chars);
} }
std::string ToChatPromptPublic(const llama_model* model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt) { const std::string& user_prompt) {
return ToChatPrompt(model, system_prompt, user_prompt); return ToChatPrompt(model, system_prompt, user_prompt);
} }
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token, void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
AppendTokenPiece(vocab, token, output); AppendTokenPiece(vocab, token, output);
} }
std::optional<std::string> ValidateBreweryJsonPublic( std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out, const std::string& raw, std::string& name_out,
std::string& description_out) { std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out); return ValidateBreweryJson(raw, name_out, description_out);
} }

View File

@@ -14,32 +14,32 @@
#include "llama.h" #include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) { void LlamaGenerator::Load(const std::string& model_path) {
if (context_ != nullptr) { if (context_ != nullptr) {
llama_free(context_); llama_free(context_);
context_ = nullptr; context_ = nullptr;
} }
if (model_ != nullptr) { if (model_ != nullptr) {
llama_model_free(model_); llama_model_free(model_);
model_ = nullptr; model_ = nullptr;
} }
const llama_model_params model_params = llama_model_default_params(); const llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params); model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) { if (model_ == nullptr) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path); "LlamaGenerator: failed to load model from path: " + model_path);
} }
llama_context_params context_params = llama_context_default_params(); llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = n_ctx_; context_params.n_ctx = n_ctx_;
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000)); context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
context_ = llama_init_from_model(model_, context_params); context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) { if (context_ == nullptr) {
llama_model_free(model_); llama_model_free(model_);
model_ = nullptr; model_ = nullptr;
throw std::runtime_error("LlamaGenerator: failed to create context"); throw std::runtime_error("LlamaGenerator: failed to create context");
} }
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path); spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
} }

View File

@@ -21,40 +21,39 @@ namespace fs = std::filesystem;
* @return Prompt text loaded from disk. * @return Prompt text loaded from disk.
*/ */
std::string LlamaGenerator::LoadBrewerySystemPrompt( std::string LlamaGenerator::LoadBrewerySystemPrompt(
const std::string& prompt_file_path) { const std::string& prompt_file_path) {
// Return cached version if already loaded // Return cached version if already loaded
if (!brewery_system_prompt_.empty()) { if (!brewery_system_prompt_.empty()) {
return brewery_system_prompt_; return brewery_system_prompt_;
} }
// Try the provided path only // Try the provided path only
const fs::path prompt_path(prompt_file_path); const fs::path prompt_path(prompt_file_path);
std::ifstream prompt_file(prompt_path); std::ifstream prompt_file(prompt_path);
if (!prompt_file.is_open()) { if (!prompt_file.is_open()) {
spdlog::error( spdlog::error(
"LlamaGenerator: Failed to open brewery system prompt file '{}'", "LlamaGenerator: Failed to open brewery system prompt file '{}'",
prompt_path.string()); prompt_path.string());
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: missing brewery system prompt file: " + "LlamaGenerator: missing brewery system prompt file: " +
prompt_path.string()); prompt_path.string());
} }
const std::string prompt((std::istreambuf_iterator(prompt_file)), const std::string prompt((std::istreambuf_iterator(prompt_file)),
std::istreambuf_iterator<char>()); std::istreambuf_iterator<char>());
prompt_file.close(); prompt_file.close();
if (prompt.empty()) { if (prompt.empty()) {
spdlog::error( spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty",
"LlamaGenerator: Brewery system prompt file '{}' is empty", prompt_path.string());
prompt_path.string()); throw std::runtime_error(
throw std::runtime_error( "LlamaGenerator: empty brewery system prompt file: " +
"LlamaGenerator: empty brewery system prompt file: " + prompt_path.string());
prompt_path.string()); }
}
spdlog::info( spdlog::info(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)", "LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)",
prompt_path.string(), prompt.length()); prompt_path.string(), prompt.length());
brewery_system_prompt_ = prompt; brewery_system_prompt_ = prompt;
return brewery_system_prompt_; return brewery_system_prompt_;
} }

View File

@@ -9,8 +9,8 @@
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
size_t MockGenerator::DeterministicHash(const Location& location) { size_t MockGenerator::DeterministicHash(const Location& location) {
size_t seed = 0; size_t seed = 0;
boost::hash_combine(seed, location.city); boost::hash_combine(seed, location.city);
boost::hash_combine(seed, location.country); boost::hash_combine(seed, location.country);
return seed; return seed;
} }

View File

@@ -12,31 +12,31 @@
BreweryResult MockGenerator::GenerateBrewery( BreweryResult MockGenerator::GenerateBrewery(
const Location& location, const std::string& /*region_context*/) { const Location& location, const std::string& /*region_context*/) {
const std::size_t hash = DeterministicHash(location); const std::size_t hash = DeterministicHash(location);
const std::string_view adjective = const std::string_view adjective =
kBreweryAdjectives.at(hash % kBreweryAdjectives.size()); kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
const std::string_view noun = const std::string_view noun =
kBreweryNouns.at(hash / 7 % kBreweryNouns.size()); kBreweryNouns.at(hash / 7 % kBreweryNouns.size());
const std::string_view base_description = const std::string_view base_description =
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size()); kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size());
const std::string name = const std::string name =
std::format("{} {} {}", location.city, adjective, noun); std::format("{} {} {}", location.city, adjective, noun);
const std::string state_suffix = const std::string state_suffix =
location.state_province.empty() location.state_province.empty()
? std::string{} ? std::string{}
: std::format(", {}", location.state_province); : std::format(", {}", location.state_province);
const std::string country_suffix = const std::string country_suffix =
location.country.empty() ? std::string{} location.country.empty() ? std::string{}
: std::format(", {}", location.country); : std::format(", {}", location.country);
const std::string description = std::format( const std::string description =
"{} Located in {}{}{}.", base_description, location.city, std::format("{} Located in {}{}{}.", base_description, location.city,
state_suffix, country_suffix); state_suffix, country_suffix);
return { return {
.name = name, .name = name,
.description = description, .description = description,
}; };
} }

View File

@@ -11,12 +11,12 @@
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
UserResult MockGenerator::GenerateUser(const std::string& locale) { UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale); const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result; UserResult result;
const std::string_view username = kUsernames[hash % kUsernames.size()]; const std::string_view username = kUsernames[hash % kUsernames.size()];
const std::string_view bio = kBios[hash / 11 % kBios.size()]; const std::string_view bio = kBios[hash / 11 % kBios.size()];
result.username = username; result.username = username;
result.bio = bio; result.bio = bio;
return result; return result;
} }

View File

@@ -16,72 +16,72 @@
static std::string ReadRequiredString(const boost::json::object& object, static std::string ReadRequiredString(const boost::json::object& object,
const char* key) { const char* key) {
const boost::json::value* value = object.if_contains(key); const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_string()) { if (value == nullptr || !value->is_string()) {
throw std::runtime_error( throw std::runtime_error(std::string("Missing or invalid string field: ") +
std::string("Missing or invalid string field: ") + key); key);
} }
const std::string_view text = value->as_string(); const std::string_view text = value->as_string();
return std::string(text); return std::string(text);
} }
static double ReadRequiredNumber(const boost::json::object& object, static double ReadRequiredNumber(const boost::json::object& object,
const char* key) { const char* key) {
const boost::json::value* value = object.if_contains(key); const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_number()) { if (value == nullptr || !value->is_number()) {
throw std::runtime_error( throw std::runtime_error(std::string("Missing or invalid numeric field: ") +
std::string("Missing or invalid numeric field: ") + key); key);
} }
return value->to_number<double>(); return value->to_number<double>();
} }
std::vector<Location> JsonLoader::LoadLocations( std::vector<Location> JsonLoader::LoadLocations(
const std::filesystem::path& filepath) { const std::filesystem::path& filepath) {
std::ifstream input(filepath); std::ifstream input(filepath);
if (!input.is_open()) { if (!input.is_open()) {
throw std::runtime_error("Failed to open locations file: " + throw std::runtime_error("Failed to open locations file: " +
filepath.string()); filepath.string());
} }
std::stringstream buffer; std::stringstream buffer;
buffer << input.rdbuf(); buffer << input.rdbuf();
const std::string content = buffer.str(); const std::string content = buffer.str();
boost::system::error_code error; boost::system::error_code error;
boost::json::value root = boost::json::parse(content, error); boost::json::value root = boost::json::parse(content, error);
if (error) { if (error) {
throw std::runtime_error("Failed to parse locations JSON: " + throw std::runtime_error("Failed to parse locations JSON: " +
error.message()); error.message());
} }
if (!root.is_array()) { if (!root.is_array()) {
throw std::runtime_error(
"Invalid locations JSON: root element must be an array");
}
std::vector<Location> locations;
const auto& items = root.as_array();
locations.reserve(items.size());
for (const auto& item : items) {
if (!item.is_object()) {
throw std::runtime_error( throw std::runtime_error(
"Invalid locations JSON: root element must be an array"); "Invalid locations JSON: each entry must be an object");
} }
std::vector<Location> locations; const auto& object = item.as_object();
const auto& items = root.as_array(); locations.push_back(Location{
locations.reserve(items.size()); .city = ReadRequiredString(object, "city"),
.state_province = ReadRequiredString(object, "state_province"),
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
.country = ReadRequiredString(object, "country"),
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
.latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"),
});
}
for (const auto& item : items) { spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(),
if (!item.is_object()) { filepath.string());
throw std::runtime_error( return locations;
"Invalid locations JSON: each entry must be an object");
}
const auto& object = item.as_object();
locations.push_back(Location{
.city = ReadRequiredString(object, "city"),
.state_province = ReadRequiredString(object, "state_province"),
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
.country = ReadRequiredString(object, "country"),
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
.latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"),
});
}
spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(),
filepath.string());
return locations;
} }

View File

@@ -34,151 +34,150 @@ namespace di = boost::di;
* @return Parsed ApplicationOptions if parsing succeeded, std::nullopt * @return Parsed ApplicationOptions if parsing succeeded, std::nullopt
* otherwise. * otherwise.
*/ */
std::optional<ApplicationOptions> ParseArguments(const int argc, std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
char** argv) { prog_opts::options_description desc("Pipeline Options");
prog_opts::options_description desc("Pipeline Options");
auto opt = desc.add_options(); auto opt = desc.add_options();
opt("help,h", "Produce help message"); opt("help,h", "Produce help message");
opt("mocked", prog_opts::bool_switch(), opt("mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data"); "Use mocked generator for brewery/user data");
opt("model,m", prog_opts::value<std::string>()->default_value(""), opt("model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)"); "Path to LLM model (gguf)");
opt("temperature", prog_opts::value<float>()->default_value(1.0F), opt("temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)"); "Sampling temperature (higher = more random)");
opt("top-p", prog_opts::value<float>()->default_value(0.95F), opt("top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)"); "Nucleus sampling top-p in (0,1] (higher = more random)");
opt("top-k", prog_opts::value<uint32_t>()->default_value(64), opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)"); "Top-k sampling parameter (higher = more candidate tokens)");
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192), opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)"); "Context window size in tokens (1-32768)");
opt("seed", prog_opts::value<int>()->default_value(-1), opt("seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer"); "Sampler seed: -1 for random, otherwise non-negative integer");
// Handle the "no arguments" or "help" case // Handle the "no arguments" or "help" case
if (argc == 1) { if (argc == 1) {
spdlog::info("Biergarten Pipeline"); spdlog::info("Biergarten Pipeline");
std::stringstream usage_stream; std::stringstream usage_stream;
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc; usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(usage_stream.str()); spdlog::info(usage_stream.str());
return std::nullopt;
}
try {
prog_opts::variables_map variables_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
if (variables_map.contains("help")) {
std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return std::nullopt; return std::nullopt;
} }
try { const auto use_mocked = variables_map["mocked"].as<bool>();
prog_opts::variables_map variables_map; const auto model_path = variables_map["model"].as<std::string>();
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
if (variables_map.contains("help")) { if (use_mocked && !model_path.empty()) {
std::stringstream help_stream; spdlog::error(
help_stream << "\n" << desc; "Invalid arguments: --mocked and --model are mutually exclusive");
spdlog::info(help_stream.str());
return std::nullopt;
}
const auto use_mocked = variables_map["mocked"].as<bool>();
const auto model_path = variables_map["model"].as<std::string>();
if (use_mocked && !model_path.empty()) {
spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive");
return std::nullopt;
}
if (!use_mocked && model_path.empty()) {
spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
return std::nullopt;
}
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
!variables_map["top-p"].defaulted() ||
!variables_map["top-k"].defaulted() ||
!variables_map["seed"].defaulted() = false;
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
}
ApplicationOptions options;
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.top_k = variables_map["top-k"].as<uint32_t>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
return options;
} catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}",
exception.what());
return std::nullopt; return std::nullopt;
} catch (...) { }
spdlog::error("Failed to parse command-line arguments: unknown error");
if (!use_mocked && model_path.empty()) {
spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
return std::nullopt; return std::nullopt;
} }
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
!variables_map["top-p"].defaulted() ||
!variables_map["top-k"].defaulted() ||
!variables_map["seed"].defaulted() = false;
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
}
ApplicationOptions options;
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.top_k = variables_map["top-k"].as<uint32_t>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
return options;
} catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}",
exception.what());
return std::nullopt;
} catch (...) {
spdlog::error("Failed to parse command-line arguments: unknown error");
return std::nullopt;
}
} }
int main(const int argc, char** argv) { int main(const int argc, char** argv) {
try { try {
const CurlGlobalState curl_state; const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state; const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v"); spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
const auto parsed_options = ParseArguments(argc, argv); const auto parsed_options = ParseArguments(argc, argv);
if (!parsed_options.has_value()) { if (!parsed_options.has_value()) {
return 0; return 0;
} }
const auto options = *parsed_options; const auto options = *parsed_options;
const auto injector = di::make_injector( const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(), di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options), di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(), di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<std::string>().to(options.model_path), di::bind<std::string>().to(options.model_path),
di::bind<DataGenerator>().to([options](const auto& inj) di::bind<DataGenerator>().to(
-> std::unique_ptr<DataGenerator> { [options](const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.use_mocked) { if (options.use_mocked) {
spdlog::info( spdlog::info(
"[Generator] Using MockGenerator (no model path provided)"); "[Generator] Using MockGenerator (no model path provided)");
return std::make_unique<MockGenerator>(); return std::make_unique<MockGenerator>();
} }
spdlog::info( spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, " "[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, top-k={}, n_ctx={}, seed={})", "top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p, options.model_path, options.temperature, options.top_p,
options.top_k, options.n_ctx, options.seed); options.top_k, options.n_ctx, options.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>(); return inj.template create<std::unique_ptr<LlamaGenerator>>();
})); }));
auto generator = injector.create<BiergartenDataGenerator>(); auto generator = injector.create<BiergartenDataGenerator>();
if (!generator.Run()) { if (!generator.Run()) {
spdlog::error("Pipeline execution failed"); spdlog::error("Pipeline execution failed");
return 1;
}
spdlog::info("Pipeline executed successfully");
return 0;
} catch (const std::exception& exception) {
spdlog::critical("Unhandled fatal error in main: {}", exception.what());
return 1; return 1;
} catch (...) { }
spdlog::critical("Unhandled fatal non-standard exception in main");
return 1; spdlog::info("Pipeline executed successfully");
} return 0;
} catch (const std::exception& exception) {
spdlog::critical("Unhandled fatal error in main: {}", exception.what());
return 1;
} catch (...) {
spdlog::critical("Unhandled fatal non-standard exception in main");
return 1;
}
} }

View File

@@ -12,51 +12,50 @@
#include "services/wikipedia_service.h" #include "services/wikipedia_service.h"
std::string WikipediaService::FetchExtract(std::string_view query) { std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string cache_key(query); const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key); const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) { if (cache_it != this->extract_cache_.end()) {
return cache_it->second; return cache_it->second;
} }
const std::string encoded = this->client_->UrlEncode(cache_key); const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url = const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json"; "&prop=extracts&explaintext=1&format=json";
const std::string body = this->client_->Get(url); const std::string body = this->client_->Get(url);
boost::system::error_code parse_error; boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error); boost::json::value doc = boost::json::parse(body, parse_error);
if (!parse_error && doc.is_object()) { if (!parse_error && doc.is_object()) {
try { try {
auto& pages = doc.at("query").at("pages").get_object(); auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) { if (!pages.empty()) {
auto& page = pages.begin()->value().get_object(); auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) { if (page.contains("extract") && page.at("extract").is_string()) {
const std::string_view extract_view = const std::string_view extract_view = page.at("extract").as_string();
page.at("extract").as_string(); std::string extract(extract_view);
std::string extract(extract_view);
spdlog::debug("WikipediaService fetched {} chars for '{}'", spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query); extract.size(), query);
this->extract_cache_.emplace(cache_key, extract); this->extract_cache_.emplace(cache_key, extract);
return extract; return extract;
} }
}
this->extract_cache_.emplace(cache_key, std::string{});
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
} }
} else if (parse_error) { this->extract_cache_.emplace(cache_key, std::string{});
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query, } catch (const std::exception& e) {
parse_error.message()); spdlog::warn(
} "WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
}
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
parse_error.message());
}
return {}; return {};
} }

View File

@@ -10,10 +10,10 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
CurlGlobalState::CurlGlobalState() { CurlGlobalState::CurlGlobalState() {
if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) { if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) {
throw std::runtime_error( throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl globally"); "[CURLWebClient] Failed to initialize libcurl globally");
} }
} }
CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); } CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); }

View File

@@ -15,63 +15,61 @@
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>; using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static CurlHandle create_handle() { static CurlHandle create_handle() {
CURL* handle = curl_easy_init(); CURL* handle = curl_easy_init();
if (handle == nullptr) { if (handle == nullptr) {
throw std::runtime_error( throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle"); "[CURLWebClient] Failed to initialize libcurl handle");
} }
return CurlHandle(handle, &curl_easy_cleanup); return CurlHandle(handle, &curl_easy_cleanup);
} }
static void set_common_get_options(CURL* curl, const std::string& url) { static void set_common_get_options(CURL* curl, const std::string& url) {
constexpr uint64_t connection_timeout = 10; constexpr uint64_t connection_timeout = 10;
constexpr uint64_t request_timeout = 30; constexpr uint64_t request_timeout = 30;
curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L); curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connection_timeout); curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connection_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, request_timeout); curl_easy_setopt(curl, CURLOPT_TIMEOUT, request_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip"); curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
} }
// curl write callback that appends response data into a std::string // curl write callback that appends response data into a std::string
static size_t WriteCallbackString(void* contents, const size_t size, static size_t WriteCallbackString(void* contents, const size_t size,
const size_t nmemb, const size_t nmemb, void* userp) {
void* userp) { const size_t real_size = size * nmemb;
const size_t real_size = size * nmemb; auto* str = static_cast<std::string*>(userp);
auto* str = static_cast<std::string*>(userp); str->append(static_cast<char*>(contents), real_size);
str->append(static_cast<char*>(contents), real_size); return real_size;
return real_size;
} }
std::string CURLWebClient::Get(const std::string& url) { std::string CURLWebClient::Get(const std::string& url) {
const CurlHandle curl = create_handle(); const CurlHandle curl = create_handle();
std::string response_string; std::string response_string;
set_common_get_options(curl.get(), url); set_common_get_options(curl.get(), url);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString); curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string); curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
CURLcode res = curl_easy_perform(curl.get()); CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) { if (res != CURLE_OK) {
const auto error = const auto error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res); std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error); throw std::runtime_error(error);
} }
int64_t httpCode = 0; int64_t httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) { if (httpCode != 200) {
const std::string error = "[CURLWebClient] HTTP error " + const std::string error = "[CURLWebClient] HTTP error " +
std::to_string(httpCode) + std::to_string(httpCode) + " for URL " + url;
" for URL " + url; throw std::runtime_error(error);
throw std::runtime_error(error); }
}
return response_string; return response_string;
} }

View File

@@ -11,14 +11,14 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
std::string CURLWebClient::UrlEncode(const std::string& value) { std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs. // A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char* output = curl_easy_escape(nullptr, value.c_str(), 0); char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (!output) { if (!output) {
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed"); throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
} }
std::string result(output); std::string result(output);
curl_free(output); curl_free(output);
return result; return result;
} }