8 Commits

Author SHA1 Message Date
Aaron Po
299a767d39 remove unused code 2026-04-11 14:42:32 -04:00
Aaron Po
f07d48f810 Add missing includes, update readme 2026-04-11 14:31:24 -04:00
Aaron Po
bcfde856fe Split data models into dedicated headers 2026-04-11 13:21:50 -04:00
Aaron Po
5946356083 Style audit: update code to strictly follow Google Style Guide 2026-04-11 11:56:45 -04:00
Aaron Po
ae67fa8566 refactor: consolidate and rename data generation and service files 2026-04-11 00:06:23 -04:00
Aaron Po
8c572a2d07 fix: stabilize Gemma 4 brewery generation
remove misleading turn-token output guidance from the brewery prompt
extract the last balanced JSON object before validation
keep README model setup and run instructions aligned
preserve Gemma 4 sampling defaults and local model usage
2026-04-10 22:25:26 -04:00
Aaron Po
902bda6eb9 feat: make Gemma 4 the default model, enable thinking mode 2026-04-10 21:43:18 -04:00
Aaron Po
61d5077a95 update readme 2026-04-10 00:03:45 -04:00
50 changed files with 1744 additions and 1528 deletions

View File

@@ -1,5 +1,5 @@
--- ---
BasedOnStyle: Google BasedOnStyle: Google
ColumnLimit: 80 ColumnLimit: 80
IndentWidth: 2 IndentWidth: 3
... ...

View File

@@ -9,23 +9,21 @@ Checks: >
-google-runtime-references -google-runtime-references
CheckOptions: CheckOptions:
# Enforce Google Naming Conventions with valid clang-tidy strings # Enforce Google Naming Conventions
- key: readability-identifier-naming.ClassCase
value: CamelCase
- key: readability-identifier-naming.ClassMemberCase - key: readability-identifier-naming.ClassMemberCase
value: lower_case value: snake_case
- key: readability-identifier-naming.ClassMemberSuffix - key: readability-identifier-naming.ClassMemberSuffix
value: _ value: _
- key: readability-identifier-naming.ClassCase
value: PascalCase
- key: readability-identifier-naming.FunctionCase - key: readability-identifier-naming.FunctionCase
value: CamelCase value: PascalCase
- key: readability-identifier-naming.StructCase - key: readability-identifier-naming.StructCase
value: CamelCase value: PascalCase
- key: readability-identifier-naming.VariableCase - key: readability-identifier-naming.VariableCase
value: lower_case value: snake_case
- key: readability-identifier-naming.GlobalConstantCase - key: readability-identifier-naming.GlobalConstantCase
value: CamelCase value: kPascalCase
- key: readability-identifier-naming.GlobalConstantPrefix
value: k
# Ensure C++20 Modernization # Ensure C++20 Modernization
- key: modernize-make-unique.MakeSmartPtrFunction - key: modernize-make-unique.MakeSmartPtrFunction

1
pipeline/.gitignore vendored
View File

@@ -1,6 +1,5 @@
dist dist
build build
cmake-build-*
data data
models models
*.gguf *.gguf

View File

@@ -1,10 +1,12 @@
cmake_minimum_required(VERSION 3.24) cmake_minimum_required(VERSION 3.24)
project(biergarten-pipeline) project(biergarten-pipeline)
# Boost.DI still declares a very old minimum CMake version, which newer CMake
# releases reject unless a policy version floor is provided.
set(CMAKE_POLICY_VERSION_MINIMUM 3.5 CACHE STRING "" FORCE) set(CMAKE_POLICY_VERSION_MINIMUM 3.5 CACHE STRING "" FORCE)
# ============================================================================= # =============================================================================
# 1. Platform & GPU Detection # 1. Platform & GPU Detection (Windows explicitly NOT supported)
# ============================================================================= # =============================================================================
if(WIN32) if(WIN32)
message(FATAL_ERROR "[biergarten] Windows is currently not supported. Please use Linux (Fedora 43) or macOS (M1 Pro).") message(FATAL_ERROR "[biergarten] Windows is currently not supported. Please use Linux (Fedora 43) or macOS (M1 Pro).")
@@ -38,15 +40,18 @@ endif()
# 2. Project-wide Settings (Standard & Optimization) # 2. Project-wide Settings (Standard & Optimization)
# ============================================================================= # =============================================================================
# Downgrade to C++20 as per Google Style Guide
set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# GCC/Clang specific settings (warnings as errors)
add_compile_options(-Wall -Wextra -Werror -Wpedantic) add_compile_options(-Wall -Wextra -Werror -Wpedantic)
# Release Build Optimization: Aggressive (-O3), Arch-specific, and LTO # Release Build Optimization: Aggressive (-O3), Arch-specific, and LTO
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -march=native -flto") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -march=native -flto")
# Debug Build Optimization: Fast and debuggable (-Og)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Og -g") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Og -g")
# ============================================================================= # =============================================================================
@@ -101,6 +106,7 @@ set(SOURCES
src/services/wikipedia/fetch_extract.cpp src/services/wikipedia/fetch_extract.cpp
src/web_client/curl_global_state.cpp src/web_client/curl_global_state.cpp
src/web_client/curl_web_client_get.cpp src/web_client/curl_web_client_get.cpp
src/web_client/curl_web_client_utils.cpp
src/web_client/curl_web_client_url_encode.cpp src/web_client/curl_web_client_url_encode.cpp
src/data_generation/llama/llama_generator.cpp src/data_generation/llama/llama_generator.cpp
src/data_generation/llama/generate_brewery.cpp src/data_generation/llama/generate_brewery.cpp
@@ -109,6 +115,7 @@ set(SOURCES
src/data_generation/llama/infer.cpp src/data_generation/llama/infer.cpp
src/data_generation/llama/load.cpp src/data_generation/llama/load.cpp
src/data_generation/llama/load_brewery_prompt.cpp src/data_generation/llama/load_brewery_prompt.cpp
src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/generate_brewery.cpp src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp src/data_generation/mock/generate_user.cpp
@@ -141,9 +148,3 @@ configure_file(
${CMAKE_BINARY_DIR}/locations.json ${CMAKE_BINARY_DIR}/locations.json
COPYONLY COPYONLY
) )
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
${CMAKE_SOURCE_DIR}/prompts
${CMAKE_BINARY_DIR}/prompts
)

View File

@@ -76,7 +76,7 @@ curl -L \
## Run ## Run
Run the executable from the build directory so the copied `locations.json` and `prompts/` directory are available. Run the executable from the build directory so the copied `locations.json` is available.
```bash ```bash
./biergarten-pipeline --mocked ./biergarten-pipeline --mocked

View File

@@ -1,7 +1,7 @@
@startuml BiergartenPipeline @startuml BiergartenPipeline
title Biergarten Pipeline - Class and Composition Diagram title Biergarten Pipeline - Class and Composition Diagram
top to bottom direction left to right direction
skinparam shadowing false skinparam shadowing false
skinparam classAttributeIconSize 0 skinparam classAttributeIconSize 0
skinparam packageStyle rectangle skinparam packageStyle rectangle
@@ -16,17 +16,11 @@ package "Composition root" {
+~CurlGlobalState() +~CurlGlobalState()
} }
class LlamaBackendState {
+LlamaBackendState()
+~LlamaBackendState()
}
note right of Main note right of Main
Binds with Boost.DI: Binds with Boost.DI:
- WebClient -> CURLWebClient - WebClient -> CURLWebClient
- IEnrichmentService -> WikipediaService - IEnrichmentService -> WikipediaService
- DataGenerator -> MockGenerator or LlamaGenerator - DataGenerator -> MockGenerator or LlamaGenerator
- std::string -> model_path
- LlamaGenerator receives ApplicationOptions and model_path directly - LlamaGenerator receives ApplicationOptions and model_path directly
end note end note
} }
@@ -38,8 +32,8 @@ package "Core orchestration" {
-generated_breweries_: std::vector<GeneratedBrewery> -generated_breweries_: std::vector<GeneratedBrewery>
+BiergartenDataGenerator(context_service: std::shared_ptr<IEnrichmentService>, generator: std::unique_ptr<DataGenerator>) +BiergartenDataGenerator(context_service: std::shared_ptr<IEnrichmentService>, generator: std::unique_ptr<DataGenerator>)
+Run(): bool +Run(): bool
{static} -QueryCitiesWithCountries(): std::vector<Location> -QueryCitiesWithCountries(): std::vector<Location>
-GenerateBreweries(cities: const std::vector<EnrichedCity>&): void -GenerateBreweries(cities: std::vector<EnrichedCity>): void
-LogResults(): void -LogResults(): void
} }
} }
@@ -55,6 +49,11 @@ package "Data models" {
+seed: int +seed: int
} }
class BreweryLocation <<struct>> {
+city_name: std::string_view
+country_name: std::string_view
}
class Location <<struct>> { class Location <<struct>> {
+city: std::string +city: std::string
+state_province: std::string +state_province: std::string
@@ -88,64 +87,67 @@ package "Data models" {
package "Generation" { package "Generation" {
interface DataGenerator { interface DataGenerator {
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult +GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: const std::string&): UserResult +GenerateUser(locale: std::string): UserResult
} }
class MockGenerator { class MockGenerator {
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult +GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: const std::string&): UserResult +GenerateUser(locale: std::string): UserResult
} }
class LlamaGenerator { class LlamaGenerator {
+LlamaGenerator(options: const ApplicationOptions&, model_path: const std::string&) +LlamaGenerator(options: ApplicationOptions, model_path: std::string)
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult +GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: const std::string&): UserResult +GenerateUser(locale: std::string): UserResult
} }
} }
package "HTTP" { package "HTTP" {
interface WebClient { interface WebClient {
+Get(url: const std::string&): std::string +DownloadToFile(url: std::string, file_path: std::string): void
+UrlEncode(value: const std::string&): std::string +Get(url: std::string): std::string
+UrlEncode(value: std::string): std::string
} }
class CURLWebClient { class CURLWebClient {
+Get(url: const std::string&): std::string +CURLWebClient()
+UrlEncode(value: const std::string&): std::string +~CURLWebClient()
+DownloadToFile(url: std::string, file_path: std::string): void
+Get(url: std::string): std::string
+UrlEncode(value: std::string): std::string
} }
} }
package "JSON handling" { package "JSON handling" {
class JsonLoader { class JsonLoader {
{static} +LoadLocations(filepath: const std::string&): std::vector<Location> {static} +LoadLocations(filepath: std::string): std::vector<Location>
} }
} }
package "Wikipedia" { package "Wikipedia" {
interface IEnrichmentService { interface IEnrichmentService {
+GetLocationContext(loc: const Location&): std::string +GetLocationContext(loc: Location): std::string
} }
class WikipediaService { class WikipediaService {
+WikipediaService(client: std::unique_ptr<WebClient>) +WikipediaService(client: std::shared_ptr<WebClient>)
+GetLocationContext(loc: const Location&): std::string +GetLocationContext(loc: Location): std::string
} }
} }
Main --> CurlGlobalState Main --> CurlGlobalState
Main --> LlamaBackendState
Main --> ApplicationOptions Main --> ApplicationOptions
Main --> BiergartenDataGenerator Main --> BiergartenDataGenerator
Main ..> IEnrichmentService : DI binding Main ..> IEnrichmentService : DI binding
Main ..> DataGenerator : DI factory Main ..> DataGenerator : DI factory
Main ..> CURLWebClient : DI binding Main ..> CURLWebClient : DI binding
BiergartenDataGenerator *-- EnrichedCity
BiergartenDataGenerator *-- GeneratedBrewery BiergartenDataGenerator *-- GeneratedBrewery
BiergartenDataGenerator ..> JsonLoader : LoadLocations() BiergartenDataGenerator ..> JsonLoader : LoadLocations()
BiergartenDataGenerator --> IEnrichmentService : context lookup BiergartenDataGenerator --> IEnrichmentService : context lookup
BiergartenDataGenerator --> DataGenerator : brewery generation BiergartenDataGenerator --> DataGenerator : brewery generation
BiergartenDataGenerator ..> EnrichedCity
BiergartenDataGenerator ..> Location BiergartenDataGenerator ..> Location
BiergartenDataGenerator ..> BreweryResult BiergartenDataGenerator ..> BreweryResult
@@ -154,7 +156,7 @@ DataGenerator <|.. LlamaGenerator
WebClient <|.. CURLWebClient WebClient <|.. CURLWebClient
IEnrichmentService <|.. WikipediaService IEnrichmentService <|.. WikipediaService
WikipediaService *-- WebClient : unique_ptr WikipediaService --> WebClient : shared_ptr
note right of BiergartenDataGenerator note right of BiergartenDataGenerator
Current behavior: Current behavior:

View File

@@ -7,7 +7,6 @@
*/ */
#include <memory> #include <memory>
#include <span>
#include <vector> #include <vector>
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
@@ -23,55 +22,55 @@
* It handles location loading, city enrichment, and brewery generation. * It handles location loading, city enrichment, and brewery generation.
*/ */
class BiergartenDataGenerator { class BiergartenDataGenerator {
public: public:
/** /**
* @brief Construct a BiergartenDataGenerator with injected dependencies. * @brief Construct a BiergartenDataGenerator with injected dependencies.
* *
* @param context_service Context provider for sampled locations. * @param context_service Context provider for sampled locations.
* @param generator Brewery and user data generator. * @param generator Brewery and user data generator.
*/ */
BiergartenDataGenerator(std::unique_ptr<IEnrichmentService> context_service, BiergartenDataGenerator(std::shared_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator); std::unique_ptr<DataGenerator> generator);
/** /**
* @brief Run the data generation pipeline. * @brief Run the data generation pipeline.
* *
* Performs the following steps: * Performs the following steps:
* 1. Load curated locations from JSON * 1. Load curated locations from JSON
* 2. Resolve context for each city using the injected context service * 2. Resolve context for each city using the injected context service
* 3. Generate brewery data for sampled cities * 3. Generate brewery data for sampled cities
* *
* @return true if successful, false if not * @return true if successful, false if not
*/ */
bool Run(); bool Run();
private: private:
/// @brief Owning context provider dependency. /// @brief Shared context provider dependency.
std::unique_ptr<IEnrichmentService> context_service_; std::shared_ptr<IEnrichmentService> context_service_;
/// @brief Generator dependency selected in the composition root. /// @brief Generator dependency selected in the composition root.
std::unique_ptr<DataGenerator> generator_; std::unique_ptr<DataGenerator> generator_;
/** /**
* @brief Load locations from JSON and sample cities. * @brief Load locations from JSON and sample cities.
* *
* @return Vector of sampled locations capped at 4 entries. * @return Vector of sampled locations capped at 4 entries.
*/ */
static std::vector<Location> QueryCitiesWithCountries(); static std::vector<Location> QueryCitiesWithCountries();
/** /**
* @brief Generate breweries for enriched cities. * @brief Generate breweries for enriched cities.
* *
* @param cities Span of enriched city data. * @param cities Vector of enriched city data.
*/ */
void GenerateBreweries(std::span<const EnrichedCity> cities); void GenerateBreweries(const std::vector<EnrichedCity>& cities);
/** /**
* @brief Log the generated brewery results. * @brief Log the generated brewery results.
*/ */
void LogResults() const; void LogResults() const;
/// @brief Stores generated brewery data. /// @brief Stores generated brewery data.
std::vector<GeneratedBrewery> generated_breweries_; std::vector<GeneratedBrewery> generated_breweries_;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_BIERGARTEN_DATA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_BIERGARTEN_DATA_GENERATOR_H_

View File

@@ -8,34 +8,35 @@
#include <string> #include <string>
#include "data_model/brewery_location.h"
#include "data_model/brewery_result.h" #include "data_model/brewery_result.h"
#include "data_model/location.h"
#include "data_model/user_result.h" #include "data_model/user_result.h"
/** /**
* @brief Interface for data generator implementations. * @brief Interface for data generator implementations.
*/ */
class DataGenerator { class DataGenerator {
public: public:
virtual ~DataGenerator() = default; /// @brief Virtual destructor for polymorphic cleanup.
virtual ~DataGenerator() = default;
/** /**
* @brief Generates brewery data for a location. * @brief Generates brewery data for a location.
* *
* @param location Location data * @param location City and country names.
* @param region_context Additional regional context text. * @param region_context Additional regional context text.
* @return Brewery generation result. * @return Brewery generation result.
*/ */
virtual BreweryResult GenerateBrewery(const Location& location, virtual BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) = 0; const std::string& region_context) = 0;
/** /**
* @brief Generates a user profile for a locale. * @brief Generates a user profile for a locale.
* *
* @param locale Locale hint used by generator. * @param locale Locale hint used by generator.
* @return User generation result. * @return User generation result.
*/ */
virtual UserResult GenerateUser(const std::string& locale) = 0; virtual UserResult GenerateUser(const std::string& locale) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_DATA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_DATA_GENERATOR_H_

View File

@@ -16,116 +16,107 @@
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
struct LlamaSampler;
/** /**
* @brief Data generator implementation backed by llama.cpp. * @brief Data generator implementation backed by llama.cpp.
*/ */
class LlamaGenerator final : public DataGenerator { class LlamaGenerator final : public DataGenerator {
public: public:
/** /**
* @brief Constructs a generator using parsed application options and loads * @brief Constructs a generator using parsed application options and loads
* the configured model immediately. * the configured model immediately.
* *
* @param options Parsed application options. * @param options Parsed application options.
* @param model_path Filesystem path to GGUF model assets. * @param model_path Filesystem path to GGUF model assets.
*/ */
LlamaGenerator(const ApplicationOptions& options, LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path); const std::string& model_path);
/// @brief Releases model/context resources. /// @brief Releases model/context resources.
~LlamaGenerator() override; ~LlamaGenerator() override;
LlamaGenerator(const LlamaGenerator&) = delete; /**
LlamaGenerator& operator=(const LlamaGenerator&) = delete; * @brief Generates brewery data for a specific location.
LlamaGenerator(LlamaGenerator&&) = delete; *
LlamaGenerator& operator=(LlamaGenerator&&) = delete; * @param location City and country names.
* @param region_context Additional regional context.
* @return Generated brewery result.
*/
BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) override;
/** /**
* @brief Generates brewery data for a specific location. * @brief Generates a user profile for the provided locale.
* *
* @param location Location object. * @param locale Locale hint.
* @param region_context Additional regional context. * @return Generated user profile.
* @return Generated brewery result. */
*/ UserResult GenerateUser(const std::string& locale) override;
BreweryResult GenerateBrewery(const Location& location,
const std::string& region_context) override;
/** private:
* @brief Generates a user profile for the provided locale. /**
* * @brief Loads model and prepares inference context.
* @param locale Locale hint. *
* @return Generated user profile. * @param model_path Filesystem path to GGUF model.
*/ */
UserResult GenerateUser(const std::string& locale) override; void Load(const std::string& model_path);
private: /**
static constexpr int kDefaultMaxTokens = 10000; * @brief Infers text from a user prompt.
static constexpr float kDefaultSamplingTopP = 0.95F; *
static constexpr uint32_t kDefaultSamplingTopK = 64; * @param prompt User prompt.
static constexpr uint32_t kDefaultContextSize = 8192; * @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string Infer(const std::string& prompt, int max_tokens = 10000);
struct SamplerState { /**
SamplerState() = default; * @brief Infers text from separate system and user prompts.
~SamplerState(); *
* This helps chat-capable models preserve system-role behavior instead of
* concatenating system text into user input.
*
* @param system_prompt System role prompt.
* @param prompt User prompt.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string Infer(const std::string& system_prompt,
const std::string& prompt, int max_tokens = 10000);
SamplerState(const SamplerState&) = delete; /**
SamplerState& operator=(const SamplerState&) = delete; * @brief Runs inference on an already-formatted prompt.
SamplerState(SamplerState&&) = delete; *
SamplerState& operator=(SamplerState&&) = delete; * @param formatted_prompt Prompt preformatted for model chat template.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = 10000);
LlamaSampler* chain = nullptr; /**
}; * @brief Loads the brewery system prompt from disk.
*
* @param prompt_file_path Prompt file path to try first.
* @return Loaded prompt text or fallback prompt.
*/
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
/** /**
* @brief Loads model and prepares inference context. * @brief Returns a built-in fallback system prompt.
* *
* @param model_path Filesystem path to GGUF model. * @return Fallback prompt text.
*/ */
void Load(const std::string& model_path); std::string GetFallbackBreweryPrompt();
/** llama_model* model_ = nullptr;
* @brief Infers text from separate system and user prompts. llama_context* context_ = nullptr;
* float sampling_temperature_ = 1.0F;
* This helps chat-capable models preserve system-role behavior instead of float sampling_top_p_ = 0.95F;
* concatenating system text into user input. uint32_t sampling_top_k_ = 64;
* std::mt19937 rng_;
* @param system_prompt System role prompt. uint32_t n_ctx_ = 8192;
* @param prompt User prompt. std::string brewery_system_prompt_;
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string Infer(const std::string& system_prompt, const std::string& prompt,
int max_tokens = kDefaultMaxTokens);
/**
* @brief Runs inference on an already-formatted prompt.
*
* @param formatted_prompt Prompt preformatted for model chat template.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = kDefaultMaxTokens);
/**
* @brief Loads the brewery system prompt from disk.
*
* @param prompt_file_path Prompt file path to try first.
* @return Loaded prompt text.
*/
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
llama_model* model_ = nullptr;
llama_context* context_ = nullptr;
/// @brief Persistent sampler chain reused across inference calls.
std::unique_ptr<SamplerState> sampler_;
float sampling_temperature_ = 1.0F;
float sampling_top_p_ = kDefaultSamplingTopP;
uint32_t sampling_top_k_ = kDefaultSamplingTopK;
std::mt19937 rng_;
uint32_t n_ctx_ = kDefaultContextSize;
std::string brewery_system_prompt_;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -7,7 +7,6 @@
*/ */
#include <cstddef> #include <cstddef>
#include <optional>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <utility> #include <utility>
@@ -36,6 +35,16 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
std::pair<std::string, std::string> ParseTwoLineResponsePublic( std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message); const std::string& raw, const std::string& error_message);
/**
* @brief Applies model chat template to a user-only prompt.
*
* @param model Loaded llama model.
* @param user_prompt User prompt text.
* @return Model-formatted prompt.
*/
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt);
/** /**
* @brief Applies model chat template to system and user prompts. * @brief Applies model chat template to system and user prompts.
* *
@@ -64,11 +73,11 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
* @param raw Raw model output. * @param raw Raw model output.
* @param name_out Parsed brewery name. * @param name_out Parsed brewery name.
* @param description_out Parsed brewery description. * @param description_out Parsed brewery description.
* @return Validation error message if invalid, or std::nullopt on success. * @return Empty string on success, or validation error message.
*/ */
std::optional<std::string> ValidateBreweryJsonPublic( std::string ValidateBreweryJsonPublic(const std::string& raw,
const std::string& raw, std::string& name_out, std::string& name_out,
std::string& description_out); std::string& description_out);
/** /**
* @brief Extracts the last balanced JSON object from text. * @brief Extracts the last balanced JSON object from text.
@@ -78,4 +87,4 @@ std::optional<std::string> ValidateBreweryJsonPublic(
*/ */
std::string ExtractLastJsonObjectPublic(const std::string& text); std::string ExtractLastJsonObjectPublic(const std::string& text);
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_

View File

@@ -6,9 +6,9 @@
* @brief Deterministic mock implementation of DataGenerator. * @brief Deterministic mock implementation of DataGenerator.
*/ */
#include <array>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <vector>
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
@@ -16,108 +16,39 @@
* @brief Mock generator used for deterministic, model-free outputs. * @brief Mock generator used for deterministic, model-free outputs.
*/ */
class MockGenerator final : public DataGenerator { class MockGenerator final : public DataGenerator {
public: public:
/** /**
* @brief Generates deterministic brewery data for a location. * @brief Generates deterministic brewery data for a location.
* *
* @param location City and country names. * @param location City and country names.
* @param region_context Unused for mock generation. * @param region_context Unused for mock generation.
* @return Generated brewery result. * @return Generated brewery result.
*/ */
BreweryResult GenerateBrewery(const Location& location, BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) override; const std::string& region_context) override;
/** /**
* @brief Generates deterministic user data for a locale. * @brief Generates deterministic user data for a locale.
* *
* @param locale Locale hint. * @param locale Locale hint.
* @return Generated user result. * @return Generated user result.
*/ */
UserResult GenerateUser(const std::string& locale) override; UserResult GenerateUser(const std::string& locale) override;
private: private:
/** /**
* @brief Combines two strings into a stable hash value. * @brief Combines two strings into a stable hash value.
* *
* @param location City and country names. * @param location City and country names.
* @return Deterministic hash value. * @return Deterministic hash value.
*/ */
static std::size_t DeterministicHash(const Location& location); static std::size_t DeterministicHash(const BreweryLocation& location);
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = { static const std::vector<std::string> kBreweryAdjectives;
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", static const std::vector<std::string> kBreweryNouns;
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel", static const std::vector<std::string> kBreweryDescriptions;
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"}; static const std::vector<std::string> kUsernames;
static const std::vector<std::string> kBios;
static constexpr std::array<std::string_view, 18> kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
static constexpr std::array<std::string_view, 18> kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark "
"porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy "
"pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky "
"lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food "
"pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional "
"grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward "
"sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal "
"blends.",
"Modern taproom focused on sessionable lagers and classic pub "
"styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp "
"lagers.",
"Creative brewery offering rotating collaborations and limited "
"draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich "
"stouts."};
static constexpr std::array<std::string_view, 18> kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
static constexpr std::array<std::string_view, 18> kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer "
"gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward "
"styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean "
"fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_MOCK_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_MOCK_GENERATOR_H_

View File

@@ -13,30 +13,30 @@
* @brief Program options for the Biergarten pipeline application. * @brief Program options for the Biergarten pipeline application.
*/ */
struct ApplicationOptions { struct ApplicationOptions {
/// @brief Path to the LLM model file (gguf format); mutually exclusive with /// @brief Path to the LLM model file (gguf format); mutually exclusive with
/// use_mocked. /// use_mocked.
std::string model_path; std::string model_path;
/// @brief Use mocked generator instead of LLM; mutually exclusive with /// @brief Use mocked generator instead of LLM; mutually exclusive with
/// model_path. /// model_path.
bool use_mocked = false; bool use_mocked = false;
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random). /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 1.0F; float temperature = 1.0F;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more /// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
/// random). /// random).
float top_p = 0.95F; float top_p = 0.95F;
/// @brief LLM top-k sampling parameter. /// @brief LLM top-k sampling parameter.
uint32_t top_k = 64; uint32_t top_k = 64;
/// @brief Context window size (tokens) for LLM inference. Higher values /// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory. /// support longer prompts but use more memory.
uint32_t n_ctx = 8192; uint32_t n_ctx = 2048;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative). /// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1; int seed = -1;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_APPLICATION_OPTIONS_H_

View File

@@ -12,11 +12,11 @@
* @brief Non-owning brewery location input. * @brief Non-owning brewery location input.
*/ */
struct BreweryLocation { struct BreweryLocation {
/// @brief City name. /// @brief City name.
std::string_view city_name; std::string_view city_name;
/// @brief Country name. /// @brief Country name.
std::string_view country_name; std::string_view country_name;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_LOCATION_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_LOCATION_H_

View File

@@ -12,11 +12,11 @@
* @brief Generated brewery payload. * @brief Generated brewery payload.
*/ */
struct BreweryResult { struct BreweryResult {
/// @brief Brewery display name. /// @brief Brewery display name.
std::string name{}; std::string name;
/// @brief Brewery description text. /// @brief Brewery description text.
std::string description{}; std::string description;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_RESULT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_RESULT_H_

View File

@@ -14,8 +14,8 @@
* @brief Enriched city data with Wikipedia context. * @brief Enriched city data with Wikipedia context.
*/ */
struct EnrichedCity { struct EnrichedCity {
Location location; Location location;
std::string region_context{}; std::string region_context;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_ENRICHED_CITY_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_ENRICHED_CITY_H_

View File

@@ -13,8 +13,8 @@
* @brief Helper struct to store generated brewery data. * @brief Helper struct to store generated brewery data.
*/ */
struct GeneratedBrewery { struct GeneratedBrewery {
Location location; Location location;
BreweryResult brewery; BreweryResult brewery;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_GENERATED_BREWERY_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_GENERATED_BREWERY_H_

View File

@@ -12,26 +12,26 @@
* @brief Canonical location record for city-level generation. * @brief Canonical location record for city-level generation.
*/ */
struct Location { struct Location {
/// @brief City name. /// @brief City name.
std::string city{}; std::string city;
/// @brief State or province name. /// @brief State or province name.
std::string state_province{}; std::string state_province;
/// @brief ISO 3166-2 subdivision code. /// @brief ISO 3166-2 subdivision code.
std::string iso3166_2{}; std::string iso3166_2;
/// @brief Country name. /// @brief Country name.
std::string country{}; std::string country;
/// @brief ISO 3166-1 country code. /// @brief ISO 3166-1 country code.
std::string iso3166_1{}; std::string iso3166_1;
/// @brief Latitude in decimal degrees. /// @brief Latitude in decimal degrees.
double latitude{}; double latitude;
/// @brief Longitude in decimal degrees. /// @brief Longitude in decimal degrees.
double longitude{}; double longitude;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_LOCATION_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_LOCATION_H_

View File

@@ -12,11 +12,11 @@
* @brief Generated user profile payload. * @brief Generated user profile payload.
*/ */
struct UserResult { struct UserResult {
/// @brief Username handle. /// @brief Username handle.
std::string username{}; std::string username;
/// @brief Short user biography. /// @brief Short user biography.
std::string bio{}; std::string bio;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_USER_RESULT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_USER_RESULT_H_

View File

@@ -6,17 +6,16 @@
* @brief Loader API for curated location data. * @brief Loader API for curated location data.
*/ */
#include <filesystem> #include <string>
#include <vector> #include <vector>
#include "data_model/location.h" #include "data_model/location.h"
/// @brief Loads curated world locations from a JSON file into memory. /// @brief Loads curated world locations from a JSON file into memory.
class JsonLoader { class JsonLoader {
public: public:
/// @brief Parses a JSON array file and returns all location records. /// @brief Parses a JSON array file and returns all location records.
static std::vector<Location> LoadLocations( static std::vector<Location> LoadLocations(const std::string& filepath);
const std::filesystem::path& filepath);
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_JSON_HANDLING_JSON_LOADER_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -15,18 +15,18 @@
* it alive for application lifetime. * it alive for application lifetime.
*/ */
class LlamaBackendState { class LlamaBackendState {
public: public:
/// @brief Initializes global llama backend state. /// @brief Initializes global llama backend state.
LlamaBackendState() { llama_backend_init(); } LlamaBackendState() { llama_backend_init(); }
/// @brief Cleans up global llama backend state. /// @brief Cleans up global llama backend state.
~LlamaBackendState() { llama_backend_free(); } ~LlamaBackendState() { llama_backend_free(); }
/// @brief Non-copyable type. /// @brief Non-copyable type.
LlamaBackendState(const LlamaBackendState&) = delete; LlamaBackendState(const LlamaBackendState&) = delete;
/// @brief Non-copyable type. /// @brief Non-copyable type.
LlamaBackendState& operator=(const LlamaBackendState&) = delete; LlamaBackendState& operator=(const LlamaBackendState&) = delete;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_LLAMA_BACKEND_STATE_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_LLAMA_BACKEND_STATE_H_

View File

@@ -14,17 +14,17 @@
* @brief Interface for services that can enrich a location with context. * @brief Interface for services that can enrich a location with context.
*/ */
class IEnrichmentService { class IEnrichmentService {
public: public:
/// @brief Virtual destructor for polymorphic cleanup. /// @brief Virtual destructor for polymorphic cleanup.
virtual ~IEnrichmentService() = default; virtual ~IEnrichmentService() = default;
/** /**
* @brief Resolves contextual enrichment for a location. * @brief Resolves contextual enrichment for a location.
* *
* @param loc Location to enrich. * @param loc Location to enrich.
* @return Context text, or an empty string if unavailable. * @return Context text, or an empty string if unavailable.
*/ */
virtual std::string GetLocationContext(const Location& loc) = 0; virtual std::string GetLocationContext(const Location& loc) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_ENRICHMENT_SERVICE_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_ENRICHMENT_SERVICE_H_

View File

@@ -14,20 +14,20 @@
#include "services/enrichment_service.h" #include "services/enrichment_service.h"
#include "web_client/web_client.h" #include "web_client/web_client.h"
/// @brief Provides Wikipedia summary lookups backed by cached raw extracts. /// @brief Provides cached Wikipedia summary lookups for city and country pairs.
class WikipediaService final : public IEnrichmentService { class WikipediaService final : public IEnrichmentService {
public: public:
/// @brief Creates a new Wikipedia service with the provided web client. /// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::unique_ptr<WebClient> client); explicit WikipediaService(std::shared_ptr<WebClient> client);
/// @brief Returns the Wikipedia-derived context for a location. /// @brief Returns the Wikipedia-derived context for a location.
[[nodiscard]] std::string GetLocationContext(const Location& loc) override; [[nodiscard]] std::string GetLocationContext(const Location& loc) override;
private: private:
std::string FetchExtract(std::string_view query); std::string FetchExtract(std::string_view query);
std::unique_ptr<WebClient> client_; std::shared_ptr<WebClient> client_;
/// @brief Canonical cache for raw Wikipedia query extracts. std::unordered_map<std::string, std::string> cache_;
std::unordered_map<std::string, std::string> extract_cache_; std::unordered_map<std::string, std::string> extract_cache_;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_WIKIPEDIA_SERVICE_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_WIKIPEDIA_SERVICE_H_

View File

@@ -15,40 +15,40 @@
* alive for application lifetime. * alive for application lifetime.
*/ */
class CurlGlobalState { class CurlGlobalState {
public: public:
/// @brief Initializes global libcurl state. /// @brief Initializes global libcurl state.
CurlGlobalState(); CurlGlobalState();
/// @brief Cleans up global libcurl state. /// @brief Cleans up global libcurl state.
~CurlGlobalState(); ~CurlGlobalState();
/// @brief Non-copyable type. /// @brief Non-copyable type.
CurlGlobalState(const CurlGlobalState&) = delete; CurlGlobalState(const CurlGlobalState&) = delete;
/// @brief Non-copyable type. /// @brief Non-copyable type.
CurlGlobalState& operator=(const CurlGlobalState&) = delete; CurlGlobalState& operator=(const CurlGlobalState&) = delete;
}; };
/** /**
* @brief WebClient implementation backed by libcurl. * @brief WebClient implementation backed by libcurl.
*/ */
class CURLWebClient : public WebClient { class CURLWebClient : public WebClient {
public: public:
/** /**
* @brief Executes an HTTP GET request. * @brief Executes an HTTP GET request.
* *
* @param url Request URL. * @param url Request URL.
* @return Response body. * @return Response body.
*/ */
std::string Get(const std::string& url) override; std::string Get(const std::string& url) override;
/** /**
* @brief URL-encodes a string value. * @brief URL-encodes a string value.
* *
* @param value Raw value. * @param value Raw value.
* @return URL-encoded string. * @return URL-encoded string.
*/ */
std::string UrlEncode(const std::string& value) override; std::string UrlEncode(const std::string& value) override;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_CURL_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_CURL_WEB_CLIENT_H_

View File

@@ -12,25 +12,25 @@
* @brief Abstract web client interface. * @brief Abstract web client interface.
*/ */
class WebClient { class WebClient {
public: public:
/// @brief Virtual destructor for polymorphic cleanup. /// @brief Virtual destructor for polymorphic cleanup.
virtual ~WebClient() = default; virtual ~WebClient() = default;
/** /**
* @brief Executes an HTTP GET request. * @brief Executes an HTTP GET request.
* *
* @param url Request URL. * @param url Request URL.
* @return Response body. * @return Response body.
*/ */
virtual std::string Get(const std::string& url) = 0; virtual std::string Get(const std::string& url) = 0;
/** /**
* @brief URL-encodes a string value. * @brief URL-encodes a string value.
* *
* @param value Raw string value. * @param value Raw string value.
* @return Encoded value safe for URL usage. * @return Encoded value safe for URL usage.
*/ */
virtual std::string UrlEncode(const std::string& value) = 0; virtual std::string UrlEncode(const std::string& value) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -8,7 +8,7 @@
#include <utility> #include <utility>
BiergartenDataGenerator::BiergartenDataGenerator( BiergartenDataGenerator::BiergartenDataGenerator(
std::unique_ptr<IEnrichmentService> context_service, std::shared_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator) std::unique_ptr<DataGenerator> generator)
: context_service_(std::move(context_service)), : context_service_(std::move(context_service)),
generator_(std::move(generator)) {} generator_(std::move(generator)) {}

View File

@@ -8,32 +8,34 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
void BiergartenDataGenerator::GenerateBreweries( void BiergartenDataGenerator::GenerateBreweries(
std::span<const EnrichedCity> cities) { const std::vector<EnrichedCity>& cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ==="); spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
generated_breweries_.clear();
generated_breweries_.clear(); size_t skipped_count = 0;
size_t skipped_count = 0;
for (const auto& [location, region_context] : cities) { for (const auto& enriched_city : cities) {
try { try {
const BreweryResult brewery = auto brewery = generator_->GenerateBrewery(
generator_->GenerateBrewery(location, region_context); BreweryLocation{enriched_city.location.city,
enriched_city.location.country},
const GeneratedBrewery gen{.location = location, .brewery = brewery}; enriched_city.region_context);
generated_breweries_.push_back(GeneratedBrewery{
generated_breweries_.push_back(gen); .location = enriched_city.location, .brewery = brewery});
} catch (const std::exception& e) { } catch (const std::exception& e) {
++skipped_count; ++skipped_count;
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): brewery generation failed: "
"{}",
enriched_city.location.city, enriched_city.location.country,
e.what());
}
}
if (skipped_count > 0) {
spdlog::warn( spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): brewery generation failed: " "[Pipeline] Skipped {} city/cities due to generation "
"{}", "errors",
location.city, location.country, e.what()); skipped_count);
} }
}
if (skipped_count > 0) {
spdlog::warn("[Pipeline] Skipped {} city/cities due to generation errors",
skipped_count);
}
} }

View File

@@ -8,16 +8,16 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
void BiergartenDataGenerator::LogResults() const { void BiergartenDataGenerator::LogResults() const {
spdlog::info("\n=== GENERATED DATA DUMP ==="); spdlog::info("\n=== GENERATED DATA DUMP ===");
size_t index = 1; size_t index = 1;
for (const auto& [location, brewery] : generated_breweries_) { for (const auto& [location, brewery] : generated_breweries_) {
spdlog::info( spdlog::info(
"{}. city=\"{}\" country=\"{}\" state=\"{}\" " "{}. city=\"{}\" country=\"{}\" state=\"{}\" "
"iso3166_2={} lat={} lon={}", "iso3166_2={} lat={} lon={}",
index, location.city, location.country, location.state_province, index, location.city, location.country, location.state_province,
location.iso3166_2, location.latitude, location.longitude); location.iso3166_2, location.latitude, location.longitude);
spdlog::info(" brewery_name=\"{}\"", brewery.name); spdlog::info(" brewery_name=\"{}\"", brewery.name);
spdlog::info(" brewery_description=\"{}\"", brewery.description); spdlog::info(" brewery_description=\"{}\"", brewery.description);
++index; ++index;
} }
} }

View File

@@ -13,28 +13,28 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
#include "json_handling/json_loader.h" #include "json_handling/json_loader.h"
static constexpr std::size_t kBreweryAmount = 4; static constexpr unsigned int brewery_amount = 4;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() { std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
const std::filesystem::path locations_path = "locations.json"; const std::filesystem::path locations_path = "locations.json";
auto all_locations = JsonLoader::LoadLocations(locations_path); auto all_locations = JsonLoader::LoadLocations(locations_path.string());
spdlog::info(" Locations available: {}", all_locations.size()); spdlog::info(" Locations available: {}", all_locations.size());
const std::size_t sample_count = const size_t sample_count =
std::min(kBreweryAmount, all_locations.size()); std::min<size_t>(brewery_amount, all_locations.size());
const auto sample_count_signed = const auto sample_count_signed =
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>( static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
sample_count); sample_count);
std::vector<Location> sampled_locations; std::vector<Location> sampled_locations;
sampled_locations.reserve(sample_count); sampled_locations.reserve(sample_count);
std::random_device random_generator; std::random_device random_generator;
std::ranges::sample(all_locations, std::back_inserter(sampled_locations), std::ranges::sample(all_locations, std::back_inserter(sampled_locations),
sample_count_signed, random_generator); sample_count_signed, random_generator);
spdlog::info(" Sampled locations: {}", sampled_locations.size()); spdlog::info(" Sampled locations: {}", sampled_locations.size());
return sampled_locations; return sampled_locations;
} }

View File

@@ -8,40 +8,40 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
bool BiergartenDataGenerator::Run() { bool BiergartenDataGenerator::Run() {
try { try {
const std::vector<Location> cities = QueryCitiesWithCountries(); const std::vector<Location> cities = QueryCitiesWithCountries();
std::vector<EnrichedCity> enriched; std::vector<EnrichedCity> enriched;
enriched.reserve(cities.size()); enriched.reserve(cities.size());
size_t skipped_count = 0; size_t skipped_count = 0;
for (const auto& city : cities) { for (const auto& city : cities) {
try { try {
const std::string region_context = const std::string region_context =
context_service_->GetLocationContext(city); context_service_->GetLocationContext(city);
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}", spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context); city.city, city.country, region_context);
enriched.push_back( enriched.push_back(EnrichedCity{.location = city,
EnrichedCity{.location = city, .region_context = region_context}); .region_context = region_context});
} catch (const std::exception& exception) { } catch (const std::exception& exception) {
++skipped_count; ++skipped_count;
spdlog::warn( spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): context lookup failed: {}", "[Pipeline] Skipping city '{}' ({}): context lookup failed: {}",
city.city, city.country, exception.what()); city.city, city.country, exception.what());
}
} }
}
if (skipped_count > 0) { if (skipped_count > 0) {
spdlog::warn( spdlog::warn(
"[Pipeline] Skipped {} city/cities due to context lookup errors", "[Pipeline] Skipped {} city/cities due to context lookup errors",
skipped_count); skipped_count);
} }
this->GenerateBreweries(enriched); this->GenerateBreweries(enriched);
this->LogResults(); this->LogResults();
return true; return true;
} catch (const std::exception& e) { } catch (const std::exception& e) {
spdlog::error("Pipeline execution failed with error: {}", e.what()); spdlog::error("Pipeline execution failed with error: {}", e.what());
return false; return false;
} }
} }

View File

@@ -7,143 +7,154 @@
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <array> #include <array>
#include <format>
#include <optional>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
static std::string ExtractFinalJsonPayload(std::string raw_response) { namespace {
auto trim = [](const std::string_view text) -> std::string_view {
const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) {
return {};
}
const std::size_t last = text.find_last_not_of(" \t\n\r"); std::string ExtractFinalJsonPayload(std::string raw_response) {
return text.substr(first, last - first + 1); auto trim = [](std::string_view text) -> std::string_view {
}; const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) {
return {};
}
static constexpr std::array<std::string_view, 6> separator_tokens = { const std::size_t last = text.find_last_not_of(" \t\n\r");
"<|think|>", "<think|>", "<|turn|>", return text.substr(first, last - first + 1);
"<turn|>", "<channel|>", "<|channel|>"}; };
std::size_t separator_pos = std::string::npos; static const std::array<std::string_view, 6> separator_tokens = {
std::size_t separator_length = 0; "<|think|>", "<think|>", "<|turn|>",
for (const std::string_view token : separator_tokens) { "<turn|>", "<channel|>", "<|channel|>"};
const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos || candidate_pos > separator_pos)) {
separator_pos = candidate_pos;
separator_length = token.size();
}
}
if (separator_pos != std::string::npos) { std::size_t separator_pos = std::string::npos;
raw_response.erase(0, separator_pos + separator_length); std::size_t separator_length = 0;
} for (const std::string_view token : separator_tokens) {
const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos ||
candidate_pos > separator_pos)) {
separator_pos = candidate_pos;
separator_length = token.size();
}
}
const std::string_view trimmed = trim(raw_response); if (separator_pos != std::string::npos) {
const std::string json_candidate = raw_response.erase(0, separator_pos + separator_length);
ExtractLastJsonObjectPublic(std::string(trimmed)); }
if (!json_candidate.empty()) { const std::string_view trimmed = trim(raw_response);
return ExtractLastJsonObjectPublic(std::string(trimmed)); std::string json_candidate =
} ExtractLastJsonObjectPublic(std::string(trimmed));
if (!json_candidate.empty()) {
return ExtractLastJsonObjectPublic(std::string(trimmed));
}
return std::string(trimmed); return std::string(trimmed);
} }
} // namespace
BreweryResult LlamaGenerator::GenerateBrewery( BreweryResult LlamaGenerator::GenerateBrewery(
const Location& location, const std::string& region_context) { const BreweryLocation& location, const std::string& region_context) {
/** /**
* Preprocess and truncate region context to manageable size * Preprocess and truncate region context to manageable size
*/ */
const std::string safe_region_context = const std::string safe_region_context =
PrepareRegionContextPublic(region_context); PrepareRegionContextPublic(region_context);
const std::string country_suffix = /**
location.country.empty() ? std::string{} * Load brewery system prompt from file
: std::format(", {}", location.country); * Falls back to minimal inline prompt if file not found
const std::string region_suffix = */
safe_region_context.empty() const std::string system_prompt =
? "." LoadBrewerySystemPrompt("prompts/system.md");
: std::format(". Regional context: {}", safe_region_context);
/** /**
* Load brewery system prompt from file * User prompt: provides geographic context to guide generation towards
* Falls back to minimal inline prompt if file not found * culturally appropriate and locally-inspired brewery attributes
*/ */
const std::string system_prompt = std::string prompt =
LoadBrewerySystemPrompt("prompts/system.md"); "Write a brewery name and place-specific long description for a craft "
"brewery in ";
prompt.append(location.city_name);
if (!location.country_name.empty()) {
prompt.append(", ");
prompt.append(location.country_name);
}
if (safe_region_context.empty()) {
prompt.append(".");
} else {
prompt.append(". Regional context: ");
prompt.append(safe_region_context);
}
/** /**
* User prompt: provides geographic context to guide generation towards * Store location context for retry prompts (without repeating full context)
* culturally relevant and locally-inspired brewery attributes */
*/ std::string retry_location = "Location: ";
std::string prompt = std::format( retry_location.append(location.city_name);
"Write a brewery name and place-specific long description for a craft " if (!location.country_name.empty()) {
"brewery in {}{}{}", retry_location.append(", ");
location.city, country_suffix, region_suffix); retry_location.append(location.country_name);
}
/** /**
* Store location context for retry prompts (without repeating full context) * RETRY LOOP with validation and error correction
*/ * Attempts to generate valid brewery data up to 3 times, with feedback-based
const std::string retry_location = * refinement
std::format("Location: {}{}", location.city, country_suffix); */
const int max_attempts = 3;
std::string raw;
std::string last_error;
/** // Limit output length to keep it concise and focused
* RETRY LOOP with validation and error correction constexpr int max_tokens = 1052;
* Attempts to generate valid brewery data up to 3 times, with feedback-based for (int attempt = 0; attempt < max_attempts; ++attempt) {
* refinement // Generate brewery data from LLM
*/ raw = Infer(system_prompt, prompt, max_tokens);
constexpr int max_attempts = 3; spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
std::string raw; raw);
std::string last_error;
// Limit output length to keep it concise and focused // Validate output: parse JSON and check required fields
for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052;
// Generate brewery data from LLM
raw = this->Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
// Validate output: parse JSON and check required fields std::string name;
std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::string validation_error =
ValidateBreweryJsonPublic(json_only, name, description);
if (validation_error.empty()) {
// Success: return parsed brewery data
return {std::move(name), std::move(description)};
}
std::string name; // Validation failed: log error and prepare corrective feedback
std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error =
ValidateBreweryJsonPublic(json_only, name, description);
if (!validation_error.has_value()) {
// Success: return parsed brewery data
return BreweryResult{.name = std::move(name),
.description = std::move(description)};
}
// Validation failed: log error and prepare corrective feedback last_error = validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error);
last_error = *validation_error; // Update prompt with error details to guide LLM toward correct output.
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", // For retries, use a compact prompt format to avoid exceeding token
attempt + 1, *validation_error); // limits.
prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with exactly these keys: "
"{\"name\": \"<brewery name>\", "
"\"description\": \"<single-paragraph description>\"}."
"\nDo not include markdown, comments, extra keys, or literal "
"placeholder values.";
prompt += "\n\n";
prompt += retry_location;
}
// Update prompt with error details to guide LLM toward correct output. // All retry attempts exhausted: log failure and throw exception
prompt = std::format( spdlog::error(
R"(Your previous response was invalid. Error: {} "LlamaGenerator: malformed brewery response after {} attempts: "
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}. "{}",
Do not include markdown, comments, extra keys, or literal placeholder values. max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
{})",
*validation_error, retry_location);
}
// All retry attempts exhausted: log failure and throw exception
spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
} }

View File

@@ -6,6 +6,7 @@
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
@@ -13,6 +14,87 @@
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
UserResult LlamaGenerator::GenerateUser(const std::string& locale) { UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
return {.username = "test_user", /**
.bio = "This is a test user profile from " + locale + "."}; * System prompt: specifies exact output format to minimize parsing errors
* Constraints: 2-line output, username format, bio length bounds
*/
const std::string system_prompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
/**
* User prompt: locale parameter guides cultural appropriateness of generated
* profiles
*/
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
/**
* RETRY LOOP with format validation
* Attempts up to 3 times to generate valid user profile with correct format
*/
const int max_attempts = 3;
std::string raw;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
/**
* Generate user profile (max 128 tokens - should fit 2 lines easily)
*/
raw = Infer(system_prompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
/**
* Parse two-line response: first line = username, second line = bio
*/
auto [username, bio] = ParseTwoLineResponsePublic(
raw, "LlamaGenerator: malformed user response");
/**
* Remove any whitespace from username (usernames shouldn't have
* spaces)
*/
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
/**
* Validate both fields are non-empty after processing
*/
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
/**
* Truncate bio if exceeds reasonable length for bio field
*/
if (bio.size() > 200) bio = bio.substr(0, 200);
/**
* Success: return parsed user profile
*/
return {username, bio};
} catch (const std::exception& e) {
/**
* Parsing failed: log and continue to next attempt
*/
spdlog::warn(
"LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
/**
* All retry attempts exhausted: log failure and throw exception
*/
spdlog::error(
"LlamaGenerator: malformed user response after {} attempts: {}",
max_attempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
} }

View File

@@ -4,17 +4,13 @@
* parsing, token decoding, and JSON validation helpers for Llama modules. * parsing, token decoding, and JSON validation helpers for Llama modules.
*/ */
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <boost/json.hpp> #include <boost/json.hpp>
#include <cctype> #include <cctype>
#include <optional>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <string_view>
#include <vector> #include <vector>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
@@ -23,42 +19,40 @@
/** /**
* String trimming: removes leading and trailing whitespace * String trimming: removes leading and trailing whitespace
*/ */
static std::string Trim(std::string_view value) { static std::string Trim(std::string value) {
constexpr std::string_view whitespace = " \t\n\r\f\v"; auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
const std::size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) {
return {};
}
const std::size_t last_index = value.find_last_not_of(whitespace); value.erase(value.begin(),
return std::string(value.substr(first_index, last_index - first_index + 1)); std::find_if(value.begin(), value.end(), not_space));
value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(),
value.end());
return value;
} }
/** /**
* Normalize whitespace: collapses multiple spaces/tabs/newlines into single * Normalize whitespace: collapses multiple spaces/tabs/newlines into single
* spaces * spaces
*/ */
static std::string CondenseWhitespace(std::string_view text) { static std::string CondenseWhitespace(std::string text) {
std::string out; std::string out;
out.reserve(text.size()); out.reserve(text.size());
bool pending_space = false; bool in_whitespace = false;
for (const unsigned char chr : text) { for (unsigned char ch : text) {
if (std::isspace(chr) != 0) { if (std::isspace(ch)) {
if (!out.empty()) { if (!in_whitespace) {
pending_space = true; out.push_back(' ');
in_whitespace = true;
}
continue;
} }
continue;
}
if (pending_space) { in_whitespace = false;
out.push_back(' '); out.push_back(static_cast<char>(ch));
pending_space = false; }
}
out.push_back(static_cast<char>(chr));
}
return out; return Trim(std::move(out));
} }
/** /**
@@ -66,286 +60,386 @@ static std::string CondenseWhitespace(std::string_view text) {
* boundaries * boundaries
*/ */
static std::string PrepareRegionContext(std::string_view region_context, static std::string PrepareRegionContext(std::string_view region_context,
const size_t max_chars) { std::size_t max_chars) {
std::string normalized = CondenseWhitespace(region_context); std::string normalized = CondenseWhitespace(std::string(region_context));
if (normalized.size() <= max_chars) { if (normalized.size() <= max_chars) {
return normalized; return normalized;
} }
normalized.resize(max_chars); normalized.resize(max_chars);
const size_t last_space = normalized.find_last_of(' '); const std::size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) { if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space); normalized.resize(last_space);
} }
normalized += "..."; normalized += "...";
return normalized; return normalized;
} }
static std::string ToChatPrompt(const llama_model* model, /**
const std::string& system_prompt, * Remove common bullet points, numbers, and field labels added by LLM in output
const std::string& user_prompt) { */
std::string combined_prompt; static std::string StripCommonPrefix(std::string line) {
combined_prompt.append(system_prompt); line = Trim(std::move(line));
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
const char* tmpl = llama_model_chat_template(model, nullptr); if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
if (tmpl == nullptr) { line = Trim(line.substr(1));
// No template found, fallback to raw text } else {
spdlog::warn( std::size_t i = 0;
"LlamaGenerator: missing chat template; using raw prompt fallback"); while (i < line.size() &&
return combined_prompt; std::isdigit(static_cast<unsigned char>(line[i]))) {
} ++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = Trim(line.substr(i + 1));
}
}
const std::array<llama_chat_message, 2> messages = { auto strip_label = [&line](const std::string& label) {
{{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}}; if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = Trim(line.substr(label.size()));
}
}
};
std::vector<char> buffer(std::max<std::size_t>( strip_label("name:");
1024, (system_prompt.size() + user_prompt.size()) * 4)); strip_label("brewery name:");
strip_label("description:");
strip_label("username:");
strip_label("bio:");
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages, return Trim(std::move(line));
int32_t message_count) -> int32_t { }
int32_t result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (result < 0) { /**
return result; * Parse two-line response from LLM: normalize line endings, strip formatting,
} * filter spurious output, and combine remaining lines if needed
*/
static std::pair<std::string, std::string> ParseTwoLineResponse(
const std::string& raw, const std::string& error_message) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
if (result >= static_cast<int32_t>(buffer.size())) { std::vector<std::string> lines;
buffer.resize(static_cast<std::size_t>(result) + 1); std::stringstream stream(normalized);
result = llama_chat_apply_template(tmpl, chat_messages, message_count, std::string line;
true, buffer.data(), while (std::getline(stream, line)) {
static_cast<int32_t>(buffer.size())); line = StripCommonPrefix(std::move(line));
} if (!line.empty()) lines.push_back(std::move(line));
}
return result; std::vector<std::string> filtered;
}; for (auto& l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
// Filter known thinking tags like <think>...</think>, but be conservative
// to avoid removing legitimate output. Only filter specific known
// patterns.
if (!l.empty() && l.front() == '<' && low.back() == '>') {
// Only filter if it's a known thinking tag: <think>, <reasoning>, etc.
if (low.find("think") != std::string::npos ||
low.find("reasoning") != std::string::npos ||
low.find("reflect") != std::string::npos) {
continue;
}
}
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
filtered.push_back(std::move(l));
}
int32_t template_result = apply_template_with_resize(messages.data(), 2); if (filtered.size() < 2) throw std::runtime_error(error_message);
if (template_result >= 0) { std::string first = Trim(filtered.front());
return {buffer.data(), static_cast<std::size_t>(template_result)}; std::string second;
} for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) second += ' ';
second += filtered[i];
}
second = Trim(std::move(second));
spdlog::warn( if (first.empty() || second.empty()) throw std::runtime_error(error_message);
"LlamaGenerator: chat template rejected system/user messages (result " return {first, second};
"{}); trying single user fallback", }
template_result); std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
// No template found, fallback to raw text
return system_prompt + "\n\n" + user_prompt;
}
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role), const std::array<llama_chat_message, 2> messages = {
// combine the system and user prompts into a single "user" message. {{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}};
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
template_result = apply_template_with_resize(fallback_msg.data(), 1); std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
// Ultimate fallback: if GGUF template parsing still fails, use raw text. int32_t required =
if (template_result < 0) { llama_chat_apply_template(tmpl, messages.data(), 2, true, buffer.data(),
spdlog::warn( static_cast<int32_t>(buffer.size()));
"LlamaGenerator: chat template fallback failed (result {}); using "
"raw prompt text",
template_result);
return combined_prompt;
}
return {buffer.data(), static_cast<std::size_t>(template_result)}; // FALLBACK: If the template fails (e.g., Gemma rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
if (required < 0) {
std::string combined_prompt = system_prompt + "\n\n" + user_prompt;
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
required = llama_chat_apply_template(tmpl, fallback_msg.data(), 1, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// THE FIX: Ultimate fallback. If the GGUF's internal template is
// completely unparseable (which happens with complex Jinja macros),
// degrade gracefully to raw text instead of throwing a runtime_error.
if (required < 0) {
return combined_prompt;
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(
tmpl, fallback_msg.data(), 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
return combined_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
// Standard buffer resize if the original "system" + "user" array succeeded
// but needed more space
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages.data(), 2, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// Final safety net on resize
if (required < 0) {
return system_prompt + "\n\n" + user_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
} }
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token, static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
std::array<char, 256> buffer{}; std::array<char, 256> buffer{};
int32_t bytes = int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true); llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true);
if (bytes < 0) { if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes)); std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), 0, static_cast<int32_t>(dynamic_buffer.size()),
true); 0, true);
if (bytes < 0) { if (bytes < 0) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece"); "LlamaGenerator: failed to decode sampled token piece");
} }
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes)); output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
return; return;
} }
output.append(buffer.data(), static_cast<std::size_t>(bytes)); output.append(buffer.data(), static_cast<std::size_t>(bytes));
} }
static bool ExtractLastJsonObject(const std::string& text, static bool ExtractLastJsonObject(const std::string& text,
std::string& json_out) { std::string& json_out) {
std::size_t start = std::string::npos; std::size_t start = std::string::npos;
int depth = 0; int depth = 0;
bool in_string = false; bool in_string = false;
bool escaped = false; bool escaped = false;
bool found = false; bool found = false;
std::string candidate; std::string candidate;
for (std::size_t i = 0; i < text.size(); ++i) { for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i]; const char ch = text[i];
if (in_string) { if (in_string) {
if (escaped) { if (escaped) {
escaped = false; escaped = false;
} else if (ch == '\\') { } else if (ch == '\\') {
escaped = true; escaped = true;
} else if (ch == '"') { } else if (ch == '"') {
in_string = false; in_string = false;
}
continue;
} }
continue;
}
if (ch == '"') { if (ch == '"') {
in_string = true; in_string = true;
continue; continue;
}
if (ch == '{') {
if (depth == 0) {
start = i;
} }
++depth;
continue;
}
if (ch == '}') { if (ch == '{') {
if (depth == 0) { if (depth == 0) {
continue; start = i;
}
++depth;
continue;
} }
--depth;
if (depth == 0 && start != std::string::npos) { if (ch == '}') {
candidate = text.substr(start, i - start + 1); if (depth == 0) {
found = true; continue;
}
--depth;
if (depth == 0 && start != std::string::npos) {
candidate = text.substr(start, i - start + 1);
found = true;
}
} }
} }
}
if (!found) { if (!found) {
return false; return false;
} }
json_out = std::move(candidate); json_out = std::move(candidate);
return true; return true;
} }
std::string ExtractLastJsonObjectPublic(const std::string& text) { std::string ExtractLastJsonObjectPublic(const std::string& text) {
std::string extracted; std::string extracted;
if (ExtractLastJsonObject(text, extracted)) { if (ExtractLastJsonObject(text, extracted)) {
return extracted; return extracted;
} }
return {}; return {};
} }
static std::optional<std::string> ValidateBreweryJson( static std::string ValidateBreweryJson(const std::string& raw,
const std::string& raw, std::string& name_out, std::string& name_out,
std::string& description_out) { std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv, auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool { std::string& error_out) -> bool {
if (!jv.is_object()) { if (!jv.is_object()) {
error_out = "JSON root must be an object"; error_out = "JSON root must be an object";
return false; return false;
} }
const auto& obj = jv.get_object(); const auto& obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) { if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string"; error_out = "JSON field 'name' is missing or not a string";
return false; return false;
} }
if (!obj.contains("description") || !obj.at("description").is_string()) { if (!obj.contains("description") || !obj.at("description").is_string()) {
error_out = "JSON field 'description' is missing or not a string"; error_out = "JSON field 'description' is missing or not a string";
return false; return false;
} }
const auto& name_value = obj.at("name").as_string(); name_out = Trim(std::string(obj.at("name").as_string().c_str()));
const auto& description_value = obj.at("description").as_string(); description_out =
name_out = Trim(std::string_view(name_value.data(), name_value.size())); Trim(std::string(obj.at("description").as_string().c_str()));
description_out = Trim(
std::string_view(description_value.data(), description_value.size()));
if (name_out.empty()) { if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty"; error_out = "JSON field 'name' must not be empty";
return false; return false;
} }
if (description_out.empty()) { if (description_out.empty()) {
error_out = "JSON field 'description' must not be empty"; error_out = "JSON field 'description' must not be empty";
return false; return false;
} }
std::string name_lower = name_out; std::string name_lower = name_out;
std::string description_lower = description_out; std::string description_lower = description_out;
std::transform( std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(), name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); }); [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(), std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) { description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c)); return static_cast<char>(std::tolower(c));
}); });
if (name_lower == "string" || description_lower == "string") { if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content"; error_out = "JSON appears to be a schema placeholder, not content";
return false; return false;
} }
error_out.clear(); error_out.clear();
return true; return true;
}; };
boost::system::error_code ec; boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec); boost::json::value jv = boost::json::parse(raw, ec);
std::string validation_error; std::string validation_error;
if (ec) { if (ec) {
std::string extracted; std::string extracted;
if (!ExtractLastJsonObject(raw, extracted)) { if (!ExtractLastJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message(); return "JSON parse error: " + ec.message();
} }
ec.clear(); ec.clear();
jv = boost::json::parse(extracted, ec); jv = boost::json::parse(extracted, ec);
if (ec) { if (ec) {
return "JSON parse error: " + ec.message(); return "JSON parse error: " + ec.message();
} }
if (!validate_object(jv, validation_error)) { if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
}
if (!validate_object(jv, validation_error)) {
return validation_error; return validation_error;
} }
return std::nullopt; return {};
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return std::nullopt;
} }
// Forward declarations for helper functions exposed to other translation units // Forward declarations for helper functions exposed to other translation units
std::string PrepareRegionContextPublic(std::string_view region_context, std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars) { std::size_t max_chars) {
return PrepareRegionContext(region_context, max_chars); return PrepareRegionContext(region_context, max_chars);
}
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message) {
return ParseTwoLineResponse(raw, error_message);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt) {
return ToChatPrompt(model, user_prompt, "");
} }
std::string ToChatPromptPublic(const llama_model* model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt) { const std::string& user_prompt) {
return ToChatPrompt(model, system_prompt, user_prompt); return ToChatPrompt(model, system_prompt, user_prompt);
} }
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token, void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
AppendTokenPiece(vocab, token, output); AppendTokenPiece(vocab, token, output);
} }
std::optional<std::string> ValidateBreweryJsonPublic( std::string ValidateBreweryJsonPublic(const std::string& raw,
const std::string& raw, std::string& name_out, std::string& name_out,
std::string& description_out) { std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out); return ValidateBreweryJson(raw, name_out, description_out);
} }

View File

@@ -2,7 +2,7 @@
* Text Generation / Inference Module * Text Generation / Inference Module
* Core module that performs LLM inference: converts text prompts into tokens, * Core module that performs LLM inference: converts text prompts into tokens,
* runs the neural network forward pass, samples the next token, and converts * runs the neural network forward pass, samples the next token, and converts
* output tokens back to text for system+user chat prompts. * output tokens back to text. Supports both simple and system+user prompts.
*/ */
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
@@ -17,156 +17,182 @@
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
#include "llama.h" #include "llama.h"
static constexpr std::size_t kPromptTokenSlack = 8; std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, prompt), max_tokens);
}
std::string LlamaGenerator::Infer(const std::string& system_prompt, std::string LlamaGenerator::Infer(const std::string& system_prompt,
const std::string& prompt, const std::string& prompt, int max_tokens) {
const int max_tokens) { return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt), max_tokens);
max_tokens);
} }
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
const int max_tokens) { int max_tokens) {
/** /**
* Validate that model and context are loaded * Validate that model and context are loaded
*/ */
if (model_ == nullptr || context_ == nullptr) { if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded"); throw std::runtime_error("LlamaGenerator: model not loaded");
}
/** /**
* Get vocabulary for tokenization and token-to-text conversion * Get vocabulary for tokenization and token-to-text conversion
*/ */
const llama_vocab* vocab = llama_model_get_vocab(model_); const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) { if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable"); throw std::runtime_error("LlamaGenerator: vocab unavailable");
}
/** /**
* Clear KV cache to ensure clean inference state (no residual context) * Clear KV cache to ensure clean inference state (no residual context)
*/ */
llama_memory_clear(llama_get_memory(context_), true); llama_memory_clear(llama_get_memory(context_), true);
/** /**
* TOKENIZATION PHASE * TOKENIZATION PHASE
* Convert text prompt into token IDs (integers) that the model understands * Convert text prompt into token IDs (integers) that the model understands
*/ */
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
kPromptTokenSlack); int32_t token_count = llama_tokenize(
int32_t token_count = llama_tokenize( vocab, formatted_prompt.c_str(),
vocab, formatted_prompt.c_str(), static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()), true, true);
static_cast<int32_t>(prompt_tokens.size()), true, true);
/** /**
* If buffer too small, negative return indicates required size * If buffer too small, negative return indicates required size
*/ */
if (token_count < 0) { if (token_count < 0) {
prompt_tokens.resize(static_cast<std::size_t>(-token_count)); prompt_tokens.resize(static_cast<std::size_t>(-token_count));
token_count = llama_tokenize( token_count = llama_tokenize(
vocab, formatted_prompt.c_str(), vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(), static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true); static_cast<int32_t>(prompt_tokens.size()), true, true);
} }
if (token_count < 0) { if (token_count < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
}
/** /**
* CONTEXT SIZE VALIDATION * CONTEXT SIZE VALIDATION
* Validate and compute effective token budgets based on context window * Validate and compute effective token budgets based on context window
* constraints * constraints
*/ */
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_)); const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_)); const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0) { if (n_ctx <= 1 || n_batch <= 0)
throw std::runtime_error("LlamaGenerator: invalid context or batch size"); throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
/** /**
* Clamp generation limit to available context window, reserve space for * Clamp generation limit to available context window, reserve space for
* output * output
*/ */
const int32_t effective_max_tokens = const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1)); std::max(1, std::min(max_tokens, n_ctx - 1));
/** /**
* Prompt can use remaining context after reserving space for generation * Prompt can use remaining context after reserving space for generation
*/ */
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget); prompt_budget = std::max<int32_t>(1, prompt_budget);
/** /**
* Truncate prompt if necessary to fit within constraints * Truncate prompt if necessary to fit within constraints
*/ */
prompt_tokens.resize(static_cast<std::size_t>(token_count)); prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) { if (token_count > prompt_budget) {
spdlog::warn( spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} " "LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens to fit n_batch/n_ctx limits", "tokens to fit n_batch/n_ctx limits",
token_count, prompt_budget); token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget)); prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
token_count = prompt_budget; token_count = prompt_budget;
} }
/** /**
* PROMPT PROCESSING PHASE * PROMPT PROCESSING PHASE
* Create a batch containing all prompt tokens and feed through the model * Create a batch containing all prompt tokens and feed through the model
* This computes internal representations and fills the KV cache * This computes internal representations and fills the KV cache
*/ */
const llama_batch prompt_batch = llama_batch_get_one( const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size())); prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0) { if (llama_decode(context_, prompt_batch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed"); throw std::runtime_error("LlamaGenerator: prompt decode failed");
}
/** /**
* TOKEN GENERATION LOOP * SAMPLER CONFIGURATION PHASE
* Iteratively generate tokens one at a time until max_tokens or * Set up the probabilistic token selection pipeline (sampler chain)
* end-of-sequence * Samplers are applied in sequence: temperature -> top-k -> top-p ->
*/ * distribution
std::vector<llama_token> generated_tokens; */
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens)); llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
if (sampler_ == nullptr || sampler_->chain == nullptr) { /**
throw std::runtime_error("LlamaGenerator: sampler not initialized"); * Temperature: scales logits before softmax (controls randomness)
} */
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
/**
* Top-K: limits sampling to the most likely tokens before nucleus
* sampling
*/
llama_sampler_chain_add(
sampler.get(),
llama_sampler_init_top_k(static_cast<int32_t>(sampling_top_k_)));
/**
* Top-P: nucleus sampling - filters to most likely tokens summing to top_p
* probability
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
/**
* Distribution sampler: selects actual token using configured seed for
* reproducibility
*/
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
for (int i = 0; i < effective_max_tokens; ++i) { /**
/** * TOKEN GENERATION LOOP
* Sample next token using configured sampler chain and model logits * Iteratively generate tokens one at a time until max_tokens or
* Index -1 means use the last output position from previous batch * end-of-sequence
*/ */
const llama_token next = std::vector<llama_token> generated_tokens;
llama_sampler_sample(sampler_->chain, context_, -1); generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
/**
* Stop if model predicts end-of-generation token (EOS/EOT)
*/
if (llama_vocab_is_eog(vocab, next)) {
break;
}
generated_tokens.push_back(next);
/**
* Feed the sampled token back into model for next iteration
* (autoregressive)
*/
llama_token decode_token = next;
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
if (llama_decode(context_, one_token_batch) != 0) {
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
}
/** for (int i = 0; i < effective_max_tokens; ++i) {
* DETOKENIZATION PHASE /**
* Convert generated token IDs back to text using vocabulary * Sample next token using configured sampler chain and model logits
*/ * Index -1 means use the last output position from previous batch
std::string output; */
for (const llama_token token : generated_tokens) { const llama_token next =
AppendTokenPiecePublic(vocab, token, output); llama_sampler_sample(sampler.get(), context_, -1);
} /**
* Stop if model predicts end-of-generation token (EOS/EOT)
*/
if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next);
/**
* Feed the sampled token back into model for next iteration
* (autoregressive)
*/
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
return output; /**
* DETOKENIZATION PHASE
* Convert generated token IDs back to text using vocabulary
*/
std::string output;
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
return output;
} }

View File

@@ -5,7 +5,6 @@
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include <memory>
#include <random> #include <random>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
@@ -13,113 +12,65 @@
#include "data_model/application_options.h" #include "data_model/application_options.h"
#include "llama.h" #include "llama.h"
/**
 * Upper bound on the context window size (in tokens) accepted by the
 * LlamaGenerator constructor's n_ctx validation.
 */
static constexpr uint32_t kMaxContextSize = 32768U;
/**
 * Sampling parameters consumed by CreateSamplerChain(). Field order matters:
 * the struct is brace-initialized positionally by its caller.
 */
struct SamplerConfig {
/** Forwarded to llama_sampler_init_temp(). */
float temperature;
/** Forwarded to llama_sampler_init_top_p(). */
float top_p;
/** Forwarded to llama_sampler_init_top_k() after a cast to int32_t. */
uint32_t top_k;
};
/**
 * Owning handle for a llama_sampler; the deleter is llama_sampler_free().
 */
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
/**
 * @brief Builds the sampler chain used to pick tokens during generation.
 *
 * Stages are appended in this order: temperature, top-k, top-p, and finally
 * the distribution sampler, which is seeded with one value drawn from @p rng.
 *
 * @param config Temperature, top-p and top-k settings for the chain.
 * @param rng Random engine; advanced by exactly one draw to seed the chain.
 * @return Owning handle to the fully assembled chain.
 * @throws std::runtime_error if the chain could not be allocated.
 */
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
                                     std::mt19937& rng) {
  SamplerPtr chain(
      llama_sampler_chain_init(llama_sampler_chain_default_params()),
      &llama_sampler_free);
  if (chain == nullptr) {
    throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
  }

  llama_sampler* const raw_chain = chain.get();
  const auto top_k = static_cast<int32_t>(config.top_k);

  llama_sampler_chain_add(raw_chain,
                          llama_sampler_init_temp(config.temperature));
  llama_sampler_chain_add(raw_chain, llama_sampler_init_top_k(top_k));
  llama_sampler_chain_add(raw_chain,
                          llama_sampler_init_top_p(config.top_p, 1));
  llama_sampler_chain_add(raw_chain, llama_sampler_init_dist(rng()));

  return chain;
}
/**
 * Releases the sampler chain held by this state object, if any, and clears
 * the pointer afterwards.
 */
LlamaGenerator::SamplerState::~SamplerState() {
  if (chain == nullptr) {
    return;
  }
  llama_sampler_free(chain);
  chain = nullptr;
}
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path) const std::string& model_path)
: rng_(std::random_device{}()) { : rng_(std::random_device{}()) {
if (model_path.empty()) { if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty"); throw std::runtime_error("LlamaGenerator: model path must not be empty");
} }
if (options.temperature < 0.0F) { if (options.temperature < 0.0F) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0"); "LlamaGenerator: sampling temperature must be >= 0");
} }
if (options.top_p <= 0.0F || options.top_p > 1.0F) { if (options.top_p <= 0.0F || options.top_p > 1.0F) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]"); "LlamaGenerator: sampling top-p must be in (0, 1]");
} }
if (options.top_k == 0U) { if (options.top_k == 0U) {
throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0"); throw std::runtime_error("LlamaGenerator: sampling top-k must be > 0");
} }
if (options.seed < -1) { if (options.seed < -1) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random"); "LlamaGenerator: seed must be >= 0, or -1 for random");
} }
if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) { if (options.n_ctx == 0 || options.n_ctx > 32768) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]"); "LlamaGenerator: context size must be in range [1, 32768]");
} }
sampling_temperature_ = options.temperature; sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p; sampling_top_p_ = options.top_p;
sampling_top_k_ = options.top_k; sampling_top_k_ = options.top_k;
if (options.seed == -1) { if (options.seed == -1) {
std::random_device random_device; std::random_device random_device;
rng_.seed(random_device()); rng_.seed(random_device());
} else { } else {
rng_.seed(static_cast<uint32_t>(options.seed)); rng_.seed(static_cast<uint32_t>(options.seed));
} }
n_ctx_ = options.n_ctx; n_ctx_ = options.n_ctx;
this->Load(model_path); this->Load(model_path);
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
sampling_top_k_};
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
sampler_.reset(new SamplerState());
sampler_->chain = sampler_chain.release();
} }
LlamaGenerator::~LlamaGenerator() { LlamaGenerator::~LlamaGenerator() {
sampler_.reset(); /**
* Free the inference context (contains KV cache and computation state)
*/
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
/** /**
* Free the inference context (contains KV cache and computation state) * Free the loaded model (contains weights and vocabulary)
*/ */
if (context_ != nullptr) { if (model_ != nullptr) {
llama_free(context_); llama_model_free(model_);
context_ = nullptr; model_ = nullptr;
} }
/**
* Free the loaded model (contains weights and vocabulary)
*/
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
} }

View File

@@ -14,32 +14,32 @@
#include "llama.h" #include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) { void LlamaGenerator::Load(const std::string& model_path) {
if (context_ != nullptr) { if (context_ != nullptr) {
llama_free(context_); llama_free(context_);
context_ = nullptr; context_ = nullptr;
} }
if (model_ != nullptr) { if (model_ != nullptr) {
llama_model_free(model_); llama_model_free(model_);
model_ = nullptr; model_ = nullptr;
} }
const llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params); model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) { if (model_ == nullptr) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path); "LlamaGenerator: failed to load model from path: " + model_path);
} }
llama_context_params context_params = llama_context_default_params(); llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = n_ctx_; context_params.n_ctx = n_ctx_;
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000)); context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(5000));
context_ = llama_init_from_model(model_, context_params); context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) { if (context_ == nullptr) {
llama_model_free(model_); llama_model_free(model_);
model_ = nullptr; model_ = nullptr;
throw std::runtime_error("LlamaGenerator: failed to create context"); throw std::runtime_error("LlamaGenerator: failed to create context");
} }
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path); spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
} }

View File

@@ -1,14 +1,13 @@
/** /**
* @file data_generation/llama/load_brewery_prompt.cpp * @file data_generation/llama/load_brewery_prompt.cpp
* @brief Resolves brewery system prompt content from cache or a configured * @brief Resolves brewery system prompt content from cache or filesystem
* filesystem path and provides a robust inline fallback prompt when absent. * search paths and provides a robust inline fallback prompt when absent.
*/ */
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <stdexcept>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
@@ -18,42 +17,81 @@ namespace fs = std::filesystem;
* @brief Loads brewery system prompt from disk or cache. * @brief Loads brewery system prompt from disk or cache.
* *
* @param prompt_file_path Preferred prompt file location. * @param prompt_file_path Preferred prompt file location.
* @return Prompt text loaded from disk. * @return Prompt text loaded from disk or fallback content.
*/ */
std::string LlamaGenerator::LoadBrewerySystemPrompt( std::string LlamaGenerator::LoadBrewerySystemPrompt(
const std::string& prompt_file_path) { const std::string& prompt_file_path) {
// Return cached version if already loaded // Return cached version if already loaded
if (!brewery_system_prompt_.empty()) { if (!brewery_system_prompt_.empty()) {
return brewery_system_prompt_; return brewery_system_prompt_;
} }
// Try the provided path only // Try multiple path locations
const fs::path prompt_path(prompt_file_path); std::vector<std::string> paths_to_try = {
std::ifstream prompt_file(prompt_path); prompt_file_path, // As provided
if (!prompt_file.is_open()) { "../" + prompt_file_path, // One level up
spdlog::error( "../../" + prompt_file_path, // Two levels up
"LlamaGenerator: Failed to open brewery system prompt file '{}'", };
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: missing brewery system prompt file: " +
prompt_path.string());
}
const std::string prompt((std::istreambuf_iterator(prompt_file)), for (const auto& path : paths_to_try) {
std::istreambuf_iterator<char>()); std::ifstream prompt_file(path);
prompt_file.close(); if (prompt_file.is_open()) {
std::string prompt((std::istreambuf_iterator<char>(prompt_file)),
std::istreambuf_iterator<char>());
prompt_file.close();
if (prompt.empty()) { if (!prompt.empty()) {
spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty", spdlog::info(
prompt_path.string()); "LlamaGenerator: Loaded brewery system prompt from '{}' ({} "
throw std::runtime_error( "chars)",
"LlamaGenerator: empty brewery system prompt file: " + path, prompt.length());
prompt_path.string()); brewery_system_prompt_ = prompt;
} return brewery_system_prompt_;
}
}
}
spdlog::info( spdlog::warn(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)", "LlamaGenerator: Could not open brewery system prompt file at any of "
prompt_path.string(), prompt.length()); "the "
brewery_system_prompt_ = prompt; "expected locations. Using fallback inline prompt.");
return brewery_system_prompt_; return GetFallbackBreweryPrompt();
} }
/**
 * @brief Returns the built-in brewery system prompt.
 *
 * LoadBrewerySystemPrompt() falls back to this text when no prompt file
 * could be read from any of the locations it tries.
 *
 * @return The inline default prompt text.
 */
std::string LlamaGenerator::GetFallbackBreweryPrompt() {
  static constexpr char kFallbackPrompt[] =
      // Role and goal.
      "You are an experienced brewmaster and owner of a local craft "
      "brewery. "
      "Create a distinctive, authentic name and detailed description that "
      "genuinely reflects your specific location, brewing philosophy, "
      "local "
      "culture, and community connection. The brewery must feel real and "
      "grounded—not generic or interchangeable.\n\n"
      // Phrases to avoid.
      "AVOID REPETITIVE PHRASES - Never use:\n"
      "Love letter to, tribute to, rolling hills, picturesque, every sip "
      "tells a story, Come for X stay for Y, rich history, passion, woven "
      "into, ancient roots, timeless, where tradition meets innovation\n\n"
      // Opening strategies.
      "OPENING APPROACHES - Choose ONE:\n"
      "1. Start with specific beer style and its regional origins\n"
      "2. Begin with specific brewing challenge (water, altitude, "
      "climate)\n"
      "3. Open with founding story or personal motivation\n"
      "4. Lead with specific local ingredient or resource\n"
      "5. Start with unexpected angle or contradiction\n"
      "6. Open with local event, tradition, or cultural moment\n"
      "7. Begin with tangible architectural or geographic detail\n\n"
      // Specificity requirements.
      "BE SPECIFIC - Include:\n"
      "- At least ONE concrete proper noun (landmark, river, "
      "neighborhood)\n"
      "- Specific beer styles relevant to the REGION'S culture\n"
      "- Concrete brewing challenges or advantages\n"
      "- Sensory details SPECIFIC to place—not generic adjectives\n\n"
      // Length, tone and output format.
      "LENGTH: 150-250 words. TONE: Can be soulful, irreverent, "
      "matter-of-fact, unpretentious, or minimalist.\n\n"
      "Output ONLY a raw JSON object with keys name and description. "
      "No markdown, backticks, preamble, or trailing text.";
  return std::string(kFallbackPrompt);
}

View File

@@ -0,0 +1,71 @@
/**
* @file data_generation/mock/data.cpp
* @brief Defines static lookup tables used by MockGenerator for deterministic
* brewery names, descriptions, usernames, and bios.
*/
#include <string>
#include <vector>
#include "data_generation/mock_generator.h"
/**
 * Adjective pool for mock brewery names. GenerateBrewery() picks one entry
 * with `hash % size()`, so reordering or resizing this list changes which
 * name a given location maps to.
 */
const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"};
/**
 * Noun/suffix pool for mock brewery names. GenerateBrewery() picks one entry
 * with `(hash / 7) % size()`; element order is therefore part of the
 * deterministic output.
 */
const std::vector<std::string> MockGenerator::kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
/**
 * Base description sentences for mock breweries. GenerateBrewery() picks one
 * entry with `(hash / 13) % size()` and then appends location text to it.
 * Entries split across two literals rely on adjacent-literal concatenation.
 */
const std::vector<std::string> MockGenerator::kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal blends.",
"Modern taproom focused on sessionable lagers and classic pub styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp lagers.",
"Creative brewery offering rotating collaborations and limited draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich stouts."};
/**
 * Username pool for mock users. GenerateUser() picks one entry with
 * `hash % size()`, where the hash is derived from the locale string.
 */
const std::vector<std::string> MockGenerator::kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
/**
 * Bio pool for mock users. GenerateUser() picks one entry with
 * `(hash / 11) % size()`; element order is part of the deterministic output.
 */
const std::vector<std::string> MockGenerator::kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};

View File

@@ -5,12 +5,13 @@
*/ */
#include <boost/container_hash/hash.hpp> #include <boost/container_hash/hash.hpp>
#include <string>
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
size_t MockGenerator::DeterministicHash(const Location& location) { std::size_t MockGenerator::DeterministicHash(const BreweryLocation& location) {
size_t seed = 0; std::size_t seed = 0;
boost::hash_combine(seed, location.city); boost::hash_combine(seed, location.city_name);
boost::hash_combine(seed, location.country); boost::hash_combine(seed, location.country_name);
return seed; return seed;
} }

View File

@@ -4,39 +4,35 @@
* and country into fixed mock phrase catalogs. * and country into fixed mock phrase catalogs.
*/ */
#include <format>
#include <string> #include <string>
#include <string_view>
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
BreweryResult MockGenerator::GenerateBrewery( BreweryResult MockGenerator::GenerateBrewery(
const Location& location, const std::string& /*region_context*/) { const BreweryLocation& location, const std::string& /*region_context*/) {
const std::size_t hash = DeterministicHash(location); const std::size_t hash = DeterministicHash(location);
const std::string_view adjective = const std::string& adjective =
kBreweryAdjectives.at(hash % kBreweryAdjectives.size()); kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
const std::string_view noun = const std::string& noun =
kBreweryNouns.at(hash / 7 % kBreweryNouns.size()); kBreweryNouns.at((hash / 7) % kBreweryNouns.size());
const std::string_view base_description = const std::string& base_description =
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size()); kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size());
const std::string name = std::string name(location.city_name);
std::format("{} {} {}", location.city, adjective, noun); name.append(" ");
name.append(adjective);
name.append(" ");
name.append(noun);
const std::string state_suffix = std::string description = base_description;
location.state_province.empty() description.append(" Based in ");
? std::string{} description.append(location.city_name);
: std::format(", {}", location.state_province); if (!location.country_name.empty()) {
const std::string country_suffix = description.append(", ");
location.country.empty() ? std::string{} description.append(location.country_name);
: std::format(", {}", location.country); }
const std::string description = description.append(".");
std::format("{} Located in {}{}{}.", base_description, location.city,
state_suffix, country_suffix);
return { return {name, description};
.name = name,
.description = description,
};
} }

View File

@@ -6,17 +6,14 @@
#include <functional> #include <functional>
#include <string> #include <string>
#include <string_view>
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
UserResult MockGenerator::GenerateUser(const std::string& locale) { UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale); const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result; UserResult result;
const std::string_view username = kUsernames[hash % kUsernames.size()]; result.username = kUsernames[hash % kUsernames.size()];
const std::string_view bio = kBios[hash / 11 % kBios.size()]; result.bio = kBios[(hash / 11) % kBios.size()];
result.username = username; return result;
result.bio = bio;
return result;
} }

View File

@@ -12,76 +12,72 @@
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string_view>
static std::string ReadRequiredString(const boost::json::object& object, static std::string ReadRequiredString(const boost::json::object& object,
const char* key) { const char* key) {
const boost::json::value* value = object.if_contains(key); const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_string()) { if (value == nullptr || !value->is_string()) {
throw std::runtime_error(std::string("Missing or invalid string field: ") + throw std::runtime_error(
key); std::string("Missing or invalid string field: ") + key);
} }
const std::string_view text = value->as_string(); return std::string(value->as_string().c_str());
return std::string(text);
} }
static double ReadRequiredNumber(const boost::json::object& object, static double ReadRequiredNumber(const boost::json::object& object,
const char* key) { const char* key) {
const boost::json::value* value = object.if_contains(key); const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_number()) { if (value == nullptr || !value->is_number()) {
throw std::runtime_error(std::string("Missing or invalid numeric field: ") +
key);
}
return value->to_number<double>();
}
std::vector<Location> JsonLoader::LoadLocations(
const std::filesystem::path& filepath) {
std::ifstream input(filepath);
if (!input.is_open()) {
throw std::runtime_error("Failed to open locations file: " +
filepath.string());
}
std::stringstream buffer;
buffer << input.rdbuf();
const std::string content = buffer.str();
boost::system::error_code error;
boost::json::value root = boost::json::parse(content, error);
if (error) {
throw std::runtime_error("Failed to parse locations JSON: " +
error.message());
}
if (!root.is_array()) {
throw std::runtime_error(
"Invalid locations JSON: root element must be an array");
}
std::vector<Location> locations;
const auto& items = root.as_array();
locations.reserve(items.size());
for (const auto& item : items) {
if (!item.is_object()) {
throw std::runtime_error( throw std::runtime_error(
"Invalid locations JSON: each entry must be an object"); std::string("Missing or invalid numeric field: ") + key);
} }
return value->to_number<double>();
const auto& object = item.as_object(); }
locations.push_back(Location{
.city = ReadRequiredString(object, "city"), std::vector<Location> JsonLoader::LoadLocations(const std::string& filepath) {
.state_province = ReadRequiredString(object, "state_province"), std::ifstream input(filepath);
.iso3166_2 = ReadRequiredString(object, "iso3166_2"), if (!input.is_open()) {
.country = ReadRequiredString(object, "country"), throw std::runtime_error("Failed to open locations file: " + filepath);
.iso3166_1 = ReadRequiredString(object, "iso3166_1"), }
.latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"), std::stringstream buffer;
}); buffer << input.rdbuf();
} const std::string content = buffer.str();
spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(), boost::system::error_code error;
filepath.string()); boost::json::value root = boost::json::parse(content, error);
return locations; if (error) {
throw std::runtime_error("Failed to parse locations JSON: " +
error.message());
}
if (!root.is_array()) {
throw std::runtime_error(
"Invalid locations JSON: root element must be an array");
}
std::vector<Location> locations;
const auto& items = root.as_array();
locations.reserve(items.size());
for (const auto& item : items) {
if (!item.is_object()) {
throw std::runtime_error(
"Invalid locations JSON: each entry must be an object");
}
const auto& object = item.as_object();
locations.push_back(Location{
.city = ReadRequiredString(object, "city"),
.state_province = ReadRequiredString(object, "state_province"),
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
.country = ReadRequiredString(object, "country"),
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
.latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"),
});
}
spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(),
filepath);
return locations;
} }

View File

@@ -10,7 +10,6 @@
#include <boost/program_options.hpp> #include <boost/program_options.hpp>
#include <exception> #include <exception>
#include <memory> #include <memory>
#include <optional>
#include <sstream> #include <sstream>
#include <string> #include <string>
@@ -31,153 +30,141 @@ namespace di = boost::di;
* *
* @param argc Command-line argument count. * @param argc Command-line argument count.
* @param argv Command-line arguments. * @param argv Command-line arguments.
* @return Parsed ApplicationOptions if parsing succeeded, std::nullopt * @param options Output ApplicationOptions struct.
* otherwise. * @return true if parsing succeeded and should proceed, false otherwise.
*/ */
std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) { bool ParseArguments(const int argc, char** argv,
prog_opts::options_description desc("Pipeline Options"); ApplicationOptions& options) noexcept {
prog_opts::options_description desc("Pipeline Options");
desc.add_options()("help,h", "Produce help message")(
"mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data")(
"model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")(
"temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)")(
"top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)")(
"top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)")(
"n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)")(
"seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
auto opt = desc.add_options(); // Handle the "no arguments" or "help" case
if (argc == 1) {
spdlog::info("Biergarten Pipeline");
std::stringstream usage_stream;
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(usage_stream.str());
return false;
}
opt("help,h", "Produce help message"); try {
prog_opts::variables_map variables_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
opt("mocked", prog_opts::bool_switch(), if (variables_map.contains("help")) {
"Use mocked generator for brewery/user data"); std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return false;
}
opt("model,m", prog_opts::value<std::string>()->default_value(""), const auto use_mocked = variables_map["mocked"].as<bool>();
"Path to LLM model (gguf)"); const auto model_path = variables_map["model"].as<std::string>();
opt("temperature", prog_opts::value<float>()->default_value(1.0F), if (use_mocked && !model_path.empty()) {
"Sampling temperature (higher = more random)"); spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive");
return false;
}
opt("top-p", prog_opts::value<float>()->default_value(0.95F), if (!use_mocked && model_path.empty()) {
"Nucleus sampling top-p in (0,1] (higher = more random)"); spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
return false;
}
opt("top-k", prog_opts::value<uint32_t>()->default_value(64), const bool has_llm_params = !variables_map["temperature"].defaulted() ||
"Top-k sampling parameter (higher = more candidate tokens)"); !variables_map["top-p"].defaulted() ||
!variables_map["top-k"].defaulted() ||
!variables_map["seed"].defaulted();
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192), if (use_mocked && has_llm_params) {
"Context window size in tokens (1-32768)"); spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
}
opt("seed", prog_opts::value<int>()->default_value(-1), options.use_mocked = use_mocked;
"Sampler seed: -1 for random, otherwise non-negative integer"); options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.top_k = variables_map["top-k"].as<uint32_t>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
// Handle the "no arguments" or "help" case return true;
if (argc == 1) { } catch (const std::exception& exception) {
spdlog::info("Biergarten Pipeline"); spdlog::error("Failed to parse command-line arguments: {}",
std::stringstream usage_stream; exception.what());
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc; return false;
spdlog::info(usage_stream.str()); } catch (...) {
return std::nullopt; spdlog::error("Failed to parse command-line arguments: unknown error");
} return false;
}
try {
prog_opts::variables_map variables_map;
prog_opts::store(prog_opts::parse_command_line(argc, argv, desc),
variables_map);
prog_opts::notify(variables_map);
if (variables_map.contains("help")) {
std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return std::nullopt;
}
const auto use_mocked = variables_map["mocked"].as<bool>();
const auto model_path = variables_map["model"].as<std::string>();
if (use_mocked && !model_path.empty()) {
spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive");
return std::nullopt;
}
if (!use_mocked && model_path.empty()) {
spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
return std::nullopt;
}
/**
 * True when the user explicitly supplied at least one sampling flag, i.e. the
 * stored value did not come from the option's default_value(). Used right
 * below to warn that these flags have no effect together with --mocked.
 * (A stray trailing `= false` previously turned this into an assignment to a
 * temporary, which is ill-formed and would have disabled the warning.)
 */
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
                            !variables_map["top-p"].defaulted() ||
                            !variables_map["top-k"].defaulted() ||
                            !variables_map["seed"].defaulted();
if (use_mocked && has_llm_params) {
spdlog::warn(
"Sampling parameters (--temperature, --top-p, --top-k, --seed) are"
" ignored when using --mocked");
}
ApplicationOptions options;
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
options.top_p = variables_map["top-p"].as<float>();
options.top_k = variables_map["top-k"].as<uint32_t>();
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
return options;
} catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}",
exception.what());
return std::nullopt;
} catch (...) {
spdlog::error("Failed to parse command-line arguments: unknown error");
return std::nullopt;
}
} }
int main(const int argc, char** argv) { int main(const int argc, char** argv) noexcept {
try { try {
const CurlGlobalState curl_state; const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state; const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v"); spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
const auto parsed_options = ParseArguments(argc, argv); ApplicationOptions options;
if (!parsed_options.has_value()) { if (!ParseArguments(argc, argv, options)) {
return 0; return 0;
} }
const auto options = *parsed_options; const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(),
const auto injector = di::make_injector( di::bind<ApplicationOptions>().to(options),
di::bind<WebClient>().to<CURLWebClient>(), di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<ApplicationOptions>().to(options), di::bind<std::string>().to(options.model_path),
di::bind<IEnrichmentService>().to<WikipediaService>(), di::bind<DataGenerator>().to([options](const auto& injector)
di::bind<std::string>().to(options.model_path), -> std::unique_ptr<DataGenerator> {
di::bind<DataGenerator>().to( if (options.use_mocked) {
[options](const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.use_mocked) {
spdlog::info( spdlog::info(
"[Generator] Using MockGenerator (no model path provided)"); "[Generator] Using MockGenerator (no model path provided)");
return std::make_unique<MockGenerator>(); return std::make_unique<MockGenerator>();
} }
spdlog::info( spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, " "[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, top-k={}, n_ctx={}, seed={})", "top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p, options.model_path, options.temperature, options.top_p,
options.top_k, options.n_ctx, options.seed); options.top_k, options.n_ctx, options.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>(); return injector.template create<std::unique_ptr<LlamaGenerator>>();
})); }));
auto generator = injector.create<BiergartenDataGenerator>(); auto generator = injector.create<BiergartenDataGenerator>();
if (!generator.Run()) { if (!generator.Run()) {
spdlog::error("Pipeline execution failed"); spdlog::error("Pipeline execution failed");
return 1;
}
spdlog::info("Pipeline executed successfully");
return 0;
} catch (const std::exception& exception) {
spdlog::critical("Unhandled fatal error in main: {}", exception.what());
return 1; return 1;
} } catch (...) {
spdlog::critical("Unhandled fatal non-standard exception in main");
spdlog::info("Pipeline executed successfully"); return 1;
return 0; }
} catch (const std::exception& exception) {
spdlog::critical("Unhandled fatal error in main: {}", exception.what());
return 1;
} catch (...) {
spdlog::critical("Unhandled fatal non-standard exception in main");
return 1;
}
} }

View File

@@ -12,50 +12,47 @@
#include "services/wikipedia_service.h" #include "services/wikipedia_service.h"
std::string WikipediaService::FetchExtract(std::string_view query) { std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string cache_key(query); const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key); const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) { if (cache_it != this->extract_cache_.end()) {
return cache_it->second; return cache_it->second;
} }
const std::string encoded = this->client_->UrlEncode(cache_key); const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url = const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json"; "&prop=extracts&explaintext=1&format=json";
const std::string body = this->client_->Get(url); const std::string body = this->client_->Get(url);
boost::system::error_code parse_error; boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error); boost::json::value doc = boost::json::parse(body, parse_error);
if (!parse_error && doc.is_object()) { if (!parse_error && doc.is_object()) {
try { try {
auto& pages = doc.at("query").at("pages").get_object(); auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) { if (!pages.empty()) {
auto& page = pages.begin()->value().get_object(); auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) { if (page.contains("extract") && page.at("extract").is_string()) {
const std::string_view extract_view = page.at("extract").as_string(); std::string extract(page.at("extract").as_string().c_str());
std::string extract(extract_view); spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
spdlog::debug("WikipediaService fetched {} chars for '{}'", this->extract_cache_.emplace(cache_key, extract);
extract.size(), query); return extract;
}
this->extract_cache_.emplace(cache_key, extract); }
return extract; this->extract_cache_.emplace(cache_key, std::string{});
} } catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
} }
this->extract_cache_.emplace(cache_key, std::string{}); } else if (parse_error) {
} catch (const std::exception& e) { spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
spdlog::warn( parse_error.message());
"WikipediaService: failed to parse response structure for '{}': " }
"{}",
query, e.what());
return {};
}
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
parse_error.message());
}
return {}; return {};
} }

View File

@@ -10,38 +10,47 @@
#include "services/wikipedia_service.h" #include "services/wikipedia_service.h"
std::string WikipediaService::GetLocationContext(const Location& loc) { std::string WikipediaService::GetLocationContext(const Location& loc) {
if (!client_) { const std::string cache_key = loc.city + "|" + loc.country;
return {}; const auto cache_it = cache_.find(cache_key);
} if (cache_it != cache_.end()) {
return cache_it->second;
}
std::string result; std::string result;
std::string region_query(loc.city); if (!client_) {
if (!loc.country.empty()) { cache_.emplace(cache_key, result);
region_query += ", "; return result;
region_query += loc.country; }
}
const std::string beer_query = "beer in " + loc.country; std::string region_query(loc.city);
const std::string city_beer_query = "beer in " + loc.city; if (!loc.country.empty()) {
region_query += ", ";
region_query += loc.country;
}
auto append_extract = [&result](const std::string& extract) -> void { const std::string beer_query = "beer in " + loc.country;
if (extract.empty()) { const std::string city_beer_query = "beer in " + loc.city;
return;
}
if (!result.empty()) {
result += "\n\n";
}
result += extract;
};
try { auto append_extract = [&result](const std::string& extract) -> void {
append_extract(FetchExtract(region_query)); if (extract.empty()) {
append_extract(FetchExtract(beer_query)); return;
append_extract(FetchExtract(city_beer_query)); }
} catch (const std::runtime_error& e) { if (!result.empty()) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query, result += "\n\n";
e.what()); }
} result += extract;
return result; };
try {
append_extract(FetchExtract(region_query));
append_extract(FetchExtract(beer_query));
append_extract(FetchExtract(city_beer_query));
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
e.what());
}
cache_.emplace(cache_key, result);
return result;
} }

View File

@@ -7,5 +7,5 @@
#include <utility> #include <utility>
WikipediaService::WikipediaService(std::unique_ptr<WebClient> client) WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {} : client_(std::move(client)) {}

View File

@@ -10,10 +10,10 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
CurlGlobalState::CurlGlobalState() { CurlGlobalState::CurlGlobalState() {
if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) { if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) {
throw std::runtime_error( throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl globally"); "[CURLWebClient] Failed to initialize libcurl globally");
} }
} }
CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); } CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); }

View File

@@ -5,71 +5,46 @@
#include <curl/curl.h> #include <curl/curl.h>
#include <cstdint> #include <sstream>
#include <memory>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "curl_web_client_utils.h"
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (handle == nullptr) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
static void set_common_get_options(CURL* curl, const std::string& url) {
constexpr uint64_t connection_timeout = 10;
constexpr uint64_t request_timeout = 30;
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connection_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, request_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
// curl write callback that appends response data into a std::string // curl write callback that appends response data into a std::string
static size_t WriteCallbackString(void* contents, const size_t size, static size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
const size_t nmemb, void* userp) { void* userp) {
const size_t real_size = size * nmemb; size_t realsize = size * nmemb;
auto* str = static_cast<std::string*>(userp); auto* s = static_cast<std::string*>(userp);
str->append(static_cast<char*>(contents), real_size); s->append(static_cast<char*>(contents), realsize);
return real_size; return realsize;
} }
std::string CURLWebClient::Get(const std::string& url) { std::string CURLWebClient::Get(const std::string& url) {
const CurlHandle curl = create_handle(); auto curl = create_handle();
std::string response_string; std::string response_string;
set_common_get_options(curl.get(), url, {10L, 20L});
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
set_common_get_options(curl.get(), url); CURLcode res = curl_easy_perform(curl.get());
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString); if (res != CURLE_OK) {
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string); std::string error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error);
}
CURLcode res = curl_easy_perform(curl.get()); long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (res != CURLE_OK) { if (httpCode != 200) {
const auto error = std::stringstream ss;
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res); ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(error); throw std::runtime_error(ss.str());
} }
int64_t httpCode = 0; return response_string;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); }
if (httpCode != 200) {
const std::string error = "[CURLWebClient] HTTP error " +
std::to_string(httpCode) + " for URL " + url;
throw std::runtime_error(error);
}
return response_string;
}

View File

@@ -11,14 +11,13 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
std::string CURLWebClient::UrlEncode(const std::string& value) { std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs. // A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char* output = curl_easy_escape(nullptr, value.c_str(), 0); char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (!output) { if (output) {
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed"); std::string result(output);
} curl_free(output);
return result;
std::string result(output); }
curl_free(output); throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
return result; }
}

View File

@@ -0,0 +1,28 @@
/**
* @file web_client/curl_web_client_utils.cpp
* @brief Shared CURLWebClient helper implementations.
*/
#include "curl_web_client_utils.h"
#include <stdexcept>
/**
 * @brief Allocates a libcurl easy handle and places it in an owning wrapper.
 *
 * @return CurlHandle whose deleter is curl_easy_cleanup.
 * @throws std::runtime_error if curl_easy_init() returns a null handle.
 */
CurlHandle create_handle() {
  // Wrapping before the check is safe: a unique_ptr holding nullptr never
  // invokes its deleter.
  CurlHandle owned(curl_easy_init(), &curl_easy_cleanup);
  if (!owned) {
    throw std::runtime_error(
        "[CURLWebClient] Failed to initialize libcurl handle");
  }
  return owned;
}
/**
 * @brief Applies the options shared by GET requests to an easy handle.
 *
 * Configures the request URL, the user agent string, redirect following
 * (at most five redirects), the connect and total timeouts, and "gzip" as the
 * accepted encoding. Return codes of curl_easy_setopt are not inspected.
 *
 * @param curl     Easy handle to configure.
 * @param url      Target of the request.
 * @param timeouts Values forwarded to CURLOPT_CONNECTTIMEOUT and
 *                 CURLOPT_TIMEOUT.
 */
void set_common_get_options(CURL* curl, const std::string& url,
                            CurlTimeouts timeouts) {
  constexpr long kFollowRedirects = 1L;
  constexpr long kMaxRedirects = 5L;
  const char* const target = url.c_str();

  // Request identity: where it goes and how the client announces itself.
  curl_easy_setopt(curl, CURLOPT_URL, target);
  curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
  curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");

  // Redirect policy.
  curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, kFollowRedirects);
  curl_easy_setopt(curl, CURLOPT_MAXREDIRS, kMaxRedirects);

  // Time limits supplied by the caller.
  const long connect_limit = timeouts.connect_timeout;
  const long total_limit = timeouts.total_timeout;
  curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_limit);
  curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_limit);
}

View File

@@ -0,0 +1,26 @@
#ifndef BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
#define BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
/**
* @file web_client/curl_web_client_utils.h
* @brief Shared helpers for CURLWebClient request setup.
*/
#include <curl/curl.h>
#include <memory>
#include <string>
/**
 * @brief Owning smart pointer for a libcurl easy handle (CURL*).
 *
 * The deleter type is a pointer to curl_easy_cleanup, so the deleter function
 * has to be passed explicitly when a CurlHandle is constructed.
 */
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
/**
 * @brief Time limits handed to set_common_get_options().
 *
 * Both values are forwarded unchanged to libcurl (CURLOPT_CONNECTTIMEOUT and
 * CURLOPT_TIMEOUT), which interprets them as whole seconds.
 *
 * The default member initializers keep a default-constructed instance from
 * carrying indeterminate values into libcurl; they mirror the 10 s / 20 s
 * limits used by CURLWebClient::Get. The struct remains an aggregate, so
 * `CurlTimeouts{10L, 20L}` continues to work.
 */
struct CurlTimeouts {
  long connect_timeout = 10L;  ///< Limit for the connection phase.
  long total_timeout = 20L;    ///< Limit for the entire transfer.
};
/**
 * @brief Creates a libcurl easy handle owned by a CurlHandle.
 *
 * @return Handle released with curl_easy_cleanup.
 * @throws std::runtime_error if curl_easy_init() returns a null handle.
 */
CurlHandle create_handle();
/**
 * @brief Applies the options shared by GET requests to an easy handle.
 *
 * Sets the URL, user agent, redirect following (at most five redirects),
 * the connect/total timeouts, and "gzip" as the accepted encoding.
 *
 * @param curl     Easy handle to configure.
 * @param url      Target of the request.
 * @param timeouts Values passed to CURLOPT_CONNECTTIMEOUT and CURLOPT_TIMEOUT.
 */
void set_common_get_options(CURL* curl, const std::string& url,
CurlTimeouts timeouts);
#endif // BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_