8 Commits

Author SHA1 Message Date
Aaron Po
299a767d39 remove unused code 2026-04-11 14:42:32 -04:00
Aaron Po
f07d48f810 Add missing includes, update readme 2026-04-11 14:31:24 -04:00
Aaron Po
bcfde856fe Split data models into dedicated headers 2026-04-11 13:21:50 -04:00
Aaron Po
5946356083 Style audit: update code to strictly follow Google Style Guide 2026-04-11 11:56:45 -04:00
Aaron Po
ae67fa8566 refactor: consolidate and rename data generation and service files 2026-04-11 00:06:23 -04:00
Aaron Po
8c572a2d07 fix: stabilize Gemma 4 brewery generation
remove misleading turn-token output guidance from the brewery prompt
extract the last balanced JSON object before validation
keep README model setup and run instructions aligned
preserve Gemma 4 sampling defaults and local model usage
2026-04-10 22:25:26 -04:00
Aaron Po
902bda6eb9 eat: make Gemma 4 the default model, enable thinking mode 2026-04-10 21:43:18 -04:00
Aaron Po
61d5077a95 update readme 2026-04-10 00:03:45 -04:00
50 changed files with 1744 additions and 1528 deletions

View File

@@ -1,5 +1,5 @@
--- ---
BasedOnStyle: Google BasedOnStyle: Google
ColumnLimit: 80 ColumnLimit: 80
IndentWidth: 2 IndentWidth: 3
... ...

View File

@@ -9,23 +9,21 @@ Checks: >
-google-runtime-references -google-runtime-references
CheckOptions: CheckOptions:
# Enforce Google Naming Conventions with valid clang-tidy strings # Enforce Google Naming Conventions
- key: readability-identifier-naming.ClassCase
value: CamelCase
- key: readability-identifier-naming.ClassMemberCase - key: readability-identifier-naming.ClassMemberCase
value: lower_case value: snake_case
- key: readability-identifier-naming.ClassMemberSuffix - key: readability-identifier-naming.ClassMemberSuffix
value: _ value: _
- key: readability-identifier-naming.ClassCase
value: PascalCase
- key: readability-identifier-naming.FunctionCase - key: readability-identifier-naming.FunctionCase
value: CamelCase value: PascalCase
- key: readability-identifier-naming.StructCase - key: readability-identifier-naming.StructCase
value: CamelCase value: PascalCase
- key: readability-identifier-naming.VariableCase - key: readability-identifier-naming.VariableCase
value: lower_case value: snake_case
- key: readability-identifier-naming.GlobalConstantCase - key: readability-identifier-naming.GlobalConstantCase
value: CamelCase value: kPascalCase
- key: readability-identifier-naming.GlobalConstantPrefix
value: k
# Ensure C++20 Modernization # Ensure C++20 Modernization
- key: modernize-make-unique.MakeSmartPtrFunction - key: modernize-make-unique.MakeSmartPtrFunction

1
pipeline/.gitignore vendored
View File

@@ -1,6 +1,5 @@
dist dist
build build
cmake-build-*
data data
models models
*.gguf *.gguf

View File

@@ -1,10 +1,12 @@
cmake_minimum_required(VERSION 3.24) cmake_minimum_required(VERSION 3.24)
project(biergarten-pipeline) project(biergarten-pipeline)
# Boost.DI still declares a very old minimum CMake version, which newer CMake
# releases reject unless a policy version floor is provided.
set(CMAKE_POLICY_VERSION_MINIMUM 3.5 CACHE STRING "" FORCE) set(CMAKE_POLICY_VERSION_MINIMUM 3.5 CACHE STRING "" FORCE)
# ============================================================================= # =============================================================================
# 1. Platform & GPU Detection # 1. Platform & GPU Detection (Windows explicitly NOT supported)
# ============================================================================= # =============================================================================
if(WIN32) if(WIN32)
message(FATAL_ERROR "[biergarten] Windows is currently not supported. Please use Linux (Fedora 43) or macOS (M1 Pro).") message(FATAL_ERROR "[biergarten] Windows is currently not supported. Please use Linux (Fedora 43) or macOS (M1 Pro).")
@@ -38,15 +40,18 @@ endif()
# 2. Project-wide Settings (Standard & Optimization) # 2. Project-wide Settings (Standard & Optimization)
# ============================================================================= # =============================================================================
# Downgrade to C++20 as per Google Style Guide
set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# GCC/Clang specific settings (warnings as errors)
add_compile_options(-Wall -Wextra -Werror -Wpedantic) add_compile_options(-Wall -Wextra -Werror -Wpedantic)
# Release Build Optimization: Aggressive (-O3), Arch-specific, and LTO # Release Build Optimization: Aggressive (-O3), Arch-specific, and LTO
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -march=native -flto") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -march=native -flto")
# Debug Build Optimization: Fast and debuggable (-Og)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Og -g") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Og -g")
# ============================================================================= # =============================================================================
@@ -101,6 +106,7 @@ set(SOURCES
src/services/wikipedia/fetch_extract.cpp src/services/wikipedia/fetch_extract.cpp
src/web_client/curl_global_state.cpp src/web_client/curl_global_state.cpp
src/web_client/curl_web_client_get.cpp src/web_client/curl_web_client_get.cpp
src/web_client/curl_web_client_utils.cpp
src/web_client/curl_web_client_url_encode.cpp src/web_client/curl_web_client_url_encode.cpp
src/data_generation/llama/llama_generator.cpp src/data_generation/llama/llama_generator.cpp
src/data_generation/llama/generate_brewery.cpp src/data_generation/llama/generate_brewery.cpp
@@ -109,6 +115,7 @@ set(SOURCES
src/data_generation/llama/infer.cpp src/data_generation/llama/infer.cpp
src/data_generation/llama/load.cpp src/data_generation/llama/load.cpp
src/data_generation/llama/load_brewery_prompt.cpp src/data_generation/llama/load_brewery_prompt.cpp
src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/generate_brewery.cpp src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp src/data_generation/mock/generate_user.cpp
@@ -141,9 +148,3 @@ configure_file(
${CMAKE_BINARY_DIR}/locations.json ${CMAKE_BINARY_DIR}/locations.json
COPYONLY COPYONLY
) )
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
${CMAKE_SOURCE_DIR}/prompts
${CMAKE_BINARY_DIR}/prompts
)

View File

@@ -76,7 +76,7 @@ curl -L \
## Run ## Run
Run the executable from the build directory so the copied `locations.json` and `prompts/` directory are available. Run the executable from the build directory so the copied `locations.json` is available.
```bash ```bash
./biergarten-pipeline --mocked ./biergarten-pipeline --mocked

View File

@@ -1,7 +1,7 @@
@startuml BiergartenPipeline @startuml BiergartenPipeline
title Biergarten Pipeline - Class and Composition Diagram title Biergarten Pipeline - Class and Composition Diagram
top to bottom direction left to right direction
skinparam shadowing false skinparam shadowing false
skinparam classAttributeIconSize 0 skinparam classAttributeIconSize 0
skinparam packageStyle rectangle skinparam packageStyle rectangle
@@ -16,17 +16,11 @@ package "Composition root" {
+~CurlGlobalState() +~CurlGlobalState()
} }
class LlamaBackendState {
+LlamaBackendState()
+~LlamaBackendState()
}
note right of Main note right of Main
Binds with Boost.DI: Binds with Boost.DI:
- WebClient -> CURLWebClient - WebClient -> CURLWebClient
- IEnrichmentService -> WikipediaService - IEnrichmentService -> WikipediaService
- DataGenerator -> MockGenerator or LlamaGenerator - DataGenerator -> MockGenerator or LlamaGenerator
- std::string -> model_path
- LlamaGenerator receives ApplicationOptions and model_path directly - LlamaGenerator receives ApplicationOptions and model_path directly
end note end note
} }
@@ -38,8 +32,8 @@ package "Core orchestration" {
-generated_breweries_: std::vector<GeneratedBrewery> -generated_breweries_: std::vector<GeneratedBrewery>
+BiergartenDataGenerator(context_service: std::shared_ptr<IEnrichmentService>, generator: std::unique_ptr<DataGenerator>) +BiergartenDataGenerator(context_service: std::shared_ptr<IEnrichmentService>, generator: std::unique_ptr<DataGenerator>)
+Run(): bool +Run(): bool
{static} -QueryCitiesWithCountries(): std::vector<Location> -QueryCitiesWithCountries(): std::vector<Location>
-GenerateBreweries(cities: const std::vector<EnrichedCity>&): void -GenerateBreweries(cities: std::vector<EnrichedCity>): void
-LogResults(): void -LogResults(): void
} }
} }
@@ -55,6 +49,11 @@ package "Data models" {
+seed: int +seed: int
} }
class BreweryLocation <<struct>> {
+city_name: std::string_view
+country_name: std::string_view
}
class Location <<struct>> { class Location <<struct>> {
+city: std::string +city: std::string
+state_province: std::string +state_province: std::string
@@ -88,64 +87,67 @@ package "Data models" {
package "Generation" { package "Generation" {
interface DataGenerator { interface DataGenerator {
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult +GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: const std::string&): UserResult +GenerateUser(locale: std::string): UserResult
} }
class MockGenerator { class MockGenerator {
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult +GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: const std::string&): UserResult +GenerateUser(locale: std::string): UserResult
} }
class LlamaGenerator { class LlamaGenerator {
+LlamaGenerator(options: const ApplicationOptions&, model_path: const std::string&) +LlamaGenerator(options: ApplicationOptions, model_path: std::string)
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult +GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: const std::string&): UserResult +GenerateUser(locale: std::string): UserResult
} }
} }
package "HTTP" { package "HTTP" {
interface WebClient { interface WebClient {
+Get(url: const std::string&): std::string +DownloadToFile(url: std::string, file_path: std::string): void
+UrlEncode(value: const std::string&): std::string +Get(url: std::string): std::string
+UrlEncode(value: std::string): std::string
} }
class CURLWebClient { class CURLWebClient {
+Get(url: const std::string&): std::string +CURLWebClient()
+UrlEncode(value: const std::string&): std::string +~CURLWebClient()
+DownloadToFile(url: std::string, file_path: std::string): void
+Get(url: std::string): std::string
+UrlEncode(value: std::string): std::string
} }
} }
package "JSON handling" { package "JSON handling" {
class JsonLoader { class JsonLoader {
{static} +LoadLocations(filepath: const std::string&): std::vector<Location> {static} +LoadLocations(filepath: std::string): std::vector<Location>
} }
} }
package "Wikipedia" { package "Wikipedia" {
interface IEnrichmentService { interface IEnrichmentService {
+GetLocationContext(loc: const Location&): std::string +GetLocationContext(loc: Location): std::string
} }
class WikipediaService { class WikipediaService {
+WikipediaService(client: std::unique_ptr<WebClient>) +WikipediaService(client: std::shared_ptr<WebClient>)
+GetLocationContext(loc: const Location&): std::string +GetLocationContext(loc: Location): std::string
} }
} }
Main --> CurlGlobalState Main --> CurlGlobalState
Main --> LlamaBackendState
Main --> ApplicationOptions Main --> ApplicationOptions
Main --> BiergartenDataGenerator Main --> BiergartenDataGenerator
Main ..> IEnrichmentService : DI binding Main ..> IEnrichmentService : DI binding
Main ..> DataGenerator : DI factory Main ..> DataGenerator : DI factory
Main ..> CURLWebClient : DI binding Main ..> CURLWebClient : DI binding
BiergartenDataGenerator *-- EnrichedCity
BiergartenDataGenerator *-- GeneratedBrewery BiergartenDataGenerator *-- GeneratedBrewery
BiergartenDataGenerator ..> JsonLoader : LoadLocations() BiergartenDataGenerator ..> JsonLoader : LoadLocations()
BiergartenDataGenerator --> IEnrichmentService : context lookup BiergartenDataGenerator --> IEnrichmentService : context lookup
BiergartenDataGenerator --> DataGenerator : brewery generation BiergartenDataGenerator --> DataGenerator : brewery generation
BiergartenDataGenerator ..> EnrichedCity
BiergartenDataGenerator ..> Location BiergartenDataGenerator ..> Location
BiergartenDataGenerator ..> BreweryResult BiergartenDataGenerator ..> BreweryResult
@@ -154,7 +156,7 @@ DataGenerator <|.. LlamaGenerator
WebClient <|.. CURLWebClient WebClient <|.. CURLWebClient
IEnrichmentService <|.. WikipediaService IEnrichmentService <|.. WikipediaService
WikipediaService *-- WebClient : unique_ptr WikipediaService --> WebClient : shared_ptr
note right of BiergartenDataGenerator note right of BiergartenDataGenerator
Current behavior: Current behavior:

View File

@@ -7,7 +7,6 @@
*/ */
#include <memory> #include <memory>
#include <span>
#include <vector> #include <vector>
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
@@ -30,7 +29,7 @@ class BiergartenDataGenerator {
* @param context_service Context provider for sampled locations. * @param context_service Context provider for sampled locations.
* @param generator Brewery and user data generator. * @param generator Brewery and user data generator.
*/ */
BiergartenDataGenerator(std::unique_ptr<IEnrichmentService> context_service, BiergartenDataGenerator(std::shared_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator); std::unique_ptr<DataGenerator> generator);
/** /**
@@ -46,8 +45,8 @@ class BiergartenDataGenerator {
bool Run(); bool Run();
private: private:
/// @brief Owning context provider dependency. /// @brief Shared context provider dependency.
std::unique_ptr<IEnrichmentService> context_service_; std::shared_ptr<IEnrichmentService> context_service_;
/// @brief Generator dependency selected in the composition root. /// @brief Generator dependency selected in the composition root.
std::unique_ptr<DataGenerator> generator_; std::unique_ptr<DataGenerator> generator_;
@@ -62,9 +61,9 @@ class BiergartenDataGenerator {
/** /**
* @brief Generate breweries for enriched cities. * @brief Generate breweries for enriched cities.
* *
* @param cities Span of enriched city data. * @param cities Vector of enriched city data.
*/ */
void GenerateBreweries(std::span<const EnrichedCity> cities); void GenerateBreweries(const std::vector<EnrichedCity>& cities);
/** /**
* @brief Log the generated brewery results. * @brief Log the generated brewery results.

View File

@@ -8,8 +8,8 @@
#include <string> #include <string>
#include "data_model/brewery_location.h"
#include "data_model/brewery_result.h" #include "data_model/brewery_result.h"
#include "data_model/location.h"
#include "data_model/user_result.h" #include "data_model/user_result.h"
/** /**
@@ -17,16 +17,17 @@
*/ */
class DataGenerator { class DataGenerator {
public: public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~DataGenerator() = default; virtual ~DataGenerator() = default;
/** /**
* @brief Generates brewery data for a location. * @brief Generates brewery data for a location.
* *
* @param location Location data * @param location City and country names.
* @param region_context Additional regional context text. * @param region_context Additional regional context text.
* @return Brewery generation result. * @return Brewery generation result.
*/ */
virtual BreweryResult GenerateBrewery(const Location& location, virtual BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) = 0; const std::string& region_context) = 0;
/** /**

View File

@@ -16,7 +16,6 @@
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
struct LlamaSampler;
/** /**
* @brief Data generator implementation backed by llama.cpp. * @brief Data generator implementation backed by llama.cpp.
@@ -36,19 +35,14 @@ class LlamaGenerator final : public DataGenerator {
/// @brief Releases model/context resources. /// @brief Releases model/context resources.
~LlamaGenerator() override; ~LlamaGenerator() override;
LlamaGenerator(const LlamaGenerator&) = delete;
LlamaGenerator& operator=(const LlamaGenerator&) = delete;
LlamaGenerator(LlamaGenerator&&) = delete;
LlamaGenerator& operator=(LlamaGenerator&&) = delete;
/** /**
* @brief Generates brewery data for a specific location. * @brief Generates brewery data for a specific location.
* *
* @param location Location object. * @param location City and country names.
* @param region_context Additional regional context. * @param region_context Additional regional context.
* @return Generated brewery result. * @return Generated brewery result.
*/ */
BreweryResult GenerateBrewery(const Location& location, BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) override; const std::string& region_context) override;
/** /**
@@ -60,23 +54,6 @@ class LlamaGenerator final : public DataGenerator {
UserResult GenerateUser(const std::string& locale) override; UserResult GenerateUser(const std::string& locale) override;
private: private:
static constexpr int kDefaultMaxTokens = 10000;
static constexpr float kDefaultSamplingTopP = 0.95F;
static constexpr uint32_t kDefaultSamplingTopK = 64;
static constexpr uint32_t kDefaultContextSize = 8192;
struct SamplerState {
SamplerState() = default;
~SamplerState();
SamplerState(const SamplerState&) = delete;
SamplerState& operator=(const SamplerState&) = delete;
SamplerState(SamplerState&&) = delete;
SamplerState& operator=(SamplerState&&) = delete;
LlamaSampler* chain = nullptr;
};
/** /**
* @brief Loads model and prepares inference context. * @brief Loads model and prepares inference context.
* *
@@ -84,6 +61,15 @@ class LlamaGenerator final : public DataGenerator {
*/ */
void Load(const std::string& model_path); void Load(const std::string& model_path);
/**
* @brief Infers text from a user prompt.
*
* @param prompt User prompt.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string Infer(const std::string& prompt, int max_tokens = 10000);
/** /**
* @brief Infers text from separate system and user prompts. * @brief Infers text from separate system and user prompts.
* *
@@ -95,8 +81,8 @@ class LlamaGenerator final : public DataGenerator {
* @param max_tokens Maximum tokens to generate. * @param max_tokens Maximum tokens to generate.
* @return Generated text. * @return Generated text.
*/ */
std::string Infer(const std::string& system_prompt, const std::string& prompt, std::string Infer(const std::string& system_prompt,
int max_tokens = kDefaultMaxTokens); const std::string& prompt, int max_tokens = 10000);
/** /**
* @brief Runs inference on an already-formatted prompt. * @brief Runs inference on an already-formatted prompt.
@@ -106,25 +92,30 @@ class LlamaGenerator final : public DataGenerator {
* @return Generated text. * @return Generated text.
*/ */
std::string InferFormatted(const std::string& formatted_prompt, std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = kDefaultMaxTokens); int max_tokens = 10000);
/** /**
* @brief Loads the brewery system prompt from disk. * @brief Loads the brewery system prompt from disk.
* *
* @param prompt_file_path Prompt file path to try first. * @param prompt_file_path Prompt file path to try first.
* @return Loaded prompt text. * @return Loaded prompt text or fallback prompt.
*/ */
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path); std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
/**
* @brief Returns a built-in fallback system prompt.
*
* @return Fallback prompt text.
*/
std::string GetFallbackBreweryPrompt();
llama_model* model_ = nullptr; llama_model* model_ = nullptr;
llama_context* context_ = nullptr; llama_context* context_ = nullptr;
/// @brief Persistent sampler chain reused across inference calls.
std::unique_ptr<SamplerState> sampler_;
float sampling_temperature_ = 1.0F; float sampling_temperature_ = 1.0F;
float sampling_top_p_ = kDefaultSamplingTopP; float sampling_top_p_ = 0.95F;
uint32_t sampling_top_k_ = kDefaultSamplingTopK; uint32_t sampling_top_k_ = 64;
std::mt19937 rng_; std::mt19937 rng_;
uint32_t n_ctx_ = kDefaultContextSize; uint32_t n_ctx_ = 8192;
std::string brewery_system_prompt_; std::string brewery_system_prompt_;
}; };

View File

@@ -7,7 +7,6 @@
*/ */
#include <cstddef> #include <cstddef>
#include <optional>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <utility> #include <utility>
@@ -36,6 +35,16 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
std::pair<std::string, std::string> ParseTwoLineResponsePublic( std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message); const std::string& raw, const std::string& error_message);
/**
* @brief Applies model chat template to a user-only prompt.
*
* @param model Loaded llama model.
* @param user_prompt User prompt text.
* @return Model-formatted prompt.
*/
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt);
/** /**
* @brief Applies model chat template to system and user prompts. * @brief Applies model chat template to system and user prompts.
* *
@@ -64,10 +73,10 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
* @param raw Raw model output. * @param raw Raw model output.
* @param name_out Parsed brewery name. * @param name_out Parsed brewery name.
* @param description_out Parsed brewery description. * @param description_out Parsed brewery description.
* @return Validation error message if invalid, or std::nullopt on success. * @return Empty string on success, or validation error message.
*/ */
std::optional<std::string> ValidateBreweryJsonPublic( std::string ValidateBreweryJsonPublic(const std::string& raw,
const std::string& raw, std::string& name_out, std::string& name_out,
std::string& description_out); std::string& description_out);
/** /**

View File

@@ -6,9 +6,9 @@
* @brief Deterministic mock implementation of DataGenerator. * @brief Deterministic mock implementation of DataGenerator.
*/ */
#include <array>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <vector>
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
@@ -24,7 +24,7 @@ class MockGenerator final : public DataGenerator {
* @param region_context Unused for mock generation. * @param region_context Unused for mock generation.
* @return Generated brewery result. * @return Generated brewery result.
*/ */
BreweryResult GenerateBrewery(const Location& location, BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) override; const std::string& region_context) override;
/** /**
@@ -42,82 +42,13 @@ class MockGenerator final : public DataGenerator {
* @param location City and country names. * @param location City and country names.
* @return Deterministic hash value. * @return Deterministic hash value.
*/ */
static std::size_t DeterministicHash(const Location& location); static std::size_t DeterministicHash(const BreweryLocation& location);
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = { static const std::vector<std::string> kBreweryAdjectives;
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", static const std::vector<std::string> kBreweryNouns;
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel", static const std::vector<std::string> kBreweryDescriptions;
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"}; static const std::vector<std::string> kUsernames;
static const std::vector<std::string> kBios;
static constexpr std::array<std::string_view, 18> kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
static constexpr std::array<std::string_view, 18> kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark "
"porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy "
"pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky "
"lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food "
"pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional "
"grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward "
"sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal "
"blends.",
"Modern taproom focused on sessionable lagers and classic pub "
"styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp "
"lagers.",
"Creative brewery offering rotating collaborations and limited "
"draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich "
"stouts."};
static constexpr std::array<std::string_view, 18> kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
static constexpr std::array<std::string_view, 18> kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer "
"gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward "
"styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean "
"fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_MOCK_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_MOCK_GENERATOR_H_

View File

@@ -33,7 +33,7 @@ struct ApplicationOptions {
/// @brief Context window size (tokens) for LLM inference. Higher values /// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory. /// support longer prompts but use more memory.
uint32_t n_ctx = 8192; uint32_t n_ctx = 2048;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative). /// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1; int seed = -1;

View File

@@ -13,10 +13,10 @@
*/ */
struct BreweryResult { struct BreweryResult {
/// @brief Brewery display name. /// @brief Brewery display name.
std::string name{}; std::string name;
/// @brief Brewery description text. /// @brief Brewery description text.
std::string description{}; std::string description;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_RESULT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_RESULT_H_

View File

@@ -15,7 +15,7 @@
*/ */
struct EnrichedCity { struct EnrichedCity {
Location location; Location location;
std::string region_context{}; std::string region_context;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_ENRICHED_CITY_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_ENRICHED_CITY_H_

View File

@@ -13,25 +13,25 @@
*/ */
struct Location { struct Location {
/// @brief City name. /// @brief City name.
std::string city{}; std::string city;
/// @brief State or province name. /// @brief State or province name.
std::string state_province{}; std::string state_province;
/// @brief ISO 3166-2 subdivision code. /// @brief ISO 3166-2 subdivision code.
std::string iso3166_2{}; std::string iso3166_2;
/// @brief Country name. /// @brief Country name.
std::string country{}; std::string country;
/// @brief ISO 3166-1 country code. /// @brief ISO 3166-1 country code.
std::string iso3166_1{}; std::string iso3166_1;
/// @brief Latitude in decimal degrees. /// @brief Latitude in decimal degrees.
double latitude{}; double latitude;
/// @brief Longitude in decimal degrees. /// @brief Longitude in decimal degrees.
double longitude{}; double longitude;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_LOCATION_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_LOCATION_H_

View File

@@ -13,10 +13,10 @@
*/ */
struct UserResult { struct UserResult {
/// @brief Username handle. /// @brief Username handle.
std::string username{}; std::string username;
/// @brief Short user biography. /// @brief Short user biography.
std::string bio{}; std::string bio;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_USER_RESULT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_USER_RESULT_H_

View File

@@ -6,7 +6,7 @@
* @brief Loader API for curated location data. * @brief Loader API for curated location data.
*/ */
#include <filesystem> #include <string>
#include <vector> #include <vector>
#include "data_model/location.h" #include "data_model/location.h"
@@ -15,8 +15,7 @@
class JsonLoader { class JsonLoader {
public: public:
/// @brief Parses a JSON array file and returns all location records. /// @brief Parses a JSON array file and returns all location records.
static std::vector<Location> LoadLocations( static std::vector<Location> LoadLocations(const std::string& filepath);
const std::filesystem::path& filepath);
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_JSON_HANDLING_JSON_LOADER_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -14,19 +14,19 @@
#include "services/enrichment_service.h" #include "services/enrichment_service.h"
#include "web_client/web_client.h" #include "web_client/web_client.h"
/// @brief Provides Wikipedia summary lookups backed by cached raw extracts. /// @brief Provides cached Wikipedia summary lookups for city and country pairs.
class WikipediaService final : public IEnrichmentService { class WikipediaService final : public IEnrichmentService {
public: public:
/// @brief Creates a new Wikipedia service with the provided web client. /// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::unique_ptr<WebClient> client); explicit WikipediaService(std::shared_ptr<WebClient> client);
/// @brief Returns the Wikipedia-derived context for a location. /// @brief Returns the Wikipedia-derived context for a location.
[[nodiscard]] std::string GetLocationContext(const Location& loc) override; [[nodiscard]] std::string GetLocationContext(const Location& loc) override;
private: private:
std::string FetchExtract(std::string_view query); std::string FetchExtract(std::string_view query);
std::unique_ptr<WebClient> client_; std::shared_ptr<WebClient> client_;
/// @brief Canonical cache for raw Wikipedia query extracts. std::unordered_map<std::string, std::string> cache_;
std::unordered_map<std::string, std::string> extract_cache_; std::unordered_map<std::string, std::string> extract_cache_;
}; };

View File

@@ -8,7 +8,7 @@
#include <utility> #include <utility>
BiergartenDataGenerator::BiergartenDataGenerator( BiergartenDataGenerator::BiergartenDataGenerator(
std::unique_ptr<IEnrichmentService> context_service, std::shared_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator) std::unique_ptr<DataGenerator> generator)
: context_service_(std::move(context_service)), : context_service_(std::move(context_service)),
generator_(std::move(generator)) {} generator_(std::move(generator)) {}

View File

@@ -8,32 +8,34 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
void BiergartenDataGenerator::GenerateBreweries( void BiergartenDataGenerator::GenerateBreweries(
std::span<const EnrichedCity> cities) { const std::vector<EnrichedCity>& cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ==="); spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
generated_breweries_.clear(); generated_breweries_.clear();
size_t skipped_count = 0; size_t skipped_count = 0;
for (const auto& [location, region_context] : cities) { for (const auto& enriched_city : cities) {
try { try {
const BreweryResult brewery = auto brewery = generator_->GenerateBrewery(
generator_->GenerateBrewery(location, region_context); BreweryLocation{enriched_city.location.city,
enriched_city.location.country},
const GeneratedBrewery gen{.location = location, .brewery = brewery}; enriched_city.region_context);
generated_breweries_.push_back(GeneratedBrewery{
generated_breweries_.push_back(gen); .location = enriched_city.location, .brewery = brewery});
} catch (const std::exception& e) { } catch (const std::exception& e) {
++skipped_count; ++skipped_count;
spdlog::warn( spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): brewery generation failed: " "[Pipeline] Skipping city '{}' ({}): brewery generation failed: "
"{}", "{}",
location.city, location.country, e.what()); enriched_city.location.city, enriched_city.location.country,
e.what());
} }
} }
if (skipped_count > 0) { if (skipped_count > 0) {
spdlog::warn("[Pipeline] Skipped {} city/cities due to generation errors", spdlog::warn(
"[Pipeline] Skipped {} city/cities due to generation "
"errors",
skipped_count); skipped_count);
} }
} }

View File

@@ -13,18 +13,18 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
#include "json_handling/json_loader.h" #include "json_handling/json_loader.h"
static constexpr std::size_t kBreweryAmount = 4; static constexpr unsigned int brewery_amount = 4;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() { std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
const std::filesystem::path locations_path = "locations.json"; const std::filesystem::path locations_path = "locations.json";
auto all_locations = JsonLoader::LoadLocations(locations_path); auto all_locations = JsonLoader::LoadLocations(locations_path.string());
spdlog::info(" Locations available: {}", all_locations.size()); spdlog::info(" Locations available: {}", all_locations.size());
const std::size_t sample_count = const size_t sample_count =
std::min(kBreweryAmount, all_locations.size()); std::min<size_t>(brewery_amount, all_locations.size());
const auto sample_count_signed = const auto sample_count_signed =
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>( static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
sample_count); sample_count);

View File

@@ -21,8 +21,8 @@ bool BiergartenDataGenerator::Run() {
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}", spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context); city.city, city.country, region_context);
enriched.push_back( enriched.push_back(EnrichedCity{.location = city,
EnrichedCity{.location = city, .region_context = region_context}); .region_context = region_context});
} catch (const std::exception& exception) { } catch (const std::exception& exception) {
++skipped_count; ++skipped_count;
spdlog::warn( spdlog::warn(

View File

@@ -7,16 +7,16 @@
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <array> #include <array>
#include <format>
#include <optional>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
static std::string ExtractFinalJsonPayload(std::string raw_response) { namespace {
auto trim = [](const std::string_view text) -> std::string_view {
std::string ExtractFinalJsonPayload(std::string raw_response) {
auto trim = [](std::string_view text) -> std::string_view {
const std::size_t first = text.find_first_not_of(" \t\n\r"); const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) { if (first == std::string_view::npos) {
return {}; return {};
@@ -26,7 +26,7 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
return text.substr(first, last - first + 1); return text.substr(first, last - first + 1);
}; };
static constexpr std::array<std::string_view, 6> separator_tokens = { static const std::array<std::string_view, 6> separator_tokens = {
"<|think|>", "<think|>", "<|turn|>", "<|think|>", "<think|>", "<|turn|>",
"<turn|>", "<channel|>", "<|channel|>"}; "<turn|>", "<channel|>", "<|channel|>"};
@@ -35,7 +35,8 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
for (const std::string_view token : separator_tokens) { for (const std::string_view token : separator_tokens) {
const std::size_t candidate_pos = raw_response.rfind(token); const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos && if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos || candidate_pos > separator_pos)) { (separator_pos == std::string::npos ||
candidate_pos > separator_pos)) {
separator_pos = candidate_pos; separator_pos = candidate_pos;
separator_length = token.size(); separator_length = token.size();
} }
@@ -46,9 +47,8 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
} }
const std::string_view trimmed = trim(raw_response); const std::string_view trimmed = trim(raw_response);
const std::string json_candidate = std::string json_candidate =
ExtractLastJsonObjectPublic(std::string(trimmed)); ExtractLastJsonObjectPublic(std::string(trimmed));
if (!json_candidate.empty()) { if (!json_candidate.empty()) {
return ExtractLastJsonObjectPublic(std::string(trimmed)); return ExtractLastJsonObjectPublic(std::string(trimmed));
} }
@@ -56,22 +56,16 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
return std::string(trimmed); return std::string(trimmed);
} }
} // namespace
BreweryResult LlamaGenerator::GenerateBrewery( BreweryResult LlamaGenerator::GenerateBrewery(
const Location& location, const std::string& region_context) { const BreweryLocation& location, const std::string& region_context) {
/** /**
* Preprocess and truncate region context to manageable size * Preprocess and truncate region context to manageable size
*/ */
const std::string safe_region_context = const std::string safe_region_context =
PrepareRegionContextPublic(region_context); PrepareRegionContextPublic(region_context);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string region_suffix =
safe_region_context.empty()
? "."
: std::format(". Regional context: {}", safe_region_context);
/** /**
* Load brewery system prompt from file * Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found * Falls back to minimal inline prompt if file not found
@@ -81,33 +75,47 @@ BreweryResult LlamaGenerator::GenerateBrewery(
/** /**
* User prompt: provides geographic context to guide generation towards * User prompt: provides geographic context to guide generation towards
* culturally relevant and locally-inspired brewery attributes * culturally appropriate and locally-inspired brewery attributes
*/ */
std::string prompt = std::format( std::string prompt =
"Write a brewery name and place-specific long description for a craft " "Write a brewery name and place-specific long description for a craft "
"brewery in {}{}{}", "brewery in ";
location.city, country_suffix, region_suffix); prompt.append(location.city_name);
if (!location.country_name.empty()) {
prompt.append(", ");
prompt.append(location.country_name);
}
if (safe_region_context.empty()) {
prompt.append(".");
} else {
prompt.append(". Regional context: ");
prompt.append(safe_region_context);
}
/** /**
* Store location context for retry prompts (without repeating full context) * Store location context for retry prompts (without repeating full context)
*/ */
const std::string retry_location = std::string retry_location = "Location: ";
std::format("Location: {}{}", location.city, country_suffix); retry_location.append(location.city_name);
if (!location.country_name.empty()) {
retry_location.append(", ");
retry_location.append(location.country_name);
}
/** /**
* RETRY LOOP with validation and error correction * RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based * Attempts to generate valid brewery data up to 3 times, with feedback-based
* refinement * refinement
*/ */
constexpr int max_attempts = 3; const int max_attempts = 3;
std::string raw; std::string raw;
std::string last_error; std::string last_error;
// Limit output length to keep it concise and focused // Limit output length to keep it concise and focused
for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052; constexpr int max_tokens = 1052;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
// Generate brewery data from LLM // Generate brewery data from LLM
raw = this->Infer(system_prompt, prompt, max_tokens); raw = Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw); raw);
@@ -116,28 +124,31 @@ BreweryResult LlamaGenerator::GenerateBrewery(
std::string name; std::string name;
std::string description; std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw); const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error = const std::string validation_error =
ValidateBreweryJsonPublic(json_only, name, description); ValidateBreweryJsonPublic(json_only, name, description);
if (!validation_error.has_value()) { if (validation_error.empty()) {
// Success: return parsed brewery data // Success: return parsed brewery data
return BreweryResult{.name = std::move(name), return {std::move(name), std::move(description)};
.description = std::move(description)};
} }
// Validation failed: log error and prepare corrective feedback // Validation failed: log error and prepare corrective feedback
last_error = *validation_error; last_error = validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, *validation_error); attempt + 1, validation_error);
// Update prompt with error details to guide LLM toward correct output. // Update prompt with error details to guide LLM toward correct output.
prompt = std::format( // For retries, use a compact prompt format to avoid exceeding token
R"(Your previous response was invalid. Error: {} // limits.
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}. prompt =
Do not include markdown, comments, extra keys, or literal placeholder values. "Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with exactly these keys: "
{})", "{\"name\": \"<brewery name>\", "
*validation_error, retry_location); "\"description\": \"<single-paragraph description>\"}."
"\nDo not include markdown, comments, extra keys, or literal "
"placeholder values.";
prompt += "\n\n";
prompt += retry_location;
} }
// All retry attempts exhausted: log failure and throw exception // All retry attempts exhausted: log failure and throw exception

View File

@@ -6,6 +6,7 @@
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
@@ -13,6 +14,87 @@
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
UserResult LlamaGenerator::GenerateUser(const std::string& locale) { UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
return {.username = "test_user", /**
.bio = "This is a test user profile from " + locale + "."}; * System prompt: specifies exact output format to minimize parsing errors
* Constraints: 2-line output, username format, bio length bounds
*/
const std::string system_prompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
/**
* User prompt: locale parameter guides cultural appropriateness of generated
* profiles
*/
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
/**
* RETRY LOOP with format validation
* Attempts up to 3 times to generate valid user profile with correct format
*/
const int max_attempts = 3;
std::string raw;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
/**
* Generate user profile (max 128 tokens - should fit 2 lines easily)
*/
raw = Infer(system_prompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
/**
* Parse two-line response: first line = username, second line = bio
*/
auto [username, bio] = ParseTwoLineResponsePublic(
raw, "LlamaGenerator: malformed user response");
/**
* Remove any whitespace from username (usernames shouldn't have
* spaces)
*/
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
/**
* Validate both fields are non-empty after processing
*/
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
/**
* Truncate bio if exceeds reasonable length for bio field
*/
if (bio.size() > 200) bio = bio.substr(0, 200);
/**
* Success: return parsed user profile
*/
return {username, bio};
} catch (const std::exception& e) {
/**
* Parsing failed: log and continue to next attempt
*/
spdlog::warn(
"LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
/**
* All retry attempts exhausted: log failure and throw exception
*/
spdlog::error(
"LlamaGenerator: malformed user response after {} attempts: {}",
max_attempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
} }

View File

@@ -4,17 +4,13 @@
* parsing, token decoding, and JSON validation helpers for Llama modules. * parsing, token decoding, and JSON validation helpers for Llama modules.
*/ */
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <boost/json.hpp> #include <boost/json.hpp>
#include <cctype> #include <cctype>
#include <optional>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <string_view>
#include <vector> #include <vector>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
@@ -23,42 +19,40 @@
/** /**
* String trimming: removes leading and trailing whitespace * String trimming: removes leading and trailing whitespace
*/ */
static std::string Trim(std::string_view value) { static std::string Trim(std::string value) {
constexpr std::string_view whitespace = " \t\n\r\f\v"; auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
const std::size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) {
return {};
}
const std::size_t last_index = value.find_last_not_of(whitespace); value.erase(value.begin(),
return std::string(value.substr(first_index, last_index - first_index + 1)); std::find_if(value.begin(), value.end(), not_space));
value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(),
value.end());
return value;
} }
/** /**
* Normalize whitespace: collapses multiple spaces/tabs/newlines into single * Normalize whitespace: collapses multiple spaces/tabs/newlines into single
* spaces * spaces
*/ */
static std::string CondenseWhitespace(std::string_view text) { static std::string CondenseWhitespace(std::string text) {
std::string out; std::string out;
out.reserve(text.size()); out.reserve(text.size());
bool pending_space = false; bool in_whitespace = false;
for (const unsigned char chr : text) { for (unsigned char ch : text) {
if (std::isspace(chr) != 0) { if (std::isspace(ch)) {
if (!out.empty()) { if (!in_whitespace) {
pending_space = true; out.push_back(' ');
in_whitespace = true;
} }
continue; continue;
} }
if (pending_space) { in_whitespace = false;
out.push_back(' '); out.push_back(static_cast<char>(ch));
pending_space = false;
}
out.push_back(static_cast<char>(chr));
} }
return out; return Trim(std::move(out));
} }
/** /**
@@ -66,14 +60,14 @@ static std::string CondenseWhitespace(std::string_view text) {
* boundaries * boundaries
*/ */
static std::string PrepareRegionContext(std::string_view region_context, static std::string PrepareRegionContext(std::string_view region_context,
const size_t max_chars) { std::size_t max_chars) {
std::string normalized = CondenseWhitespace(region_context); std::string normalized = CondenseWhitespace(std::string(region_context));
if (normalized.size() <= max_chars) { if (normalized.size() <= max_chars) {
return normalized; return normalized;
} }
normalized.resize(max_chars); normalized.resize(max_chars);
const size_t last_space = normalized.find_last_of(' '); const std::size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) { if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space); normalized.resize(last_space);
} }
@@ -82,20 +76,108 @@ static std::string PrepareRegionContext(std::string_view region_context,
return normalized; return normalized;
} }
static std::string ToChatPrompt(const llama_model* model, /**
* Remove common bullet points, numbers, and field labels added by LLM in output
*/
static std::string StripCommonPrefix(std::string line) {
line = Trim(std::move(line));
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = Trim(line.substr(1));
} else {
std::size_t i = 0;
while (i < line.size() &&
std::isdigit(static_cast<unsigned char>(line[i]))) {
++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = Trim(line.substr(i + 1));
}
}
auto strip_label = [&line](const std::string& label) {
if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = Trim(line.substr(label.size()));
}
}
};
strip_label("name:");
strip_label("brewery name:");
strip_label("description:");
strip_label("username:");
strip_label("bio:");
return Trim(std::move(line));
}
/**
* Parse two-line response from LLM: normalize line endings, strip formatting,
* filter spurious output, and combine remaining lines if needed
*/
static std::pair<std::string, std::string> ParseTwoLineResponse(
const std::string& raw, const std::string& error_message) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = StripCommonPrefix(std::move(line));
if (!line.empty()) lines.push_back(std::move(line));
}
std::vector<std::string> filtered;
for (auto& l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
// Filter known thinking tags like <think>...</think>, but be conservative
// to avoid removing legitimate output. Only filter specific known
// patterns.
if (!l.empty() && l.front() == '<' && low.back() == '>') {
// Only filter if it's a known thinking tag: <think>, <reasoning>, etc.
if (low.find("think") != std::string::npos ||
low.find("reasoning") != std::string::npos ||
low.find("reflect") != std::string::npos) {
continue;
}
}
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
filtered.push_back(std::move(l));
}
if (filtered.size() < 2) throw std::runtime_error(error_message);
std::string first = Trim(filtered.front());
std::string second;
for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) second += ' ';
second += filtered[i];
}
second = Trim(std::move(second));
if (first.empty() || second.empty()) throw std::runtime_error(error_message);
return {first, second};
}
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt) { const std::string& user_prompt) {
std::string combined_prompt;
combined_prompt.append(system_prompt);
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
const char* tmpl = llama_model_chat_template(model, nullptr); const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) { if (tmpl == nullptr) {
// No template found, fallback to raw text // No template found, fallback to raw text
spdlog::warn( return system_prompt + "\n\n" + user_prompt;
"LlamaGenerator: missing chat template; using raw prompt fallback");
return combined_prompt;
} }
const std::array<llama_chat_message, 2> messages = { const std::array<llama_chat_message, 2> messages = {
@@ -104,67 +186,71 @@ static std::string ToChatPrompt(const llama_model* model,
std::vector<char> buffer(std::max<std::size_t>( std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4)); 1024, (system_prompt.size() + user_prompt.size()) * 4));
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages, int32_t required =
int32_t message_count) -> int32_t { llama_chat_apply_template(tmpl, messages.data(), 2, true, buffer.data(),
int32_t result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size())); static_cast<int32_t>(buffer.size()));
if (result < 0) { // FALLBACK: If the template fails (e.g., Gemma rejecting the "system" role),
return result;
}
if (result >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(tmpl, chat_messages, message_count,
true, buffer.data(),
static_cast<int32_t>(buffer.size()));
}
return result;
};
int32_t template_result = apply_template_with_resize(messages.data(), 2);
if (template_result >= 0) {
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
spdlog::warn(
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
// combine the system and user prompts into a single "user" message. // combine the system and user prompts into a single "user" message.
if (required < 0) {
std::string combined_prompt = system_prompt + "\n\n" + user_prompt;
const std::array<llama_chat_message, 1> fallback_msg = { const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}}; {{"user", combined_prompt.c_str()}}};
template_result = apply_template_with_resize(fallback_msg.data(), 1); required = llama_chat_apply_template(tmpl, fallback_msg.data(), 1, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// Ultimate fallback: if GGUF template parsing still fails, use raw text. // THE FIX: Ultimate fallback. If the GGUF's internal template is
if (template_result < 0) { // completely unparseable (which happens with complex Jinja macros),
spdlog::warn( // degrade gracefully to raw text instead of throwing a runtime_error.
"LlamaGenerator: chat template fallback failed (result {}); using " if (required < 0) {
"raw prompt text",
template_result);
return combined_prompt; return combined_prompt;
} }
return {buffer.data(), static_cast<std::size_t>(template_result)}; if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(
tmpl, fallback_msg.data(), 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
return combined_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
// Standard buffer resize if the original "system" + "user" array succeeded
// but needed more space
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages.data(), 2, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// Final safety net on resize
if (required < 0) {
return system_prompt + "\n\n" + user_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
} }
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token, static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
std::array<char, 256> buffer{}; std::array<char, 256> buffer{};
int32_t bytes = int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true); llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true);
if (bytes < 0) { if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes)); std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), 0, static_cast<int32_t>(dynamic_buffer.size()),
true); 0, true);
if (bytes < 0) { if (bytes < 0) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece"); "LlamaGenerator: failed to decode sampled token piece");
@@ -242,8 +328,8 @@ std::string ExtractLastJsonObjectPublic(const std::string& text) {
return {}; return {};
} }
static std::optional<std::string> ValidateBreweryJson( static std::string ValidateBreweryJson(const std::string& raw,
const std::string& raw, std::string& name_out, std::string& name_out,
std::string& description_out) { std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv, auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool { std::string& error_out) -> bool {
@@ -263,11 +349,9 @@ static std::optional<std::string> ValidateBreweryJson(
return false; return false;
} }
const auto& name_value = obj.at("name").as_string(); name_out = Trim(std::string(obj.at("name").as_string().c_str()));
const auto& description_value = obj.at("description").as_string(); description_out =
name_out = Trim(std::string_view(name_value.data(), name_value.size())); Trim(std::string(obj.at("description").as_string().c_str()));
description_out = Trim(
std::string_view(description_value.data(), description_value.size()));
if (name_out.empty()) { if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty"; error_out = "JSON field 'name' must not be empty";
@@ -317,14 +401,14 @@ static std::optional<std::string> ValidateBreweryJson(
return validation_error; return validation_error;
} }
return std::nullopt; return {};
} }
if (!validate_object(jv, validation_error)) { if (!validate_object(jv, validation_error)) {
return validation_error; return validation_error;
} }
return std::nullopt; return {};
} }
// Forward declarations for helper functions exposed to other translation units // Forward declarations for helper functions exposed to other translation units
@@ -333,6 +417,16 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
return PrepareRegionContext(region_context, max_chars); return PrepareRegionContext(region_context, max_chars);
} }
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message) {
return ParseTwoLineResponse(raw, error_message);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt) {
return ToChatPrompt(model, user_prompt, "");
}
std::string ToChatPromptPublic(const llama_model* model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt) { const std::string& user_prompt) {
@@ -344,8 +438,8 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
AppendTokenPiece(vocab, token, output); AppendTokenPiece(vocab, token, output);
} }
std::optional<std::string> ValidateBreweryJsonPublic( std::string ValidateBreweryJsonPublic(const std::string& raw,
const std::string& raw, std::string& name_out, std::string& name_out,
std::string& description_out) { std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out); return ValidateBreweryJson(raw, name_out, description_out);
} }

View File

@@ -2,7 +2,7 @@
* Text Generation / Inference Module * Text Generation / Inference Module
* Core module that performs LLM inference: converts text prompts into tokens, * Core module that performs LLM inference: converts text prompts into tokens,
* runs the neural network forward pass, samples the next token, and converts * runs the neural network forward pass, samples the next token, and converts
* output tokens back to text for system+user chat prompts. * output tokens back to text. Supports both simple and system+user prompts.
*/ */
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
@@ -17,31 +17,30 @@
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
#include "llama.h" #include "llama.h"
static constexpr std::size_t kPromptTokenSlack = 8; std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, prompt), max_tokens);
}
std::string LlamaGenerator::Infer(const std::string& system_prompt, std::string LlamaGenerator::Infer(const std::string& system_prompt,
const std::string& prompt, const std::string& prompt, int max_tokens) {
const int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt), return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
max_tokens); max_tokens);
} }
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
const int max_tokens) { int max_tokens) {
/** /**
* Validate that model and context are loaded * Validate that model and context are loaded
*/ */
if (model_ == nullptr || context_ == nullptr) { if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded"); throw std::runtime_error("LlamaGenerator: model not loaded");
}
/** /**
* Get vocabulary for tokenization and token-to-text conversion * Get vocabulary for tokenization and token-to-text conversion
*/ */
const llama_vocab* vocab = llama_model_get_vocab(model_); const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) { if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable"); throw std::runtime_error("LlamaGenerator: vocab unavailable");
}
/** /**
* Clear KV cache to ensure clean inference state (no residual context) * Clear KV cache to ensure clean inference state (no residual context)
@@ -52,8 +51,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
* TOKENIZATION PHASE * TOKENIZATION PHASE
* Convert text prompt into token IDs (integers) that the model understands * Convert text prompt into token IDs (integers) that the model understands
*/ */
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
kPromptTokenSlack);
int32_t token_count = llama_tokenize( int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(), vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(), static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
@@ -70,20 +68,18 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
static_cast<int32_t>(prompt_tokens.size()), true, true); static_cast<int32_t>(prompt_tokens.size()), true, true);
} }
if (token_count < 0) { if (token_count < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
}
/** /**
* CONTEXT SIZE VALIDATION * CONTEXT SIZE VALIDATION
* Validate and compute effective token budgets based on context window * Validate and compute effective token budgets based on context window
* constraints * constraints
*/ */
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_)); const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_)); const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0) { if (n_ctx <= 1 || n_batch <= 0)
throw std::runtime_error("LlamaGenerator: invalid context or batch size"); throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
/** /**
* Clamp generation limit to available context window, reserve space for * Clamp generation limit to available context window, reserve space for
@@ -117,9 +113,47 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
*/ */
const llama_batch prompt_batch = llama_batch_get_one( const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size())); prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0) { if (llama_decode(context_, prompt_batch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed"); throw std::runtime_error("LlamaGenerator: prompt decode failed");
}
/**
* SAMPLER CONFIGURATION PHASE
* Set up the probabilistic token selection pipeline (sampler chain)
* Samplers are applied in sequence: temperature -> top-k -> top-p ->
* distribution
*/
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
/**
* Temperature: scales logits before softmax (controls randomness)
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
/**
* Top-K: limits sampling to the most likely tokens before nucleus
* sampling
*/
llama_sampler_chain_add(
sampler.get(),
llama_sampler_init_top_k(static_cast<int32_t>(sampling_top_k_)));
/**
* Top-P: nucleus sampling - filters to most likely tokens summing to top_p
* probability
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
/**
* Distribution sampler: selects actual token using configured seed for
* reproducibility
*/
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
/** /**
* TOKEN GENERATION LOOP * TOKEN GENERATION LOOP
@@ -129,44 +163,36 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
std::vector<llama_token> generated_tokens; std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens)); generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
if (sampler_ == nullptr || sampler_->chain == nullptr) {
throw std::runtime_error("LlamaGenerator: sampler not initialized");
}
for (int i = 0; i < effective_max_tokens; ++i) { for (int i = 0; i < effective_max_tokens; ++i) {
/** /**
* Sample next token using configured sampler chain and model logits * Sample next token using configured sampler chain and model logits
* Index -1 means use the last output position from previous batch * Index -1 means use the last output position from previous batch
*/ */
const llama_token next = const llama_token next =
llama_sampler_sample(sampler_->chain, context_, -1); llama_sampler_sample(sampler.get(), context_, -1);
/** /**
* Stop if model predicts end-of-generation token (EOS/EOT) * Stop if model predicts end-of-generation token (EOS/EOT)
*/ */
if (llama_vocab_is_eog(vocab, next)) { if (llama_vocab_is_eog(vocab, next)) break;
break;
}
generated_tokens.push_back(next); generated_tokens.push_back(next);
/** /**
* Feed the sampled token back into model for next iteration * Feed the sampled token back into model for next iteration
* (autoregressive) * (autoregressive)
*/ */
llama_token decode_token = next; llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1); const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0) { if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: decode failed during generation"); "LlamaGenerator: decode failed during generation");
} }
}
/** /**
* DETOKENIZATION PHASE * DETOKENIZATION PHASE
* Convert generated token IDs back to text using vocabulary * Convert generated token IDs back to text using vocabulary
*/ */
std::string output; std::string output;
for (const llama_token token : generated_tokens) { for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output); AppendTokenPiecePublic(vocab, token, output);
}
return output; return output;
} }

View File

@@ -5,7 +5,6 @@
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include <memory>
#include <random> #include <random>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
@@ -13,47 +12,6 @@
#include "data_model/application_options.h" #include "data_model/application_options.h"
#include "llama.h" #include "llama.h"
static constexpr uint32_t kMaxContextSize = 32768U;
struct SamplerConfig {
float temperature;
float top_p;
uint32_t top_k;
};
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
std::mt19937& rng) {
const llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler) {
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
}
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(config.temperature));
llama_sampler_chain_add(
sampler.get(),
llama_sampler_init_top_k(static_cast<int32_t>(config.top_k)));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(config.top_p, 1));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng()));
return sampler;
}
LlamaGenerator::SamplerState::~SamplerState() {
if (chain != nullptr) {
llama_sampler_free(chain);
chain = nullptr;
}
}
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path) const std::string& model_path)
: rng_(std::random_device{}()) { : rng_(std::random_device{}()) {
@@ -80,7 +38,7 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
"LlamaGenerator: seed must be >= 0, or -1 for random"); "LlamaGenerator: seed must be >= 0, or -1 for random");
} }
if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) { if (options.n_ctx == 0 || options.n_ctx > 32768) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]"); "LlamaGenerator: context size must be in range [1, 32768]");
} }
@@ -97,16 +55,9 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
n_ctx_ = options.n_ctx; n_ctx_ = options.n_ctx;
this->Load(model_path); this->Load(model_path);
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
sampling_top_k_};
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
sampler_.reset(new SamplerState());
sampler_->chain = sampler_chain.release();
} }
LlamaGenerator::~LlamaGenerator() { LlamaGenerator::~LlamaGenerator() {
sampler_.reset();
/** /**
* Free the inference context (contains KV cache and computation state) * Free the inference context (contains KV cache and computation state)
*/ */

View File

@@ -23,7 +23,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
model_ = nullptr; model_ = nullptr;
} }
const llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params); model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) { if (model_ == nullptr) {
throw std::runtime_error( throw std::runtime_error(

View File

@@ -1,14 +1,13 @@
/** /**
* @file data_generation/llama/load_brewery_prompt.cpp * @file data_generation/llama/load_brewery_prompt.cpp
* @brief Resolves brewery system prompt content from cache or a configured * @brief Resolves brewery system prompt content from cache or filesystem
* filesystem path and provides a robust inline fallback prompt when absent. * search paths and provides a robust inline fallback prompt when absent.
*/ */
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <stdexcept>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
@@ -18,7 +17,7 @@ namespace fs = std::filesystem;
* @brief Loads brewery system prompt from disk or cache. * @brief Loads brewery system prompt from disk or cache.
* *
* @param prompt_file_path Preferred prompt file location. * @param prompt_file_path Preferred prompt file location.
* @return Prompt text loaded from disk. * @return Prompt text loaded from disk or fallback content.
*/ */
std::string LlamaGenerator::LoadBrewerySystemPrompt( std::string LlamaGenerator::LoadBrewerySystemPrompt(
const std::string& prompt_file_path) { const std::string& prompt_file_path) {
@@ -27,33 +26,72 @@ std::string LlamaGenerator::LoadBrewerySystemPrompt(
return brewery_system_prompt_; return brewery_system_prompt_;
} }
// Try the provided path only // Try multiple path locations
const fs::path prompt_path(prompt_file_path); std::vector<std::string> paths_to_try = {
std::ifstream prompt_file(prompt_path); prompt_file_path, // As provided
if (!prompt_file.is_open()) { "../" + prompt_file_path, // One level up
spdlog::error( "../../" + prompt_file_path, // Two levels up
"LlamaGenerator: Failed to open brewery system prompt file '{}'", };
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: missing brewery system prompt file: " +
prompt_path.string());
}
const std::string prompt((std::istreambuf_iterator(prompt_file)), for (const auto& path : paths_to_try) {
std::ifstream prompt_file(path);
if (prompt_file.is_open()) {
std::string prompt((std::istreambuf_iterator<char>(prompt_file)),
std::istreambuf_iterator<char>()); std::istreambuf_iterator<char>());
prompt_file.close(); prompt_file.close();
if (prompt.empty()) { if (!prompt.empty()) {
spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: empty brewery system prompt file: " +
prompt_path.string());
}
spdlog::info( spdlog::info(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)", "LlamaGenerator: Loaded brewery system prompt from '{}' ({} "
prompt_path.string(), prompt.length()); "chars)",
path, prompt.length());
brewery_system_prompt_ = prompt; brewery_system_prompt_ = prompt;
return brewery_system_prompt_; return brewery_system_prompt_;
}
}
}
spdlog::warn(
"LlamaGenerator: Could not open brewery system prompt file at any of "
"the "
"expected locations. Using fallback inline prompt.");
return GetFallbackBreweryPrompt();
}
/**
* @brief Provides an inline fallback brewery system prompt.
*
* @return Default fallback prompt text.
*/
std::string LlamaGenerator::GetFallbackBreweryPrompt() {
return "You are an experienced brewmaster and owner of a local craft "
"brewery. "
"Create a distinctive, authentic name and detailed description that "
"genuinely reflects your specific location, brewing philosophy, "
"local "
"culture, and community connection. The brewery must feel real and "
"grounded—not generic or interchangeable.\n\n"
"AVOID REPETITIVE PHRASES - Never use:\n"
"Love letter to, tribute to, rolling hills, picturesque, every sip "
"tells a story, Come for X stay for Y, rich history, passion, woven "
"into, ancient roots, timeless, where tradition meets innovation\n\n"
"OPENING APPROACHES - Choose ONE:\n"
"1. Start with specific beer style and its regional origins\n"
"2. Begin with specific brewing challenge (water, altitude, "
"climate)\n"
"3. Open with founding story or personal motivation\n"
"4. Lead with specific local ingredient or resource\n"
"5. Start with unexpected angle or contradiction\n"
"6. Open with local event, tradition, or cultural moment\n"
"7. Begin with tangible architectural or geographic detail\n\n"
"BE SPECIFIC - Include:\n"
"- At least ONE concrete proper noun (landmark, river, "
"neighborhood)\n"
"- Specific beer styles relevant to the REGION'S culture\n"
"- Concrete brewing challenges or advantages\n"
"- Sensory details SPECIFIC to place—not generic adjectives\n\n"
"LENGTH: 150-250 words. TONE: Can be soulful, irreverent, "
"matter-of-fact, unpretentious, or minimalist.\n\n"
"Output ONLY a raw JSON object with keys name and description. "
"No markdown, backticks, preamble, or trailing text.";
} }

View File

@@ -0,0 +1,71 @@
/**
* @file data_generation/mock/data.cpp
* @brief Defines static lookup tables used by MockGenerator for deterministic
* brewery names, descriptions, usernames, and bios.
*/
#include <string>
#include <vector>
#include "data_generation/mock_generator.h"
const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"};
const std::vector<std::string> MockGenerator::kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
const std::vector<std::string> MockGenerator::kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal blends.",
"Modern taproom focused on sessionable lagers and classic pub styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp lagers.",
"Creative brewery offering rotating collaborations and limited draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich stouts."};
const std::vector<std::string> MockGenerator::kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
const std::vector<std::string> MockGenerator::kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};

View File

@@ -5,12 +5,13 @@
*/ */
#include <boost/container_hash/hash.hpp> #include <boost/container_hash/hash.hpp>
#include <string>
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
size_t MockGenerator::DeterministicHash(const Location& location) { std::size_t MockGenerator::DeterministicHash(const BreweryLocation& location) {
size_t seed = 0; std::size_t seed = 0;
boost::hash_combine(seed, location.city); boost::hash_combine(seed, location.city_name);
boost::hash_combine(seed, location.country); boost::hash_combine(seed, location.country_name);
return seed; return seed;
} }

View File

@@ -4,39 +4,35 @@
* and country into fixed mock phrase catalogs. * and country into fixed mock phrase catalogs.
*/ */
#include <format>
#include <string> #include <string>
#include <string_view>
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
BreweryResult MockGenerator::GenerateBrewery( BreweryResult MockGenerator::GenerateBrewery(
const Location& location, const std::string& /*region_context*/) { const BreweryLocation& location, const std::string& /*region_context*/) {
const std::size_t hash = DeterministicHash(location); const std::size_t hash = DeterministicHash(location);
const std::string_view adjective = const std::string& adjective =
kBreweryAdjectives.at(hash % kBreweryAdjectives.size()); kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
const std::string_view noun = const std::string& noun =
kBreweryNouns.at(hash / 7 % kBreweryNouns.size()); kBreweryNouns.at((hash / 7) % kBreweryNouns.size());
const std::string_view base_description = const std::string& base_description =
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size()); kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size());
const std::string name = std::string name(location.city_name);
std::format("{} {} {}", location.city, adjective, noun); name.append(" ");
name.append(adjective);
name.append(" ");
name.append(noun);
const std::string state_suffix = std::string description = base_description;
location.state_province.empty() description.append(" Based in ");
? std::string{} description.append(location.city_name);
: std::format(", {}", location.state_province); if (!location.country_name.empty()) {
const std::string country_suffix = description.append(", ");
location.country.empty() ? std::string{} description.append(location.country_name);
: std::format(", {}", location.country); }
const std::string description = description.append(".");
std::format("{} Located in {}{}{}.", base_description, location.city,
state_suffix, country_suffix);
return { return {name, description};
.name = name,
.description = description,
};
} }

View File

@@ -6,7 +6,6 @@
#include <functional> #include <functional>
#include <string> #include <string>
#include <string_view>
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
@@ -14,9 +13,7 @@ UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale); const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result; UserResult result;
const std::string_view username = kUsernames[hash % kUsernames.size()]; result.username = kUsernames[hash % kUsernames.size()];
const std::string_view bio = kBios[hash / 11 % kBios.size()]; result.bio = kBios[(hash / 11) % kBios.size()];
result.username = username;
result.bio = bio;
return result; return result;
} }

View File

@@ -12,35 +12,31 @@
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string_view>
static std::string ReadRequiredString(const boost::json::object& object, static std::string ReadRequiredString(const boost::json::object& object,
const char* key) { const char* key) {
const boost::json::value* value = object.if_contains(key); const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_string()) { if (value == nullptr || !value->is_string()) {
throw std::runtime_error(std::string("Missing or invalid string field: ") + throw std::runtime_error(
key); std::string("Missing or invalid string field: ") + key);
} }
const std::string_view text = value->as_string(); return std::string(value->as_string().c_str());
return std::string(text);
} }
static double ReadRequiredNumber(const boost::json::object& object, static double ReadRequiredNumber(const boost::json::object& object,
const char* key) { const char* key) {
const boost::json::value* value = object.if_contains(key); const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_number()) { if (value == nullptr || !value->is_number()) {
throw std::runtime_error(std::string("Missing or invalid numeric field: ") + throw std::runtime_error(
key); std::string("Missing or invalid numeric field: ") + key);
} }
return value->to_number<double>(); return value->to_number<double>();
} }
std::vector<Location> JsonLoader::LoadLocations( std::vector<Location> JsonLoader::LoadLocations(const std::string& filepath) {
const std::filesystem::path& filepath) {
std::ifstream input(filepath); std::ifstream input(filepath);
if (!input.is_open()) { if (!input.is_open()) {
throw std::runtime_error("Failed to open locations file: " + throw std::runtime_error("Failed to open locations file: " + filepath);
filepath.string());
} }
std::stringstream buffer; std::stringstream buffer;
@@ -82,6 +78,6 @@ std::vector<Location> JsonLoader::LoadLocations(
} }
spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(), spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(),
filepath.string()); filepath);
return locations; return locations;
} }

View File

@@ -10,7 +10,6 @@
#include <boost/program_options.hpp> #include <boost/program_options.hpp>
#include <exception> #include <exception>
#include <memory> #include <memory>
#include <optional>
#include <sstream> #include <sstream>
#include <string> #include <string>
@@ -31,35 +30,26 @@ namespace di = boost::di;
* *
* @param argc Command-line argument count. * @param argc Command-line argument count.
* @param argv Command-line arguments. * @param argv Command-line arguments.
* @return Parsed ApplicationOptions if parsing succeeded, std::nullopt * @param options Output ApplicationOptions struct.
* otherwise. * @return true if parsing succeeded and should proceed, false otherwise.
*/ */
std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) { bool ParseArguments(const int argc, char** argv,
ApplicationOptions& options) noexcept {
prog_opts::options_description desc("Pipeline Options"); prog_opts::options_description desc("Pipeline Options");
desc.add_options()("help,h", "Produce help message")(
auto opt = desc.add_options(); "mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data")(
opt("help,h", "Produce help message"); "model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")(
opt("mocked", prog_opts::bool_switch(), "temperature", prog_opts::value<float>()->default_value(1.0F),
"Use mocked generator for brewery/user data"); "Sampling temperature (higher = more random)")(
"top-p", prog_opts::value<float>()->default_value(0.95F),
opt("model,m", prog_opts::value<std::string>()->default_value(""), "Nucleus sampling top-p in (0,1] (higher = more random)")(
"Path to LLM model (gguf)"); "top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)")(
opt("temperature", prog_opts::value<float>()->default_value(1.0F), "n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Sampling temperature (higher = more random)"); "Context window size in tokens (1-32768)")(
"seed", prog_opts::value<int>()->default_value(-1),
opt("top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)");
opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)");
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)");
opt("seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer"); "Sampler seed: -1 for random, otherwise non-negative integer");
// Handle the "no arguments" or "help" case // Handle the "no arguments" or "help" case
@@ -68,7 +58,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
std::stringstream usage_stream; std::stringstream usage_stream;
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc; usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(usage_stream.str()); spdlog::info(usage_stream.str());
return std::nullopt; return false;
} }
try { try {
@@ -81,7 +71,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
std::stringstream help_stream; std::stringstream help_stream;
help_stream << "\n" << desc; help_stream << "\n" << desc;
spdlog::info(help_stream.str()); spdlog::info(help_stream.str());
return std::nullopt; return false;
} }
const auto use_mocked = variables_map["mocked"].as<bool>(); const auto use_mocked = variables_map["mocked"].as<bool>();
@@ -90,19 +80,19 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
if (use_mocked && !model_path.empty()) { if (use_mocked && !model_path.empty()) {
spdlog::error( spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive"); "Invalid arguments: --mocked and --model are mutually exclusive");
return std::nullopt; return false;
} }
if (!use_mocked && model_path.empty()) { if (!use_mocked && model_path.empty()) {
spdlog::error( spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified"); "Invalid arguments: Either --mocked or --model must be specified");
return std::nullopt; return false;
} }
const bool has_llm_params = !variables_map["temperature"].defaulted() || const bool has_llm_params = !variables_map["temperature"].defaulted() ||
!variables_map["top-p"].defaulted() || !variables_map["top-p"].defaulted() ||
!variables_map["top-k"].defaulted() || !variables_map["top-k"].defaulted() ||
!variables_map["seed"].defaulted() = false; !variables_map["seed"].defaulted();
if (use_mocked && has_llm_params) { if (use_mocked && has_llm_params) {
spdlog::warn( spdlog::warn(
@@ -110,7 +100,6 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
" ignored when using --mocked"); " ignored when using --mocked");
} }
ApplicationOptions options;
options.use_mocked = use_mocked; options.use_mocked = use_mocked;
options.model_path = model_path; options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>(); options.temperature = variables_map["temperature"].as<float>();
@@ -119,37 +108,35 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
options.n_ctx = variables_map["n-ctx"].as<uint32_t>(); options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>(); options.seed = variables_map["seed"].as<int>();
return options; return true;
} catch (const std::exception& exception) { } catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}", spdlog::error("Failed to parse command-line arguments: {}",
exception.what()); exception.what());
return std::nullopt; return false;
} catch (...) { } catch (...) {
spdlog::error("Failed to parse command-line arguments: unknown error"); spdlog::error("Failed to parse command-line arguments: unknown error");
return std::nullopt; return false;
} }
} }
int main(const int argc, char** argv) { int main(const int argc, char** argv) noexcept {
try { try {
const CurlGlobalState curl_state; const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state; const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v"); spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
const auto parsed_options = ParseArguments(argc, argv); ApplicationOptions options;
if (!parsed_options.has_value()) { if (!ParseArguments(argc, argv, options)) {
return 0; return 0;
} }
const auto options = *parsed_options;
const auto injector = di::make_injector( const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(), di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options), di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(), di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<std::string>().to(options.model_path), di::bind<std::string>().to(options.model_path),
di::bind<DataGenerator>().to( di::bind<DataGenerator>().to([options](const auto& injector)
[options](const auto& inj) -> std::unique_ptr<DataGenerator> { -> std::unique_ptr<DataGenerator> {
if (options.use_mocked) { if (options.use_mocked) {
spdlog::info( spdlog::info(
"[Generator] Using MockGenerator (no model path provided)"); "[Generator] Using MockGenerator (no model path provided)");
@@ -161,7 +148,7 @@ int main(const int argc, char** argv) {
"top-p={}, top-k={}, n_ctx={}, seed={})", "top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p, options.model_path, options.temperature, options.top_p,
options.top_k, options.n_ctx, options.seed); options.top_k, options.n_ctx, options.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>(); return injector.template create<std::unique_ptr<LlamaGenerator>>();
})); }));
auto generator = injector.create<BiergartenDataGenerator>(); auto generator = injector.create<BiergartenDataGenerator>();

View File

@@ -34,12 +34,9 @@ std::string WikipediaService::FetchExtract(std::string_view query) {
if (!pages.empty()) { if (!pages.empty()) {
auto& page = pages.begin()->value().get_object(); auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) { if (page.contains("extract") && page.at("extract").is_string()) {
const std::string_view extract_view = page.at("extract").as_string(); std::string extract(page.at("extract").as_string().c_str());
std::string extract(extract_view);
spdlog::debug("WikipediaService fetched {} chars for '{}'", spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query); extract.size(), query);
this->extract_cache_.emplace(cache_key, extract); this->extract_cache_.emplace(cache_key, extract);
return extract; return extract;
} }

View File

@@ -10,12 +10,19 @@
#include "services/wikipedia_service.h" #include "services/wikipedia_service.h"
std::string WikipediaService::GetLocationContext(const Location& loc) { std::string WikipediaService::GetLocationContext(const Location& loc) {
if (!client_) { const std::string cache_key = loc.city + "|" + loc.country;
return {}; const auto cache_it = cache_.find(cache_key);
if (cache_it != cache_.end()) {
return cache_it->second;
} }
std::string result; std::string result;
if (!client_) {
cache_.emplace(cache_key, result);
return result;
}
std::string region_query(loc.city); std::string region_query(loc.city);
if (!loc.country.empty()) { if (!loc.country.empty()) {
region_query += ", "; region_query += ", ";
@@ -43,5 +50,7 @@ std::string WikipediaService::GetLocationContext(const Location& loc) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query, spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
e.what()); e.what());
} }
cache_.emplace(cache_key, result);
return result; return result;
} }

View File

@@ -7,5 +7,5 @@
#include <utility> #include <utility>
WikipediaService::WikipediaService(std::unique_ptr<WebClient> client) WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {} : client_(std::move(client)) {}

View File

@@ -5,70 +5,45 @@
#include <curl/curl.h> #include <curl/curl.h>
#include <cstdint> #include <sstream>
#include <memory>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "curl_web_client_utils.h"
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (handle == nullptr) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
static void set_common_get_options(CURL* curl, const std::string& url) {
constexpr uint64_t connection_timeout = 10;
constexpr uint64_t request_timeout = 30;
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connection_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, request_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
// curl write callback that appends response data into a std::string // curl write callback that appends response data into a std::string
static size_t WriteCallbackString(void* contents, const size_t size, static size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
const size_t nmemb, void* userp) { void* userp) {
const size_t real_size = size * nmemb; size_t realsize = size * nmemb;
auto* str = static_cast<std::string*>(userp); auto* s = static_cast<std::string*>(userp);
str->append(static_cast<char*>(contents), real_size); s->append(static_cast<char*>(contents), realsize);
return real_size; return realsize;
} }
std::string CURLWebClient::Get(const std::string& url) { std::string CURLWebClient::Get(const std::string& url) {
const CurlHandle curl = create_handle(); auto curl = create_handle();
std::string response_string; std::string response_string;
set_common_get_options(curl.get(), url, {10L, 20L});
set_common_get_options(curl.get(), url);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString); curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string); curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
CURLcode res = curl_easy_perform(curl.get()); CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) { if (res != CURLE_OK) {
const auto error = std::string error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res); std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error); throw std::runtime_error(error);
} }
int64_t httpCode = 0; long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) { if (httpCode != 200) {
const std::string error = "[CURLWebClient] HTTP error " + std::stringstream ss;
std::to_string(httpCode) + " for URL " + url; ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(error); throw std::runtime_error(ss.str());
} }
return response_string; return response_string;

View File

@@ -14,11 +14,10 @@ std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs. // A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char* output = curl_easy_escape(nullptr, value.c_str(), 0); char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (!output) { if (output) {
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
}
std::string result(output); std::string result(output);
curl_free(output); curl_free(output);
return result; return result;
}
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
} }

View File

@@ -0,0 +1,28 @@
/**
* @file web_client/curl_web_client_utils.cpp
* @brief Shared CURLWebClient helper implementations.
*/
#include "curl_web_client_utils.h"
#include <stdexcept>
CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (handle == nullptr) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
void set_common_get_options(CURL* curl, const std::string& url,
CurlTimeouts timeouts) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, timeouts.connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeouts.total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}

View File

@@ -0,0 +1,26 @@
#ifndef BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
#define BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
/**
* @file web_client/curl_web_client_utils.h
* @brief Shared helpers for CURLWebClient request setup.
*/
#include <curl/curl.h>
#include <memory>
#include <string>
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
struct CurlTimeouts {
long connect_timeout;
long total_timeout;
};
CurlHandle create_handle();
void set_common_get_options(CURL* curl, const std::string& url,
CurlTimeouts timeouts);
#endif // BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_