Compare commits

4 Commits

Author SHA1 Message Date
Aaron Po
077f6ab4ae edit prompt 2026-04-02 22:56:18 -04:00
Aaron Po
534403734a Refactor BiergartenDataGenerator and LlamaGenerator 2026-04-02 22:46:00 -04:00
Aaron Po
3af053f0eb format codebase 2026-04-02 21:46:46 -04:00
Aaron Po
ba165d8aa7 Separate llama generator class src file into method files 2026-04-02 21:37:46 -04:00
34 changed files with 1879 additions and 1754 deletions

View File

@@ -1,10 +1,5 @@
--- ---
BasedOnStyle: Google BasedOnStyle: Google
Standard: c++23 ColumnLimit: 80
ColumnLimit: 100 IndentWidth: 3
IndentWidth: 2
DerivePointerAlignment: false
PointerAlignment: Left
SortIncludes: true
IncludeBlocks: Preserve
... ...

View File

@@ -83,8 +83,18 @@ set(PIPELINE_SOURCES
src/data_generation/data_downloader.cpp src/data_generation/data_downloader.cpp
src/database/database.cpp src/database/database.cpp
src/json_handling/json_loader.cpp src/json_handling/json_loader.cpp
src/data_generation/llama_generator.cpp src/data_generation/llama/destructor.cpp
src/data_generation/mock_generator.cpp src/data_generation/llama/set_sampling_options.cpp
src/data_generation/llama/load.cpp
src/data_generation/llama/infer.cpp
src/data_generation/llama/generate_brewery.cpp
src/data_generation/llama/generate_user.cpp
src/data_generation/llama/helpers.cpp
src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/load.cpp
src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp
src/json_handling/stream_parser.cpp src/json_handling/stream_parser.cpp
src/wikipedia/wikipedia_service.cpp src/wikipedia/wikipedia_service.cpp
src/main.cpp src/main.cpp

View File

@@ -3,23 +3,24 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
#include "database/database.h" #include "database/database.h"
#include "web_client/web_client.h" #include "web_client/web_client.h"
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
/** /**
* @brief Program options for the Biergarten pipeline application. * @brief Program options for the Biergarten pipeline application.
*/ */
struct ApplicationOptions { struct ApplicationOptions {
/// @brief Path to the LLM model file (gguf format); mutually exclusive with use_mocked. /// @brief Path to the LLM model file (gguf format); mutually exclusive with
/// use_mocked.
std::string model_path; std::string model_path;
/// @brief Use mocked generator instead of LLM; mutually exclusive with model_path. /// @brief Use mocked generator instead of LLM; mutually exclusive with
/// model_path.
bool use_mocked = false; bool use_mocked = false;
/// @brief Directory for cached JSON and database files. /// @brief Directory for cached JSON and database files.
@@ -28,27 +29,27 @@ struct ApplicationOptions {
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random). /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 0.8f; float temperature = 0.8f;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more random). /// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
/// random).
float top_p = 0.92f; float top_p = 0.92f;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative). /// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1; int seed = -1;
/// @brief Git commit hash for database consistency (always pinned to c5eb7772). /// @brief Git commit hash for database consistency (always pinned to
/// c5eb7772).
std::string commit = "c5eb7772"; std::string commit = "c5eb7772";
}; };
#endif // BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_
/** /**
* @brief Main data generator class for the Biergarten pipeline. * @brief Main data generator class for the Biergarten pipeline.
* *
* This class encapsulates the core logic for generating brewery data. * This class encapsulates the core logic for generating brewery data.
* It handles database initialization, data loading/downloading, and brewery generation. * It handles database initialization, data loading/downloading, and brewery
* generation.
*/ */
class BiergartenDataGenerator { class BiergartenDataGenerator {
public: public:
/** /**
* @brief Construct a BiergartenDataGenerator with injected dependencies. * @brief Construct a BiergartenDataGenerator with injected dependencies.
* *
@@ -56,9 +57,9 @@ public:
* @param web_client HTTP client for downloading data. * @param web_client HTTP client for downloading data.
* @param database SQLite database instance. * @param database SQLite database instance.
*/ */
BiergartenDataGenerator(const ApplicationOptions &options, BiergartenDataGenerator(const ApplicationOptions& options,
std::shared_ptr<WebClient> web_client, std::shared_ptr<WebClient> web_client,
SqliteDatabase &database); SqliteDatabase& database);
/** /**
* @brief Run the data generation pipeline. * @brief Run the data generation pipeline.
@@ -73,7 +74,7 @@ public:
*/ */
int Run(); int Run();
private: private:
/// @brief Immutable application options. /// @brief Immutable application options.
const ApplicationOptions options_; const ApplicationOptions options_;
@@ -81,7 +82,17 @@ private:
std::shared_ptr<WebClient> webClient_; std::shared_ptr<WebClient> webClient_;
/// @brief Database dependency. /// @brief Database dependency.
SqliteDatabase &database_; SqliteDatabase& database_;
/**
* @brief Enriched city data with Wikipedia context.
*/
struct EnrichedCity {
int city_id;
std::string city_name;
std::string country_name;
std::string region_context;
};
/** /**
* @brief Initialize the data generator based on options. * @brief Initialize the data generator based on options.
@@ -98,19 +109,45 @@ private:
void LoadGeographicData(); void LoadGeographicData();
/** /**
* @brief Generate sample breweries for demonstration. * @brief Query cities from database and build country name map.
*
* @return Vector of (City, country_name) pairs capped at 30 entries.
*/ */
void GenerateSampleBreweries(); std::vector<std::pair<City, std::string>> QueryCitiesWithCountries();
/**
* @brief Enrich cities with Wikipedia summaries.
*
* @param cities Vector of (City, country_name) pairs.
* @return Vector of enriched city data with context.
*/
std::vector<EnrichedCity> EnrichWithWikipedia(
const std::vector<std::pair<City, std::string>>& cities);
/**
* @brief Generate breweries for enriched cities.
*
* @param generator The data generator instance.
* @param cities Vector of enriched city data.
*/
void GenerateBreweries(DataGenerator& generator,
const std::vector<EnrichedCity>& cities);
/**
* @brief Log the generated brewery results.
*/
void LogResults() const;
/** /**
* @brief Helper struct to store generated brewery data. * @brief Helper struct to store generated brewery data.
*/ */
struct GeneratedBrewery { struct GeneratedBrewery {
int cityId; int city_id;
std::string cityName; std::string city_name;
BreweryResult brewery; BreweryResult brewery;
}; };
/// @brief Stores generated brewery data. /// @brief Stores generated brewery data.
std::vector<GeneratedBrewery> generatedBreweries_; std::vector<GeneratedBrewery> generatedBreweries_;
}; };
#endif // BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_

View File

@@ -9,7 +9,7 @@
/// @brief Downloads and caches source geography JSON payloads. /// @brief Downloads and caches source geography JSON payloads.
class DataDownloader { class DataDownloader {
public: public:
/// @brief Initializes global curl state used by this downloader. /// @brief Initializes global curl state used by this downloader.
explicit DataDownloader(std::shared_ptr<WebClient> web_client); explicit DataDownloader(std::shared_ptr<WebClient> web_client);
@@ -18,12 +18,13 @@ public:
/// @brief Returns a local JSON path, downloading it when cache is missing. /// @brief Returns a local JSON path, downloading it when cache is missing.
std::string DownloadCountriesDatabase( std::string DownloadCountriesDatabase(
const std::string &cache_path, const std::string& cache_path,
const std::string &commit = "c5eb7772" // Stable commit: 2026-03-28 export const std::string& commit =
"c5eb7772" // Stable commit: 2026-03-28 export
); );
private: private:
static bool FileExists(const std::string &file_path); static bool FileExists(const std::string& file_path);
std::shared_ptr<WebClient> web_client_; std::shared_ptr<WebClient> web_client_;
}; };

View File

@@ -14,16 +14,16 @@ struct UserResult {
}; };
class DataGenerator { class DataGenerator {
public: public:
virtual ~DataGenerator() = default; virtual ~DataGenerator() = default;
virtual void Load(const std::string &model_path) = 0; virtual void Load(const std::string& model_path) = 0;
virtual BreweryResult GenerateBrewery(const std::string &city_name, virtual BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) = 0; const std::string& region_context) = 0;
virtual UserResult GenerateUser(const std::string &locale) = 0; virtual UserResult GenerateUser(const std::string& locale) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_

View File

@@ -10,29 +10,32 @@ struct llama_model;
struct llama_context; struct llama_context;
class LlamaGenerator final : public DataGenerator { class LlamaGenerator final : public DataGenerator {
public: public:
LlamaGenerator() = default; LlamaGenerator() = default;
~LlamaGenerator() override; ~LlamaGenerator() override;
void SetSamplingOptions(float temperature, float top_p, int seed = -1); void SetSamplingOptions(float temperature, float top_p, int seed = -1);
void Load(const std::string &model_path) override; void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string &city_name, BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) override; const std::string& region_context) override;
UserResult GenerateUser(const std::string &locale) override; UserResult GenerateUser(const std::string& locale) override;
private: private:
std::string Infer(const std::string &prompt, int max_tokens = 10000); std::string Infer(const std::string& prompt, int max_tokens = 10000);
// Overload that allows passing a system message separately so chat-capable // Overload that allows passing a system message separately so chat-capable
// models receive a proper system role instead of having the system text // models receive a proper system role instead of having the system text
// concatenated into the user prompt (helps avoid revealing internal // concatenated into the user prompt (helps avoid revealing internal
// reasoning or instructions in model output). // reasoning or instructions in model output).
std::string Infer(const std::string &system_prompt, const std::string &prompt, std::string Infer(const std::string& system_prompt,
const std::string& prompt, int max_tokens = 10000);
std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = 10000); int max_tokens = 10000);
llama_model *model_ = nullptr; llama_model* model_ = nullptr;
llama_context *context_ = nullptr; llama_context* context_ = nullptr;
float sampling_temperature_ = 0.8f; float sampling_temperature_ = 0.8f;
float sampling_top_p_ = 0.92f; float sampling_top_p_ = 0.92f;
uint32_t sampling_seed_ = 0xFFFFFFFFu; uint32_t sampling_seed_ = 0xFFFFFFFFu;

View File

@@ -0,0 +1,32 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
#include <string>
#include <utility>
struct llama_model;
struct llama_vocab;
typedef int llama_token;
// Helper functions for LlamaGenerator methods
std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars = 700);
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message);
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt);
std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt);
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output);
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out);
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_

View File

@@ -1,21 +1,22 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#include "data_generation/data_generator.h"
#include <string> #include <string>
#include <vector> #include <vector>
class MockGenerator final : public DataGenerator { #include "data_generation/data_generator.h"
public:
void Load(const std::string &model_path) override;
BreweryResult GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) override;
UserResult GenerateUser(const std::string &locale) override;
private: class MockGenerator final : public DataGenerator {
static std::size_t DeterministicHash(const std::string &a, public:
const std::string &b); void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string& city_name,
const std::string& country_name,
const std::string& region_context) override;
UserResult GenerateUser(const std::string& locale) override;
private:
static std::size_t DeterministicHash(const std::string& a,
const std::string& b);
static const std::vector<std::string> kBreweryAdjectives; static const std::vector<std::string> kBreweryAdjectives;
static const std::vector<std::string> kBreweryNouns; static const std::vector<std::string> kBreweryNouns;

View File

@@ -1,8 +1,9 @@
#ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#include <mutex>
#include <sqlite3.h> #include <sqlite3.h>
#include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
@@ -39,18 +40,18 @@ struct City {
/// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks. /// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks.
class SqliteDatabase { class SqliteDatabase {
private: private:
sqlite3 *db_ = nullptr; sqlite3* db_ = nullptr;
std::mutex db_mutex_; std::mutex db_mutex_;
void InitializeSchema(); void InitializeSchema();
public: public:
/// @brief Closes the SQLite connection if initialized. /// @brief Closes the SQLite connection if initialized.
~SqliteDatabase(); ~SqliteDatabase();
/// @brief Opens the SQLite database at db_path and creates schema objects. /// @brief Opens the SQLite database at db_path and creates schema objects.
void Initialize(const std::string &db_path = ":memory:"); void Initialize(const std::string& db_path = ":memory:");
/// @brief Starts a database transaction for batched writes. /// @brief Starts a database transaction for batched writes.
void BeginTransaction(); void BeginTransaction();
@@ -59,16 +60,16 @@ public:
void CommitTransaction(); void CommitTransaction();
/// @brief Inserts a country row. /// @brief Inserts a country row.
void InsertCountry(int id, const std::string &name, const std::string &iso2, void InsertCountry(int id, const std::string& name, const std::string& iso2,
const std::string &iso3); const std::string& iso3);
/// @brief Inserts a state row linked to a country. /// @brief Inserts a state row linked to a country.
void InsertState(int id, int country_id, const std::string &name, void InsertState(int id, int country_id, const std::string& name,
const std::string &iso2); const std::string& iso2);
/// @brief Inserts a city row linked to state and country. /// @brief Inserts a city row linked to state and country.
void InsertCity(int id, int state_id, int country_id, const std::string &name, void InsertCity(int id, int state_id, int country_id,
double latitude, double longitude); const std::string& name, double latitude, double longitude);
/// @brief Returns city records including parent country id. /// @brief Returns city records including parent country id.
std::vector<City> QueryCities(); std::vector<City> QueryCities();

View File

@@ -1,15 +1,17 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#include <string>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
#include <string>
/// @brief Loads world-city JSON data into SQLite through streaming parsing. /// @brief Loads world-city JSON data into SQLite through streaming parsing.
class JsonLoader { class JsonLoader {
public: public:
/// @brief Parses a JSON file and writes country/state/city rows into db. /// @brief Parses a JSON file and writes country/state/city rows into db.
static void LoadWorldCities(const std::string &json_path, SqliteDatabase &db); static void LoadWorldCities(const std::string& json_path,
SqliteDatabase& db);
}; };
#endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -1,10 +1,11 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#include "database/database.h"
#include <functional> #include <functional>
#include <string> #include <string>
#include "database/database.h"
// Forward declaration to avoid circular dependency // Forward declaration to avoid circular dependency
class SqliteDatabase; class SqliteDatabase;
@@ -20,13 +21,13 @@ struct CityRecord {
/// @brief Streaming SAX parser that emits city records during traversal. /// @brief Streaming SAX parser that emits city records during traversal.
class StreamingJsonParser { class StreamingJsonParser {
public: public:
/// @brief Parses file_path and invokes callbacks for city rows and progress. /// @brief Parses file_path and invokes callbacks for city rows and progress.
static void Parse(const std::string &file_path, SqliteDatabase &db, static void Parse(const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord &)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress = nullptr); std::function<void(size_t, size_t)> on_progress = nullptr);
private: private:
/// @brief Mutable SAX handler state while traversing nested JSON arrays. /// @brief Mutable SAX handler state while traversing nested JSON arrays.
struct ParseState { struct ParseState {
int current_country_id = 0; int current_country_id = 0;
@@ -42,7 +43,7 @@ private:
bool in_states_array = false; bool in_states_array = false;
bool in_cities_array = false; bool in_cities_array = false;
std::function<void(const CityRecord &)> on_city; std::function<void(const CityRecord&)> on_city;
std::function<void(size_t, size_t)> on_progress; std::function<void(size_t, size_t)> on_progress;
size_t bytes_processed = 0; size_t bytes_processed = 0;
}; };

View File

@@ -1,29 +1,30 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#include "web_client/web_client.h"
#include <memory> #include <memory>
#include "web_client/web_client.h"
// RAII for curl_global_init/cleanup. // RAII for curl_global_init/cleanup.
// An instance of this class should be created in main() before any curl // An instance of this class should be created in main() before any curl
// operations and exist for the lifetime of the application. // operations and exist for the lifetime of the application.
class CurlGlobalState { class CurlGlobalState {
public: public:
CurlGlobalState(); CurlGlobalState();
~CurlGlobalState(); ~CurlGlobalState();
CurlGlobalState(const CurlGlobalState &) = delete; CurlGlobalState(const CurlGlobalState&) = delete;
CurlGlobalState &operator=(const CurlGlobalState &) = delete; CurlGlobalState& operator=(const CurlGlobalState&) = delete;
}; };
class CURLWebClient : public WebClient { class CURLWebClient : public WebClient {
public: public:
CURLWebClient(); CURLWebClient();
~CURLWebClient() override; ~CURLWebClient() override;
void DownloadToFile(const std::string &url, void DownloadToFile(const std::string& url,
const std::string &file_path) override; const std::string& file_path) override;
std::string Get(const std::string &url) override; std::string Get(const std::string& url) override;
std::string UrlEncode(const std::string &value) override; std::string UrlEncode(const std::string& value) override;
}; };
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_

View File

@@ -4,19 +4,19 @@
#include <string> #include <string>
class WebClient { class WebClient {
public: public:
virtual ~WebClient() = default; virtual ~WebClient() = default;
// Downloads content from a URL to a file. Throws on error. // Downloads content from a URL to a file. Throws on error.
virtual void DownloadToFile(const std::string &url, virtual void DownloadToFile(const std::string& url,
const std::string &file_path) = 0; const std::string& file_path) = 0;
// Performs a GET request and returns the response body as a string. Throws on // Performs a GET request and returns the response body as a string. Throws
// error. // on error.
virtual std::string Get(const std::string &url) = 0; virtual std::string Get(const std::string& url) = 0;
// URL-encodes a string. // URL-encodes a string.
virtual std::string UrlEncode(const std::string &value) = 0; virtual std::string UrlEncode(const std::string& value) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -10,7 +10,7 @@
/// @brief Provides cached Wikipedia summary lookups for city and country pairs. /// @brief Provides cached Wikipedia summary lookups for city and country pairs.
class WikipediaService { class WikipediaService {
public: public:
/// @brief Creates a new Wikipedia service with the provided web client. /// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::shared_ptr<WebClient> client); explicit WikipediaService(std::shared_ptr<WebClient> client);
@@ -18,7 +18,7 @@ public:
[[nodiscard]] std::string GetSummary(std::string_view city, [[nodiscard]] std::string GetSummary(std::string_view city,
std::string_view country); std::string_view country);
private: private:
std::string FetchExtract(std::string_view query); std::string FetchExtract(std::string_view query);
std::shared_ptr<WebClient> client_; std::shared_ptr<WebClient> client_;
std::unordered_map<std::string, std::string> cache_; std::unordered_map<std::string, std::string> cache_;

View File

@@ -1,21 +1,20 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <filesystem> #include <filesystem>
#include <unordered_map> #include <unordered_map>
#include <spdlog/spdlog.h>
#include "data_generation/data_downloader.h" #include "data_generation/data_downloader.h"
#include "json_handling/json_loader.h"
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
#include "json_handling/json_loader.h"
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
BiergartenDataGenerator::BiergartenDataGenerator( BiergartenDataGenerator::BiergartenDataGenerator(
const ApplicationOptions &options, const ApplicationOptions& options, std::shared_ptr<WebClient> web_client,
std::shared_ptr<WebClient> web_client, SqliteDatabase& database)
SqliteDatabase &database)
: options_(options), webClient_(web_client), database_(database) {} : options_(options), webClient_(web_client), database_(database) {}
std::unique_ptr<DataGenerator> BiergartenDataGenerator::InitializeGenerator() { std::unique_ptr<DataGenerator> BiergartenDataGenerator::InitializeGenerator() {
@@ -62,57 +61,79 @@ void BiergartenDataGenerator::LoadGeographicData() {
} }
} }
void BiergartenDataGenerator::GenerateSampleBreweries() { std::vector<std::pair<City, std::string>>
auto generator = InitializeGenerator(); BiergartenDataGenerator::QueryCitiesWithCountries() {
WikipediaService wikipedia_service(webClient_);
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
auto countries = database_.QueryCountries(50);
auto states = database_.QueryStates(50);
auto cities = database_.QueryCities(); auto cities = database_.QueryCities();
// Build a quick map of country id -> name for per-city lookups. // Build a quick map of country id -> name for per-city lookups.
auto all_countries = database_.QueryCountries(0); auto all_countries = database_.QueryCountries(0);
std::unordered_map<int, std::string> country_map; std::unordered_map<int, std::string> country_map;
for (const auto &c : all_countries) for (const auto& c : all_countries) {
country_map[c.id] = c.name; country_map[c.id] = c.name;
}
spdlog::info("\nTotal records loaded:"); spdlog::info("\nTotal records loaded:");
spdlog::info(" Countries: {}", database_.QueryCountries(0).size()); spdlog::info(" Countries: {}", database_.QueryCountries(0).size());
spdlog::info(" States: {}", database_.QueryStates(0).size()); spdlog::info(" States: {}", database_.QueryStates(0).size());
spdlog::info(" Cities: {}", cities.size()); spdlog::info(" Cities: {}", cities.size());
generatedBreweries_.clear(); // Cap at 30 entries.
const size_t sample_count = std::min(size_t(30), cities.size()); const size_t sample_count = std::min(size_t(30), cities.size());
std::vector<std::pair<City, std::string>> result;
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
for (size_t i = 0; i < sample_count; i++) { for (size_t i = 0; i < sample_count; i++) {
const auto &city = cities[i]; const auto& city = cities[i];
const int city_id = city.id; std::string country_name;
const std::string city_name = city.name;
std::string local_country;
const auto country_it = country_map.find(city.country_id); const auto country_it = country_map.find(city.country_id);
if (country_it != country_map.end()) { if (country_it != country_map.end()) {
local_country = country_it->second; country_name = country_it->second;
}
result.push_back({city, country_name});
} }
return result;
}
std::vector<BiergartenDataGenerator::EnrichedCity>
BiergartenDataGenerator::EnrichWithWikipedia(
const std::vector<std::pair<City, std::string>>& cities) {
WikipediaService wikipedia_service(webClient_);
std::vector<EnrichedCity> enriched;
for (const auto& [city, country_name] : cities) {
const std::string region_context = const std::string region_context =
wikipedia_service.GetSummary(city_name, local_country); wikipedia_service.GetSummary(city.name, country_name);
spdlog::debug("[Pipeline] Region context for {}: {}", city_name, spdlog::debug("[Pipeline] Region context for {}: {}", city.name,
region_context); region_context);
auto brewery = enriched.push_back({city.id, city.name, country_name, region_context});
generator->GenerateBrewery(city_name, local_country, region_context);
generatedBreweries_.push_back({city_id, city_name, brewery});
} }
return enriched;
}
void BiergartenDataGenerator::GenerateBreweries(
DataGenerator& generator, const std::vector<EnrichedCity>& cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
generatedBreweries_.clear();
for (const auto& enriched_city : cities) {
auto brewery = generator.GenerateBrewery(enriched_city.city_name,
enriched_city.country_name,
enriched_city.region_context);
generatedBreweries_.push_back(
{enriched_city.city_id, enriched_city.city_name, brewery});
}
}
void BiergartenDataGenerator::LogResults() const {
spdlog::info("\n=== GENERATED DATA DUMP ==="); spdlog::info("\n=== GENERATED DATA DUMP ===");
for (size_t i = 0; i < generatedBreweries_.size(); i++) { for (size_t i = 0; i < generatedBreweries_.size(); i++) {
const auto &entry = generatedBreweries_[i]; const auto& entry = generatedBreweries_[i];
spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.cityId, spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.city_id,
entry.cityName); entry.city_name);
spdlog::info(" brewery_name=\"{}\"", entry.brewery.name); spdlog::info(" brewery_name=\"{}\"", entry.brewery.name);
spdlog::info(" brewery_description=\"{}\"", entry.brewery.description); spdlog::info(" brewery_description=\"{}\"", entry.brewery.description);
} }
@@ -121,11 +142,15 @@ void BiergartenDataGenerator::GenerateSampleBreweries() {
int BiergartenDataGenerator::Run() { int BiergartenDataGenerator::Run() {
try { try {
LoadGeographicData(); LoadGeographicData();
GenerateSampleBreweries(); auto generator = InitializeGenerator();
auto cities = QueryCitiesWithCountries();
auto enriched = EnrichWithWikipedia(cities);
GenerateBreweries(*generator, enriched);
LogResults();
spdlog::info("\nOK: Pipeline completed successfully"); spdlog::info("\nOK: Pipeline completed successfully");
return 0; return 0;
} catch (const std::exception &e) { } catch (const std::exception& e) {
spdlog::error("ERROR: Pipeline failed: {}", e.what()); spdlog::error("ERROR: Pipeline failed: {}", e.what());
return 1; return 1;
} }

View File

@@ -1,23 +1,25 @@
#include "data_generation/data_downloader.h" #include "data_generation/data_downloader.h"
#include "web_client/web_client.h"
#include <spdlog/spdlog.h>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <spdlog/spdlog.h>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include "web_client/web_client.h"
DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client) DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client)
: web_client_(std::move(web_client)) {} : web_client_(std::move(web_client)) {}
DataDownloader::~DataDownloader() {} DataDownloader::~DataDownloader() {}
bool DataDownloader::FileExists(const std::string &file_path) { bool DataDownloader::FileExists(const std::string& file_path) {
return std::filesystem::exists(file_path); return std::filesystem::exists(file_path);
} }
std::string std::string DataDownloader::DownloadCountriesDatabase(
DataDownloader::DownloadCountriesDatabase(const std::string &cache_path, const std::string& cache_path, const std::string& commit) {
const std::string &commit) {
if (FileExists(cache_path)) { if (FileExists(cache_path)) {
spdlog::info("[DataDownloader] Cache hit: {}", cache_path); spdlog::info("[DataDownloader] Cache hit: {}", cache_path);
return cache_path; return cache_path;
@@ -28,7 +30,8 @@ DataDownloader::DownloadCountriesDatabase(const std::string &cache_path,
short_commit = commit.substr(0, 7); short_commit = commit.substr(0, 7);
} }
std::string url = "https://raw.githubusercontent.com/dr5hn/" std::string url =
"https://raw.githubusercontent.com/dr5hn/"
"countries-states-cities-database/" + "countries-states-cities-database/" +
short_commit + "/json/countries+states+cities.json"; short_commit + "/json/countries+states+cities.json";

View File

@@ -0,0 +1,16 @@
#include "data_generation/llama_generator.h"
#include "llama.h"
LlamaGenerator::~LlamaGenerator() {
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
llama_backend_free();
}

View File

@@ -0,0 +1,74 @@
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
BreweryResult LlamaGenerator::GenerateBrewery(
const std::string& city_name, const std::string& country_name,
const std::string& region_context) {
const std::string safe_region_context =
PrepareRegionContextPublic(region_context);
const std::string system_prompt =
"You are the brewmaster and owner of a local craft brewery. "
"Write a name and a short, soulful description for your brewery that "
"reflects your pride in the local community and your craft. "
"The tone should be authentic and welcoming, like a note on a "
"chalkboard "
"menu. Output ONLY a single JSON object with keys \"name\" and "
"\"description\". "
"Do not include markdown formatting or backticks.";
std::string prompt =
"Write a brewery name and place-specific long description for a craft "
"brewery in " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string(".")
: std::string(". Regional context: ") + safe_region_context);
const int max_attempts = 3;
std::string raw;
std::string last_error;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
raw = Infer(system_prompt, prompt, 384);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
std::string name;
std::string description;
const std::string validation_error =
ValidateBreweryJsonPublic(raw, name, description);
if (validation_error.empty()) {
return {std::move(name), std::move(description)};
}
last_error = validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error);
prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys."
"\n\nLocation: " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string("")
: std::string("\nRegional context: ") + safe_region_context);
}
spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
}

View File

@@ -0,0 +1,57 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
const std::string system_prompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
const int max_attempts = 3;
std::string raw;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
raw = Infer(system_prompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
auto [username, bio] = ParseTwoLineResponsePublic(
raw, "LlamaGenerator: malformed user response");
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
if (bio.size() > 200) bio = bio.substr(0, 200);
return {username, bio};
} catch (const std::exception& e) {
spdlog::warn(
"LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
spdlog::error(
"LlamaGenerator: malformed user response after {} attempts: {}",
max_attempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -0,0 +1,398 @@
#include <algorithm>
#include <array>
#include <boost/json.hpp>
#include <cctype>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "data_generation/llama_generator.h"
#include "llama.h"
namespace {
std::string Trim(std::string value) {
auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
value.erase(value.begin(),
std::find_if(value.begin(), value.end(), not_space));
value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(),
value.end());
return value;
}
std::string CondenseWhitespace(std::string text) {
std::string out;
out.reserve(text.size());
bool in_whitespace = false;
for (unsigned char ch : text) {
if (std::isspace(ch)) {
if (!in_whitespace) {
out.push_back(' ');
in_whitespace = true;
}
continue;
}
in_whitespace = false;
out.push_back(static_cast<char>(ch));
}
return Trim(std::move(out));
}
std::string PrepareRegionContext(std::string_view region_context,
std::size_t max_chars) {
std::string normalized = CondenseWhitespace(std::string(region_context));
if (normalized.size() <= max_chars) {
return normalized;
}
normalized.resize(max_chars);
const std::size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space);
}
normalized += "...";
return normalized;
}
std::string StripCommonPrefix(std::string line) {
line = Trim(std::move(line));
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = Trim(line.substr(1));
} else {
std::size_t i = 0;
while (i < line.size() &&
std::isdigit(static_cast<unsigned char>(line[i]))) {
++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = Trim(line.substr(i + 1));
}
}
auto strip_label = [&line](const std::string& label) {
if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = Trim(line.substr(label.size()));
}
}
};
strip_label("name:");
strip_label("brewery name:");
strip_label("description:");
strip_label("username:");
strip_label("bio:");
return Trim(std::move(line));
}
std::pair<std::string, std::string> ParseTwoLineResponse(
const std::string& raw, const std::string& error_message) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = StripCommonPrefix(std::move(line));
if (!line.empty()) lines.push_back(std::move(line));
}
std::vector<std::string> filtered;
for (auto& l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (!l.empty() && l.front() == '<' && low.back() == '>') continue;
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
filtered.push_back(std::move(l));
}
if (filtered.size() < 2) throw std::runtime_error(error_message);
std::string first = Trim(filtered.front());
std::string second;
for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) second += ' ';
second += filtered[i];
}
second = Trim(std::move(second));
if (first.empty() || second.empty()) throw std::runtime_error(error_message);
return {first, second};
}
std::string ToChatPrompt(const llama_model* model,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return user_prompt;
}
const llama_chat_message message{"user", user_prompt.c_str()};
std::vector<char> buffer(
std::max<std::size_t>(1024, user_prompt.size() * 4));
int32_t required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return system_prompt + "\n\n" + user_prompt;
}
const llama_chat_message messages[2] = {{"system", system_prompt.c_str()},
{"user", user_prompt.c_str()}};
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
int32_t required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
std::array<char, 256> buffer{};
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true);
if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()),
0, true);
if (bytes < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece");
}
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
return;
}
output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
bool ExtractFirstJsonObject(const std::string& text, std::string& json_out) {
std::size_t start = std::string::npos;
int depth = 0;
bool in_string = false;
bool escaped = false;
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
if (in_string) {
if (escaped) {
escaped = false;
} else if (ch == '\\') {
escaped = true;
} else if (ch == '"') {
in_string = false;
}
continue;
}
if (ch == '"') {
in_string = true;
continue;
}
if (ch == '{') {
if (depth == 0) {
start = i;
}
++depth;
continue;
}
if (ch == '}') {
if (depth == 0) {
continue;
}
--depth;
if (depth == 0 && start != std::string::npos) {
json_out = text.substr(start, i - start + 1);
return true;
}
}
}
return false;
}
std::string ValidateBreweryJson(const std::string& raw, std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
error_out = "JSON root must be an object";
return false;
}
const auto& obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string";
return false;
}
if (!obj.contains("description") || !obj.at("description").is_string()) {
error_out = "JSON field 'description' is missing or not a string";
return false;
}
name_out = Trim(std::string(obj.at("name").as_string().c_str()));
description_out =
Trim(std::string(obj.at("description").as_string().c_str()));
if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty";
return false;
}
if (description_out.empty()) {
error_out = "JSON field 'description' must not be empty";
return false;
}
std::string name_lower = name_out;
std::string description_lower = description_out;
std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content";
return false;
}
error_out.clear();
return true;
};
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
std::string validation_error;
if (ec) {
std::string extracted;
if (!ExtractFirstJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
}
} // namespace
// Forward declarations for helper functions exposed to other translation units
std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars) {
return PrepareRegionContext(region_context, max_chars);
}
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message) {
return ParseTwoLineResponse(raw, error_message);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt) {
return ToChatPrompt(model, user_prompt);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
return ToChatPrompt(model, system_prompt, user_prompt);
}
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output) {
AppendTokenPiece(vocab, token, output);
}
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out);
}

View File

@@ -0,0 +1,111 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
#include "llama.h"
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, prompt), max_tokens);
}
std::string LlamaGenerator::Infer(const std::string& system_prompt,
const std::string& prompt, int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
max_tokens);
}
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
int max_tokens) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
llama_memory_clear(llama_get_memory(context_), true);
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
if (token_count < 0) {
prompt_tokens.resize(static_cast<std::size_t>(-token_count));
token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
}
if (token_count < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0)
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens to fit n_batch/n_ctx limits",
token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
token_count = prompt_budget;
}
const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next =
llama_sampler_sample(sampler.get(), context_, -1);
if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next);
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
std::string output;
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
return output;
}

View File

@@ -0,0 +1,42 @@
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) {
if (model_path.empty())
throw std::runtime_error("LlamaGenerator: model path must not be empty");
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path);
}
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = 2048;
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {
llama_model_free(model_);
model_ = nullptr;
throw std::runtime_error("LlamaGenerator: failed to create context");
}
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}

View File

@@ -0,0 +1,25 @@
#include <stdexcept>
#include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
int seed) {
if (temperature < 0.0f) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (!(top_p > 0.0f && top_p <= 1.0f)) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
sampling_temperature_ = temperature;
sampling_top_p_ = top_p;
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed);
}

View File

@@ -1,734 +0,0 @@
#include <algorithm>
#include <array>
#include <cctype>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "llama.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h"
namespace {
std::string trim(std::string value) {
auto notSpace = [](unsigned char ch) { return !std::isspace(ch); };
value.erase(value.begin(),
std::find_if(value.begin(), value.end(), notSpace));
value.erase(std::find_if(value.rbegin(), value.rend(), notSpace).base(),
value.end());
return value;
}
std::string CondenseWhitespace(std::string text) {
std::string out;
out.reserve(text.size());
bool inWhitespace = false;
for (unsigned char ch : text) {
if (std::isspace(ch)) {
if (!inWhitespace) {
out.push_back(' ');
inWhitespace = true;
}
continue;
}
inWhitespace = false;
out.push_back(static_cast<char>(ch));
}
return trim(std::move(out));
}
std::string PrepareRegionContext(std::string_view regionContext,
std::size_t maxChars = 700) {
std::string normalized = CondenseWhitespace(std::string(regionContext));
if (normalized.size() <= maxChars) {
return normalized;
}
normalized.resize(maxChars);
const std::size_t lastSpace = normalized.find_last_of(' ');
if (lastSpace != std::string::npos && lastSpace > maxChars / 2) {
normalized.resize(lastSpace);
}
normalized += "...";
return normalized;
}
std::string stripCommonPrefix(std::string line) {
line = trim(std::move(line));
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = trim(line.substr(1));
} else {
std::size_t i = 0;
while (i < line.size() &&
std::isdigit(static_cast<unsigned char>(line[i]))) {
++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = trim(line.substr(i + 1));
}
}
auto stripLabel = [&line](const std::string &label) {
if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = trim(line.substr(label.size()));
}
}
};
stripLabel("name:");
stripLabel("brewery name:");
stripLabel("description:");
stripLabel("username:");
stripLabel("bio:");
return trim(std::move(line));
}
std::pair<std::string, std::string>
parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = stripCommonPrefix(std::move(line));
if (!line.empty())
lines.push_back(std::move(line));
}
std::vector<std::string> filtered;
for (auto &l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (!l.empty() && l.front() == '<' && low.back() == '>')
continue;
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0)
continue;
filtered.push_back(std::move(l));
}
if (filtered.size() < 2)
throw std::runtime_error(errorMessage);
std::string first = trim(filtered.front());
std::string second;
for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty())
second += ' ';
second += filtered[i];
}
second = trim(std::move(second));
if (first.empty() || second.empty())
throw std::runtime_error(errorMessage);
return {first, second};
}
std::string toChatPrompt(const llama_model *model,
const std::string &userPrompt) {
const char *tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return userPrompt;
}
const llama_chat_message message{"user", userPrompt.c_str()};
std::vector<char> buffer(std::max<std::size_t>(1024, userPrompt.size() * 4));
int32_t required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
std::string toChatPrompt(const llama_model *model,
const std::string &system_prompt,
const std::string &userPrompt) {
const char *tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return system_prompt + "\n\n" + userPrompt;
}
const llama_chat_message messages[2] = {{"system", system_prompt.c_str()},
{"user", userPrompt.c_str()}};
std::vector<char> buffer(std::max<std::size_t>(
1024, (systemPrompt.size() + userPrompt.size()) * 4));
int32_t required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
void appendTokenPiece(const llama_vocab *vocab, llama_token token,
std::string &output) {
std::array<char, 256> buffer{};
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true);
if (bytes < 0) {
std::vector<char> dynamicBuffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamicBuffer.data(),
static_cast<int32_t>(dynamicBuffer.size()), 0,
true);
if (bytes < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece");
}
output.append(dynamicBuffer.data(), static_cast<std::size_t>(bytes));
return;
}
output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
bool extractFirstJsonObject(const std::string &text, std::string &jsonOut) {
std::size_t start = std::string::npos;
int depth = 0;
bool inString = false;
bool escaped = false;
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
if (inString) {
if (escaped) {
escaped = false;
} else if (ch == '\\') {
escaped = true;
} else if (ch == '"') {
inString = false;
}
continue;
}
if (ch == '"') {
inString = true;
continue;
}
if (ch == '{') {
if (depth == 0) {
start = i;
}
++depth;
continue;
}
if (ch == '}') {
if (depth == 0) {
continue;
}
--depth;
if (depth == 0 && start != std::string::npos) {
jsonOut = text.substr(start, i - start + 1);
return true;
}
}
}
return false;
}
std::string ValidateBreweryJson(const std::string &raw, std::string &nameOut,
std::string &descriptionOut) {
auto validateObject = [&](const boost::json::value &jv,
std::string &errorOut) -> bool {
if (!jv.is_object()) {
errorOut = "JSON root must be an object";
return false;
}
const auto &obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
errorOut = "JSON field 'name' is missing or not a string";
return false;
}
if (!obj.contains("description") || !obj.at("description").is_string()) {
errorOut = "JSON field 'description' is missing or not a string";
return false;
}
nameOut = trim(std::string(obj.at("name").as_string().c_str()));
descriptionOut =
trim(std::string(obj.at("description").as_string().c_str()));
if (nameOut.empty()) {
errorOut = "JSON field 'name' must not be empty";
return false;
}
if (descriptionOut.empty()) {
errorOut = "JSON field 'description' must not be empty";
return false;
}
std::string nameLower = nameOut;
std::string descriptionLower = descriptionOut;
std::transform(
nameLower.begin(), nameLower.end(), nameLower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(descriptionLower.begin(), descriptionLower.end(),
descriptionLower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (nameLower == "string" || descriptionLower == "string") {
errorOut = "JSON appears to be a schema placeholder, not content";
return false;
}
errorOut.clear();
return true;
};
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
std::string validationError;
if (ec) {
std::string extracted;
if (!extractFirstJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
if (!validateObject(jv, validationError)) {
return validationError;
}
return {};
}
if (!validateObject(jv, validationError)) {
return validationError;
}
return {};
}
} // namespace
LlamaGenerator::~LlamaGenerator() {
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
llama_backend_free();
}
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
int seed) {
if (temperature < 0.0f) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (!(top_p > 0.0f && top_p <= 1.0f)) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
sampling_temperature_ = temperature;
sampling_top_p_ = top_p;
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed);
}
void LlamaGenerator::Load(const std::string &model_path) {
if (model_path.empty())
throw std::runtime_error("LlamaGenerator: model path must not be empty");
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path);
}
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = 2048;
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {
llama_model_free(model_);
model_ = nullptr;
throw std::runtime_error("LlamaGenerator: failed to create context");
}
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}
std::string LlamaGenerator::Infer(const std::string &prompt, int max_tokens) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab *vocab = llama_model_get_vocab(model_);
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
llama_memory_clear(llama_get_memory(context_), true);
const std::string formatted_prompt = toChatPrompt(model_, prompt);
std::vector<llama_token> promptTokens(formatted_prompt.size() + 8);
int32_t tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
if (tokenCount < 0) {
promptTokens.resize(static_cast<std::size_t>(-tokenCount));
tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
}
if (tokenCount < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
if (nCtx <= 1 || nBatch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, nCtx - 1));
const int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(tokenCount));
if (tokenCount > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
"to fit n_batch/n_ctx limits",
tokenCount, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(prompt_budget));
tokenCount = prompt_budget;
}
const llama_batch promptBatch = llama_batch_get_one(
promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
if (llama_decode(context_, promptBatch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
if (llama_vocab_is_eog(vocab, next))
break;
generated_tokens.push_back(next);
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
std::string output;
for (const llama_token token : generated_tokens)
appendTokenPiece(vocab, token, output);
return output;
}
std::string LlamaGenerator::Infer(const std::string &system_prompt,
const std::string &prompt, int max_tokens) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab *vocab = llama_model_get_vocab(model_);
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
llama_memory_clear(llama_get_memory(context_), true);
const std::string formatted_prompt =
toChatPrompt(model_, system_prompt, prompt);
std::vector<llama_token> promptTokens(formatted_prompt.size() + 8);
int32_t tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
if (tokenCount < 0) {
promptTokens.resize(static_cast<std::size_t>(-tokenCount));
tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
}
if (tokenCount < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
if (nCtx <= 1 || nBatch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, nCtx - 1));
int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(tokenCount));
if (tokenCount > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
"to fit n_batch/n_ctx limits",
tokenCount, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(prompt_budget));
tokenCount = prompt_budget;
}
const llama_batch promptBatch = llama_batch_get_one(
promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
if (llama_decode(context_, promptBatch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
if (llama_vocab_is_eog(vocab, next))
break;
generated_tokens.push_back(next);
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
std::string output;
for (const llama_token token : generated_tokens)
appendTokenPiece(vocab, token, output);
return output;
}
BreweryResult
LlamaGenerator::GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) {
const std::string safe_region_context = PrepareRegionContext(region_context);
const std::string system_prompt =
"You are a copywriter for a craft beer travel guide. "
"Your writing is vivid, specific to place, and avoids generic beer "
"cliches. "
"You must output ONLY valid JSON. "
"The JSON schema must be exactly: {\"name\": \"string\", "
"\"description\": \"string\"}. "
"Do not include markdown formatting or backticks.";
std::string prompt =
"Write a brewery name and place-specific description for a craft "
"brewery in " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string(".")
: std::string(". Regional context: ") + safe_region_context);
const int maxAttempts = 3;
std::string raw;
std::string lastError;
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
raw = Infer(system_prompt, prompt, 384);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
std::string name;
std::string description;
const std::string validationError =
ValidateBreweryJson(raw, name, description);
if (validationError.empty()) {
return {std::move(name), std::move(description)};
}
lastError = validationError;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validationError);
prompt = "Your previous response was invalid. Error: " + validationError +
"\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys."
"\n\nLocation: " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string("")
: std::string("\nRegional context: ") + safe_region_context);
}
spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
maxAttempts, lastError.empty() ? raw : lastError);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
}
UserResult LlamaGenerator::GenerateUser(const std::string &locale) {
const std::string system_prompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
const int maxAttempts = 3;
std::string raw;
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
raw = Infer(system_prompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
auto [username, bio] =
parseTwoLineResponse(raw, "LlamaGenerator: malformed user response");
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
if (bio.size() > 200)
bio = bio.substr(0, 200);
return {username, bio};
} catch (const std::exception &e) {
spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}",
maxAttempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -1,7 +1,7 @@
#include "data_generation/mock_generator.h" #include <string>
#include <vector>
#include <functional> #include "data_generation/mock_generator.h"
#include <spdlog/spdlog.h>
const std::vector<std::string> MockGenerator::kBreweryAdjectives = { const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
@@ -63,42 +63,3 @@ const std::vector<std::string> MockGenerator::kBios = {
"Craft beer fan mapping tasting notes and favorite brew routes.", "Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.", "Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."}; "Keeping a running list of must-try collab releases and tap takeovers."};
void MockGenerator::Load(const std::string & /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}
std::size_t MockGenerator::DeterministicHash(const std::string &a,
const std::string &b) {
std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b);
seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13));
return seed;
}
BreweryResult MockGenerator::GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) {
const std::string location_key =
country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash = region_context.empty()
? std::hash<std::string>{}(location_key)
: DeterministicHash(location_key, region_context);
BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
result.description =
kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
return result;
}
UserResult MockGenerator::GenerateUser(const std::string &locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -0,0 +1,12 @@
#include <string>
#include "data_generation/mock_generator.h"
std::size_t MockGenerator::DeterministicHash(const std::string& a,
const std::string& b) {
std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b);
seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13));
return seed;
}

View File

@@ -0,0 +1,21 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
BreweryResult MockGenerator::GenerateBrewery(
const std::string& city_name, const std::string& country_name,
const std::string& region_context) {
const std::string location_key =
country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash =
region_context.empty() ? std::hash<std::string>{}(location_key)
: DeterministicHash(location_key, region_context);
BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
result.description =
kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
return result;
}

View File

@@ -0,0 +1,13 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -0,0 +1,9 @@
#include <spdlog/spdlog.h>
#include <string>
#include "data_generation/mock_generator.h"
void MockGenerator::Load(const std::string& /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}

View File

@@ -1,11 +1,13 @@
#include "database/database.h" #include "database/database.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
void SqliteDatabase::InitializeSchema() { void SqliteDatabase::InitializeSchema() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *schema = R"( const char* schema = R"(
CREATE TABLE IF NOT EXISTS countries ( CREATE TABLE IF NOT EXISTS countries (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT NOT NULL, name TEXT NOT NULL,
@@ -33,7 +35,7 @@ void SqliteDatabase::InitializeSchema() {
); );
)"; )";
char *errMsg = nullptr; char* errMsg = nullptr;
int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg); int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
std::string error = errMsg ? std::string(errMsg) : "Unknown error"; std::string error = errMsg ? std::string(errMsg) : "Unknown error";
@@ -48,7 +50,7 @@ SqliteDatabase::~SqliteDatabase() {
} }
} }
void SqliteDatabase::Initialize(const std::string &db_path) { void SqliteDatabase::Initialize(const std::string& db_path) {
int rc = sqlite3_open(db_path.c_str(), &db_); int rc = sqlite3_open(db_path.c_str(), &db_);
if (rc) { if (rc) {
throw std::runtime_error("Failed to open SQLite database: " + db_path); throw std::runtime_error("Failed to open SQLite database: " + db_path);
@@ -59,7 +61,7 @@ void SqliteDatabase::Initialize(const std::string &db_path) {
void SqliteDatabase::BeginTransaction() { void SqliteDatabase::BeginTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char* err = nullptr;
if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) != if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) !=
SQLITE_OK) { SQLITE_OK) {
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
@@ -70,7 +72,7 @@ void SqliteDatabase::BeginTransaction() {
void SqliteDatabase::CommitTransaction() { void SqliteDatabase::CommitTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char* err = nullptr;
if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) { if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) {
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
sqlite3_free(err); sqlite3_free(err);
@@ -78,17 +80,17 @@ void SqliteDatabase::CommitTransaction() {
} }
} }
void SqliteDatabase::InsertCountry(int id, const std::string &name, void SqliteDatabase::InsertCountry(int id, const std::string& name,
const std::string &iso2, const std::string& iso2,
const std::string &iso3) { const std::string& iso3) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO countries (id, name, iso2, iso3) INSERT OR IGNORE INTO countries (id, name, iso2, iso3)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare country insert"); throw std::runtime_error("Failed to prepare country insert");
@@ -104,16 +106,17 @@ void SqliteDatabase::InsertCountry(int id, const std::string &name,
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertState(int id, int country_id, const std::string &name, void SqliteDatabase::InsertState(int id, int country_id,
const std::string &iso2) { const std::string& name,
const std::string& iso2) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO states (id, country_id, name, iso2) INSERT OR IGNORE INTO states (id, country_id, name, iso2)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare state insert"); throw std::runtime_error("Failed to prepare state insert");
@@ -130,16 +133,16 @@ void SqliteDatabase::InsertState(int id, int country_id, const std::string &name
} }
void SqliteDatabase::InsertCity(int id, int state_id, int country_id, void SqliteDatabase::InsertCity(int id, int state_id, int country_id,
const std::string &name, double latitude, const std::string& name, double latitude,
double longitude) { double longitude) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude) INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare city insert"); throw std::runtime_error("Failed to prepare city insert");
@@ -160,9 +163,9 @@ void SqliteDatabase::InsertCity(int id, int state_id, int country_id,
std::vector<City> SqliteDatabase::QueryCities() { std::vector<City> SqliteDatabase::QueryCities() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<City> cities; std::vector<City> cities;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
const char *query = "SELECT id, name, country_id FROM cities ORDER BY name"; const char* query = "SELECT id, name, country_id FROM cities ORDER BY name";
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
@@ -171,8 +174,8 @@ std::vector<City> SqliteDatabase::QueryCities() {
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
int country_id = sqlite3_column_int(stmt, 2); int country_id = sqlite3_column_int(stmt, 2);
cities.push_back({id, name ? std::string(name) : "", country_id}); cities.push_back({id, name ? std::string(name) : "", country_id});
} }
@@ -185,7 +188,7 @@ std::vector<Country> SqliteDatabase::QueryCountries(int limit) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<Country> countries; std::vector<Country> countries;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
std::string query = std::string query =
"SELECT id, name, iso2, iso3 FROM countries ORDER BY name"; "SELECT id, name, iso2, iso3 FROM countries ORDER BY name";
@@ -201,12 +204,12 @@ std::vector<Country> SqliteDatabase::QueryCountries(int limit) {
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
const char *iso2 = const char* iso2 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
const char *iso3 = const char* iso3 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 3)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 3));
countries.push_back({id, name ? std::string(name) : "", countries.push_back({id, name ? std::string(name) : "",
iso2 ? std::string(iso2) : "", iso2 ? std::string(iso2) : "",
iso3 ? std::string(iso3) : ""}); iso3 ? std::string(iso3) : ""});
@@ -220,7 +223,7 @@ std::vector<State> SqliteDatabase::QueryStates(int limit) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<State> states; std::vector<State> states;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
std::string query = std::string query =
"SELECT id, name, iso2, country_id FROM states ORDER BY name"; "SELECT id, name, iso2, country_id FROM states ORDER BY name";
@@ -236,10 +239,10 @@ std::vector<State> SqliteDatabase::QueryStates(int limit) {
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
const char *iso2 = const char* iso2 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
int country_id = sqlite3_column_int(stmt, 3); int country_id = sqlite3_column_int(stmt, 3);
states.push_back({id, name ? std::string(name) : "", states.push_back({id, name ? std::string(name) : "",
iso2 ? std::string(iso2) : "", country_id}); iso2 ? std::string(iso2) : "", country_id});

View File

@@ -1,12 +1,13 @@
#include <chrono> #include "json_handling/json_loader.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include "json_handling/json_loader.h" #include <chrono>
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
void JsonLoader::LoadWorldCities(const std::string &json_path, void JsonLoader::LoadWorldCities(const std::string& json_path,
SqliteDatabase &db) { SqliteDatabase& db) {
constexpr size_t kBatchSize = 10000; constexpr size_t kBatchSize = 10000;
auto startTime = std::chrono::high_resolution_clock::now(); auto startTime = std::chrono::high_resolution_clock::now();
@@ -19,7 +20,7 @@ void JsonLoader::LoadWorldCities(const std::string &json_path,
try { try {
StreamingJsonParser::Parse( StreamingJsonParser::Parse(
json_path, db, json_path, db,
[&](const CityRecord &record) { [&](const CityRecord& record) {
db.InsertCity(record.id, record.state_id, record.country_id, db.InsertCity(record.id, record.state_id, record.country_id,
record.name, record.latitude, record.longitude); record.name, record.latitude, record.longitude);
++citiesProcessed; ++citiesProcessed;

View File

@@ -1,25 +1,26 @@
#include <cstdio> #include "json_handling/stream_parser.h"
#include <stdexcept>
#include <spdlog/spdlog.h>
#include <boost/json.hpp> #include <boost/json.hpp>
#include <boost/json/basic_parser_impl.hpp> #include <boost/json/basic_parser_impl.hpp>
#include <spdlog/spdlog.h> #include <cstdio>
#include <stdexcept>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h"
class CityRecordHandler { class CityRecordHandler {
friend class boost::json::basic_parser<CityRecordHandler>; friend class boost::json::basic_parser<CityRecordHandler>;
public: public:
static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1);
struct ParseContext { struct ParseContext {
SqliteDatabase *db = nullptr; SqliteDatabase* db = nullptr;
std::function<void(const CityRecord &)> on_city; std::function<void(const CityRecord&)> on_city;
std::function<void(size_t, size_t)> on_progress; std::function<void(size_t, size_t)> on_progress;
size_t cities_emitted = 0; size_t cities_emitted = 0;
size_t total_file_size = 0; size_t total_file_size = 0;
@@ -27,10 +28,10 @@ public:
int states_inserted = 0; int states_inserted = 0;
}; };
explicit CityRecordHandler(ParseContext &ctx) : context(ctx) {} explicit CityRecordHandler(ParseContext& ctx) : context(ctx) {}
private: private:
ParseContext &context; ParseContext& context;
int depth = 0; int depth = 0;
bool in_countries_array = false; bool in_countries_array = false;
@@ -51,10 +52,10 @@ private:
std::string state_info[2]; std::string state_info[2];
// Boost.JSON SAX Hooks // Boost.JSON SAX Hooks
bool on_document_begin(boost::system::error_code &) { return true; } bool on_document_begin(boost::system::error_code&) { return true; }
bool on_document_end(boost::system::error_code &) { return true; } bool on_document_end(boost::system::error_code&) { return true; }
bool on_array_begin(boost::system::error_code &) { bool on_array_begin(boost::system::error_code&) {
depth++; depth++;
if (depth == 1) { if (depth == 1) {
in_countries_array = true; in_countries_array = true;
@@ -66,7 +67,7 @@ private:
return true; return true;
} }
bool on_array_end(std::size_t, boost::system::error_code &) { bool on_array_end(std::size_t, boost::system::error_code&) {
if (depth == 1) { if (depth == 1) {
in_countries_array = false; in_countries_array = false;
} else if (depth == 3) { } else if (depth == 3) {
@@ -78,7 +79,7 @@ private:
return true; return true;
} }
bool on_object_begin(boost::system::error_code &) { bool on_object_begin(boost::system::error_code&) {
depth++; depth++;
if (depth == 2 && in_countries_array) { if (depth == 2 && in_countries_array) {
in_country_object = true; in_country_object = true;
@@ -98,7 +99,7 @@ private:
return true; return true;
} }
bool on_object_end(std::size_t, boost::system::error_code &) { bool on_object_end(std::size_t, boost::system::error_code&) {
if (depth == 6 && building_city) { if (depth == 6 && building_city) {
if (current_city.id > 0 && current_state_id > 0 && if (current_city.id > 0 && current_state_id > 0 &&
current_country_id > 0) { current_country_id > 0) {
@@ -113,7 +114,7 @@ private:
context.on_progress(context.cities_emitted, context.on_progress(context.cities_emitted,
context.total_file_size); context.total_file_size);
} }
} catch (const std::exception &e) { } catch (const std::exception& e) {
spdlog::warn("Record parsing failed: {}", e.what()); spdlog::warn("Record parsing failed: {}", e.what());
} }
} }
@@ -124,7 +125,7 @@ private:
context.db->InsertState(current_state_id, current_country_id, context.db->InsertState(current_state_id, current_country_id,
state_info[0], state_info[1]); state_info[0], state_info[1]);
context.states_inserted++; context.states_inserted++;
} catch (const std::exception &e) { } catch (const std::exception& e) {
spdlog::warn("Record parsing failed: {}", e.what()); spdlog::warn("Record parsing failed: {}", e.what());
} }
} }
@@ -135,7 +136,7 @@ private:
context.db->InsertCountry(current_country_id, country_info[0], context.db->InsertCountry(current_country_id, country_info[0],
country_info[1], country_info[2]); country_info[1], country_info[2]);
context.countries_inserted++; context.countries_inserted++;
} catch (const std::exception &e) { } catch (const std::exception& e) {
spdlog::warn("Record parsing failed: {}", e.what()); spdlog::warn("Record parsing failed: {}", e.what());
} }
} }
@@ -147,13 +148,13 @@ private:
} }
bool on_key_part(boost::json::string_view s, std::size_t, bool on_key_part(boost::json::string_view s, std::size_t,
boost::system::error_code &) { boost::system::error_code&) {
current_key_val.append(s.data(), s.size()); current_key_val.append(s.data(), s.size());
return true; return true;
} }
bool on_key(boost::json::string_view s, std::size_t, bool on_key(boost::json::string_view s, std::size_t,
boost::system::error_code &) { boost::system::error_code&) {
current_key_val.append(s.data(), s.size()); current_key_val.append(s.data(), s.size());
current_key = current_key_val; current_key = current_key_val;
current_key_val.clear(); current_key_val.clear();
@@ -161,13 +162,13 @@ private:
} }
bool on_string_part(boost::json::string_view s, std::size_t, bool on_string_part(boost::json::string_view s, std::size_t,
boost::system::error_code &) { boost::system::error_code&) {
current_string_val.append(s.data(), s.size()); current_string_val.append(s.data(), s.size());
return true; return true;
} }
bool on_string(boost::json::string_view s, std::size_t, bool on_string(boost::json::string_view s, std::size_t,
boost::system::error_code &) { boost::system::error_code&) {
current_string_val.append(s.data(), s.size()); current_string_val.append(s.data(), s.size());
if (building_city && current_key == "name") { if (building_city && current_key == "name") {
@@ -188,12 +189,12 @@ private:
return true; return true;
} }
bool on_number_part(boost::json::string_view, boost::system::error_code &) { bool on_number_part(boost::json::string_view, boost::system::error_code&) {
return true; return true;
} }
bool on_int64(int64_t i, boost::json::string_view, bool on_int64(int64_t i, boost::json::string_view,
boost::system::error_code &) { boost::system::error_code&) {
if (building_city && current_key == "id") { if (building_city && current_key == "id") {
current_city.id = static_cast<int>(i); current_city.id = static_cast<int>(i);
} else if (in_state_object && current_key == "id") { } else if (in_state_object && current_key == "id") {
@@ -205,12 +206,12 @@ private:
} }
bool on_uint64(uint64_t u, boost::json::string_view, bool on_uint64(uint64_t u, boost::json::string_view,
boost::system::error_code &ec) { boost::system::error_code& ec) {
return on_int64(static_cast<int64_t>(u), "", ec); return on_int64(static_cast<int64_t>(u), "", ec);
} }
bool on_double(double d, boost::json::string_view, bool on_double(double d, boost::json::string_view,
boost::system::error_code &) { boost::system::error_code&) {
if (building_city) { if (building_city) {
if (current_key == "latitude") { if (current_key == "latitude") {
current_city.latitude = d; current_city.latitude = d;
@@ -221,24 +222,23 @@ private:
return true; return true;
} }
bool on_bool(bool, boost::system::error_code &) { return true; } bool on_bool(bool, boost::system::error_code&) { return true; }
bool on_null(boost::system::error_code &) { return true; } bool on_null(boost::system::error_code&) { return true; }
bool on_comment_part(boost::json::string_view, boost::system::error_code &) { bool on_comment_part(boost::json::string_view, boost::system::error_code&) {
return true; return true;
} }
bool on_comment(boost::json::string_view, boost::system::error_code &) { bool on_comment(boost::json::string_view, boost::system::error_code&) {
return true; return true;
} }
}; };
void StreamingJsonParser::Parse( void StreamingJsonParser::Parse(
const std::string &file_path, SqliteDatabase &db, const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord &)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress) { std::function<void(size_t, size_t)> on_progress) {
spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path); spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path);
FILE *file = std::fopen(file_path.c_str(), "rb"); FILE* file = std::fopen(file_path.c_str(), "rb");
if (!file) { if (!file) {
throw std::runtime_error("Failed to open JSON file: " + file_path); throw std::runtime_error("Failed to open JSON file: " + file_path);
} }
@@ -252,8 +252,8 @@ void StreamingJsonParser::Parse(
std::rewind(file); std::rewind(file);
} }
CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, total_size,
total_size, 0, 0}; 0, 0};
boost::json::basic_parser<CityRecordHandler> parser( boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx); boost::json::parse_options{}, ctx);
@@ -262,7 +262,7 @@ void StreamingJsonParser::Parse(
boost::system::error_code ec; boost::system::error_code ec;
while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) { while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) {
char const *p = buf; char const* p = buf;
std::size_t remain = bytes_read; std::size_t remain = bytes_read;
while (remain > 0) { while (remain > 0) {
@@ -284,5 +284,6 @@ void StreamingJsonParser::Parse(
} }
spdlog::info(" OK: Parsed {} countries, {} states, {} cities", spdlog::info(" OK: Parsed {} countries, {} states, {} cities",
ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted); ctx.countries_inserted, ctx.states_inserted,
ctx.cities_emitted);
} }

View File

@@ -1,6 +1,8 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
#include <cstdio>
#include <curl/curl.h> #include <curl/curl.h>
#include <cstdio>
#include <fstream> #include <fstream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
@@ -17,20 +19,20 @@ CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); }
namespace { namespace {
// curl write callback that appends response data into a std::string // curl write callback that appends response data into a std::string
size_t WriteCallbackString(void *contents, size_t size, size_t nmemb, size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
void *userp) { void* userp) {
size_t realsize = size * nmemb; size_t realsize = size * nmemb;
auto *s = static_cast<std::string *>(userp); auto* s = static_cast<std::string*>(userp);
s->append(static_cast<char *>(contents), realsize); s->append(static_cast<char*>(contents), realsize);
return realsize; return realsize;
} }
// curl write callback that writes to a file stream // curl write callback that writes to a file stream
size_t WriteCallbackFile(void *contents, size_t size, size_t nmemb, size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb,
void *userp) { void* userp) {
size_t realsize = size * nmemb; size_t realsize = size * nmemb;
auto *outFile = static_cast<std::ofstream *>(userp); auto* outFile = static_cast<std::ofstream*>(userp);
outFile->write(static_cast<char *>(contents), realsize); outFile->write(static_cast<char*>(contents), realsize);
return realsize; return realsize;
} }
@@ -38,7 +40,7 @@ size_t WriteCallbackString(void *contents, size_t size, size_t nmemb,
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>; using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
CurlHandle create_handle() { CurlHandle create_handle() {
CURL *handle = curl_easy_init(); CURL* handle = curl_easy_init();
if (!handle) { if (!handle) {
throw std::runtime_error( throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle"); "[CURLWebClient] Failed to initialize libcurl handle");
@@ -46,7 +48,7 @@ CurlHandle create_handle() {
return CurlHandle(handle, &curl_easy_cleanup); return CurlHandle(handle, &curl_easy_cleanup);
} }
void set_common_get_options(CURL *curl, const std::string &url, void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) { long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
@@ -62,20 +64,20 @@ CURLWebClient::CURLWebClient() {}
CURLWebClient::~CURLWebClient() {} CURLWebClient::~CURLWebClient() {}
void CURLWebClient::DownloadToFile(const std::string &url, void CURLWebClient::DownloadToFile(const std::string& url,
const std::string &file_path) { const std::string& file_path) {
auto curl = create_handle(); auto curl = create_handle();
std::ofstream outFile(file_path, std::ios::binary); std::ofstream outFile(file_path, std::ios::binary);
if (!outFile.is_open()) { if (!outFile.is_open()) {
throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " + throw std::runtime_error(
file_path); "[CURLWebClient] Cannot open file for writing: " + file_path);
} }
set_common_get_options(curl.get(), url, 30L, 300L); set_common_get_options(curl.get(), url, 30L, 300L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile); curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
static_cast<void *>(&outFile)); static_cast<void*>(&outFile));
CURLcode res = curl_easy_perform(curl.get()); CURLcode res = curl_easy_perform(curl.get());
outFile.close(); outFile.close();
@@ -98,7 +100,7 @@ void CURLWebClient::DownloadToFile(const std::string &url,
} }
} }
std::string CURLWebClient::Get(const std::string &url) { std::string CURLWebClient::Get(const std::string& url) {
auto curl = create_handle(); auto curl = create_handle();
std::string response_string; std::string response_string;
@@ -126,9 +128,9 @@ std::string CURLWebClient::Get(const std::string &url) {
return response_string; return response_string;
} }
std::string CURLWebClient::UrlEncode(const std::string &value) { std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs. // A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char *output = curl_easy_escape(nullptr, value.c_str(), 0); char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (output) { if (output) {
std::string result(output); std::string result(output);

View File

@@ -1,8 +1,10 @@
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
WikipediaService::WikipediaService(std::shared_ptr<IWebClient> client) #include <boost/json.hpp>
WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {} : client_(std::move(client)) {}
std::string WikipediaService::FetchExtract(std::string_view query) { std::string WikipediaService::FetchExtract(std::string_view query) {
@@ -17,9 +19,9 @@ std::string WikipediaService::FetchExtract(std::string_view query) {
boost::json::value doc = boost::json::parse(body, ec); boost::json::value doc = boost::json::parse(body, ec);
if (!ec && doc.is_object()) { if (!ec && doc.is_object()) {
auto &pages = doc.at("query").at("pages").get_object(); auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) { if (!pages.empty()) {
auto &page = pages.begin()->value().get_object(); auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) { if (page.contains("extract") && page.at("extract").is_string()) {
std::string extract(page.at("extract").as_string().c_str()); std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'", spdlog::debug("WikipediaService fetched {} chars for '{}'",
@@ -63,11 +65,10 @@ std::string WikipediaService::GetSummary(std::string_view city,
result += regionExtract; result += regionExtract;
} }
if (!beerExtract.empty()) { if (!beerExtract.empty()) {
if (!result.empty()) if (!result.empty()) result += "\n\n";
result += "\n\n";
result += beerExtract; result += beerExtract;
} }
} catch (const std::runtime_error &e) { } catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery, spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery,
e.what()); e.what());
} }