8 Commits

Author SHA1 Message Date
Aaron Po
299a767d39 remove unused code 2026-04-11 14:42:32 -04:00
Aaron Po
f07d48f810 Add missing includes, update readme 2026-04-11 14:31:24 -04:00
Aaron Po
bcfde856fe Split data models into dedicated headers 2026-04-11 13:21:50 -04:00
Aaron Po
5946356083 Style audit: update code to strictly follow Google Style Guide 2026-04-11 11:56:45 -04:00
Aaron Po
ae67fa8566 refactor: consolidate and rename data generation and service files 2026-04-11 00:06:23 -04:00
Aaron Po
8c572a2d07 fix: stabilize Gemma 4 brewery generation
remove misleading turn-token output guidance from the brewery prompt
extract the last balanced JSON object before validation
keep README model setup and run instructions aligned
preserve Gemma 4 sampling defaults and local model usage
2026-04-10 22:25:26 -04:00
Aaron Po
902bda6eb9 feat: make Gemma 4 the default model, enable thinking mode 2026-04-10 21:43:18 -04:00
Aaron Po
61d5077a95 update readme 2026-04-10 00:03:45 -04:00
50 changed files with 1744 additions and 1528 deletions

View File

@@ -1,5 +1,5 @@
---
BasedOnStyle: Google
ColumnLimit: 80
IndentWidth: 2
IndentWidth: 3
...

View File

@@ -9,23 +9,21 @@ Checks: >
-google-runtime-references
CheckOptions:
# Enforce Google Naming Conventions with valid clang-tidy strings
- key: readability-identifier-naming.ClassCase
value: CamelCase
# Enforce Google Naming Conventions
- key: readability-identifier-naming.ClassMemberCase
value: lower_case
value: snake_case
- key: readability-identifier-naming.ClassMemberSuffix
value: _
- key: readability-identifier-naming.ClassCase
value: PascalCase
- key: readability-identifier-naming.FunctionCase
value: CamelCase
value: PascalCase
- key: readability-identifier-naming.StructCase
value: CamelCase
value: PascalCase
- key: readability-identifier-naming.VariableCase
value: lower_case
value: snake_case
- key: readability-identifier-naming.GlobalConstantCase
value: CamelCase
- key: readability-identifier-naming.GlobalConstantPrefix
value: k
value: kPascalCase
# Ensure C++20 Modernization
- key: modernize-make-unique.MakeSmartPtrFunction

1
pipeline/.gitignore vendored
View File

@@ -1,6 +1,5 @@
dist
build
cmake-build-*
data
models
*.gguf

View File

@@ -1,10 +1,12 @@
cmake_minimum_required(VERSION 3.24)
project(biergarten-pipeline)
# Boost.DI still declares a very old minimum CMake version, which newer CMake
# releases reject unless a policy version floor is provided.
set(CMAKE_POLICY_VERSION_MINIMUM 3.5 CACHE STRING "" FORCE)
# =============================================================================
# 1. Platform & GPU Detection
# 1. Platform & GPU Detection (Windows explicitly NOT supported)
# =============================================================================
if(WIN32)
message(FATAL_ERROR "[biergarten] Windows is currently not supported. Please use Linux (Fedora 43) or macOS (M1 Pro).")
@@ -38,15 +40,18 @@ endif()
# 2. Project-wide Settings (Standard & Optimization)
# =============================================================================
# Downgrade to C++20 as per Google Style Guide
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# GCC/Clang specific settings (warnings as errors)
add_compile_options(-Wall -Wextra -Werror -Wpedantic)
# Release Build Optimization: Aggressive (-O3), Arch-specific, and LTO
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -march=native -flto")
# Debug Build Optimization: Fast and debuggable (-Og)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Og -g")
# =============================================================================
@@ -101,6 +106,7 @@ set(SOURCES
src/services/wikipedia/fetch_extract.cpp
src/web_client/curl_global_state.cpp
src/web_client/curl_web_client_get.cpp
src/web_client/curl_web_client_utils.cpp
src/web_client/curl_web_client_url_encode.cpp
src/data_generation/llama/llama_generator.cpp
src/data_generation/llama/generate_brewery.cpp
@@ -109,6 +115,7 @@ set(SOURCES
src/data_generation/llama/infer.cpp
src/data_generation/llama/load.cpp
src/data_generation/llama/load_brewery_prompt.cpp
src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp
@@ -141,9 +148,3 @@ configure_file(
${CMAKE_BINARY_DIR}/locations.json
COPYONLY
)
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
${CMAKE_SOURCE_DIR}/prompts
${CMAKE_BINARY_DIR}/prompts
)

View File

@@ -76,7 +76,7 @@ curl -L \
## Run
Run the executable from the build directory so the copied `locations.json` and `prompts/` directory are available.
Run the executable from the build directory so the copied `locations.json` is available.
```bash
./biergarten-pipeline --mocked

View File

@@ -1,7 +1,7 @@
@startuml BiergartenPipeline
title Biergarten Pipeline - Class and Composition Diagram
top to bottom direction
left to right direction
skinparam shadowing false
skinparam classAttributeIconSize 0
skinparam packageStyle rectangle
@@ -16,17 +16,11 @@ package "Composition root" {
+~CurlGlobalState()
}
class LlamaBackendState {
+LlamaBackendState()
+~LlamaBackendState()
}
note right of Main
Binds with Boost.DI:
- WebClient -> CURLWebClient
- IEnrichmentService -> WikipediaService
- DataGenerator -> MockGenerator or LlamaGenerator
- std::string -> model_path
- LlamaGenerator receives ApplicationOptions and model_path directly
end note
}
@@ -38,8 +32,8 @@ package "Core orchestration" {
-generated_breweries_: std::vector<GeneratedBrewery>
+BiergartenDataGenerator(context_service: std::shared_ptr<IEnrichmentService>, generator: std::unique_ptr<DataGenerator>)
+Run(): bool
{static} -QueryCitiesWithCountries(): std::vector<Location>
-GenerateBreweries(cities: const std::vector<EnrichedCity>&): void
-QueryCitiesWithCountries(): std::vector<Location>
-GenerateBreweries(cities: std::vector<EnrichedCity>): void
-LogResults(): void
}
}
@@ -55,6 +49,11 @@ package "Data models" {
+seed: int
}
class BreweryLocation <<struct>> {
+city_name: std::string_view
+country_name: std::string_view
}
class Location <<struct>> {
+city: std::string
+state_province: std::string
@@ -88,64 +87,67 @@ package "Data models" {
package "Generation" {
interface DataGenerator {
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult
+GenerateUser(locale: const std::string&): UserResult
+GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: std::string): UserResult
}
class MockGenerator {
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult
+GenerateUser(locale: const std::string&): UserResult
+GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: std::string): UserResult
}
class LlamaGenerator {
+LlamaGenerator(options: const ApplicationOptions&, model_path: const std::string&)
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult
+GenerateUser(locale: const std::string&): UserResult
+LlamaGenerator(options: ApplicationOptions, model_path: std::string)
+GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: std::string): UserResult
}
}
package "HTTP" {
interface WebClient {
+Get(url: const std::string&): std::string
+UrlEncode(value: const std::string&): std::string
+DownloadToFile(url: std::string, file_path: std::string): void
+Get(url: std::string): std::string
+UrlEncode(value: std::string): std::string
}
class CURLWebClient {
+Get(url: const std::string&): std::string
+UrlEncode(value: const std::string&): std::string
+CURLWebClient()
+~CURLWebClient()
+DownloadToFile(url: std::string, file_path: std::string): void
+Get(url: std::string): std::string
+UrlEncode(value: std::string): std::string
}
}
package "JSON handling" {
class JsonLoader {
{static} +LoadLocations(filepath: const std::string&): std::vector<Location>
{static} +LoadLocations(filepath: std::string): std::vector<Location>
}
}
package "Wikipedia" {
interface IEnrichmentService {
+GetLocationContext(loc: const Location&): std::string
+GetLocationContext(loc: Location): std::string
}
class WikipediaService {
+WikipediaService(client: std::unique_ptr<WebClient>)
+GetLocationContext(loc: const Location&): std::string
+WikipediaService(client: std::shared_ptr<WebClient>)
+GetLocationContext(loc: Location): std::string
}
}
Main --> CurlGlobalState
Main --> LlamaBackendState
Main --> ApplicationOptions
Main --> BiergartenDataGenerator
Main ..> IEnrichmentService : DI binding
Main ..> DataGenerator : DI factory
Main ..> CURLWebClient : DI binding
BiergartenDataGenerator *-- EnrichedCity
BiergartenDataGenerator *-- GeneratedBrewery
BiergartenDataGenerator ..> JsonLoader : LoadLocations()
BiergartenDataGenerator --> IEnrichmentService : context lookup
BiergartenDataGenerator --> DataGenerator : brewery generation
BiergartenDataGenerator ..> EnrichedCity
BiergartenDataGenerator ..> Location
BiergartenDataGenerator ..> BreweryResult
@@ -154,7 +156,7 @@ DataGenerator <|.. LlamaGenerator
WebClient <|.. CURLWebClient
IEnrichmentService <|.. WikipediaService
WikipediaService *-- WebClient : unique_ptr
WikipediaService --> WebClient : shared_ptr
note right of BiergartenDataGenerator
Current behavior:

View File

@@ -7,7 +7,6 @@
*/
#include <memory>
#include <span>
#include <vector>
#include "data_generation/data_generator.h"
@@ -23,55 +22,55 @@
* It handles location loading, city enrichment, and brewery generation.
*/
class BiergartenDataGenerator {
public:
/**
* @brief Construct a BiergartenDataGenerator with injected dependencies.
*
* @param context_service Context provider for sampled locations.
* @param generator Brewery and user data generator.
*/
BiergartenDataGenerator(std::unique_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator);
public:
/**
* @brief Construct a BiergartenDataGenerator with injected dependencies.
*
* @param context_service Context provider for sampled locations.
* @param generator Brewery and user data generator.
*/
BiergartenDataGenerator(std::shared_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator);
/**
* @brief Run the data generation pipeline.
*
* Performs the following steps:
* 1. Load curated locations from JSON
* 2. Resolve context for each city using the injected context service
* 3. Generate brewery data for sampled cities
*
* @return true if successful, false if not
*/
bool Run();
/**
* @brief Run the data generation pipeline.
*
* Performs the following steps:
* 1. Load curated locations from JSON
* 2. Resolve context for each city using the injected context service
* 3. Generate brewery data for sampled cities
*
* @return true if successful, false if not
*/
bool Run();
private:
/// @brief Owning context provider dependency.
std::unique_ptr<IEnrichmentService> context_service_;
private:
/// @brief Shared context provider dependency.
std::shared_ptr<IEnrichmentService> context_service_;
/// @brief Generator dependency selected in the composition root.
std::unique_ptr<DataGenerator> generator_;
/// @brief Generator dependency selected in the composition root.
std::unique_ptr<DataGenerator> generator_;
/**
* @brief Load locations from JSON and sample cities.
*
* @return Vector of sampled locations capped at 4 entries.
*/
static std::vector<Location> QueryCitiesWithCountries();
/**
* @brief Load locations from JSON and sample cities.
*
* @return Vector of sampled locations capped at 4 entries.
*/
static std::vector<Location> QueryCitiesWithCountries();
/**
* @brief Generate breweries for enriched cities.
*
* @param cities Span of enriched city data.
*/
void GenerateBreweries(std::span<const EnrichedCity> cities);
/**
* @brief Generate breweries for enriched cities.
*
* @param cities Vector of enriched city data.
*/
void GenerateBreweries(const std::vector<EnrichedCity>& cities);
/**
* @brief Log the generated brewery results.
*/
void LogResults() const;
/**
* @brief Log the generated brewery results.
*/
void LogResults() const;
/// @brief Stores generated brewery data.
std::vector<GeneratedBrewery> generated_breweries_;
/// @brief Stores generated brewery data.
std::vector<GeneratedBrewery> generated_breweries_;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_BIERGARTEN_DATA_GENERATOR_H_

View File

@@ -8,34 +8,35 @@
#include <string>
#include "data_model/brewery_location.h"
#include "data_model/brewery_result.h"
#include "data_model/location.h"
#include "data_model/user_result.h"
/**
* @brief Interface for data generator implementations.
*/
class DataGenerator {
public:
virtual ~DataGenerator() = default;
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~DataGenerator() = default;
/**
* @brief Generates brewery data for a location.
*
* @param location Location data
* @param region_context Additional regional context text.
* @return Brewery generation result.
*/
virtual BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) = 0;
/**
* @brief Generates brewery data for a location.
*
* @param location City and country names.
* @param region_context Additional regional context text.
* @return Brewery generation result.
*/
virtual BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) = 0;
/**
* @brief Generates a user profile for a locale.
*
* @param locale Locale hint used by generator.
* @return User generation result.
*/
virtual UserResult GenerateUser(const std::string& locale) = 0;
/**
* @brief Generates a user profile for a locale.
*
* @param locale Locale hint used by generator.
* @return User generation result.
*/
virtual UserResult GenerateUser(const std::string& locale) = 0;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_DATA_GENERATOR_H_

View File

@@ -16,116 +16,107 @@
struct llama_model;
struct llama_context;
struct LlamaSampler;
/**
* @brief Data generator implementation backed by llama.cpp.
*/
class LlamaGenerator final : public DataGenerator {
public:
/**
* @brief Constructs a generator using parsed application options and loads
* the configured model immediately.
*
* @param options Parsed application options.
* @param model_path Filesystem path to GGUF model assets.
*/
LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path);
public:
/**
* @brief Constructs a generator using parsed application options and loads
* the configured model immediately.
*
* @param options Parsed application options.
* @param model_path Filesystem path to GGUF model assets.
*/
LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path);
/// @brief Releases model/context resources.
~LlamaGenerator() override;
/// @brief Releases model/context resources.
~LlamaGenerator() override;
LlamaGenerator(const LlamaGenerator&) = delete;
LlamaGenerator& operator=(const LlamaGenerator&) = delete;
LlamaGenerator(LlamaGenerator&&) = delete;
LlamaGenerator& operator=(LlamaGenerator&&) = delete;
/**
* @brief Generates brewery data for a specific location.
*
* @param location City and country names.
* @param region_context Additional regional context.
* @return Generated brewery result.
*/
BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) override;
/**
* @brief Generates brewery data for a specific location.
*
* @param location Location object.
* @param region_context Additional regional context.
* @return Generated brewery result.
*/
BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) override;
/**
* @brief Generates a user profile for the provided locale.
*
* @param locale Locale hint.
* @return Generated user profile.
*/
UserResult GenerateUser(const std::string& locale) override;
/**
* @brief Generates a user profile for the provided locale.
*
* @param locale Locale hint.
* @return Generated user profile.
*/
UserResult GenerateUser(const std::string& locale) override;
private:
/**
* @brief Loads model and prepares inference context.
*
* @param model_path Filesystem path to GGUF model.
*/
void Load(const std::string& model_path);
private:
static constexpr int kDefaultMaxTokens = 10000;
static constexpr float kDefaultSamplingTopP = 0.95F;
static constexpr uint32_t kDefaultSamplingTopK = 64;
static constexpr uint32_t kDefaultContextSize = 8192;
/**
* @brief Infers text from a user prompt.
*
* @param prompt User prompt.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string Infer(const std::string& prompt, int max_tokens = 10000);
struct SamplerState {
SamplerState() = default;
~SamplerState();
/**
* @brief Infers text from separate system and user prompts.
*
* This helps chat-capable models preserve system-role behavior instead of
* concatenating system text into user input.
*
* @param system_prompt System role prompt.
* @param prompt User prompt.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string Infer(const std::string& system_prompt,
const std::string& prompt, int max_tokens = 10000);
SamplerState(const SamplerState&) = delete;
SamplerState& operator=(const SamplerState&) = delete;
SamplerState(SamplerState&&) = delete;
SamplerState& operator=(SamplerState&&) = delete;
/**
* @brief Runs inference on an already-formatted prompt.
*
* @param formatted_prompt Prompt preformatted for model chat template.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = 10000);
LlamaSampler* chain = nullptr;
};
/**
* @brief Loads the brewery system prompt from disk.
*
* @param prompt_file_path Prompt file path to try first.
* @return Loaded prompt text or fallback prompt.
*/
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
/**
* @brief Loads model and prepares inference context.
*
* @param model_path Filesystem path to GGUF model.
*/
void Load(const std::string& model_path);
/**
* @brief Returns a built-in fallback system prompt.
*
* @return Fallback prompt text.
*/
std::string GetFallbackBreweryPrompt();
/**
* @brief Infers text from separate system and user prompts.
*
* This helps chat-capable models preserve system-role behavior instead of
* concatenating system text into user input.
*
* @param system_prompt System role prompt.
* @param prompt User prompt.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string Infer(const std::string& system_prompt, const std::string& prompt,
int max_tokens = kDefaultMaxTokens);
/**
* @brief Runs inference on an already-formatted prompt.
*
* @param formatted_prompt Prompt preformatted for model chat template.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = kDefaultMaxTokens);
/**
* @brief Loads the brewery system prompt from disk.
*
* @param prompt_file_path Prompt file path to try first.
* @return Loaded prompt text.
*/
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
llama_model* model_ = nullptr;
llama_context* context_ = nullptr;
/// @brief Persistent sampler chain reused across inference calls.
std::unique_ptr<SamplerState> sampler_;
float sampling_temperature_ = 1.0F;
float sampling_top_p_ = kDefaultSamplingTopP;
uint32_t sampling_top_k_ = kDefaultSamplingTopK;
std::mt19937 rng_;
uint32_t n_ctx_ = kDefaultContextSize;
std::string brewery_system_prompt_;
llama_model* model_ = nullptr;
llama_context* context_ = nullptr;
float sampling_temperature_ = 1.0F;
float sampling_top_p_ = 0.95F;
uint32_t sampling_top_k_ = 64;
std::mt19937 rng_;
uint32_t n_ctx_ = 8192;
std::string brewery_system_prompt_;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -7,7 +7,6 @@
*/
#include <cstddef>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
@@ -36,6 +35,16 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message);
/**
* @brief Applies model chat template to a user-only prompt.
*
* @param model Loaded llama model.
* @param user_prompt User prompt text.
* @return Model-formatted prompt.
*/
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt);
/**
* @brief Applies model chat template to system and user prompts.
*
@@ -64,11 +73,11 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
* @param raw Raw model output.
* @param name_out Parsed brewery name.
* @param description_out Parsed brewery description.
* @return Validation error message if invalid, or std::nullopt on success.
* @return Empty string on success, or validation error message.
*/
std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out,
std::string& description_out);
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out);
/**
* @brief Extracts the last balanced JSON object from text.
@@ -78,4 +87,4 @@ std::optional<std::string> ValidateBreweryJsonPublic(
*/
std::string ExtractLastJsonObjectPublic(const std::string& text);
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_

View File

@@ -6,9 +6,9 @@
* @brief Deterministic mock implementation of DataGenerator.
*/
#include <array>
#include <string>
#include <string_view>
#include <vector>
#include "data_generation/data_generator.h"
@@ -16,108 +16,39 @@
* @brief Mock generator used for deterministic, model-free outputs.
*/
class MockGenerator final : public DataGenerator {
public:
/**
* @brief Generates deterministic brewery data for a location.
*
* @param location City and country names.
* @param region_context Unused for mock generation.
* @return Generated brewery result.
*/
BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) override;
public:
/**
* @brief Generates deterministic brewery data for a location.
*
* @param location City and country names.
* @param region_context Unused for mock generation.
* @return Generated brewery result.
*/
BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) override;
/**
* @brief Generates deterministic user data for a locale.
*
* @param locale Locale hint.
* @return Generated user result.
*/
UserResult GenerateUser(const std::string& locale) override;
/**
* @brief Generates deterministic user data for a locale.
*
* @param locale Locale hint.
* @return Generated user result.
*/
UserResult GenerateUser(const std::string& locale) override;
private:
/**
* @brief Hashes the location's city and country names into a stable value.
*
* @param location City and country names.
* @return Deterministic hash value.
*/
static std::size_t DeterministicHash(const Location& location);
private:
/**
* @brief Hashes the location's city and country names into a stable value.
*
* @param location City and country names.
* @return Deterministic hash value.
*/
static std::size_t DeterministicHash(const BreweryLocation& location);
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"};
static constexpr std::array<std::string_view, 18> kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
static constexpr std::array<std::string_view, 18> kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark "
"porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy "
"pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky "
"lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food "
"pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional "
"grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward "
"sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal "
"blends.",
"Modern taproom focused on sessionable lagers and classic pub "
"styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp "
"lagers.",
"Creative brewery offering rotating collaborations and limited "
"draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich "
"stouts."};
static constexpr std::array<std::string_view, 18> kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
static constexpr std::array<std::string_view, 18> kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer "
"gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward "
"styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean "
"fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};
static const std::vector<std::string> kBreweryAdjectives;
static const std::vector<std::string> kBreweryNouns;
static const std::vector<std::string> kBreweryDescriptions;
static const std::vector<std::string> kUsernames;
static const std::vector<std::string> kBios;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_MOCK_GENERATOR_H_

View File

@@ -13,30 +13,30 @@
* @brief Program options for the Biergarten pipeline application.
*/
struct ApplicationOptions {
/// @brief Path to the LLM model file (gguf format); mutually exclusive with
/// use_mocked.
std::string model_path;
/// @brief Path to the LLM model file (gguf format); mutually exclusive with
/// use_mocked.
std::string model_path;
/// @brief Use mocked generator instead of LLM; mutually exclusive with
/// model_path.
bool use_mocked = false;
/// @brief Use mocked generator instead of LLM; mutually exclusive with
/// model_path.
bool use_mocked = false;
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 1.0F;
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 1.0F;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
/// random).
float top_p = 0.95F;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
/// random).
float top_p = 0.95F;
/// @brief LLM top-k sampling parameter.
uint32_t top_k = 64;
/// @brief LLM top-k sampling parameter.
uint32_t top_k = 64;
/// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory.
uint32_t n_ctx = 8192;
/// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory.
uint32_t n_ctx = 2048;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_

View File

@@ -12,11 +12,11 @@
* @brief Non-owning brewery location input.
*/
struct BreweryLocation {
/// @brief City name.
std::string_view city_name;
/// @brief City name.
std::string_view city_name;
/// @brief Country name.
std::string_view country_name;
/// @brief Country name.
std::string_view country_name;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_LOCATION_H_

View File

@@ -12,11 +12,11 @@
* @brief Generated brewery payload.
*/
struct BreweryResult {
/// @brief Brewery display name.
std::string name{};
/// @brief Brewery display name.
std::string name;
/// @brief Brewery description text.
std::string description{};
/// @brief Brewery description text.
std::string description;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_RESULT_H_

View File

@@ -14,8 +14,8 @@
* @brief Enriched city data with Wikipedia context.
*/
struct EnrichedCity {
Location location;
std::string region_context{};
Location location;
std::string region_context;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_ENRICHED_CITY_H_

View File

@@ -13,8 +13,8 @@
* @brief Helper struct to store generated brewery data.
*/
struct GeneratedBrewery {
Location location;
BreweryResult brewery;
Location location;
BreweryResult brewery;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_GENERATED_BREWERY_H_

View File

@@ -12,26 +12,26 @@
* @brief Canonical location record for city-level generation.
*/
struct Location {
/// @brief City name.
std::string city{};
/// @brief City name.
std::string city;
/// @brief State or province name.
std::string state_province{};
/// @brief State or province name.
std::string state_province;
/// @brief ISO 3166-2 subdivision code.
std::string iso3166_2{};
/// @brief ISO 3166-2 subdivision code.
std::string iso3166_2;
/// @brief Country name.
std::string country{};
/// @brief Country name.
std::string country;
/// @brief ISO 3166-1 country code.
std::string iso3166_1{};
/// @brief ISO 3166-1 country code.
std::string iso3166_1;
/// @brief Latitude in decimal degrees.
double latitude{};
/// @brief Latitude in decimal degrees.
double latitude;
/// @brief Longitude in decimal degrees.
double longitude{};
/// @brief Longitude in decimal degrees.
double longitude;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_LOCATION_H_

View File

@@ -12,11 +12,11 @@
* @brief Generated user profile payload.
*/
struct UserResult {
/// @brief Username handle.
std::string username{};
/// @brief Username handle.
std::string username;
/// @brief Short user biography.
std::string bio{};
/// @brief Short user biography.
std::string bio;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_USER_RESULT_H_

View File

@@ -6,17 +6,16 @@
* @brief Loader API for curated location data.
*/
#include <filesystem>
#include <string>
#include <vector>
#include "data_model/location.h"
/// @brief Loads curated world locations from a JSON file into memory.
class JsonLoader {
public:
/// @brief Parses a JSON array file and returns all location records.
static std::vector<Location> LoadLocations(
const std::filesystem::path& filepath);
public:
/// @brief Parses a JSON array file and returns all location records.
static std::vector<Location> LoadLocations(const std::string& filepath);
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -15,18 +15,18 @@
* it alive for application lifetime.
*/
class LlamaBackendState {
public:
/// @brief Initializes global llama backend state.
LlamaBackendState() { llama_backend_init(); }
public:
/// @brief Initializes global llama backend state.
LlamaBackendState() { llama_backend_init(); }
/// @brief Cleans up global llama backend state.
~LlamaBackendState() { llama_backend_free(); }
/// @brief Cleans up global llama backend state.
~LlamaBackendState() { llama_backend_free(); }
/// @brief Non-copyable type.
LlamaBackendState(const LlamaBackendState&) = delete;
/// @brief Non-copyable type.
LlamaBackendState(const LlamaBackendState&) = delete;
/// @brief Non-copyable type.
LlamaBackendState& operator=(const LlamaBackendState&) = delete;
/// @brief Non-copyable type.
LlamaBackendState& operator=(const LlamaBackendState&) = delete;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_LLAMA_BACKEND_STATE_H_

View File

@@ -14,17 +14,17 @@
* @brief Interface for services that can enrich a location with context.
*/
class IEnrichmentService {
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~IEnrichmentService() = default;
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~IEnrichmentService() = default;
/**
* @brief Resolves contextual enrichment for a location.
*
* @param loc Location to enrich.
* @return Context text, or an empty string if unavailable.
*/
virtual std::string GetLocationContext(const Location& loc) = 0;
/**
* @brief Resolves contextual enrichment for a location.
*
* @param loc Location to enrich.
* @return Context text, or an empty string if unavailable.
*/
virtual std::string GetLocationContext(const Location& loc) = 0;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_ENRICHMENT_SERVICE_H_

View File

@@ -14,20 +14,20 @@
#include "services/enrichment_service.h"
#include "web_client/web_client.h"
/// @brief Provides Wikipedia summary lookups backed by cached raw extracts.
/// @brief Provides cached Wikipedia summary lookups for city and country pairs.
class WikipediaService final : public IEnrichmentService {
public:
/// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::unique_ptr<WebClient> client);
public:
/// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::shared_ptr<WebClient> client);
/// @brief Returns the Wikipedia-derived context for a location.
[[nodiscard]] std::string GetLocationContext(const Location& loc) override;
/// @brief Returns the Wikipedia-derived context for a location.
[[nodiscard]] std::string GetLocationContext(const Location& loc) override;
private:
std::string FetchExtract(std::string_view query);
std::unique_ptr<WebClient> client_;
/// @brief Canonical cache for raw Wikipedia query extracts.
std::unordered_map<std::string, std::string> extract_cache_;
private:
std::string FetchExtract(std::string_view query);
std::shared_ptr<WebClient> client_;
std::unordered_map<std::string, std::string> cache_;
std::unordered_map<std::string, std::string> extract_cache_;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_WIKIPEDIA_SERVICE_H_

View File

@@ -15,40 +15,40 @@
* alive for application lifetime.
*/
class CurlGlobalState {
public:
/// @brief Initializes global libcurl state.
CurlGlobalState();
public:
/// @brief Initializes global libcurl state.
CurlGlobalState();
/// @brief Cleans up global libcurl state.
~CurlGlobalState();
/// @brief Cleans up global libcurl state.
~CurlGlobalState();
/// @brief Non-copyable type.
CurlGlobalState(const CurlGlobalState&) = delete;
/// @brief Non-copyable type.
CurlGlobalState(const CurlGlobalState&) = delete;
/// @brief Non-copyable type.
CurlGlobalState& operator=(const CurlGlobalState&) = delete;
/// @brief Non-copyable type.
CurlGlobalState& operator=(const CurlGlobalState&) = delete;
};
/**
* @brief WebClient implementation backed by libcurl.
*/
class CURLWebClient : public WebClient {
public:
/**
* @brief Executes an HTTP GET request.
*
* @param url Request URL.
* @return Response body.
*/
std::string Get(const std::string& url) override;
public:
/**
* @brief Executes an HTTP GET request.
*
* @param url Request URL.
* @return Response body.
*/
std::string Get(const std::string& url) override;
/**
* @brief URL-encodes a string value.
*
* @param value Raw value.
* @return URL-encoded string.
*/
std::string UrlEncode(const std::string& value) override;
/**
* @brief URL-encodes a string value.
*
* @param value Raw value.
* @return URL-encoded string.
*/
std::string UrlEncode(const std::string& value) override;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_CURL_WEB_CLIENT_H_

View File

@@ -12,25 +12,25 @@
* @brief Abstract web client interface.
*/
class WebClient {
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~WebClient() = default;
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~WebClient() = default;
/**
* @brief Executes an HTTP GET request.
*
* @param url Request URL.
* @return Response body.
*/
virtual std::string Get(const std::string& url) = 0;
/**
* @brief Executes an HTTP GET request.
*
* @param url Request URL.
* @return Response body.
*/
virtual std::string Get(const std::string& url) = 0;
/**
* @brief URL-encodes a string value.
*
* @param value Raw string value.
* @return Encoded value safe for URL usage.
*/
virtual std::string UrlEncode(const std::string& value) = 0;
/**
* @brief URL-encodes a string value.
*
* @param value Raw string value.
* @return Encoded value safe for URL usage.
*/
virtual std::string UrlEncode(const std::string& value) = 0;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -8,7 +8,7 @@
#include <utility>
BiergartenDataGenerator::BiergartenDataGenerator(
std::unique_ptr<IEnrichmentService> context_service,
std::shared_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator)
: context_service_(std::move(context_service)),
generator_(std::move(generator)) {}

View File

@@ -8,32 +8,34 @@
#include "biergarten_data_generator.h"
void BiergartenDataGenerator::GenerateBreweries(
std::span<const EnrichedCity> cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
const std::vector<EnrichedCity>& cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
generated_breweries_.clear();
generated_breweries_.clear();
size_t skipped_count = 0;
size_t skipped_count = 0;
for (const auto& [location, region_context] : cities) {
try {
const BreweryResult brewery =
generator_->GenerateBrewery(location, region_context);
const GeneratedBrewery gen{.location = location, .brewery = brewery};
generated_breweries_.push_back(gen);
} catch (const std::exception& e) {
++skipped_count;
for (const auto& enriched_city : cities) {
try {
auto brewery = generator_->GenerateBrewery(
BreweryLocation{enriched_city.location.city,
enriched_city.location.country},
enriched_city.region_context);
generated_breweries_.push_back(GeneratedBrewery{
.location = enriched_city.location, .brewery = brewery});
} catch (const std::exception& e) {
++skipped_count;
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): brewery generation failed: "
"{}",
enriched_city.location.city, enriched_city.location.country,
e.what());
}
}
if (skipped_count > 0) {
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): brewery generation failed: "
"{}",
location.city, location.country, e.what());
}
}
if (skipped_count > 0) {
spdlog::warn("[Pipeline] Skipped {} city/cities due to generation errors",
skipped_count);
}
"[Pipeline] Skipped {} city/cities due to generation "
"errors",
skipped_count);
}
}

View File

@@ -8,16 +8,16 @@
#include "biergarten_data_generator.h"
void BiergartenDataGenerator::LogResults() const {
spdlog::info("\n=== GENERATED DATA DUMP ===");
size_t index = 1;
for (const auto& [location, brewery] : generated_breweries_) {
spdlog::info(
"{}. city=\"{}\" country=\"{}\" state=\"{}\" "
"iso3166_2={} lat={} lon={}",
index, location.city, location.country, location.state_province,
location.iso3166_2, location.latitude, location.longitude);
spdlog::info(" brewery_name=\"{}\"", brewery.name);
spdlog::info(" brewery_description=\"{}\"", brewery.description);
++index;
}
spdlog::info("\n=== GENERATED DATA DUMP ===");
size_t index = 1;
for (const auto& [location, brewery] : generated_breweries_) {
spdlog::info(
"{}. city=\"{}\" country=\"{}\" state=\"{}\" "
"iso3166_2={} lat={} lon={}",
index, location.city, location.country, location.state_province,
location.iso3166_2, location.latitude, location.longitude);
spdlog::info(" brewery_name=\"{}\"", brewery.name);
spdlog::info(" brewery_description=\"{}\"", brewery.description);
++index;
}
}

View File

@@ -13,28 +13,28 @@
#include "biergarten_data_generator.h"
#include "json_handling/json_loader.h"
static constexpr std::size_t kBreweryAmount = 4;
static constexpr unsigned int brewery_amount = 4;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
const std::filesystem::path locations_path = "locations.json";
const std::filesystem::path locations_path = "locations.json";
auto all_locations = JsonLoader::LoadLocations(locations_path);
spdlog::info(" Locations available: {}", all_locations.size());
auto all_locations = JsonLoader::LoadLocations(locations_path.string());
spdlog::info(" Locations available: {}", all_locations.size());
const std::size_t sample_count =
std::min(kBreweryAmount, all_locations.size());
const auto sample_count_signed =
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
sample_count);
std::vector<Location> sampled_locations;
sampled_locations.reserve(sample_count);
const size_t sample_count =
std::min<size_t>(brewery_amount, all_locations.size());
const auto sample_count_signed =
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
sample_count);
std::vector<Location> sampled_locations;
sampled_locations.reserve(sample_count);
std::random_device random_generator;
std::ranges::sample(all_locations, std::back_inserter(sampled_locations),
sample_count_signed, random_generator);
std::random_device random_generator;
std::ranges::sample(all_locations, std::back_inserter(sampled_locations),
sample_count_signed, random_generator);
spdlog::info(" Sampled locations: {}", sampled_locations.size());
return sampled_locations;
spdlog::info(" Sampled locations: {}", sampled_locations.size());
return sampled_locations;
}

View File

@@ -8,40 +8,40 @@
#include "biergarten_data_generator.h"
bool BiergartenDataGenerator::Run() {
try {
const std::vector<Location> cities = QueryCitiesWithCountries();
std::vector<EnrichedCity> enriched;
enriched.reserve(cities.size());
try {
const std::vector<Location> cities = QueryCitiesWithCountries();
std::vector<EnrichedCity> enriched;
enriched.reserve(cities.size());
size_t skipped_count = 0;
for (const auto& city : cities) {
try {
const std::string region_context =
context_service_->GetLocationContext(city);
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context);
size_t skipped_count = 0;
for (const auto& city : cities) {
try {
const std::string region_context =
context_service_->GetLocationContext(city);
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context);
enriched.push_back(
EnrichedCity{.location = city, .region_context = region_context});
} catch (const std::exception& exception) {
++skipped_count;
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): context lookup failed: {}",
city.city, city.country, exception.what());
enriched.push_back(EnrichedCity{.location = city,
.region_context = region_context});
} catch (const std::exception& exception) {
++skipped_count;
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): context lookup failed: {}",
city.city, city.country, exception.what());
}
}
}
if (skipped_count > 0) {
spdlog::warn(
"[Pipeline] Skipped {} city/cities due to context lookup errors",
skipped_count);
}
if (skipped_count > 0) {
spdlog::warn(
"[Pipeline] Skipped {} city/cities due to context lookup errors",
skipped_count);
}
this->GenerateBreweries(enriched);
this->LogResults();
return true;
} catch (const std::exception& e) {
spdlog::error("Pipeline execution failed with error: {}", e.what());
return false;
}
this->GenerateBreweries(enriched);
this->LogResults();
return true;
} catch (const std::exception& e) {
spdlog::error("Pipeline execution failed with error: {}", e.what());
return false;
}
}

View File

@@ -7,143 +7,154 @@
#include <spdlog/spdlog.h>
#include <array>
#include <format>
#include <optional>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
static std::string ExtractFinalJsonPayload(std::string raw_response) {
auto trim = [](const std::string_view text) -> std::string_view {
const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) {
return {};
}
namespace {
const std::size_t last = text.find_last_not_of(" \t\n\r");
return text.substr(first, last - first + 1);
};
std::string ExtractFinalJsonPayload(std::string raw_response) {
auto trim = [](std::string_view text) -> std::string_view {
const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) {
return {};
}
static constexpr std::array<std::string_view, 6> separator_tokens = {
"<|think|>", "<think|>", "<|turn|>",
"<turn|>", "<channel|>", "<|channel|>"};
const std::size_t last = text.find_last_not_of(" \t\n\r");
return text.substr(first, last - first + 1);
};
std::size_t separator_pos = std::string::npos;
std::size_t separator_length = 0;
for (const std::string_view token : separator_tokens) {
const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos || candidate_pos > separator_pos)) {
separator_pos = candidate_pos;
separator_length = token.size();
}
}
static const std::array<std::string_view, 6> separator_tokens = {
"<|think|>", "<think|>", "<|turn|>",
"<turn|>", "<channel|>", "<|channel|>"};
if (separator_pos != std::string::npos) {
raw_response.erase(0, separator_pos + separator_length);
}
std::size_t separator_pos = std::string::npos;
std::size_t separator_length = 0;
for (const std::string_view token : separator_tokens) {
const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos ||
candidate_pos > separator_pos)) {
separator_pos = candidate_pos;
separator_length = token.size();
}
}
const std::string_view trimmed = trim(raw_response);
const std::string json_candidate =
ExtractLastJsonObjectPublic(std::string(trimmed));
if (separator_pos != std::string::npos) {
raw_response.erase(0, separator_pos + separator_length);
}
if (!json_candidate.empty()) {
return ExtractLastJsonObjectPublic(std::string(trimmed));
}
const std::string_view trimmed = trim(raw_response);
std::string json_candidate =
ExtractLastJsonObjectPublic(std::string(trimmed));
if (!json_candidate.empty()) {
return ExtractLastJsonObjectPublic(std::string(trimmed));
}
return std::string(trimmed);
return std::string(trimmed);
}
} // namespace
BreweryResult LlamaGenerator::GenerateBrewery(
const Location& location, const std::string& region_context) {
/**
* Preprocess and truncate region context to manageable size
*/
const std::string safe_region_context =
PrepareRegionContextPublic(region_context);
const BreweryLocation& location, const std::string& region_context) {
/**
* Preprocess and truncate region context to manageable size
*/
const std::string safe_region_context =
PrepareRegionContextPublic(region_context);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string region_suffix =
safe_region_context.empty()
? "."
: std::format(". Regional context: {}", safe_region_context);
/**
* Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found
*/
const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/system.md");
/**
* Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found
*/
const std::string system_prompt =
LoadBrewerySystemPrompt("prompts/system.md");
/**
* User prompt: provides geographic context to guide generation towards
* culturally appropriate and locally-inspired brewery attributes
*/
std::string prompt =
"Write a brewery name and place-specific long description for a craft "
"brewery in ";
prompt.append(location.city_name);
if (!location.country_name.empty()) {
prompt.append(", ");
prompt.append(location.country_name);
}
if (safe_region_context.empty()) {
prompt.append(".");
} else {
prompt.append(". Regional context: ");
prompt.append(safe_region_context);
}
/**
* User prompt: provides geographic context to guide generation towards
* culturally relevant and locally-inspired brewery attributes
*/
std::string prompt = std::format(
"Write a brewery name and place-specific long description for a craft "
"brewery in {}{}{}",
location.city, country_suffix, region_suffix);
/**
* Store location context for retry prompts (without repeating full context)
*/
std::string retry_location = "Location: ";
retry_location.append(location.city_name);
if (!location.country_name.empty()) {
retry_location.append(", ");
retry_location.append(location.country_name);
}
/**
* Store location context for retry prompts (without repeating full context)
*/
const std::string retry_location =
std::format("Location: {}{}", location.city, country_suffix);
/**
* RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based
* refinement
*/
const int max_attempts = 3;
std::string raw;
std::string last_error;
/**
* RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based
* refinement
*/
constexpr int max_attempts = 3;
std::string raw;
std::string last_error;
// Limit output length to keep it concise and focused
constexpr int max_tokens = 1052;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
// Generate brewery data from LLM
raw = Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
// Limit output length to keep it concise and focused
for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052;
// Generate brewery data from LLM
raw = this->Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
// Validate output: parse JSON and check required fields
// Validate output: parse JSON and check required fields
std::string name;
std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::string validation_error =
ValidateBreweryJsonPublic(json_only, name, description);
if (validation_error.empty()) {
// Success: return parsed brewery data
return {std::move(name), std::move(description)};
}
std::string name;
std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error =
ValidateBreweryJsonPublic(json_only, name, description);
if (!validation_error.has_value()) {
// Success: return parsed brewery data
return BreweryResult{.name = std::move(name),
.description = std::move(description)};
}
// Validation failed: log error and prepare corrective feedback
// Validation failed: log error and prepare corrective feedback
last_error = validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error);
last_error = *validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, *validation_error);
// Update prompt with error details to guide LLM toward correct output.
// For retries, use a compact prompt format to avoid exceeding token
// limits.
prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with exactly these keys: "
"{\"name\": \"<brewery name>\", "
"\"description\": \"<single-paragraph description>\"}."
"\nDo not include markdown, comments, extra keys, or literal "
"placeholder values.";
prompt += "\n\n";
prompt += retry_location;
}
// Update prompt with error details to guide LLM toward correct output.
prompt = std::format(
R"(Your previous response was invalid. Error: {}
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}.
Do not include markdown, comments, extra keys, or literal placeholder values.
{})",
*validation_error, retry_location);
}
// All retry attempts exhausted: log failure and throw exception
spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
// All retry attempts exhausted: log failure and throw exception
spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
}

View File

@@ -6,6 +6,7 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept>
#include <string>
@@ -13,6 +14,87 @@
#include "data_generation/llama_generator_helpers.h"
UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
return {.username = "test_user",
.bio = "This is a test user profile from " + locale + "."};
/**
* System prompt: specifies exact output format to minimize parsing errors
* Constraints: 2-line output, username format, bio length bounds
*/
const std::string system_prompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
/**
* User prompt: locale parameter guides cultural appropriateness of generated
* profiles
*/
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
/**
* RETRY LOOP with format validation
* Attempts up to 3 times to generate valid user profile with correct format
*/
const int max_attempts = 3;
std::string raw;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
/**
* Generate user profile (max 128 tokens - should fit 2 lines easily)
*/
raw = Infer(system_prompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
/**
* Parse two-line response: first line = username, second line = bio
*/
auto [username, bio] = ParseTwoLineResponsePublic(
raw, "LlamaGenerator: malformed user response");
/**
* Remove any whitespace from username (usernames shouldn't have
* spaces)
*/
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
/**
* Validate both fields are non-empty after processing
*/
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
/**
* Truncate bio if exceeds reasonable length for bio field
*/
if (bio.size() > 200) bio = bio.substr(0, 200);
/**
* Success: return parsed user profile
*/
return {username, bio};
} catch (const std::exception& e) {
/**
* Parsing failed: log and continue to next attempt
*/
spdlog::warn(
"LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
/**
* All retry attempts exhausted: log failure and throw exception
*/
spdlog::error(
"LlamaGenerator: malformed user response after {} attempts: {}",
max_attempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -4,17 +4,13 @@
* parsing, token decoding, and JSON validation helpers for Llama modules.
*/
#include <spdlog/spdlog.h>
#include <algorithm>
#include <array>
#include <boost/json.hpp>
#include <cctype>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
#include "data_generation/llama_generator.h"
@@ -23,42 +19,40 @@
/**
* String trimming: removes leading and trailing whitespace
*/
static std::string Trim(std::string_view value) {
constexpr std::string_view whitespace = " \t\n\r\f\v";
const std::size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) {
return {};
}
static std::string Trim(std::string value) {
auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
const std::size_t last_index = value.find_last_not_of(whitespace);
return std::string(value.substr(first_index, last_index - first_index + 1));
value.erase(value.begin(),
std::find_if(value.begin(), value.end(), not_space));
value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(),
value.end());
return value;
}
/**
* Normalize whitespace: collapses multiple spaces/tabs/newlines into single
* spaces
*/
static std::string CondenseWhitespace(std::string_view text) {
std::string out;
out.reserve(text.size());
static std::string CondenseWhitespace(std::string text) {
std::string out;
out.reserve(text.size());
bool pending_space = false;
for (const unsigned char chr : text) {
if (std::isspace(chr) != 0) {
if (!out.empty()) {
pending_space = true;
bool in_whitespace = false;
for (unsigned char ch : text) {
if (std::isspace(ch)) {
if (!in_whitespace) {
out.push_back(' ');
in_whitespace = true;
}
continue;
}
continue;
}
if (pending_space) {
out.push_back(' ');
pending_space = false;
}
out.push_back(static_cast<char>(chr));
}
in_whitespace = false;
out.push_back(static_cast<char>(ch));
}
return out;
return Trim(std::move(out));
}
/**
@@ -66,286 +60,386 @@ static std::string CondenseWhitespace(std::string_view text) {
* boundaries
*/
static std::string PrepareRegionContext(std::string_view region_context,
const size_t max_chars) {
std::string normalized = CondenseWhitespace(region_context);
if (normalized.size() <= max_chars) {
return normalized;
}
std::size_t max_chars) {
std::string normalized = CondenseWhitespace(std::string(region_context));
if (normalized.size() <= max_chars) {
return normalized;
}
normalized.resize(max_chars);
const size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space);
}
normalized.resize(max_chars);
const std::size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space);
}
normalized += "...";
return normalized;
normalized += "...";
return normalized;
}
static std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
std::string combined_prompt;
combined_prompt.append(system_prompt);
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
/**
* Remove common bullet points, numbers, and field labels added by LLM in output
*/
static std::string StripCommonPrefix(std::string line) {
line = Trim(std::move(line));
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
// No template found, fallback to raw text
spdlog::warn(
"LlamaGenerator: missing chat template; using raw prompt fallback");
return combined_prompt;
}
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = Trim(line.substr(1));
} else {
std::size_t i = 0;
while (i < line.size() &&
std::isdigit(static_cast<unsigned char>(line[i]))) {
++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = Trim(line.substr(i + 1));
}
}
const std::array<llama_chat_message, 2> messages = {
{{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}};
auto strip_label = [&line](const std::string& label) {
if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = Trim(line.substr(label.size()));
}
}
};
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
strip_label("name:");
strip_label("brewery name:");
strip_label("description:");
strip_label("username:");
strip_label("bio:");
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
int32_t message_count) -> int32_t {
int32_t result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
return Trim(std::move(line));
}
if (result < 0) {
return result;
}
/**
* Parse two-line response from LLM: normalize line endings, strip formatting,
* filter spurious output, and combine remaining lines if needed
*/
static std::pair<std::string, std::string> ParseTwoLineResponse(
const std::string& raw, const std::string& error_message) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
if (result >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(tmpl, chat_messages, message_count,
true, buffer.data(),
static_cast<int32_t>(buffer.size()));
}
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = StripCommonPrefix(std::move(line));
if (!line.empty()) lines.push_back(std::move(line));
}
return result;
};
std::vector<std::string> filtered;
for (auto& l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
// Filter known thinking tags like <think>...</think>, but be conservative
// to avoid removing legitimate output. Only filter specific known
// patterns.
if (!l.empty() && l.front() == '<' && low.back() == '>') {
// Only filter if it's a known thinking tag: <think>, <reasoning>, etc.
if (low.find("think") != std::string::npos ||
low.find("reasoning") != std::string::npos ||
low.find("reflect") != std::string::npos) {
continue;
}
}
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
filtered.push_back(std::move(l));
}
int32_t template_result = apply_template_with_resize(messages.data(), 2);
if (filtered.size() < 2) throw std::runtime_error(error_message);
if (template_result >= 0) {
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
std::string first = Trim(filtered.front());
std::string second;
for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) second += ' ';
second += filtered[i];
}
second = Trim(std::move(second));
spdlog::warn(
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
if (first.empty() || second.empty()) throw std::runtime_error(error_message);
return {first, second};
}
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
// No template found, fallback to raw text
return system_prompt + "\n\n" + user_prompt;
}
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
const std::array<llama_chat_message, 2> messages = {
{{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}};
template_result = apply_template_with_resize(fallback_msg.data(), 1);
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
// Ultimate fallback: if GGUF template parsing still fails, use raw text.
if (template_result < 0) {
spdlog::warn(
"LlamaGenerator: chat template fallback failed (result {}); using "
"raw prompt text",
template_result);
return combined_prompt;
}
int32_t required =
llama_chat_apply_template(tmpl, messages.data(), 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
return {buffer.data(), static_cast<std::size_t>(template_result)};
// FALLBACK: If the template fails (e.g., Gemma rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
if (required < 0) {
std::string combined_prompt = system_prompt + "\n\n" + user_prompt;
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
required = llama_chat_apply_template(tmpl, fallback_msg.data(), 1, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// THE FIX: Ultimate fallback. If the GGUF's internal template is
// completely unparseable (which happens with complex Jinja macros),
// degrade gracefully to raw text instead of throwing a runtime_error.
if (required < 0) {
return combined_prompt;
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(
tmpl, fallback_msg.data(), 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
return combined_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
// Standard buffer resize if the original "system" + "user" array succeeded
// but needed more space
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages.data(), 2, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// Final safety net on resize
if (required < 0) {
return system_prompt + "\n\n" + user_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
/**
 * Appends the text piece of `token` to `output`.
 *
 * A 256-byte stack buffer is tried first. When llama_token_to_piece reports a
 * negative length, its magnitude is used as the size of a heap buffer and the
 * conversion is retried once.
 *
 * @param vocab  Vocabulary handed to llama_token_to_piece.
 * @param token  Token to convert.
 * @param output String that receives the converted bytes.
 * @throws std::runtime_error if the retry also reports a negative length.
 */
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
                             std::string& output) {
  std::array<char, 256> buffer{};
  int32_t bytes =
      llama_token_to_piece(vocab, token, buffer.data(),
                           static_cast<int32_t>(buffer.size()), 0, true);

  if (bytes < 0) {
    // Negative result: retry once with a buffer of the reported magnitude.
    std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
    bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
                                 static_cast<int32_t>(dynamic_buffer.size()),
                                 0, true);
    if (bytes < 0) {
      throw std::runtime_error(
          "LlamaGenerator: failed to decode sampled token piece");
    }
    output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
    return;
  }

  output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
/**
 * Extracts the last balanced top-level `{...}` span from `text`.
 *
 * The scan tracks double-quoted strings (with backslash escapes) so braces
 * inside string literals do not affect nesting depth. Every time the depth
 * returns to zero the span is recorded, so the final recorded span wins.
 * A stray `}` at depth zero is ignored.
 *
 * Note: quote tracking is active outside objects as well, so an unterminated
 * `"` in surrounding text hides any braces that follow it.
 *
 * @param text     Text to scan.
 * @param json_out Receives the extracted span; left untouched on failure.
 * @return true if at least one balanced object was found.
 */
static bool ExtractLastJsonObject(const std::string& text,
                                  std::string& json_out) {
  std::size_t start = std::string::npos;
  int depth = 0;
  bool in_string = false;
  bool escaped = false;
  bool found = false;
  std::string candidate;

  for (std::size_t i = 0; i < text.size(); ++i) {
    const char ch = text[i];

    if (in_string) {
      if (escaped) {
        // The character after a backslash is consumed verbatim.
        escaped = false;
      } else if (ch == '\\') {
        escaped = true;
      } else if (ch == '"') {
        in_string = false;
      }
      continue;
    }

    if (ch == '"') {
      in_string = true;
      continue;
    }

    if (ch == '{') {
      if (depth == 0) {
        start = i;
      }
      ++depth;
      continue;
    }

    if (ch == '}') {
      if (depth == 0) {
        // Unmatched closing brace: ignore it.
        continue;
      }
      --depth;
      if (depth == 0 && start != std::string::npos) {
        candidate = text.substr(start, i - start + 1);
        found = true;
      }
    }
  }

  if (!found) {
    return false;
  }

  json_out = std::move(candidate);
  return true;
}
/**
 * Public wrapper around ExtractLastJsonObject.
 *
 * @param text Text to scan.
 * @return The extracted span, or an empty string when ExtractLastJsonObject
 *         reports that nothing was found.
 */
std::string ExtractLastJsonObjectPublic(const std::string& text) {
  std::string extracted;
  if (ExtractLastJsonObject(text, extracted)) {
    return extracted;
  }
  return {};
}
static std::optional<std::string> ValidateBreweryJson(
const std::string& raw, std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
error_out = "JSON root must be an object";
return false;
}
static std::string ValidateBreweryJson(const std::string& raw,
std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
error_out = "JSON root must be an object";
return false;
}
const auto& obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string";
return false;
}
const auto& obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string";
return false;
}
if (!obj.contains("description") || !obj.at("description").is_string()) {
error_out = "JSON field 'description' is missing or not a string";
return false;
}
if (!obj.contains("description") || !obj.at("description").is_string()) {
error_out = "JSON field 'description' is missing or not a string";
return false;
}
const auto& name_value = obj.at("name").as_string();
const auto& description_value = obj.at("description").as_string();
name_out = Trim(std::string_view(name_value.data(), name_value.size()));
description_out = Trim(
std::string_view(description_value.data(), description_value.size()));
name_out = Trim(std::string(obj.at("name").as_string().c_str()));
description_out =
Trim(std::string(obj.at("description").as_string().c_str()));
if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty";
return false;
}
if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty";
return false;
}
if (description_out.empty()) {
error_out = "JSON field 'description' must not be empty";
return false;
}
if (description_out.empty()) {
error_out = "JSON field 'description' must not be empty";
return false;
}
std::string name_lower = name_out;
std::string description_lower = description_out;
std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
std::string name_lower = name_out;
std::string description_lower = description_out;
std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content";
return false;
}
if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content";
return false;
}
error_out.clear();
return true;
};
error_out.clear();
return true;
};
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
std::string validation_error;
if (ec) {
std::string extracted;
if (!ExtractLastJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
std::string validation_error;
if (ec) {
std::string extracted;
if (!ExtractLastJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
if (!validate_object(jv, validation_error)) {
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
}
return std::nullopt;
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return std::nullopt;
return {};
}
// Public wrappers that expose this file's internal helpers to other
// translation units
/**
 * Public wrapper that forwards to PrepareRegionContext.
 *
 * @param region_context Region text passed through unchanged.
 * @param max_chars      Limit passed through unchanged.
 * @return Whatever PrepareRegionContext returns for these arguments.
 */
std::string PrepareRegionContextPublic(std::string_view region_context,
                                       std::size_t max_chars) {
  return PrepareRegionContext(region_context, max_chars);
}
/**
 * Public wrapper that forwards to ParseTwoLineResponse.
 *
 * @param raw           Raw response text, passed through unchanged.
 * @param error_message Message passed through unchanged.
 * @return The pair produced by ParseTwoLineResponse.
 */
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
    const std::string& raw, const std::string& error_message) {
  std::pair<std::string, std::string> parsed_lines =
      ParseTwoLineResponse(raw, error_message);
  return parsed_lines;
}
/**
 * Formats a single user prompt (no system prompt) through ToChatPrompt.
 *
 * ToChatPrompt takes (model, system_prompt, user_prompt), so the caller's
 * text must be passed as the third argument; the system prompt is empty.
 *
 * @param model       Model handed to ToChatPrompt.
 * @param user_prompt Text for the "user" message.
 * @return The prompt string produced by ToChatPrompt.
 */
std::string ToChatPromptPublic(const llama_model* model,
                               const std::string& user_prompt) {
  return ToChatPrompt(model, "", user_prompt);
}
/**
 * Public wrapper that forwards a system + user prompt pair to ToChatPrompt.
 *
 * @param model         Model handed to ToChatPrompt.
 * @param system_prompt Text for the "system" message.
 * @param user_prompt   Text for the "user" message.
 * @return The prompt string produced by ToChatPrompt.
 */
std::string ToChatPromptPublic(const llama_model* model,
                               const std::string& system_prompt,
                               const std::string& user_prompt) {
  return ToChatPrompt(model, system_prompt, user_prompt);
}
/**
 * Public wrapper that forwards to AppendTokenPiece.
 *
 * The helper is invoked exactly once per call so each token piece is
 * appended to `output` a single time.
 *
 * @param vocab  Vocabulary passed through unchanged.
 * @param token  Token passed through unchanged.
 * @param output String that AppendTokenPiece appends to.
 */
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
                            std::string& output) {
  AppendTokenPiece(vocab, token, output);
}
std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out,
std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out);
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out);
}

View File

@@ -2,7 +2,7 @@
* Text Generation / Inference Module
* Core module that performs LLM inference: converts text prompts into tokens,
* runs the neural network forward pass, samples the next token, and converts
* output tokens back to text for system+user chat prompts.
* output tokens back to text. Supports both simple and system+user prompts.
*/
#include <spdlog/spdlog.h>
@@ -17,156 +17,182 @@
#include "data_generation/llama_generator_helpers.h"
#include "llama.h"
static constexpr std::size_t kPromptTokenSlack = 8;
/**
 * Runs inference for a single prompt.
 *
 * The prompt is first formatted with ToChatPromptPublic(model_, prompt) and
 * the result is handed to InferFormatted.
 *
 * @param prompt     Prompt text to format.
 * @param max_tokens Limit forwarded to InferFormatted.
 * @return The string returned by InferFormatted.
 */
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
  const std::string formatted_prompt = ToChatPromptPublic(model_, prompt);
  return InferFormatted(formatted_prompt, max_tokens);
}
/**
 * Runs inference for a system + user prompt pair.
 *
 * The pair is formatted with ToChatPromptPublic and the result is handed to
 * InferFormatted.
 *
 * @param system_prompt Text for the "system" message.
 * @param prompt        Text for the "user" message.
 * @param max_tokens    Limit forwarded to InferFormatted.
 * @return The string returned by InferFormatted.
 */
std::string LlamaGenerator::Infer(const std::string& system_prompt,
                                  const std::string& prompt,
                                  const int max_tokens) {
  return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
                        max_tokens);
}
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
const int max_tokens) {
/**
* Validate that model and context are loaded
*/
if (model_ == nullptr || context_ == nullptr) {
throw std::runtime_error("LlamaGenerator: model not loaded");
}
int max_tokens) {
/**
* Validate that model and context are loaded
*/
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
/**
* Get vocabulary for tokenization and token-to-text conversion
*/
const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) {
throw std::runtime_error("LlamaGenerator: vocab unavailable");
}
/**
* Get vocabulary for tokenization and token-to-text conversion
*/
const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
/**
* Clear KV cache to ensure clean inference state (no residual context)
*/
llama_memory_clear(llama_get_memory(context_), true);
/**
* Clear KV cache to ensure clean inference state (no residual context)
*/
llama_memory_clear(llama_get_memory(context_), true);
/**
* TOKENIZATION PHASE
* Convert text prompt into token IDs (integers) that the model understands
*/
std::vector<llama_token> prompt_tokens(formatted_prompt.size() +
kPromptTokenSlack);
int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
/**
* TOKENIZATION PHASE
* Convert text prompt into token IDs (integers) that the model understands
*/
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
/**
* If buffer too small, negative return indicates required size
*/
if (token_count < 0) {
prompt_tokens.resize(static_cast<std::size_t>(-token_count));
token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
}
/**
* If buffer too small, negative return indicates required size
*/
if (token_count < 0) {
prompt_tokens.resize(static_cast<std::size_t>(-token_count));
token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
}
if (token_count < 0) {
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
}
if (token_count < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
/**
* CONTEXT SIZE VALIDATION
* Validate and compute effective token budgets based on context window
* constraints
*/
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
/**
* CONTEXT SIZE VALIDATION
* Validate and compute effective token budgets based on context window
* constraints
*/
const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0)
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
/**
* Clamp generation limit to available context window, reserve space for
* output
*/
const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
/**
* Prompt can use remaining context after reserving space for generation
*/
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
/**
* Clamp generation limit to available context window, reserve space for
* output
*/
const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
/**
* Prompt can use remaining context after reserving space for generation
*/
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
/**
* Truncate prompt if necessary to fit within constraints
*/
prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens to fit n_batch/n_ctx limits",
token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
token_count = prompt_budget;
}
/**
* Truncate prompt if necessary to fit within constraints
*/
prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens to fit n_batch/n_ctx limits",
token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
token_count = prompt_budget;
}
/**
* PROMPT PROCESSING PHASE
* Create a batch containing all prompt tokens and feed through the model
* This computes internal representations and fills the KV cache
*/
const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0) {
throw std::runtime_error("LlamaGenerator: prompt decode failed");
}
/**
* PROMPT PROCESSING PHASE
* Create a batch containing all prompt tokens and feed through the model
* This computes internal representations and fills the KV cache
*/
const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
/**
* TOKEN GENERATION LOOP
* Iteratively generate tokens one at a time until max_tokens or
* end-of-sequence
*/
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
/**
* SAMPLER CONFIGURATION PHASE
* Set up the probabilistic token selection pipeline (sampler chain)
* Samplers are applied in sequence: temperature -> top-k -> top-p ->
* distribution
*/
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
if (sampler_ == nullptr || sampler_->chain == nullptr) {
throw std::runtime_error("LlamaGenerator: sampler not initialized");
}
/**
* Temperature: scales logits before softmax (controls randomness)
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
/**
* Top-K: limits sampling to the most likely tokens before nucleus
* sampling
*/
llama_sampler_chain_add(
sampler.get(),
llama_sampler_init_top_k(static_cast<int32_t>(sampling_top_k_)));
/**
* Top-P: nucleus sampling - filters to most likely tokens summing to top_p
* probability
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
/**
* Distribution sampler: selects actual token using configured seed for
* reproducibility
*/
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
for (int i = 0; i < effective_max_tokens; ++i) {
/**
* Sample next token using configured sampler chain and model logits
* Index -1 means use the last output position from previous batch
*/
const llama_token next =
llama_sampler_sample(sampler_->chain, context_, -1);
/**
* Stop if model predicts end-of-generation token (EOS/EOT)
*/
if (llama_vocab_is_eog(vocab, next)) {
break;
}
generated_tokens.push_back(next);
/**
* Feed the sampled token back into model for next iteration
* (autoregressive)
*/
llama_token decode_token = next;
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
if (llama_decode(context_, one_token_batch) != 0) {
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
}
/**
* TOKEN GENERATION LOOP
* Iteratively generate tokens one at a time until max_tokens or
* end-of-sequence
*/
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
/**
* DETOKENIZATION PHASE
* Convert generated token IDs back to text using vocabulary
*/
std::string output;
for (const llama_token token : generated_tokens) {
AppendTokenPiecePublic(vocab, token, output);
}
for (int i = 0; i < effective_max_tokens; ++i) {
/**
* Sample next token using configured sampler chain and model logits
* Index -1 means use the last output position from previous batch
*/
const llama_token next =
llama_sampler_sample(sampler.get(), context_, -1);
/**
* Stop if model predicts end-of-generation token (EOS/EOT)
*/
if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next);
/**
* Feed the sampled token back into model for next iteration
* (autoregressive)
*/
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
return output;
/**
* DETOKENIZATION PHASE
* Convert generated token IDs back to text using vocabulary
*/
std::string output;
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
return output;
}

View File

@@ -5,7 +5,6 @@
#include "data_generation/llama_generator.h"
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
@@ -13,113 +12,65 @@
#include "data_model/application_options.h"
#include "llama.h"
static constexpr uint32_t kMaxContextSize = 32768U;
// Sampling settings read by CreateSamplerChain when it builds a sampler
// chain. Field order matters: the struct is brace-initialised positionally.
struct SamplerConfig {
// Passed to llama_sampler_init_temp.
float temperature;
// Passed to llama_sampler_init_top_p.
float top_p;
// Cast to int32_t and passed to llama_sampler_init_top_k.
uint32_t top_k;
};
// Owning handle for a llama_sampler; releases it with llama_sampler_free.
using SamplerPtr =
    std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;

/**
 * Builds a sampler chain from `config`.
 *
 * Stages are added in this order: temperature, top-k, top-p, and finally a
 * dist sampler seeded with one value drawn from `rng`.
 *
 * @param config Sampling settings for the temperature/top-k/top-p stages.
 * @param rng    Generator that supplies the seed for the dist stage.
 * @return Owning pointer to the assembled chain.
 * @throws std::runtime_error if llama_sampler_chain_init returns null.
 */
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
                                     std::mt19937& rng) {
  SamplerPtr chain(
      llama_sampler_chain_init(llama_sampler_chain_default_params()),
      &llama_sampler_free);
  if (chain == nullptr) {
    throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
  }

  llama_sampler* const raw_chain = chain.get();
  const auto top_k = static_cast<int32_t>(config.top_k);

  llama_sampler_chain_add(raw_chain,
                          llama_sampler_init_temp(config.temperature));
  llama_sampler_chain_add(raw_chain, llama_sampler_init_top_k(top_k));
  llama_sampler_chain_add(raw_chain,
                          llama_sampler_init_top_p(config.top_p, 1));
  llama_sampler_chain_add(raw_chain, llama_sampler_init_dist(rng()));

  return chain;
}
/**
 * Frees the held sampler chain, if any, and clears the pointer.
 */
LlamaGenerator::SamplerState::~SamplerState() {
  if (chain == nullptr) {
    return;
  }
  llama_sampler_free(chain);
  chain = nullptr;
}
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path)
: rng_(std::random_device{}()) {
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
if (options.temperature < 0.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (options.temperature < 0.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (options.top_p <= 0.0F || options.top_p > 1.0F) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (options.top_k == 0U) {
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
}
if (options.top_k == 0U) {
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
}
if (options.seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
if (options.seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
if (options.n_ctx == 0 || options.n_ctx > 32768) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_top_k_ = options.top_k;
if (options.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
} else {
rng_.seed(static_cast<uint32_t>(options.seed));
}
n_ctx_ = options.n_ctx;
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_top_k_ = options.top_k;
if (options.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
} else {
rng_.seed(static_cast<uint32_t>(options.seed));
}
n_ctx_ = options.n_ctx;
this->Load(model_path);
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
sampling_top_k_};
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
sampler_.reset(new SamplerState());
sampler_->chain = sampler_chain.release();
this->Load(model_path);
}
LlamaGenerator::~LlamaGenerator() {
sampler_.reset();
/**
* Free the inference context (contains KV cache and computation state)
*/
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
/**
* Free the inference context (contains KV cache and computation state)
*/
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
/**
* Free the loaded model (contains weights and vocabulary)
*/
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
/**
* Free the loaded model (contains weights and vocabulary)
*/
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
}

View File

@@ -14,32 +14,32 @@
#include "llama.h"
/**
 * Loads the model at `model_path` and creates an inference context for it.
 *
 * Any context and model already held are freed first (context, then model).
 * The context is created with n_ctx = n_ctx_ and n_batch = min(n_ctx_, 5000).
 *
 * @param model_path Path handed to llama_model_load_from_file.
 * @throws std::runtime_error if the model cannot be loaded, or if the context
 *         cannot be created (the freshly loaded model is freed in that case).
 */
void LlamaGenerator::Load(const std::string& model_path) {
  if (context_ != nullptr) {
    llama_free(context_);
    context_ = nullptr;
  }
  if (model_ != nullptr) {
    llama_model_free(model_);
    model_ = nullptr;
  }

  const llama_model_params model_params = llama_model_default_params();
  model_ = llama_model_load_from_file(model_path.c_str(), model_params);
  if (model_ == nullptr) {
    throw std::runtime_error(
        "LlamaGenerator: failed to load model from path: " + model_path);
  }

  llama_context_params context_params = llama_context_default_params();
  context_params.n_ctx = n_ctx_;
  // Batch size follows the context size but is capped at 5000.
  context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));

  context_ = llama_init_from_model(model_, context_params);
  if (context_ == nullptr) {
    // Do not keep a model around without a usable context.
    llama_model_free(model_);
    model_ = nullptr;
    throw std::runtime_error("LlamaGenerator: failed to create context");
  }

  spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}

View File

@@ -1,14 +1,13 @@
/**
* @file data_generation/llama/load_brewery_prompt.cpp
* @brief Resolves brewery system prompt content from cache or a configured
* filesystem path and provides a robust inline fallback prompt when absent.
* @brief Resolves brewery system prompt content from cache or filesystem
* search paths and provides a robust inline fallback prompt when absent.
*/
#include <spdlog/spdlog.h>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include "data_generation/llama_generator.h"
@@ -18,42 +17,81 @@ namespace fs = std::filesystem;
* @brief Loads brewery system prompt from disk or cache.
*
* @param prompt_file_path Preferred prompt file location.
* @return Prompt text loaded from disk.
* @return Prompt text loaded from disk or fallback content.
*/
std::string LlamaGenerator::LoadBrewerySystemPrompt(
const std::string& prompt_file_path) {
// Return cached version if already loaded
if (!brewery_system_prompt_.empty()) {
return brewery_system_prompt_;
}
// Return cached version if already loaded
if (!brewery_system_prompt_.empty()) {
return brewery_system_prompt_;
}
// Try the provided path only
const fs::path prompt_path(prompt_file_path);
std::ifstream prompt_file(prompt_path);
if (!prompt_file.is_open()) {
spdlog::error(
"LlamaGenerator: Failed to open brewery system prompt file '{}'",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: missing brewery system prompt file: " +
prompt_path.string());
}
// Try multiple path locations
std::vector<std::string> paths_to_try = {
prompt_file_path, // As provided
"../" + prompt_file_path, // One level up
"../../" + prompt_file_path, // Two levels up
};
const std::string prompt((std::istreambuf_iterator(prompt_file)),
std::istreambuf_iterator<char>());
prompt_file.close();
for (const auto& path : paths_to_try) {
std::ifstream prompt_file(path);
if (prompt_file.is_open()) {
std::string prompt((std::istreambuf_iterator<char>(prompt_file)),
std::istreambuf_iterator<char>());
prompt_file.close();
if (prompt.empty()) {
spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: empty brewery system prompt file: " +
prompt_path.string());
}
if (!prompt.empty()) {
spdlog::info(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} "
"chars)",
path, prompt.length());
brewery_system_prompt_ = prompt;
return brewery_system_prompt_;
}
}
}
spdlog::info(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)",
prompt_path.string(), prompt.length());
brewery_system_prompt_ = prompt;
return brewery_system_prompt_;
}
spdlog::warn(
"LlamaGenerator: Could not open brewery system prompt file at any of "
"the "
"expected locations. Using fallback inline prompt.");
return GetFallbackBreweryPrompt();
}
/**
 * @brief Provides an inline fallback brewery system prompt.
 *
 * The text is assembled section by section: persona, phrases to avoid,
 * opening approaches, specificity requirements, length/tone, output format.
 *
 * @return Default fallback prompt text.
 */
std::string LlamaGenerator::GetFallbackBreweryPrompt() {
  // Persona and overall task.
  std::string fallback_prompt =
      "You are an experienced brewmaster and owner of a local craft brewery. "
      "Create a distinctive, authentic name and detailed description that "
      "genuinely reflects your specific location, brewing philosophy, local "
      "culture, and community connection. The brewery must feel real and "
      "grounded—not generic or interchangeable.\n\n";

  // Phrases the model is told never to use.
  fallback_prompt +=
      "AVOID REPETITIVE PHRASES - Never use:\n"
      "Love letter to, tribute to, rolling hills, picturesque, every sip "
      "tells a story, Come for X stay for Y, rich history, passion, woven "
      "into, ancient roots, timeless, where tradition meets innovation\n\n";

  // Numbered list of opening strategies.
  fallback_prompt +=
      "OPENING APPROACHES - Choose ONE:\n"
      "1. Start with specific beer style and its regional origins\n"
      "2. Begin with specific brewing challenge (water, altitude, climate)\n"
      "3. Open with founding story or personal motivation\n"
      "4. Lead with specific local ingredient or resource\n"
      "5. Start with unexpected angle or contradiction\n"
      "6. Open with local event, tradition, or cultural moment\n"
      "7. Begin with tangible architectural or geographic detail\n\n";

  // Concrete details the description must include.
  fallback_prompt +=
      "BE SPECIFIC - Include:\n"
      "- At least ONE concrete proper noun (landmark, river, neighborhood)\n"
      "- Specific beer styles relevant to the REGION'S culture\n"
      "- Concrete brewing challenges or advantages\n"
      "- Sensory details SPECIFIC to place—not generic adjectives\n\n";

  // Length, tone, and output-format constraints.
  fallback_prompt +=
      "LENGTH: 150-250 words. TONE: Can be soulful, irreverent, "
      "matter-of-fact, unpretentious, or minimalist.\n\n"
      "Output ONLY a raw JSON object with keys name and description. "
      "No markdown, backticks, preamble, or trailing text.";

  return fallback_prompt;
}

View File

@@ -0,0 +1,71 @@
/**
* @file data_generation/mock/data.cpp
* @brief Defines static lookup tables used by MockGenerator for deterministic
* brewery names, descriptions, usernames, and bios.
*/
#include <string>
#include <vector>
#include "data_generation/mock_generator.h"
/**
 * @brief Adjective pool for mock brewery names.
 *
 * MockGenerator::GenerateBrewery picks one entry with `hash % size()`.
 */
const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"};
/**
 * @brief Noun pool that forms the trailing part of a mock brewery name.
 *
 * MockGenerator::GenerateBrewery picks one entry with `(hash / 7) % size()`.
 */
const std::vector<std::string> MockGenerator::kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
/**
 * @brief Base description sentences for mock breweries.
 *
 * MockGenerator::GenerateBrewery picks one entry with `(hash / 13) % size()`
 * and appends a location sentence to it, so every entry is a complete
 * sentence ending in a period.
 */
const std::vector<std::string> MockGenerator::kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal blends.",
"Modern taproom focused on sessionable lagers and classic pub styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp lagers.",
"Creative brewery offering rotating collaborations and limited draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich stouts."};
/**
 * @brief Username pool for mock users.
 *
 * MockGenerator::GenerateUser picks one entry with `hash % size()`, where the
 * hash is computed from the locale string.
 */
const std::vector<std::string> MockGenerator::kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
/**
 * @brief Bio pool for mock users.
 *
 * MockGenerator::GenerateUser picks one entry with `(hash / 11) % size()`,
 * where the hash is computed from the locale string.
 */
const std::vector<std::string> MockGenerator::kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};

View File

@@ -5,12 +5,13 @@
*/
#include <boost/container_hash/hash.hpp>
#include <string>
#include "data_generation/mock_generator.h"
size_t MockGenerator::DeterministicHash(const Location& location) {
size_t seed = 0;
boost::hash_combine(seed, location.city);
boost::hash_combine(seed, location.country);
return seed;
/**
 * @brief Folds a location's city and country names into one hash value.
 *
 * @param location Location whose `city_name` and `country_name` are hashed.
 * @return Combined hash of the two fields, starting from a zero seed.
 */
std::size_t MockGenerator::DeterministicHash(const BreweryLocation& location) {
  std::size_t combined = 0;

  // Keep city before country: swapping the order would change the result.
  boost::hash_combine(combined, location.city_name);
  boost::hash_combine(combined, location.country_name);

  return combined;
}

View File

@@ -4,39 +4,35 @@
* and country into fixed mock phrase catalogs.
*/
#include <format>
#include <string>
#include <string_view>
#include "data_generation/mock_generator.h"
BreweryResult MockGenerator::GenerateBrewery(
const Location& location, const std::string& /*region_context*/) {
const std::size_t hash = DeterministicHash(location);
const BreweryLocation& location, const std::string& /*region_context*/) {
const std::size_t hash = DeterministicHash(location);
const std::string_view adjective =
kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
const std::string_view noun =
kBreweryNouns.at(hash / 7 % kBreweryNouns.size());
const std::string_view base_description =
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size());
const std::string& adjective =
kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
const std::string& noun =
kBreweryNouns.at((hash / 7) % kBreweryNouns.size());
const std::string& base_description =
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size());
const std::string name =
std::format("{} {} {}", location.city, adjective, noun);
std::string name(location.city_name);
name.append(" ");
name.append(adjective);
name.append(" ");
name.append(noun);
const std::string state_suffix =
location.state_province.empty()
? std::string{}
: std::format(", {}", location.state_province);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string description =
std::format("{} Located in {}{}{}.", base_description, location.city,
state_suffix, country_suffix);
std::string description = base_description;
description.append(" Based in ");
description.append(location.city_name);
if (!location.country_name.empty()) {
description.append(", ");
description.append(location.country_name);
}
description.append(".");
return {
.name = name,
.description = description,
};
return {name, description};
}

View File

@@ -6,17 +6,14 @@
#include <functional>
#include <string>
#include <string_view>
#include "data_generation/mock_generator.h"
UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
const std::string_view username = kUsernames[hash % kUsernames.size()];
const std::string_view bio = kBios[hash / 11 % kBios.size()];
result.username = username;
result.bio = bio;
return result;
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -12,76 +12,72 @@
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string_view>
static std::string ReadRequiredString(const boost::json::object& object,
const char* key) {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_string()) {
throw std::runtime_error(std::string("Missing or invalid string field: ") +
key);
}
const std::string_view text = value->as_string();
return std::string(text);
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_string()) {
throw std::runtime_error(
std::string("Missing or invalid string field: ") + key);
}
return std::string(value->as_string().c_str());
}
static double ReadRequiredNumber(const boost::json::object& object,
const char* key) {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_number()) {
throw std::runtime_error(std::string("Missing or invalid numeric field: ") +
key);
}
return value->to_number<double>();
}
std::vector<Location> JsonLoader::LoadLocations(
const std::filesystem::path& filepath) {
std::ifstream input(filepath);
if (!input.is_open()) {
throw std::runtime_error("Failed to open locations file: " +
filepath.string());
}
std::stringstream buffer;
buffer << input.rdbuf();
const std::string content = buffer.str();
boost::system::error_code error;
boost::json::value root = boost::json::parse(content, error);
if (error) {
throw std::runtime_error("Failed to parse locations JSON: " +
error.message());
}
if (!root.is_array()) {
throw std::runtime_error(
"Invalid locations JSON: root element must be an array");
}
std::vector<Location> locations;
const auto& items = root.as_array();
locations.reserve(items.size());
for (const auto& item : items) {
if (!item.is_object()) {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_number()) {
throw std::runtime_error(
"Invalid locations JSON: each entry must be an object");
}
const auto& object = item.as_object();
locations.push_back(Location{
.city = ReadRequiredString(object, "city"),
.state_province = ReadRequiredString(object, "state_province"),
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
.country = ReadRequiredString(object, "country"),
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
.latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"),
});
}
spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(),
filepath.string());
return locations;
std::string("Missing or invalid numeric field: ") + key);
}
return value->to_number<double>();
}
std::vector<Location> JsonLoader::LoadLocations(const std::string& filepath) {
std::ifstream input(filepath);
if (!input.is_open()) {
throw std::runtime_error("Failed to open locations file: " + filepath);
}
std::stringstream buffer;
buffer << input.rdbuf();
const std::string content = buffer.str();
boost::system::error_code error;
boost::json::value root = boost::json::parse(content, error);
if (error) {
throw std::runtime_error("Failed to parse locations JSON: " +
error.message());
}
if (!root.is_array()) {
throw std::runtime_error(
"Invalid locations JSON: root element must be an array");
}
std::vector<Location> locations;
const auto& items = root.as_array();
locations.reserve(items.size());
for (const auto& item : items) {
if (!item.is_object()) {
throw std::runtime_error(
"Invalid locations JSON: each entry must be an object");
}
const auto& object = item.as_object();
locations.push_back(Location{
.city = ReadRequiredString(object, "city"),
.state_province = ReadRequiredString(object, "state_province"),
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
.country = ReadRequiredString(object, "country"),
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
.latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"),
});
}
spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(),
filepath);
return locations;
}

View File

@@ -10,7 +10,6 @@
#include <boost/program_options.hpp>
#include <exception>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
@@ -31,153 +30,141 @@ namespace di = boost::di;
*
* @param argc Command-line argument count.
* @param argv Command-line arguments.
* @return Parsed ApplicationOptions if parsing succeeded, std::nullopt
* otherwise.
* @param options Output ApplicationOptions struct.
* @return true if parsing succeeded and should proceed, false otherwise.
*/
std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
prog_opts::options_description desc("Pipeline Options");
bool ParseArguments(const int argc, char** argv,
ApplicationOptions& options) noexcept {
prog_opts::options_description desc("Pipeline Options");
desc.add_options()("help,h", "Produce help message")(
"mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data")(
"model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")(
"temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)")(
"top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)")(
"top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)")(
"n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)")(
"seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
auto opt = desc.add_options();
// Handle the "no arguments" or "help" case
if (argc == 1) {
spdlog::info("Biergarten Pipeline");
std::stringstream usage_stream;
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(usage_stream.str());
return false;
}
opt("help,h", "Produce help message");
try {
prog_opts::variables_map variables_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
opt("mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data");
if (variables_map.contains("help")) {
std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return false;
}
opt("model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)");
const auto use_mocked = variables_map["mocked"].as<bool>();
const auto model_path = variables_map["model"].as<std::string>();
opt("temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)");
if (use_mocked && !model_path.empty()) {
spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive");
return false;
}
opt("top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)");
if (!use_mocked && model_path.empty()) {
spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
return false;
}
opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)");
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
!variables_map["top-p"].defaulted() ||
!variables_map["top-k"].defaulted() ||
!variables_map["seed"].defaulted();
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)");
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
}
opt("seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.top_k = variables_map["top-k"].as<uint32_t>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
// Handle the "no arguments" or "help" case
if (argc == 1) {
spdlog::info("Biergarten Pipeline");
std::stringstream usage_stream;
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(usage_stream.str());
return std::nullopt;
}
try {
prog_opts::variables_map variables_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
if (variables_map.contains("help")) {
std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return std::nullopt;
}
const auto use_mocked = variables_map["mocked"].as<bool>();
const auto model_path = variables_map["model"].as<std::string>();
if (use_mocked && !model_path.empty()) {
spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive");
return std::nullopt;
}
if (!use_mocked && model_path.empty()) {
spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
return std::nullopt;
}
// True when the user explicitly passed any sampling flag (temperature, top-p,
// top-k, seed); `defaulted()` reports whether the option kept its default.
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
                            !variables_map["top-p"].defaulted() ||
                            !variables_map["top-k"].defaulted() ||
                            !variables_map["seed"].defaulted();
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
}
ApplicationOptions options;
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.top_k = variables_map["top-k"].as<uint32_t>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
return options;
} catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}",
exception.what());
return std::nullopt;
} catch (...) {
spdlog::error("Failed to parse command-line arguments: unknown error");
return std::nullopt;
}
return true;
} catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}",
exception.what());
return false;
} catch (...) {
spdlog::error("Failed to parse command-line arguments: unknown error");
return false;
}
}
int main(const int argc, char** argv) {
try {
const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
int main(const int argc, char** argv) noexcept {
try {
const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
const auto parsed_options = ParseArguments(argc, argv);
if (!parsed_options.has_value()) {
return 0;
}
ApplicationOptions options;
if (!ParseArguments(argc, argv, options)) {
return 0;
}
const auto options = *parsed_options;
const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<std::string>().to(options.model_path),
di::bind<DataGenerator>().to(
[options](const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.use_mocked) {
const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<std::string>().to(options.model_path),
di::bind<DataGenerator>().to([options](const auto& injector)
-> std::unique_ptr<DataGenerator> {
if (options.use_mocked) {
spdlog::info(
"[Generator] Using MockGenerator (no model path provided)");
return std::make_unique<MockGenerator>();
}
}
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p,
options.top_k, options.n_ctx, options.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>();
}));
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p,
options.top_k, options.n_ctx, options.seed);
return injector.template create<std::unique_ptr<LlamaGenerator>>();
}));
auto generator = injector.create<BiergartenDataGenerator>();
auto generator = injector.create<BiergartenDataGenerator>();
if (!generator.Run()) {
spdlog::error("Pipeline execution failed");
if (!generator.Run()) {
spdlog::error("Pipeline execution failed");
return 1;
}
spdlog::info("Pipeline executed successfully");
return 0;
} catch (const std::exception& exception) {
spdlog::critical("Unhandled fatal error in main: {}", exception.what());
return 1;
}
spdlog::info("Pipeline executed successfully");
return 0;
} catch (const std::exception& exception) {
spdlog::critical("Unhandled fatal error in main: {}", exception.what());
return 1;
} catch (...) {
spdlog::critical("Unhandled fatal non-standard exception in main");
return 1;
}
} catch (...) {
spdlog::critical("Unhandled fatal non-standard exception in main");
return 1;
}
}

View File

@@ -12,50 +12,47 @@
#include "services/wikipedia_service.h"
std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) {
return cache_it->second;
}
const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) {
return cache_it->second;
}
const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json";
const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json";
const std::string body = this->client_->Get(url);
const std::string body = this->client_->Get(url);
boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error);
boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error);
if (!parse_error && doc.is_object()) {
try {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
const std::string_view extract_view = page.at("extract").as_string();
std::string extract(extract_view);
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
this->extract_cache_.emplace(cache_key, extract);
return extract;
}
if (!parse_error && doc.is_object()) {
try {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
this->extract_cache_.emplace(cache_key, extract);
return extract;
}
}
this->extract_cache_.emplace(cache_key, std::string{});
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
}
this->extract_cache_.emplace(cache_key, std::string{});
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
}
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
parse_error.message());
}
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
parse_error.message());
}
return {};
return {};
}

View File

@@ -10,38 +10,47 @@
#include "services/wikipedia_service.h"
std::string WikipediaService::GetLocationContext(const Location& loc) {
if (!client_) {
return {};
}
const std::string cache_key = loc.city + "|" + loc.country;
const auto cache_it = cache_.find(cache_key);
if (cache_it != cache_.end()) {
return cache_it->second;
}
std::string result;
std::string result;
std::string region_query(loc.city);
if (!loc.country.empty()) {
region_query += ", ";
region_query += loc.country;
}
if (!client_) {
cache_.emplace(cache_key, result);
return result;
}
const std::string beer_query = "beer in " + loc.country;
const std::string city_beer_query = "beer in " + loc.city;
std::string region_query(loc.city);
if (!loc.country.empty()) {
region_query += ", ";
region_query += loc.country;
}
auto append_extract = [&result](const std::string& extract) -> void {
if (extract.empty()) {
return;
}
if (!result.empty()) {
result += "\n\n";
}
result += extract;
};
const std::string beer_query = "beer in " + loc.country;
const std::string city_beer_query = "beer in " + loc.city;
try {
append_extract(FetchExtract(region_query));
append_extract(FetchExtract(beer_query));
append_extract(FetchExtract(city_beer_query));
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
e.what());
}
return result;
auto append_extract = [&result](const std::string& extract) -> void {
if (extract.empty()) {
return;
}
if (!result.empty()) {
result += "\n\n";
}
result += extract;
};
try {
append_extract(FetchExtract(region_query));
append_extract(FetchExtract(beer_query));
append_extract(FetchExtract(city_beer_query));
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
e.what());
}
cache_.emplace(cache_key, result);
return result;
}

View File

@@ -7,5 +7,5 @@
#include <utility>
WikipediaService::WikipediaService(std::unique_ptr<WebClient> client)
WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {}

View File

@@ -10,10 +10,10 @@
#include "web_client/curl_web_client.h"
CurlGlobalState::CurlGlobalState() {
if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl globally");
}
if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl globally");
}
}
/**
 * @brief Releases the global libcurl state set up by the constructor's
 * curl_global_init call.
 */
CurlGlobalState::~CurlGlobalState() {
  curl_global_cleanup();
}

View File

@@ -5,71 +5,46 @@
#include <curl/curl.h>
#include <cstdint>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include "curl_web_client_utils.h"
#include "web_client/curl_web_client.h"
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (handle == nullptr) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
static void set_common_get_options(CURL* curl, const std::string& url) {
constexpr uint64_t connection_timeout = 10;
constexpr uint64_t request_timeout = 30;
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connection_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, request_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
// curl write callback that appends response data into a std::string
static size_t WriteCallbackString(void* contents, const size_t size,
const size_t nmemb, void* userp) {
const size_t real_size = size * nmemb;
auto* str = static_cast<std::string*>(userp);
str->append(static_cast<char*>(contents), real_size);
return real_size;
static size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
void* userp) {
size_t realsize = size * nmemb;
auto* s = static_cast<std::string*>(userp);
s->append(static_cast<char*>(contents), realsize);
return realsize;
}
std::string CURLWebClient::Get(const std::string& url) {
const CurlHandle curl = create_handle();
auto curl = create_handle();
std::string response_string;
std::string response_string;
set_common_get_options(curl.get(), url, {10L, 20L});
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
set_common_get_options(curl.get(), url);
CURLcode res = curl_easy_perform(curl.get());
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
if (res != CURLE_OK) {
std::string error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error);
}
CURLcode res = curl_easy_perform(curl.get());
long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (res != CURLE_OK) {
const auto error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error);
}
if (httpCode != 200) {
std::stringstream ss;
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(ss.str());
}
int64_t httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) {
const std::string error = "[CURLWebClient] HTTP error " +
std::to_string(httpCode) + " for URL " + url;
throw std::runtime_error(error);
}
return response_string;
}
return response_string;
}

View File

@@ -11,14 +11,13 @@
#include "web_client/curl_web_client.h"
std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char* output = curl_easy_escape(nullptr, value.c_str(), 0);
// A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (!output) {
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
}
std::string result(output);
curl_free(output);
return result;
}
if (output) {
std::string result(output);
curl_free(output);
return result;
}
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
}

View File

@@ -0,0 +1,28 @@
/**
* @file web_client/curl_web_client_utils.cpp
* @brief Shared CURLWebClient helper implementations.
*/
#include "curl_web_client_utils.h"
#include <stdexcept>
/**
 * @brief Creates a libcurl easy handle wrapped in an owning CurlHandle.
 *
 * @return Handle that calls curl_easy_cleanup when it goes out of scope.
 * @throws std::runtime_error If curl_easy_init returns a null handle.
 */
CurlHandle create_handle() {
  // Wrap first, then validate: a null unique_ptr never invokes its deleter.
  CurlHandle owned(curl_easy_init(), &curl_easy_cleanup);
  if (!owned) {
    throw std::runtime_error(
        "[CURLWebClient] Failed to initialize libcurl handle");
  }
  return owned;
}
/**
 * @brief Applies the request options shared by CURLWebClient GET calls.
 *
 * Sets the URL, user agent, redirect policy, both timeouts, and gzip as the
 * accepted encoding. Return codes from curl_easy_setopt are not inspected.
 *
 * @param curl Easy handle to configure.
 * @param url Request URL; libcurl is handed `url.c_str()`.
 * @param timeouts Values passed to CURLOPT_CONNECTTIMEOUT and CURLOPT_TIMEOUT.
 */
void set_common_get_options(CURL* curl, const std::string& url,
                            CurlTimeouts timeouts) {
  static constexpr char kUserAgent[] = "biergarten-pipeline/0.1.0";
  static constexpr char kAcceptedEncoding[] = "gzip";
  constexpr long kFollowRedirects = 1L;
  constexpr long kMaxRedirects = 5L;

  // Target and client identity.
  curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
  curl_easy_setopt(curl, CURLOPT_USERAGENT, kUserAgent);

  // Redirect policy: follow, but give up after a bounded number of hops.
  curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, kFollowRedirects);
  curl_easy_setopt(curl, CURLOPT_MAXREDIRS, kMaxRedirects);

  // Caller-supplied time limits.
  curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, timeouts.connect_timeout);
  curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeouts.total_timeout);

  // Ask the server for compressed responses.
  curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, kAcceptedEncoding);
}

View File

@@ -0,0 +1,26 @@
#ifndef BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
#define BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
/**
* @file web_client/curl_web_client_utils.h
* @brief Shared helpers for CURLWebClient request setup.
*/
#include <curl/curl.h>
#include <memory>
#include <string>
/**
 * @brief Owning wrapper for a libcurl easy handle; a non-null handle is
 * released with curl_easy_cleanup when the wrapper is destroyed.
 */
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
/**
 * @brief Timeout pair applied by set_common_get_options.
 */
struct CurlTimeouts {
/** Value passed to CURLOPT_CONNECTTIMEOUT (libcurl interprets it as seconds). */
long connect_timeout;
/** Value passed to CURLOPT_TIMEOUT (libcurl interprets it as seconds). */
long total_timeout;
};
/**
 * @brief Creates a new libcurl easy handle.
 *
 * @return Owning handle for the new easy session.
 * @throws std::runtime_error If libcurl fails to create the handle.
 */
CurlHandle create_handle();
/**
 * @brief Applies the shared GET request options to an easy handle.
 *
 * @param curl Easy handle to configure.
 * @param url Request URL.
 * @param timeouts Connect and total timeouts to apply.
 */
void set_common_get_options(CURL* curl, const std::string& url,
CurlTimeouts timeouts);
#endif // BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_