8 Commits

Author SHA1 Message Date
Aaron Po
299a767d39 remove unused code 2026-04-11 14:42:32 -04:00
Aaron Po
f07d48f810 Add missing includes, update readme 2026-04-11 14:31:24 -04:00
Aaron Po
bcfde856fe Split data models into dedicated headers 2026-04-11 13:21:50 -04:00
Aaron Po
5946356083 Style audit: update code to strictly follow Google Style Guide 2026-04-11 11:56:45 -04:00
Aaron Po
ae67fa8566 refactor: consolidate and rename data generation and service files 2026-04-11 00:06:23 -04:00
Aaron Po
8c572a2d07 fix: stabilize Gemma 4 brewery generation
remove misleading turn-token output guidance from the brewery prompt
extract the last balanced JSON object before validation
keep README model setup and run instructions aligned
preserve Gemma 4 sampling defaults and local model usage
2026-04-10 22:25:26 -04:00
Aaron Po
902bda6eb9 feat: make Gemma 4 the default model, enable thinking mode 2026-04-10 21:43:18 -04:00
Aaron Po
61d5077a95 update readme 2026-04-10 00:03:45 -04:00
50 changed files with 1744 additions and 1528 deletions

View File

@@ -1,5 +1,5 @@
---
BasedOnStyle: Google
ColumnLimit: 80
IndentWidth: 2
IndentWidth: 3
...

View File

@@ -9,23 +9,21 @@ Checks: >
-google-runtime-references
CheckOptions:
# Enforce Google Naming Conventions with valid clang-tidy strings
- key: readability-identifier-naming.ClassCase
value: CamelCase
# Enforce Google Naming Conventions
- key: readability-identifier-naming.ClassMemberCase
value: lower_case
value: snake_case
- key: readability-identifier-naming.ClassMemberSuffix
value: _
- key: readability-identifier-naming.ClassCase
value: PascalCase
- key: readability-identifier-naming.FunctionCase
value: CamelCase
value: PascalCase
- key: readability-identifier-naming.StructCase
value: CamelCase
value: PascalCase
- key: readability-identifier-naming.VariableCase
value: lower_case
value: snake_case
- key: readability-identifier-naming.GlobalConstantCase
value: CamelCase
- key: readability-identifier-naming.GlobalConstantPrefix
value: k
value: kPascalCase
# Ensure C++20 Modernization
- key: modernize-make-unique.MakeSmartPtrFunction

1
pipeline/.gitignore vendored
View File

@@ -1,6 +1,5 @@
dist
build
cmake-build-*
data
models
*.gguf

View File

@@ -1,10 +1,12 @@
cmake_minimum_required(VERSION 3.24)
project(biergarten-pipeline)
# Boost.DI still declares a very old minimum CMake version, which newer CMake
# releases reject unless a policy version floor is provided.
set(CMAKE_POLICY_VERSION_MINIMUM 3.5 CACHE STRING "" FORCE)
# =============================================================================
# 1. Platform & GPU Detection
# 1. Platform & GPU Detection (Windows explicitly NOT supported)
# =============================================================================
if(WIN32)
message(FATAL_ERROR "[biergarten] Windows is currently not supported. Please use Linux (Fedora 43) or macOS (M1 Pro).")
@@ -38,15 +40,18 @@ endif()
# 2. Project-wide Settings (Standard & Optimization)
# =============================================================================
# Downgrade to C++20 as per Google Style Guide
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# GCC/Clang specific settings (warnings as errors)
add_compile_options(-Wall -Wextra -Werror -Wpedantic)
# Release Build Optimization: Aggressive (-O3), Arch-specific, and LTO
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -march=native -flto")
# Debug Build Optimization: Fast and debuggable (-Og)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Og -g")
# =============================================================================
@@ -101,6 +106,7 @@ set(SOURCES
src/services/wikipedia/fetch_extract.cpp
src/web_client/curl_global_state.cpp
src/web_client/curl_web_client_get.cpp
src/web_client/curl_web_client_utils.cpp
src/web_client/curl_web_client_url_encode.cpp
src/data_generation/llama/llama_generator.cpp
src/data_generation/llama/generate_brewery.cpp
@@ -109,6 +115,7 @@ set(SOURCES
src/data_generation/llama/infer.cpp
src/data_generation/llama/load.cpp
src/data_generation/llama/load_brewery_prompt.cpp
src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp
@@ -141,9 +148,3 @@ configure_file(
${CMAKE_BINARY_DIR}/locations.json
COPYONLY
)
add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
${CMAKE_SOURCE_DIR}/prompts
${CMAKE_BINARY_DIR}/prompts
)

View File

@@ -76,7 +76,7 @@ curl -L \
## Run
Run the executable from the build directory so the copied `locations.json` and `prompts/` directory are available.
Run the executable from the build directory so the copied `locations.json` is available.
```bash
./biergarten-pipeline --mocked

View File

@@ -1,7 +1,7 @@
@startuml BiergartenPipeline
title Biergarten Pipeline - Class and Composition Diagram
top to bottom direction
left to right direction
skinparam shadowing false
skinparam classAttributeIconSize 0
skinparam packageStyle rectangle
@@ -16,17 +16,11 @@ package "Composition root" {
+~CurlGlobalState()
}
class LlamaBackendState {
+LlamaBackendState()
+~LlamaBackendState()
}
note right of Main
Binds with Boost.DI:
- WebClient -> CURLWebClient
- IEnrichmentService -> WikipediaService
- DataGenerator -> MockGenerator or LlamaGenerator
- std::string -> model_path
- LlamaGenerator receives ApplicationOptions and model_path directly
end note
}
@@ -38,8 +32,8 @@ package "Core orchestration" {
-generated_breweries_: std::vector<GeneratedBrewery>
+BiergartenDataGenerator(context_service: std::shared_ptr<IEnrichmentService>, generator: std::unique_ptr<DataGenerator>)
+Run(): bool
{static} -QueryCitiesWithCountries(): std::vector<Location>
-GenerateBreweries(cities: const std::vector<EnrichedCity>&): void
-QueryCitiesWithCountries(): std::vector<Location>
-GenerateBreweries(cities: std::vector<EnrichedCity>): void
-LogResults(): void
}
}
@@ -55,6 +49,11 @@ package "Data models" {
+seed: int
}
class BreweryLocation <<struct>> {
+city_name: std::string_view
+country_name: std::string_view
}
class Location <<struct>> {
+city: std::string
+state_province: std::string
@@ -88,64 +87,67 @@ package "Data models" {
package "Generation" {
interface DataGenerator {
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult
+GenerateUser(locale: const std::string&): UserResult
+GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: std::string): UserResult
}
class MockGenerator {
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult
+GenerateUser(locale: const std::string&): UserResult
+GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: std::string): UserResult
}
class LlamaGenerator {
+LlamaGenerator(options: const ApplicationOptions&, model_path: const std::string&)
+GenerateBrewery(location: const Location&, region_context: const std::string&): BreweryResult
+GenerateUser(locale: const std::string&): UserResult
+LlamaGenerator(options: ApplicationOptions, model_path: std::string)
+GenerateBrewery(location: BreweryLocation, region_context: std::string): BreweryResult
+GenerateUser(locale: std::string): UserResult
}
}
package "HTTP" {
interface WebClient {
+Get(url: const std::string&): std::string
+UrlEncode(value: const std::string&): std::string
+DownloadToFile(url: std::string, file_path: std::string): void
+Get(url: std::string): std::string
+UrlEncode(value: std::string): std::string
}
class CURLWebClient {
+Get(url: const std::string&): std::string
+UrlEncode(value: const std::string&): std::string
+CURLWebClient()
+~CURLWebClient()
+DownloadToFile(url: std::string, file_path: std::string): void
+Get(url: std::string): std::string
+UrlEncode(value: std::string): std::string
}
}
package "JSON handling" {
class JsonLoader {
{static} +LoadLocations(filepath: const std::string&): std::vector<Location>
{static} +LoadLocations(filepath: std::string): std::vector<Location>
}
}
package "Wikipedia" {
interface IEnrichmentService {
+GetLocationContext(loc: const Location&): std::string
+GetLocationContext(loc: Location): std::string
}
class WikipediaService {
+WikipediaService(client: std::unique_ptr<WebClient>)
+GetLocationContext(loc: const Location&): std::string
+WikipediaService(client: std::shared_ptr<WebClient>)
+GetLocationContext(loc: Location): std::string
}
}
Main --> CurlGlobalState
Main --> LlamaBackendState
Main --> ApplicationOptions
Main --> BiergartenDataGenerator
Main ..> IEnrichmentService : DI binding
Main ..> DataGenerator : DI factory
Main ..> CURLWebClient : DI binding
BiergartenDataGenerator *-- EnrichedCity
BiergartenDataGenerator *-- GeneratedBrewery
BiergartenDataGenerator ..> JsonLoader : LoadLocations()
BiergartenDataGenerator --> IEnrichmentService : context lookup
BiergartenDataGenerator --> DataGenerator : brewery generation
BiergartenDataGenerator ..> EnrichedCity
BiergartenDataGenerator ..> Location
BiergartenDataGenerator ..> BreweryResult
@@ -154,7 +156,7 @@ DataGenerator <|.. LlamaGenerator
WebClient <|.. CURLWebClient
IEnrichmentService <|.. WikipediaService
WikipediaService *-- WebClient : unique_ptr
WikipediaService --> WebClient : shared_ptr
note right of BiergartenDataGenerator
Current behavior:

View File

@@ -7,7 +7,6 @@
*/
#include <memory>
#include <span>
#include <vector>
#include "data_generation/data_generator.h"
@@ -30,7 +29,7 @@ class BiergartenDataGenerator {
* @param context_service Context provider for sampled locations.
* @param generator Brewery and user data generator.
*/
BiergartenDataGenerator(std::unique_ptr<IEnrichmentService> context_service,
BiergartenDataGenerator(std::shared_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator);
/**
@@ -46,8 +45,8 @@ class BiergartenDataGenerator {
bool Run();
private:
/// @brief Owning context provider dependency.
std::unique_ptr<IEnrichmentService> context_service_;
/// @brief Shared context provider dependency.
std::shared_ptr<IEnrichmentService> context_service_;
/// @brief Generator dependency selected in the composition root.
std::unique_ptr<DataGenerator> generator_;
@@ -62,9 +61,9 @@ class BiergartenDataGenerator {
/**
* @brief Generate breweries for enriched cities.
*
* @param cities Span of enriched city data.
* @param cities Vector of enriched city data.
*/
void GenerateBreweries(std::span<const EnrichedCity> cities);
void GenerateBreweries(const std::vector<EnrichedCity>& cities);
/**
* @brief Log the generated brewery results.

View File

@@ -8,8 +8,8 @@
#include <string>
#include "data_model/brewery_location.h"
#include "data_model/brewery_result.h"
#include "data_model/location.h"
#include "data_model/user_result.h"
/**
@@ -17,16 +17,17 @@
*/
class DataGenerator {
public:
/// @brief Virtual destructor for polymorphic cleanup.
virtual ~DataGenerator() = default;
/**
* @brief Generates brewery data for a location.
*
* @param location Location data
* @param location City and country names.
* @param region_context Additional regional context text.
* @return Brewery generation result.
*/
virtual BreweryResult GenerateBrewery(const Location& location,
virtual BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) = 0;
/**

View File

@@ -16,7 +16,6 @@
struct llama_model;
struct llama_context;
struct LlamaSampler;
/**
* @brief Data generator implementation backed by llama.cpp.
@@ -36,19 +35,14 @@ class LlamaGenerator final : public DataGenerator {
/// @brief Releases model/context resources.
~LlamaGenerator() override;
LlamaGenerator(const LlamaGenerator&) = delete;
LlamaGenerator& operator=(const LlamaGenerator&) = delete;
LlamaGenerator(LlamaGenerator&&) = delete;
LlamaGenerator& operator=(LlamaGenerator&&) = delete;
/**
* @brief Generates brewery data for a specific location.
*
* @param location Location object.
* @param location City and country names.
* @param region_context Additional regional context.
* @return Generated brewery result.
*/
BreweryResult GenerateBrewery(const Location& location,
BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) override;
/**
@@ -60,23 +54,6 @@ class LlamaGenerator final : public DataGenerator {
UserResult GenerateUser(const std::string& locale) override;
private:
static constexpr int kDefaultMaxTokens = 10000;
static constexpr float kDefaultSamplingTopP = 0.95F;
static constexpr uint32_t kDefaultSamplingTopK = 64;
static constexpr uint32_t kDefaultContextSize = 8192;
struct SamplerState {
SamplerState() = default;
~SamplerState();
SamplerState(const SamplerState&) = delete;
SamplerState& operator=(const SamplerState&) = delete;
SamplerState(SamplerState&&) = delete;
SamplerState& operator=(SamplerState&&) = delete;
LlamaSampler* chain = nullptr;
};
/**
* @brief Loads model and prepares inference context.
*
@@ -84,6 +61,15 @@ class LlamaGenerator final : public DataGenerator {
*/
void Load(const std::string& model_path);
/**
* @brief Infers text from a user prompt.
*
* @param prompt User prompt.
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string Infer(const std::string& prompt, int max_tokens = 10000);
/**
* @brief Infers text from separate system and user prompts.
*
@@ -95,8 +81,8 @@ class LlamaGenerator final : public DataGenerator {
* @param max_tokens Maximum tokens to generate.
* @return Generated text.
*/
std::string Infer(const std::string& system_prompt, const std::string& prompt,
int max_tokens = kDefaultMaxTokens);
std::string Infer(const std::string& system_prompt,
const std::string& prompt, int max_tokens = 10000);
/**
* @brief Runs inference on an already-formatted prompt.
@@ -106,25 +92,30 @@ class LlamaGenerator final : public DataGenerator {
* @return Generated text.
*/
std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = kDefaultMaxTokens);
int max_tokens = 10000);
/**
* @brief Loads the brewery system prompt from disk.
*
* @param prompt_file_path Prompt file path to try first.
* @return Loaded prompt text.
* @return Loaded prompt text or fallback prompt.
*/
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
/**
* @brief Returns a built-in fallback system prompt.
*
* @return Fallback prompt text.
*/
std::string GetFallbackBreweryPrompt();
llama_model* model_ = nullptr;
llama_context* context_ = nullptr;
/// @brief Persistent sampler chain reused across inference calls.
std::unique_ptr<SamplerState> sampler_;
float sampling_temperature_ = 1.0F;
float sampling_top_p_ = kDefaultSamplingTopP;
uint32_t sampling_top_k_ = kDefaultSamplingTopK;
float sampling_top_p_ = 0.95F;
uint32_t sampling_top_k_ = 64;
std::mt19937 rng_;
uint32_t n_ctx_ = kDefaultContextSize;
uint32_t n_ctx_ = 8192;
std::string brewery_system_prompt_;
};

View File

@@ -7,7 +7,6 @@
*/
#include <cstddef>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
@@ -36,6 +35,16 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message);
/**
* @brief Applies model chat template to a user-only prompt.
*
* @param model Loaded llama model.
* @param user_prompt User prompt text.
* @return Model-formatted prompt.
*/
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt);
/**
* @brief Applies model chat template to system and user prompts.
*
@@ -64,10 +73,10 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
* @param raw Raw model output.
* @param name_out Parsed brewery name.
* @param description_out Parsed brewery description.
* @return Validation error message if invalid, or std::nullopt on success.
* @return Empty string on success, or validation error message.
*/
std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out,
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out);
/**

View File

@@ -6,9 +6,9 @@
* @brief Deterministic mock implementation of DataGenerator.
*/
#include <array>
#include <string>
#include <string_view>
#include <vector>
#include "data_generation/data_generator.h"
@@ -24,7 +24,7 @@ class MockGenerator final : public DataGenerator {
* @param region_context Unused for mock generation.
* @return Generated brewery result.
*/
BreweryResult GenerateBrewery(const Location& location,
BreweryResult GenerateBrewery(const BreweryLocation& location,
const std::string& region_context) override;
/**
@@ -42,82 +42,13 @@ class MockGenerator final : public DataGenerator {
* @param location City and country names.
* @return Deterministic hash value.
*/
static std::size_t DeterministicHash(const Location& location);
static std::size_t DeterministicHash(const BreweryLocation& location);
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"};
static constexpr std::array<std::string_view, 18> kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
static constexpr std::array<std::string_view, 18> kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark "
"porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy "
"pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky "
"lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food "
"pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional "
"grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward "
"sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal "
"blends.",
"Modern taproom focused on sessionable lagers and classic pub "
"styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp "
"lagers.",
"Creative brewery offering rotating collaborations and limited "
"draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich "
"stouts."};
static constexpr std::array<std::string_view, 18> kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
static constexpr std::array<std::string_view, 18> kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer "
"gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward "
"styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean "
"fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};
static const std::vector<std::string> kBreweryAdjectives;
static const std::vector<std::string> kBreweryNouns;
static const std::vector<std::string> kBreweryDescriptions;
static const std::vector<std::string> kUsernames;
static const std::vector<std::string> kBios;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_MOCK_GENERATOR_H_

View File

@@ -33,7 +33,7 @@ struct ApplicationOptions {
/// @brief Context window size (tokens) for LLM inference. Higher values
/// support longer prompts but use more memory.
uint32_t n_ctx = 8192;
uint32_t n_ctx = 2048;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1;

View File

@@ -13,10 +13,10 @@
*/
struct BreweryResult {
/// @brief Brewery display name.
std::string name{};
std::string name;
/// @brief Brewery description text.
std::string description{};
std::string description;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_BREWERY_RESULT_H_

View File

@@ -15,7 +15,7 @@
*/
struct EnrichedCity {
Location location;
std::string region_context{};
std::string region_context;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_ENRICHED_CITY_H_

View File

@@ -13,25 +13,25 @@
*/
struct Location {
/// @brief City name.
std::string city{};
std::string city;
/// @brief State or province name.
std::string state_province{};
std::string state_province;
/// @brief ISO 3166-2 subdivision code.
std::string iso3166_2{};
std::string iso3166_2;
/// @brief Country name.
std::string country{};
std::string country;
/// @brief ISO 3166-1 country code.
std::string iso3166_1{};
std::string iso3166_1;
/// @brief Latitude in decimal degrees.
double latitude{};
double latitude;
/// @brief Longitude in decimal degrees.
double longitude{};
double longitude;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_LOCATION_H_

View File

@@ -13,10 +13,10 @@
*/
struct UserResult {
/// @brief Username handle.
std::string username{};
std::string username;
/// @brief Short user biography.
std::string bio{};
std::string bio;
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_MODEL_USER_RESULT_H_

View File

@@ -6,7 +6,7 @@
* @brief Loader API for curated location data.
*/
#include <filesystem>
#include <string>
#include <vector>
#include "data_model/location.h"
@@ -15,8 +15,7 @@
class JsonLoader {
public:
/// @brief Parses a JSON array file and returns all location records.
static std::vector<Location> LoadLocations(
const std::filesystem::path& filepath);
static std::vector<Location> LoadLocations(const std::string& filepath);
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -14,19 +14,19 @@
#include "services/enrichment_service.h"
#include "web_client/web_client.h"
/// @brief Provides Wikipedia summary lookups backed by cached raw extracts.
/// @brief Provides cached Wikipedia summary lookups for city and country pairs.
class WikipediaService final : public IEnrichmentService {
public:
/// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::unique_ptr<WebClient> client);
explicit WikipediaService(std::shared_ptr<WebClient> client);
/// @brief Returns the Wikipedia-derived context for a location.
[[nodiscard]] std::string GetLocationContext(const Location& loc) override;
private:
std::string FetchExtract(std::string_view query);
std::unique_ptr<WebClient> client_;
/// @brief Canonical cache for raw Wikipedia query extracts.
std::shared_ptr<WebClient> client_;
std::unordered_map<std::string, std::string> cache_;
std::unordered_map<std::string, std::string> extract_cache_;
};

View File

@@ -8,7 +8,7 @@
#include <utility>
BiergartenDataGenerator::BiergartenDataGenerator(
std::unique_ptr<IEnrichmentService> context_service,
std::shared_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator)
: context_service_(std::move(context_service)),
generator_(std::move(generator)) {}

View File

@@ -8,32 +8,34 @@
#include "biergarten_data_generator.h"
void BiergartenDataGenerator::GenerateBreweries(
std::span<const EnrichedCity> cities) {
const std::vector<EnrichedCity>& cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
generated_breweries_.clear();
size_t skipped_count = 0;
for (const auto& [location, region_context] : cities) {
for (const auto& enriched_city : cities) {
try {
const BreweryResult brewery =
generator_->GenerateBrewery(location, region_context);
const GeneratedBrewery gen{.location = location, .brewery = brewery};
generated_breweries_.push_back(gen);
auto brewery = generator_->GenerateBrewery(
BreweryLocation{enriched_city.location.city,
enriched_city.location.country},
enriched_city.region_context);
generated_breweries_.push_back(GeneratedBrewery{
.location = enriched_city.location, .brewery = brewery});
} catch (const std::exception& e) {
++skipped_count;
spdlog::warn(
"[Pipeline] Skipping city '{}' ({}): brewery generation failed: "
"{}",
location.city, location.country, e.what());
enriched_city.location.city, enriched_city.location.country,
e.what());
}
}
if (skipped_count > 0) {
spdlog::warn("[Pipeline] Skipped {} city/cities due to generation errors",
spdlog::warn(
"[Pipeline] Skipped {} city/cities due to generation "
"errors",
skipped_count);
}
}

View File

@@ -13,18 +13,18 @@
#include "biergarten_data_generator.h"
#include "json_handling/json_loader.h"
static constexpr std::size_t kBreweryAmount = 4;
static constexpr unsigned int brewery_amount = 4;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
const std::filesystem::path locations_path = "locations.json";
auto all_locations = JsonLoader::LoadLocations(locations_path);
auto all_locations = JsonLoader::LoadLocations(locations_path.string());
spdlog::info(" Locations available: {}", all_locations.size());
const std::size_t sample_count =
std::min(kBreweryAmount, all_locations.size());
const size_t sample_count =
std::min<size_t>(brewery_amount, all_locations.size());
const auto sample_count_signed =
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
sample_count);

View File

@@ -21,8 +21,8 @@ bool BiergartenDataGenerator::Run() {
spdlog::info("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context);
enriched.push_back(
EnrichedCity{.location = city, .region_context = region_context});
enriched.push_back(EnrichedCity{.location = city,
.region_context = region_context});
} catch (const std::exception& exception) {
++skipped_count;
spdlog::warn(

View File

@@ -7,16 +7,16 @@
#include <spdlog/spdlog.h>
#include <array>
#include <format>
#include <optional>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
static std::string ExtractFinalJsonPayload(std::string raw_response) {
auto trim = [](const std::string_view text) -> std::string_view {
namespace {
std::string ExtractFinalJsonPayload(std::string raw_response) {
auto trim = [](std::string_view text) -> std::string_view {
const std::size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) {
return {};
@@ -26,7 +26,7 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
return text.substr(first, last - first + 1);
};
static constexpr std::array<std::string_view, 6> separator_tokens = {
static const std::array<std::string_view, 6> separator_tokens = {
"<|think|>", "<think|>", "<|turn|>",
"<turn|>", "<channel|>", "<|channel|>"};
@@ -35,7 +35,8 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
for (const std::string_view token : separator_tokens) {
const std::size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos || candidate_pos > separator_pos)) {
(separator_pos == std::string::npos ||
candidate_pos > separator_pos)) {
separator_pos = candidate_pos;
separator_length = token.size();
}
@@ -46,9 +47,8 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
}
const std::string_view trimmed = trim(raw_response);
const std::string json_candidate =
std::string json_candidate =
ExtractLastJsonObjectPublic(std::string(trimmed));
if (!json_candidate.empty()) {
return ExtractLastJsonObjectPublic(std::string(trimmed));
}
@@ -56,22 +56,16 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
return std::string(trimmed);
}
} // namespace
BreweryResult LlamaGenerator::GenerateBrewery(
const Location& location, const std::string& region_context) {
const BreweryLocation& location, const std::string& region_context) {
/**
* Preprocess and truncate the region context to a manageable size
*/
const std::string safe_region_context =
PrepareRegionContextPublic(region_context);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string region_suffix =
safe_region_context.empty()
? "."
: std::format(". Regional context: {}", safe_region_context);
/**
* Load brewery system prompt from file
* Falls back to minimal inline prompt if file not found
@@ -81,33 +75,47 @@ BreweryResult LlamaGenerator::GenerateBrewery(
/**
* User prompt: provides geographic context to guide generation towards
* culturally relevant and locally-inspired brewery attributes
* culturally appropriate and locally-inspired brewery attributes
*/
std::string prompt = std::format(
std::string prompt =
"Write a brewery name and place-specific long description for a craft "
"brewery in {}{}{}",
location.city, country_suffix, region_suffix);
"brewery in ";
prompt.append(location.city_name);
if (!location.country_name.empty()) {
prompt.append(", ");
prompt.append(location.country_name);
}
if (safe_region_context.empty()) {
prompt.append(".");
} else {
prompt.append(". Regional context: ");
prompt.append(safe_region_context);
}
/**
* Store location context for retry prompts (without repeating full context)
*/
const std::string retry_location =
std::format("Location: {}{}", location.city, country_suffix);
std::string retry_location = "Location: ";
retry_location.append(location.city_name);
if (!location.country_name.empty()) {
retry_location.append(", ");
retry_location.append(location.country_name);
}
/**
* RETRY LOOP with validation and error correction
* Attempts to generate valid brewery data up to 3 times, with feedback-based
* refinement
*/
constexpr int max_attempts = 3;
const int max_attempts = 3;
std::string raw;
std::string last_error;
// Limit output length to keep it concise and focused
for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
// Generate brewery data from LLM
raw = this->Infer(system_prompt, prompt, max_tokens);
raw = Infer(system_prompt, prompt, max_tokens);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
@@ -116,28 +124,31 @@ BreweryResult LlamaGenerator::GenerateBrewery(
std::string name;
std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error =
const std::string validation_error =
ValidateBreweryJsonPublic(json_only, name, description);
if (!validation_error.has_value()) {
if (validation_error.empty()) {
// Success: return parsed brewery data
return BreweryResult{.name = std::move(name),
.description = std::move(description)};
return {std::move(name), std::move(description)};
}
// Validation failed: log error and prepare corrective feedback
last_error = *validation_error;
last_error = validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, *validation_error);
attempt + 1, validation_error);
// Update prompt with error details to guide LLM toward correct output.
prompt = std::format(
R"(Your previous response was invalid. Error: {}
Return ONLY valid JSON with exactly these keys: {{"name": "<brewery name>", "description": "<single-paragraph description>"}}.
Do not include markdown, comments, extra keys, or literal placeholder values.
{})",
*validation_error, retry_location);
// For retries, use a compact prompt format to avoid exceeding token
// limits.
prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with exactly these keys: "
"{\"name\": \"<brewery name>\", "
"\"description\": \"<single-paragraph description>\"}."
"\nDo not include markdown, comments, extra keys, or literal "
"placeholder values.";
prompt += "\n\n";
prompt += retry_location;
}
// All retry attempts exhausted: log failure and throw exception

View File

@@ -6,6 +6,7 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept>
#include <string>
@@ -13,6 +14,87 @@
#include "data_generation/llama_generator_helpers.h"
UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
return {.username = "test_user",
.bio = "This is a test user profile from " + locale + "."};
/**
* System prompt: specifies exact output format to minimize parsing errors
* Constraints: 2-line output, username format, bio length bounds
*/
const std::string system_prompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
/**
* User prompt: locale parameter guides cultural appropriateness of generated
* profiles
*/
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
/**
* RETRY LOOP with format validation
* Attempts up to 3 times to generate valid user profile with correct format
*/
const int max_attempts = 3;
std::string raw;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
/**
* Generate user profile (max 128 tokens - should fit 2 lines easily)
*/
raw = Infer(system_prompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
/**
* Parse two-line response: first line = username, second line = bio
*/
auto [username, bio] = ParseTwoLineResponsePublic(
raw, "LlamaGenerator: malformed user response");
/**
* Remove any whitespace from username (usernames shouldn't have
* spaces)
*/
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
/**
* Validate both fields are non-empty after processing
*/
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
/**
* Truncate bio if exceeds reasonable length for bio field
*/
if (bio.size() > 200) bio = bio.substr(0, 200);
/**
* Success: return parsed user profile
*/
return {username, bio};
} catch (const std::exception& e) {
/**
* Parsing failed: log and continue to next attempt
*/
spdlog::warn(
"LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
/**
* All retry attempts exhausted: log failure and throw exception
*/
spdlog::error(
"LlamaGenerator: malformed user response after {} attempts: {}",
max_attempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -4,17 +4,13 @@
* parsing, token decoding, and JSON validation helpers for Llama modules.
*/
#include <spdlog/spdlog.h>
#include <algorithm>
#include <array>
#include <boost/json.hpp>
#include <cctype>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
#include "data_generation/llama_generator.h"
@@ -23,42 +19,40 @@
/**
* String trimming: removes leading and trailing whitespace
*/
static std::string Trim(std::string_view value) {
constexpr std::string_view whitespace = " \t\n\r\f\v";
const std::size_t first_index = value.find_first_not_of(whitespace);
if (first_index == std::string_view::npos) {
return {};
}
static std::string Trim(std::string value) {
auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
const std::size_t last_index = value.find_last_not_of(whitespace);
return std::string(value.substr(first_index, last_index - first_index + 1));
value.erase(value.begin(),
std::find_if(value.begin(), value.end(), not_space));
value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(),
value.end());
return value;
}
/**
* Normalize whitespace: collapses multiple spaces/tabs/newlines into single
* spaces
*/
static std::string CondenseWhitespace(std::string_view text) {
static std::string CondenseWhitespace(std::string text) {
std::string out;
out.reserve(text.size());
bool pending_space = false;
for (const unsigned char chr : text) {
if (std::isspace(chr) != 0) {
if (!out.empty()) {
pending_space = true;
bool in_whitespace = false;
for (unsigned char ch : text) {
if (std::isspace(ch)) {
if (!in_whitespace) {
out.push_back(' ');
in_whitespace = true;
}
continue;
}
if (pending_space) {
out.push_back(' ');
pending_space = false;
}
out.push_back(static_cast<char>(chr));
in_whitespace = false;
out.push_back(static_cast<char>(ch));
}
return out;
return Trim(std::move(out));
}
/**
@@ -66,14 +60,14 @@ static std::string CondenseWhitespace(std::string_view text) {
* boundaries
*/
static std::string PrepareRegionContext(std::string_view region_context,
const size_t max_chars) {
std::string normalized = CondenseWhitespace(region_context);
std::size_t max_chars) {
std::string normalized = CondenseWhitespace(std::string(region_context));
if (normalized.size() <= max_chars) {
return normalized;
}
normalized.resize(max_chars);
const size_t last_space = normalized.find_last_of(' ');
const std::size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space);
}
@@ -82,20 +76,108 @@ static std::string PrepareRegionContext(std::string_view region_context,
return normalized;
}
static std::string ToChatPrompt(const llama_model* model,
/**
 * Strip list decoration and field labels that an LLM tends to prepend to a
 * line of output.
 *
 * Processing order:
 *   1. Trim surrounding whitespace.
 *   2. Drop one leading bullet ('-' or '*'); otherwise drop a leading run of
 *      digits when it is immediately followed by '.' or ')'.
 *   3. Drop each known label ("name:", "brewery name:", "description:",
 *      "username:", "bio:") if the line starts with it, compared
 *      case-insensitively. Labels are tried once each, in that order.
 *
 * @param line One line of model output.
 * @return The line without the recognised prefixes, trimmed.
 */
static std::string StripCommonPrefix(std::string line) {
  line = Trim(std::move(line));

  const bool starts_with_bullet =
      !line.empty() && (line[0] == '-' || line[0] == '*');
  if (starts_with_bullet) {
    line = Trim(line.substr(1));
  } else {
    // Locate the end of the leading digit run ("12." / "3)" style numbering).
    const auto digits_end =
        std::find_if_not(line.begin(), line.end(), [](unsigned char ch) {
          return std::isdigit(ch) != 0;
        });
    const bool has_digits = digits_end != line.begin();
    if (has_digits && digits_end != line.end() &&
        (*digits_end == '.' || *digits_end == ')')) {
      const std::size_t marker_index =
          static_cast<std::size_t>(digits_end - line.begin());
      line = Trim(line.substr(marker_index + 1));
    }
  }

  // Labels are checked sequentially so that stacked labels are all removed.
  const std::string known_labels[] = {"name:", "brewery name:", "description:",
                                      "username:", "bio:"};
  for (const std::string& label : known_labels) {
    if (line.size() < label.size()) {
      continue;
    }
    const bool has_label = std::equal(
        label.begin(), label.end(), line.begin(),
        [](unsigned char expected, unsigned char actual) {
          return std::tolower(expected) == std::tolower(actual);
        });
    if (has_label) {
      line = Trim(line.substr(label.size()));
    }
  }

  return Trim(std::move(line));
}
/**
 * Parse a two-line LLM response into (first line, rest).
 *
 * Steps:
 *   1. Normalize '\r' to '\n' and split into lines.
 *   2. Clean each line with StripCommonPrefix and drop empty lines.
 *   3. Drop spurious output: thinking/reasoning tags, the content enclosed by
 *      a matched pair of such tags, and lines starting with "okay," or "hmm".
 *   4. Use the first remaining line as the first field and join all
 *      remaining lines with single spaces as the second field.
 *
 * @param raw Raw model output.
 * @param error_message Message used for the exception when parsing fails.
 * @return Pair of (first usable line, remaining usable lines joined by ' ').
 * @throws std::runtime_error if fewer than two usable lines remain or either
 *         field is empty after trimming.
 */
static std::pair<std::string, std::string> ParseTwoLineResponse(
    const std::string& raw, const std::string& error_message) {
  std::string normalized = raw;
  std::replace(normalized.begin(), normalized.end(), '\r', '\n');

  // Cleaned, non-empty lines plus a lowercase copy of each for matching.
  std::vector<std::string> lines;
  std::vector<std::string> lowered;
  std::stringstream stream(normalized);
  std::string line;
  while (std::getline(stream, line)) {
    line = StripCommonPrefix(std::move(line));
    if (line.empty()) continue;
    std::string low = line;
    std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
      return static_cast<char>(std::tolower(c));
    });
    lowered.push_back(std::move(low));
    lines.push_back(std::move(line));
  }

  // A line counts as a thinking tag only if the whole line is wrapped in
  // '<' ... '>' and mentions a known keyword. This stays conservative so
  // legitimate output is not removed.
  const auto is_thinking_tag = [](const std::string& low) {
    if (low.front() != '<' || low.back() != '>') return false;
    return low.find("think") != std::string::npos ||
           low.find("reasoning") != std::string::npos ||
           low.find("reflect") != std::string::npos;
  };
  const auto is_closing_tag = [&is_thinking_tag](const std::string& low) {
    return is_thinking_tag(low) && low.rfind("</", 0) == 0;
  };

  std::vector<std::string> filtered;
  for (std::size_t i = 0; i < lines.size(); ++i) {
    const std::string& low = lowered[i];
    if (is_thinking_tag(low)) {
      // An opening tag on its own line (no "</" in it) starts a block. When a
      // matching closing tag line exists later, skip everything in between so
      // the model's reasoning is not mistaken for the answer. If the block is
      // never closed, only the tag line itself is dropped.
      if (low.find("</") == std::string::npos) {
        std::size_t close = i + 1;
        while (close < lines.size() && !is_closing_tag(lowered[close])) {
          ++close;
        }
        if (close < lines.size()) i = close;
      }
      continue;
    }
    if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
    filtered.push_back(std::move(lines[i]));
  }

  if (filtered.size() < 2) throw std::runtime_error(error_message);

  std::string first = Trim(filtered.front());
  std::string second;
  for (std::size_t i = 1; i < filtered.size(); ++i) {
    if (!second.empty()) second += ' ';
    second += filtered[i];
  }
  second = Trim(std::move(second));

  if (first.empty() || second.empty()) throw std::runtime_error(error_message);
  return {std::move(first), std::move(second)};
}
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
std::string combined_prompt;
combined_prompt.append(system_prompt);
combined_prompt.append("\n\n");
combined_prompt.append(user_prompt);
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
// No template found, fallback to raw text
spdlog::warn(
"LlamaGenerator: missing chat template; using raw prompt fallback");
return combined_prompt;
return system_prompt + "\n\n" + user_prompt;
}
const std::array<llama_chat_message, 2> messages = {
@@ -104,67 +186,71 @@ static std::string ToChatPrompt(const llama_model* model,
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
int32_t message_count) -> int32_t {
int32_t result = llama_chat_apply_template(
tmpl, chat_messages, message_count, true, buffer.data(),
int32_t required =
llama_chat_apply_template(tmpl, messages.data(), 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (result < 0) {
return result;
}
if (result >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(result) + 1);
result = llama_chat_apply_template(tmpl, chat_messages, message_count,
true, buffer.data(),
static_cast<int32_t>(buffer.size()));
}
return result;
};
int32_t template_result = apply_template_with_resize(messages.data(), 2);
if (template_result >= 0) {
return {buffer.data(), static_cast<std::size_t>(template_result)};
}
spdlog::warn(
"LlamaGenerator: chat template rejected system/user messages (result "
"{}); trying single user fallback",
template_result);
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
// FALLBACK: If the template fails (e.g., Gemma rejecting the "system" role),
// combine the system and user prompts into a single "user" message.
if (required < 0) {
std::string combined_prompt = system_prompt + "\n\n" + user_prompt;
const std::array<llama_chat_message, 1> fallback_msg = {
{{"user", combined_prompt.c_str()}}};
template_result = apply_template_with_resize(fallback_msg.data(), 1);
required = llama_chat_apply_template(tmpl, fallback_msg.data(), 1, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// Ultimate fallback: if GGUF template parsing still fails, use raw text.
if (template_result < 0) {
spdlog::warn(
"LlamaGenerator: chat template fallback failed (result {}); using "
"raw prompt text",
template_result);
// THE FIX: Ultimate fallback. If the GGUF's internal template is
// completely unparseable (which happens with complex Jinja macros),
// degrade gracefully to raw text instead of throwing a runtime_error.
if (required < 0) {
return combined_prompt;
}
return {buffer.data(), static_cast<std::size_t>(template_result)};
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(
tmpl, fallback_msg.data(), 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
return combined_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
// Standard buffer resize if the original "system" + "user" array succeeded
// but needed more space
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages.data(), 2, true,
buffer.data(),
static_cast<int32_t>(buffer.size()));
// Final safety net on resize
if (required < 0) {
return system_prompt + "\n\n" + user_prompt;
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
std::array<char, 256> buffer{};
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true);
llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true);
if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), 0,
true);
static_cast<int32_t>(dynamic_buffer.size()),
0, true);
if (bytes < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece");
@@ -242,8 +328,8 @@ std::string ExtractLastJsonObjectPublic(const std::string& text) {
return {};
}
static std::optional<std::string> ValidateBreweryJson(
const std::string& raw, std::string& name_out,
static std::string ValidateBreweryJson(const std::string& raw,
std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
@@ -263,11 +349,9 @@ static std::optional<std::string> ValidateBreweryJson(
return false;
}
const auto& name_value = obj.at("name").as_string();
const auto& description_value = obj.at("description").as_string();
name_out = Trim(std::string_view(name_value.data(), name_value.size()));
description_out = Trim(
std::string_view(description_value.data(), description_value.size()));
name_out = Trim(std::string(obj.at("name").as_string().c_str()));
description_out =
Trim(std::string(obj.at("description").as_string().c_str()));
if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty";
@@ -317,14 +401,14 @@ static std::optional<std::string> ValidateBreweryJson(
return validation_error;
}
return std::nullopt;
return {};
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return std::nullopt;
return {};
}
// Public wrappers that expose the file-local helpers to other translation units
@@ -333,6 +417,16 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
return PrepareRegionContext(region_context, max_chars);
}
/**
 * Public entry point for the file-local ParseTwoLineResponse helper.
 *
 * @param raw Raw model output to parse.
 * @param error_message Message forwarded to the helper for its exception.
 * @return The pair produced by ParseTwoLineResponse.
 */
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
    const std::string& raw, const std::string& error_message) {
  auto parsed_fields = ParseTwoLineResponse(raw, error_message);
  return parsed_fields;
}
/**
 * Public wrapper: build a chat prompt from a user prompt only (no system
 * prompt).
 *
 * @param model Model whose chat template is used by ToChatPrompt.
 * @param user_prompt Text of the user turn.
 * @return The formatted prompt produced by ToChatPrompt.
 */
std::string ToChatPromptPublic(const llama_model* model,
                               const std::string& user_prompt) {
  // ToChatPrompt's parameter order is (model, system_prompt, user_prompt).
  // The user text must go in the last slot; previously it was passed as the
  // system prompt with an empty user turn.
  return ToChatPrompt(model, "", user_prompt);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
@@ -344,8 +438,8 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
AppendTokenPiece(vocab, token, output);
}
std::optional<std::string> ValidateBreweryJsonPublic(
const std::string& raw, std::string& name_out,
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out);
}

View File

@@ -2,7 +2,7 @@
* Text Generation / Inference Module
* Core module that performs LLM inference: converts text prompts into tokens,
* runs the neural network forward pass, samples the next token, and converts
* output tokens back to text for system+user chat prompts.
* output tokens back to text. Supports both simple and system+user prompts.
*/
#include <spdlog/spdlog.h>
@@ -17,31 +17,30 @@
#include "data_generation/llama_generator_helpers.h"
#include "llama.h"
static constexpr std::size_t kPromptTokenSlack = 8;
/**
 * Single-prompt inference: formats the prompt via ToChatPromptPublic for the
 * loaded model and runs InferFormatted on the result.
 *
 * @param prompt Prompt text to format.
 * @param max_tokens Token limit forwarded to InferFormatted.
 * @return Text returned by InferFormatted.
 */
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
  const std::string formatted_prompt = ToChatPromptPublic(model_, prompt);
  return InferFormatted(formatted_prompt, max_tokens);
}
std::string LlamaGenerator::Infer(const std::string& system_prompt,
const std::string& prompt,
const int max_tokens) {
const std::string& prompt, int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
max_tokens);
}
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
const int max_tokens) {
int max_tokens) {
/**
* Validate that model and context are loaded
*/
if (model_ == nullptr || context_ == nullptr) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
}
/**
* Get vocabulary for tokenization and token-to-text conversion
*/
const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) {
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
}
/**
* Clear KV cache to ensure clean inference state (no residual context)
@@ -52,8 +51,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
* TOKENIZATION PHASE
* Convert text prompt into token IDs (integers) that the model understands
*/
std::vector<llama_token> prompt_tokens(formatted_prompt.size() +
kPromptTokenSlack);
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
@@ -70,20 +68,18 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
static_cast<int32_t>(prompt_tokens.size()), true, true);
}
if (token_count < 0) {
if (token_count < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
}
/**
* CONTEXT SIZE VALIDATION
* Validate and compute effective token budgets based on context window
* constraints
*/
const auto n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const auto n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0) {
const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0)
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
/**
* Clamp generation limit to available context window, reserve space for
@@ -117,9 +113,47 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
*/
const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0) {
if (llama_decode(context_, prompt_batch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
}
/**
* SAMPLER CONFIGURATION PHASE
* Set up the probabilistic token selection pipeline (sampler chain)
* Samplers are applied in sequence: temperature -> top-k -> top-p ->
* distribution
*/
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
/**
* Temperature: scales logits before softmax (controls randomness)
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
/**
* Top-K: limits sampling to the most likely tokens before nucleus
* sampling
*/
llama_sampler_chain_add(
sampler.get(),
llama_sampler_init_top_k(static_cast<int32_t>(sampling_top_k_)));
/**
* Top-P: nucleus sampling - filters to most likely tokens summing to top_p
* probability
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
/**
* Distribution sampler: selects actual token using configured seed for
* reproducibility
*/
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
/**
* TOKEN GENERATION LOOP
@@ -129,44 +163,36 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
if (sampler_ == nullptr || sampler_->chain == nullptr) {
throw std::runtime_error("LlamaGenerator: sampler not initialized");
}
for (int i = 0; i < effective_max_tokens; ++i) {
/**
* Sample next token using configured sampler chain and model logits
* Index -1 means use the last output position from previous batch
*/
const llama_token next =
llama_sampler_sample(sampler_->chain, context_, -1);
llama_sampler_sample(sampler.get(), context_, -1);
/**
* Stop if model predicts end-of-generation token (EOS/EOT)
*/
if (llama_vocab_is_eog(vocab, next)) {
break;
}
if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next);
/**
* Feed the sampled token back into model for next iteration
* (autoregressive)
*/
llama_token decode_token = next;
const llama_batch one_token_batch = llama_batch_get_one(&decode_token, 1);
if (llama_decode(context_, one_token_batch) != 0) {
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
}
/**
* DETOKENIZATION PHASE
* Convert generated token IDs back to text using vocabulary
*/
std::string output;
for (const llama_token token : generated_tokens) {
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
}
return output;
}

View File

@@ -5,7 +5,6 @@
#include "data_generation/llama_generator.h"
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
@@ -13,47 +12,6 @@
#include "data_model/application_options.h"
#include "llama.h"
static constexpr uint32_t kMaxContextSize = 32768U;
struct SamplerConfig {
float temperature;
float top_p;
uint32_t top_k;
};
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
std::mt19937& rng) {
const llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler) {
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
}
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(config.temperature));
llama_sampler_chain_add(
sampler.get(),
llama_sampler_init_top_k(static_cast<int32_t>(config.top_k)));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(config.top_p, 1));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng()));
return sampler;
}
LlamaGenerator::SamplerState::~SamplerState() {
if (chain != nullptr) {
llama_sampler_free(chain);
chain = nullptr;
}
}
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path)
: rng_(std::random_device{}()) {
@@ -80,7 +38,7 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
if (options.n_ctx == 0 || options.n_ctx > kMaxContextSize) {
if (options.n_ctx == 0 || options.n_ctx > 32768) {
throw std::runtime_error(
"LlamaGenerator: context size must be in range [1, 32768]");
}
@@ -97,16 +55,9 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
n_ctx_ = options.n_ctx;
this->Load(model_path);
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
sampling_top_k_};
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
sampler_.reset(new SamplerState());
sampler_->chain = sampler_chain.release();
}
LlamaGenerator::~LlamaGenerator() {
sampler_.reset();
/**
* Free the inference context (contains KV cache and computation state)
*/

View File

@@ -23,7 +23,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
model_ = nullptr;
}
const llama_model_params model_params = llama_model_default_params();
llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
throw std::runtime_error(

View File

@@ -1,14 +1,13 @@
/**
* @file data_generation/llama/load_brewery_prompt.cpp
* @brief Resolves brewery system prompt content from cache or a configured
* filesystem path and provides a robust inline fallback prompt when absent.
* @brief Resolves brewery system prompt content from cache or filesystem
* search paths and provides a robust inline fallback prompt when absent.
*/
#include <spdlog/spdlog.h>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include "data_generation/llama_generator.h"
@@ -18,7 +17,7 @@ namespace fs = std::filesystem;
* @brief Loads brewery system prompt from disk or cache.
*
* @param prompt_file_path Preferred prompt file location.
* @return Prompt text loaded from disk.
* @return Prompt text loaded from disk or fallback content.
*/
std::string LlamaGenerator::LoadBrewerySystemPrompt(
const std::string& prompt_file_path) {
@@ -27,33 +26,72 @@ std::string LlamaGenerator::LoadBrewerySystemPrompt(
return brewery_system_prompt_;
}
// Try the provided path only
const fs::path prompt_path(prompt_file_path);
std::ifstream prompt_file(prompt_path);
if (!prompt_file.is_open()) {
spdlog::error(
"LlamaGenerator: Failed to open brewery system prompt file '{}'",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: missing brewery system prompt file: " +
prompt_path.string());
}
// Try multiple path locations
std::vector<std::string> paths_to_try = {
prompt_file_path, // As provided
"../" + prompt_file_path, // One level up
"../../" + prompt_file_path, // Two levels up
};
const std::string prompt((std::istreambuf_iterator(prompt_file)),
for (const auto& path : paths_to_try) {
std::ifstream prompt_file(path);
if (prompt_file.is_open()) {
std::string prompt((std::istreambuf_iterator<char>(prompt_file)),
std::istreambuf_iterator<char>());
prompt_file.close();
if (prompt.empty()) {
spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty",
prompt_path.string());
throw std::runtime_error(
"LlamaGenerator: empty brewery system prompt file: " +
prompt_path.string());
}
if (!prompt.empty()) {
spdlog::info(
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)",
prompt_path.string(), prompt.length());
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} "
"chars)",
path, prompt.length());
brewery_system_prompt_ = prompt;
return brewery_system_prompt_;
}
}
}
spdlog::warn(
"LlamaGenerator: Could not open brewery system prompt file at any of "
"the "
"expected locations. Using fallback inline prompt.");
return GetFallbackBreweryPrompt();
}
/**
 * @brief Builds the inline brewery system prompt returned when no prompt
 * file could be loaded.
 *
 * The text is assembled section by section: persona, phrases to avoid,
 * opening approaches, specificity requirements, length/tone, output format.
 *
 * @return Default fallback prompt text.
 */
std::string LlamaGenerator::GetFallbackBreweryPrompt() {
  // Persona and overall goal.
  std::string fallback_prompt =
      "You are an experienced brewmaster and owner of a local craft "
      "brewery. "
      "Create a distinctive, authentic name and detailed description that "
      "genuinely reflects your specific location, brewing philosophy, "
      "local "
      "culture, and community connection. The brewery must feel real and "
      "grounded—not generic or interchangeable.\n\n";

  // Stock phrases the model must not use.
  fallback_prompt +=
      "AVOID REPETITIVE PHRASES - Never use:\n"
      "Love letter to, tribute to, rolling hills, picturesque, every sip "
      "tells a story, Come for X stay for Y, rich history, passion, woven "
      "into, ancient roots, timeless, where tradition meets innovation\n\n";

  // Ways to open the description.
  fallback_prompt +=
      "OPENING APPROACHES - Choose ONE:\n"
      "1. Start with specific beer style and its regional origins\n"
      "2. Begin with specific brewing challenge (water, altitude, "
      "climate)\n"
      "3. Open with founding story or personal motivation\n"
      "4. Lead with specific local ingredient or resource\n"
      "5. Start with unexpected angle or contradiction\n"
      "6. Open with local event, tradition, or cultural moment\n"
      "7. Begin with tangible architectural or geographic detail\n\n";

  // Required concrete details.
  fallback_prompt +=
      "BE SPECIFIC - Include:\n"
      "- At least ONE concrete proper noun (landmark, river, "
      "neighborhood)\n"
      "- Specific beer styles relevant to the REGION'S culture\n"
      "- Concrete brewing challenges or advantages\n"
      "- Sensory details SPECIFIC to place—not generic adjectives\n\n";

  // Length, tone, and output format.
  fallback_prompt +=
      "LENGTH: 150-250 words. TONE: Can be soulful, irreverent, "
      "matter-of-fact, unpretentious, or minimalist.\n\n"
      "Output ONLY a raw JSON object with keys name and description. "
      "No markdown, backticks, preamble, or trailing text.";

  return fallback_prompt;
}

View File

@@ -0,0 +1,71 @@
/**
* @file data_generation/mock/data.cpp
* @brief Defines static lookup tables used by MockGenerator for deterministic
* brewery names, descriptions, usernames, and bios.
*/
#include <string>
#include <vector>
#include "data_generation/mock_generator.h"
/**
 * Lookup table of leading words for mock brewery names (18 entries).
 * NOTE(review): presumably paired with an entry from kBreweryNouns to form a
 * full name; the selection logic lives outside this file, so confirm there.
 */
const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
"Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel",
"Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"};
/**
 * Lookup table of business-type suffixes for mock brewery names (18 entries),
 * e.g. "Brewing Co." or "Taproom".
 * NOTE(review): presumably appended to an entry from kBreweryAdjectives;
 * confirm against the generator code.
 */
const std::vector<std::string> MockGenerator::kBreweryNouns = {
"Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works",
"House", "Fermentery", "Ale Co.", "Cellars", "Collective",
"Project", "Foundry", "Malthouse", "Public House", "Co-op",
"Lab", "Beer Hall", "Guild"};
/**
 * Lookup table of one-sentence mock brewery descriptions (18 entries).
 * Entries that span two source lines are adjacent string literals, which the
 * compiler concatenates into a single element.
 */
const std::vector<std::string> MockGenerator::kBreweryDescriptions = {
"Handcrafted pale ales and seasonal IPAs with local ingredients.",
"Traditional lagers and experimental sours in small batches.",
"Award-winning stouts and wildly hoppy blonde ales.",
"Craft brewery specializing in Belgian-style triples and dark porters.",
"Modern brewery blending tradition with bold experimental flavors.",
"Neighborhood-focused taproom pouring crisp pilsners and citrusy pale "
"ales.",
"Small-batch brewery known for barrel-aged releases and smoky lagers.",
"Independent brewhouse pairing farmhouse ales with rotating food pop-ups.",
"Community brewpub making balanced bitters, saisons, and hazy IPAs.",
"Experimental nanobrewery exploring local yeast and regional grains.",
"Family-run brewery producing smooth amber ales and robust porters.",
"Urban brewery crafting clean lagers and bright, fruit-forward sours.",
"Riverfront brewhouse featuring oak-matured ales and seasonal blends.",
"Modern taproom focused on sessionable lagers and classic pub styles.",
"Brewery rooted in tradition with a lineup of malty reds and crisp lagers.",
"Creative brewery offering rotating collaborations and limited draft-only "
"pours.",
"Locally inspired brewery serving approachable ales with bold hop "
"character.",
"Destination taproom known for balanced IPAs and cocoa-rich stouts."};
/**
 * Lookup table of mock usernames (18 entries); every entry is lowercase with
 * no spaces.
 */
const std::vector<std::string> MockGenerator::kUsernames = {
"hopseeker", "malttrail", "yeastwhisper", "lagerlane",
"barrelbound", "foamfinder", "taphunter", "graingeist",
"brewscout", "aleatlas", "caskcompass", "hopsandmaps",
"mashpilot", "pintnomad", "fermentfriend", "stoutsignal",
"sessionwander", "kettlekeeper"};
// Catalog of mock user bios. GenerateUser selects one entry with
// (hash / 11) % size, where hash is std::hash of the locale string.
const std::vector<std::string> MockGenerator::kBios = {
"Always chasing balanced IPAs and crisp lagers across local taprooms.",
"Weekend brewery explorer with a soft spot for dark, roasty stouts.",
"Documenting tiny brewpubs, fresh pours, and unforgettable beer gardens.",
"Fan of farmhouse ales, food pairings, and long tasting flights.",
"Collecting favorite pilsners one city at a time.",
"Hops-first drinker who still saves room for classic malt-forward styles.",
"Finding hidden tap lists and sharing the best seasonal releases.",
"Brewery road-tripper focused on local ingredients and clean fermentation.",
"Always comparing house lagers and ranking patio pint vibes.",
"Curious about yeast strains, barrel programs, and cellar experiments.",
"Believes every neighborhood deserves a great community taproom.",
"Looking for session beers that taste great from first sip to last.",
"Belgian ale enthusiast who never skips a new saison.",
"Hazy IPA critic with deep respect for a perfectly clear pilsner.",
"Visits breweries for the stories, stays for the flagship pours.",
"Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."};

View File

@@ -5,12 +5,13 @@
*/
#include <boost/container_hash/hash.hpp>
#include <string>
#include "data_generation/mock_generator.h"
// Folds the location's city and country strings into one seed with
// boost::hash_combine, starting from 0, and returns that seed.
// NOTE(review): this span is a diff hunk shown without +/- markers; the
// Location-based and BreweryLocation-based variants of the signature and body
// appear back to back. Confirm which variant is current before editing.
size_t MockGenerator::DeterministicHash(const Location& location) {
size_t seed = 0;
boost::hash_combine(seed, location.city);
boost::hash_combine(seed, location.country);
std::size_t MockGenerator::DeterministicHash(const BreweryLocation& location) {
std::size_t seed = 0;
boost::hash_combine(seed, location.city_name);
boost::hash_combine(seed, location.country_name);
return seed;
}

View File

@@ -4,39 +4,35 @@
* and country into fixed mock phrase catalogs.
*/
#include <format>
#include <string>
#include <string_view>
#include "data_generation/mock_generator.h"
// Builds a mock brewery name and description for a location.
//
// DeterministicHash(location) drives three catalog lookups: the adjective
// uses hash % size, the noun uses (hash / 7) % size, and the base description
// uses (hash / 13) % size. The name is "<city> <adjective> <noun>"; the
// description is the base sentence followed by a clause naming the city and,
// when non-empty, further location parts. The region_context argument is
// unnamed and unused.
//
// The lookups use vector::at, which throws std::out_of_range on a bad index.
// NOTE(review): this span is a diff hunk shown without +/- markers; the
// std::format/string_view variant and the append-based variant are
// interleaved. Confirm which variant is current before editing.
BreweryResult MockGenerator::GenerateBrewery(
const Location& location, const std::string& /*region_context*/) {
const BreweryLocation& location, const std::string& /*region_context*/) {
const std::size_t hash = DeterministicHash(location);
const std::string_view adjective =
const std::string& adjective =
kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
const std::string_view noun =
kBreweryNouns.at(hash / 7 % kBreweryNouns.size());
const std::string_view base_description =
const std::string& noun =
kBreweryNouns.at((hash / 7) % kBreweryNouns.size());
const std::string& base_description =
kBreweryDescriptions.at((hash / 13) % kBreweryDescriptions.size());
const std::string name =
std::format("{} {} {}", location.city, adjective, noun);
std::string name(location.city_name);
name.append(" ");
name.append(adjective);
name.append(" ");
name.append(noun);
const std::string state_suffix =
location.state_province.empty()
? std::string{}
: std::format(", {}", location.state_province);
const std::string country_suffix =
location.country.empty() ? std::string{}
: std::format(", {}", location.country);
const std::string description =
std::format("{} Located in {}{}{}.", base_description, location.city,
state_suffix, country_suffix);
return {
.name = name,
.description = description,
};
std::string description = base_description;
description.append(" Based in ");
description.append(location.city_name);
if (!location.country_name.empty()) {
description.append(", ");
description.append(location.country_name);
}
description.append(".");
return {name, description};
}

View File

@@ -6,7 +6,6 @@
#include <functional>
#include <string>
#include <string_view>
#include "data_generation/mock_generator.h"
@@ -14,9 +13,7 @@ UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
const std::string_view username = kUsernames[hash % kUsernames.size()];
const std::string_view bio = kBios[hash / 11 % kBios.size()];
result.username = username;
result.bio = bio;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -12,35 +12,31 @@
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string_view>
// Returns the string stored under `key` in `object`.
// Throws std::runtime_error, with the key name in the message, when the key
// is absent or its value is not a JSON string.
// NOTE(review): this span is a diff hunk shown without +/- markers; two
// formattings of the throw and two forms of the return appear back to back.
static std::string ReadRequiredString(const boost::json::object& object,
const char* key) {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_string()) {
throw std::runtime_error(std::string("Missing or invalid string field: ") +
key);
throw std::runtime_error(
std::string("Missing or invalid string field: ") + key);
}
const std::string_view text = value->as_string();
return std::string(text);
// NOTE(review): building the result from c_str() stops at the first embedded
// NUL character, whereas the string_view form above copies the full length.
return std::string(value->as_string().c_str());
}
// Returns the numeric value stored under `key` in `object`, converted to
// double via to_number<double>().
// Throws std::runtime_error, with the key name in the message, when the key
// is absent or its value is not a JSON number.
// NOTE(review): this span is a diff hunk shown without +/- markers; two
// formattings of the same throw statement appear back to back.
static double ReadRequiredNumber(const boost::json::object& object,
const char* key) {
const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_number()) {
throw std::runtime_error(std::string("Missing or invalid numeric field: ") +
key);
throw std::runtime_error(
std::string("Missing or invalid numeric field: ") + key);
}
return value->to_number<double>();
}
std::vector<Location> JsonLoader::LoadLocations(
const std::filesystem::path& filepath) {
std::vector<Location> JsonLoader::LoadLocations(const std::string& filepath) {
std::ifstream input(filepath);
if (!input.is_open()) {
throw std::runtime_error("Failed to open locations file: " +
filepath.string());
throw std::runtime_error("Failed to open locations file: " + filepath);
}
std::stringstream buffer;
@@ -82,6 +78,6 @@ std::vector<Location> JsonLoader::LoadLocations(
}
spdlog::info("[JsonLoader] Loaded {} locations from {}", locations.size(),
filepath.string());
filepath);
return locations;
}

View File

@@ -10,7 +10,6 @@
#include <boost/program_options.hpp>
#include <exception>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
@@ -31,35 +30,26 @@ namespace di = boost::di;
*
* @param argc Command-line argument count.
* @param argv Command-line arguments.
* @return Parsed ApplicationOptions if parsing succeeded, std::nullopt
* otherwise.
* @param options Output ApplicationOptions struct.
* @return true if parsing succeeded and should proceed, false otherwise.
*/
std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
bool ParseArguments(const int argc, char** argv,
ApplicationOptions& options) noexcept {
prog_opts::options_description desc("Pipeline Options");
auto opt = desc.add_options();
opt("help,h", "Produce help message");
opt("mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data");
opt("model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)");
opt("temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)");
opt("top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)");
opt("top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)");
opt("n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)");
opt("seed", prog_opts::value<int>()->default_value(-1),
desc.add_options()("help,h", "Produce help message")(
"mocked", prog_opts::bool_switch(),
"Use mocked generator for brewery/user data")(
"model,m", prog_opts::value<std::string>()->default_value(""),
"Path to LLM model (gguf)")(
"temperature", prog_opts::value<float>()->default_value(1.0F),
"Sampling temperature (higher = more random)")(
"top-p", prog_opts::value<float>()->default_value(0.95F),
"Nucleus sampling top-p in (0,1] (higher = more random)")(
"top-k", prog_opts::value<uint32_t>()->default_value(64),
"Top-k sampling parameter (higher = more candidate tokens)")(
"n-ctx", prog_opts::value<uint32_t>()->default_value(8192),
"Context window size in tokens (1-32768)")(
"seed", prog_opts::value<int>()->default_value(-1),
"Sampler seed: -1 for random, otherwise non-negative integer");
// Handle the "no arguments" or "help" case
@@ -68,7 +58,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
std::stringstream usage_stream;
usage_stream << "\nUsage: biergarten-pipeline [options]\n\n" << desc;
spdlog::info(usage_stream.str());
return std::nullopt;
return false;
}
try {
@@ -81,7 +71,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
std::stringstream help_stream;
help_stream << "\n" << desc;
spdlog::info(help_stream.str());
return std::nullopt;
return false;
}
const auto use_mocked = variables_map["mocked"].as<bool>();
@@ -90,19 +80,19 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
if (use_mocked && !model_path.empty()) {
spdlog::error(
"Invalid arguments: --mocked and --model are mutually exclusive");
return std::nullopt;
return false;
}
if (!use_mocked && model_path.empty()) {
spdlog::error(
"Invalid arguments: Either --mocked or --model must be specified");
return std::nullopt;
return false;
}
const bool has_llm_params = !variables_map["temperature"].defaulted() ||
!variables_map["top-p"].defaulted() ||
!variables_map["top-k"].defaulted() ||
!variables_map["seed"].defaulted() = false;
!variables_map["seed"].defaulted();
if (use_mocked && has_llm_params) {
spdlog::warn(
@@ -110,7 +100,6 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
" ignored when using --mocked");
}
ApplicationOptions options;
options.use_mocked = use_mocked;
options.model_path = model_path;
options.temperature = variables_map["temperature"].as<float>();
@@ -119,37 +108,35 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
options.n_ctx = variables_map["n-ctx"].as<uint32_t>();
options.seed = variables_map["seed"].as<int>();
return options;
return true;
} catch (const std::exception& exception) {
spdlog::error("Failed to parse command-line arguments: {}",
exception.what());
return std::nullopt;
return false;
} catch (...) {
spdlog::error("Failed to parse command-line arguments: unknown error");
return std::nullopt;
return false;
}
}
int main(const int argc, char** argv) {
int main(const int argc, char** argv) noexcept {
try {
const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
const auto parsed_options = ParseArguments(argc, argv);
if (!parsed_options.has_value()) {
ApplicationOptions options;
if (!ParseArguments(argc, argv, options)) {
return 0;
}
const auto options = *parsed_options;
const auto injector = di::make_injector(
di::bind<WebClient>().to<CURLWebClient>(),
di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(),
di::bind<std::string>().to(options.model_path),
di::bind<DataGenerator>().to(
[options](const auto& inj) -> std::unique_ptr<DataGenerator> {
di::bind<DataGenerator>().to([options](const auto& injector)
-> std::unique_ptr<DataGenerator> {
if (options.use_mocked) {
spdlog::info(
"[Generator] Using MockGenerator (no model path provided)");
@@ -161,7 +148,7 @@ int main(const int argc, char** argv) {
"top-p={}, top-k={}, n_ctx={}, seed={})",
options.model_path, options.temperature, options.top_p,
options.top_k, options.n_ctx, options.seed);
return inj.template create<std::unique_ptr<LlamaGenerator>>();
return injector.template create<std::unique_ptr<LlamaGenerator>>();
}));
auto generator = injector.create<BiergartenDataGenerator>();

View File

@@ -34,12 +34,9 @@ std::string WikipediaService::FetchExtract(std::string_view query) {
if (!pages.empty()) {
auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
const std::string_view extract_view = page.at("extract").as_string();
std::string extract(extract_view);
std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
this->extract_cache_.emplace(cache_key, extract);
return extract;
}

View File

@@ -10,12 +10,19 @@
#include "services/wikipedia_service.h"
std::string WikipediaService::GetLocationContext(const Location& loc) {
if (!client_) {
return {};
const std::string cache_key = loc.city + "|" + loc.country;
const auto cache_it = cache_.find(cache_key);
if (cache_it != cache_.end()) {
return cache_it->second;
}
std::string result;
if (!client_) {
cache_.emplace(cache_key, result);
return result;
}
std::string region_query(loc.city);
if (!loc.country.empty()) {
region_query += ", ";
@@ -43,5 +50,7 @@ std::string WikipediaService::GetLocationContext(const Location& loc) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
e.what());
}
cache_.emplace(cache_key, result);
return result;
}

View File

@@ -7,5 +7,5 @@
#include <utility>
// Stores the supplied WebClient in client_ by moving the smart pointer.
// NOTE(review): this span is a diff hunk shown without +/- markers; the
// unique_ptr and shared_ptr signatures appear back to back. Confirm which
// ownership model is current before editing.
WikipediaService::WikipediaService(std::unique_ptr<WebClient> client)
WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {}

View File

@@ -5,70 +5,45 @@
#include <curl/curl.h>
#include <cstdint>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include "curl_web_client_utils.h"
#include "web_client/curl_web_client.h"
// Owning wrapper for a libcurl easy handle; the deleter type is a pointer to
// curl_easy_cleanup, so the deleter must be supplied at construction.
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
// Creates a libcurl easy handle and returns it wrapped in a CurlHandle that
// releases it with curl_easy_cleanup.
// Throws std::runtime_error when curl_easy_init() yields no handle.
static CurlHandle create_handle() {
  // Wrap first, then validate: a null unique_ptr never invokes its deleter,
  // so the failure path releases nothing.
  CurlHandle owned(curl_easy_init(), &curl_easy_cleanup);
  if (!owned) {
    throw std::runtime_error(
        "[CURLWebClient] Failed to initialize libcurl handle");
  }
  return owned;
}
// Applies the options shared by every GET request to `curl`: target URL,
// user agent, redirect following (at most 5 hops), connect and total
// timeouts, and gzip content decoding.
//
// curl  - easy handle to configure; not checked for null here.
// url   - request URL; libcurl is handed url.c_str().
static void set_common_get_options(CURL* curl, const std::string& url) {
  // curl_easy_setopt is variadic and reads numeric options as `long`, so the
  // timeout constants must have exactly that type. A 64-bit unsigned argument
  // is mis-read on platforms where long is 32 bits.
  constexpr long kConnectTimeoutSeconds = 10L;
  constexpr long kRequestTimeoutSeconds = 30L;

  curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
  curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
  curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
  curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
  curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, kConnectTimeoutSeconds);
  curl_easy_setopt(curl, CURLOPT_TIMEOUT, kRequestTimeoutSeconds);
  curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
// curl write callback that appends response data into a std::string.
// `userp` is cast to std::string* and receives size * nmemb bytes from
// `contents`; the same byte count is returned to the caller.
// NOTE(review): this span is a diff hunk shown without +/- markers; two
// variants of the signature and body appear back to back.
static size_t WriteCallbackString(void* contents, const size_t size,
const size_t nmemb, void* userp) {
const size_t real_size = size * nmemb;
auto* str = static_cast<std::string*>(userp);
str->append(static_cast<char*>(contents), real_size);
return real_size;
static size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
void* userp) {
size_t realsize = size * nmemb;
auto* s = static_cast<std::string*>(userp);
s->append(static_cast<char*>(contents), realsize);
return realsize;
}
std::string CURLWebClient::Get(const std::string& url) {
const CurlHandle curl = create_handle();
auto curl = create_handle();
std::string response_string;
set_common_get_options(curl.get(), url);
set_common_get_options(curl.get(), url, {10L, 20L});
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
const auto error =
std::string error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error);
}
int64_t httpCode = 0;
long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) {
const std::string error = "[CURLWebClient] HTTP error " +
std::to_string(httpCode) + " for URL " + url;
throw std::runtime_error(error);
std::stringstream ss;
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(ss.str());
}
return response_string;

View File

@@ -14,11 +14,10 @@ std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (!output) {
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
}
if (output) {
std::string result(output);
curl_free(output);
return result;
}
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
}

View File

@@ -0,0 +1,28 @@
/**
* @file web_client/curl_web_client_utils.cpp
* @brief Shared CURLWebClient helper implementations.
*/
#include "curl_web_client_utils.h"
#include <stdexcept>
/**
 * @brief Creates a libcurl easy handle wrapped in an owning CurlHandle.
 * @return Handle that is released with curl_easy_cleanup on destruction.
 * @throws std::runtime_error if curl_easy_init() returns no handle.
 */
CurlHandle create_handle() {
  // Take ownership immediately; a null unique_ptr never runs its deleter, so
  // the failure path below has nothing to release.
  CurlHandle owned(curl_easy_init(), &curl_easy_cleanup);
  if (!owned) {
    throw std::runtime_error(
        "[CURLWebClient] Failed to initialize libcurl handle");
  }
  return owned;
}
/**
 * @brief Applies the options shared by every GET request to an easy handle.
 *
 * Sets the URL, user agent, redirect policy, both timeouts, and gzip content
 * decoding. Return codes from curl_easy_setopt are not inspected.
 *
 * @param curl Easy handle to configure; not checked for null here.
 * @param url Request URL; libcurl is handed url.c_str().
 * @param timeouts Values forwarded to CURLOPT_CONNECTTIMEOUT and
 *        CURLOPT_TIMEOUT.
 */
void set_common_get_options(CURL* curl, const std::string& url,
                            CurlTimeouts timeouts) {
  constexpr long kFollowRedirects = 1L;
  constexpr long kMaxRedirects = 5L;
  const long connect_limit = timeouts.connect_timeout;
  const long total_limit = timeouts.total_timeout;

  // Request target and client identification.
  curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
  curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");

  // Redirect policy.
  curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, kFollowRedirects);
  curl_easy_setopt(curl, CURLOPT_MAXREDIRS, kMaxRedirects);

  // Time limits supplied by the caller.
  curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_limit);
  curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_limit);

  // Response encoding.
  curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}

View File

@@ -0,0 +1,26 @@
#ifndef BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
#define BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
/**
* @file web_client/curl_web_client_utils.h
* @brief Shared helpers for CURLWebClient request setup.
*/
#include <curl/curl.h>
#include <memory>
#include <string>
/// Owning wrapper for a libcurl easy handle. The deleter type is a pointer to
/// curl_easy_cleanup, so the deleter must be supplied at construction.
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
/**
 * @brief Timeout pair passed to set_common_get_options.
 *
 * Both members are `long` because that is the argument type libcurl expects
 * for its numeric options. They default to 0 so a default-constructed value
 * is well defined instead of indeterminate; libcurl treats 0 as "use the
 * built-in connect timeout" and "no overall transfer timeout". Brace
 * initialization such as `{10L, 20L}` keeps working.
 */
struct CurlTimeouts {
  /// Forwarded to CURLOPT_CONNECTTIMEOUT.
  long connect_timeout = 0L;
  /// Forwarded to CURLOPT_TIMEOUT.
  long total_timeout = 0L;
};
/**
 * @brief Creates a libcurl easy handle wrapped in an owning CurlHandle.
 * @return Handle that is released with curl_easy_cleanup on destruction.
 * @throws std::runtime_error if curl_easy_init() returns no handle.
 */
CurlHandle create_handle();
/**
 * @brief Applies the options shared by every GET request to an easy handle:
 *        URL, user agent, redirect policy, timeouts, and gzip decoding.
 * @param curl Easy handle to configure.
 * @param url Request URL.
 * @param timeouts Values forwarded to CURLOPT_CONNECTTIMEOUT and
 *        CURLOPT_TIMEOUT.
 */
void set_common_get_options(CURL* curl, const std::string& url,
CurlTimeouts timeouts);
#endif // BIERGARTEN_PIPELINE_SRC_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_