Compare commits

4 Commits

Author SHA1 Message Date
Aaron Po
077f6ab4ae edit prompt 2026-04-02 22:56:18 -04:00
Aaron Po
534403734a Refactor BiergartenDataGenerator and LlamaGenerator 2026-04-02 22:46:00 -04:00
Aaron Po
3af053f0eb format codebase 2026-04-02 21:46:46 -04:00
Aaron Po
ba165d8aa7 Separate llama generator class src file into method files 2026-04-02 21:37:46 -04:00
34 changed files with 1879 additions and 1754 deletions

View File

@@ -1,10 +1,5 @@
--- ---
BasedOnStyle: Google BasedOnStyle: Google
Standard: c++23 ColumnLimit: 80
ColumnLimit: 100 IndentWidth: 3
IndentWidth: 2
DerivePointerAlignment: false
PointerAlignment: Left
SortIncludes: true
IncludeBlocks: Preserve
... ...

View File

@@ -83,8 +83,18 @@ set(PIPELINE_SOURCES
src/data_generation/data_downloader.cpp src/data_generation/data_downloader.cpp
src/database/database.cpp src/database/database.cpp
src/json_handling/json_loader.cpp src/json_handling/json_loader.cpp
src/data_generation/llama_generator.cpp src/data_generation/llama/destructor.cpp
src/data_generation/mock_generator.cpp src/data_generation/llama/set_sampling_options.cpp
src/data_generation/llama/load.cpp
src/data_generation/llama/infer.cpp
src/data_generation/llama/generate_brewery.cpp
src/data_generation/llama/generate_user.cpp
src/data_generation/llama/helpers.cpp
src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/load.cpp
src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp
src/json_handling/stream_parser.cpp src/json_handling/stream_parser.cpp
src/wikipedia/wikipedia_service.cpp src/wikipedia/wikipedia_service.cpp
src/main.cpp src/main.cpp

View File

@@ -3,114 +3,151 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
#include "database/database.h" #include "database/database.h"
#include "web_client/web_client.h" #include "web_client/web_client.h"
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
/** /**
* @brief Program options for the Biergarten pipeline application. * @brief Program options for the Biergarten pipeline application.
*/ */
struct ApplicationOptions { struct ApplicationOptions {
/// @brief Path to the LLM model file (gguf format); mutually exclusive with use_mocked. /// @brief Path to the LLM model file (gguf format); mutually exclusive with
std::string model_path; /// use_mocked.
std::string model_path;
/// @brief Use mocked generator instead of LLM; mutually exclusive with model_path. /// @brief Use mocked generator instead of LLM; mutually exclusive with
bool use_mocked = false; /// model_path.
bool use_mocked = false;
/// @brief Directory for cached JSON and database files. /// @brief Directory for cached JSON and database files.
std::string cache_dir; std::string cache_dir;
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random). /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 0.8f; float temperature = 0.8f;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more random). /// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
float top_p = 0.92f; /// random).
float top_p = 0.92f;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative). /// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1; int seed = -1;
/// @brief Git commit hash for database consistency (always pinned to c5eb7772). /// @brief Git commit hash for database consistency (always pinned to
std::string commit = "c5eb7772"; /// c5eb7772).
std::string commit = "c5eb7772";
}; };
#endif // BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_
/** /**
* @brief Main data generator class for the Biergarten pipeline. * @brief Main data generator class for the Biergarten pipeline.
* *
* This class encapsulates the core logic for generating brewery data. * This class encapsulates the core logic for generating brewery data.
* It handles database initialization, data loading/downloading, and brewery generation. * It handles database initialization, data loading/downloading, and brewery
* generation.
*/ */
class BiergartenDataGenerator { class BiergartenDataGenerator {
public: public:
/** /**
* @brief Construct a BiergartenDataGenerator with injected dependencies. * @brief Construct a BiergartenDataGenerator with injected dependencies.
* *
* @param options Application configuration options. * @param options Application configuration options.
* @param web_client HTTP client for downloading data. * @param web_client HTTP client for downloading data.
* @param database SQLite database instance. * @param database SQLite database instance.
*/ */
BiergartenDataGenerator(const ApplicationOptions &options, BiergartenDataGenerator(const ApplicationOptions& options,
std::shared_ptr<WebClient> web_client, std::shared_ptr<WebClient> web_client,
SqliteDatabase &database); SqliteDatabase& database);
/** /**
* @brief Run the data generation pipeline. * @brief Run the data generation pipeline.
* *
* Performs the following steps: * Performs the following steps:
* 1. Initialize database * 1. Initialize database
* 2. Download geographic data if needed * 2. Download geographic data if needed
* 3. Initialize the generator (LLM or Mock) * 3. Initialize the generator (LLM or Mock)
* 4. Generate brewery data for sample cities * 4. Generate brewery data for sample cities
* *
* @return 0 on success, 1 on failure. * @return 0 on success, 1 on failure.
*/ */
int Run(); int Run();
private: private:
/// @brief Immutable application options. /// @brief Immutable application options.
const ApplicationOptions options_; const ApplicationOptions options_;
/// @brief Shared HTTP client dependency. /// @brief Shared HTTP client dependency.
std::shared_ptr<WebClient> webClient_; std::shared_ptr<WebClient> webClient_;
/// @brief Database dependency. /// @brief Database dependency.
SqliteDatabase &database_; SqliteDatabase& database_;
/** /**
* @brief Initialize the data generator based on options. * @brief Enriched city data with Wikipedia context.
* */
* Creates either a MockGenerator (if no model path) or LlamaGenerator. struct EnrichedCity {
* int city_id;
* @return A unique_ptr to the initialized generator. std::string city_name;
*/ std::string country_name;
std::unique_ptr<DataGenerator> InitializeGenerator(); std::string region_context;
};
/** /**
* @brief Download and load geographic data if not cached. * @brief Initialize the data generator based on options.
*/ *
void LoadGeographicData(); * Creates either a MockGenerator (if no model path) or LlamaGenerator.
*
* @return A unique_ptr to the initialized generator.
*/
std::unique_ptr<DataGenerator> InitializeGenerator();
/** /**
* @brief Generate sample breweries for demonstration. * @brief Download and load geographic data if not cached.
*/ */
void GenerateSampleBreweries(); void LoadGeographicData();
/** /**
* @brief Helper struct to store generated brewery data. * @brief Query cities from database and build country name map.
*/ *
struct GeneratedBrewery { * @return Vector of (City, country_name) pairs capped at 30 entries.
int cityId; */
std::string cityName; std::vector<std::pair<City, std::string>> QueryCitiesWithCountries();
BreweryResult brewery;
};
/// @brief Stores generated brewery data. /**
std::vector<GeneratedBrewery> generatedBreweries_; * @brief Enrich cities with Wikipedia summaries.
*
* @param cities Vector of (City, country_name) pairs.
* @return Vector of enriched city data with context.
*/
std::vector<EnrichedCity> EnrichWithWikipedia(
const std::vector<std::pair<City, std::string>>& cities);
/**
* @brief Generate breweries for enriched cities.
*
* @param generator The data generator instance.
* @param cities Vector of enriched city data.
*/
void GenerateBreweries(DataGenerator& generator,
const std::vector<EnrichedCity>& cities);
/**
* @brief Log the generated brewery results.
*/
void LogResults() const;
/**
* @brief Helper struct to store generated brewery data.
*/
struct GeneratedBrewery {
int city_id;
std::string city_name;
BreweryResult brewery;
};
/// @brief Stores generated brewery data.
std::vector<GeneratedBrewery> generatedBreweries_;
}; };
#endif // BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_

View File

@@ -9,22 +9,23 @@
/// @brief Downloads and caches source geography JSON payloads. /// @brief Downloads and caches source geography JSON payloads.
class DataDownloader { class DataDownloader {
public: public:
/// @brief Initializes global curl state used by this downloader. /// @brief Initializes global curl state used by this downloader.
explicit DataDownloader(std::shared_ptr<WebClient> web_client); explicit DataDownloader(std::shared_ptr<WebClient> web_client);
/// @brief Cleans up global curl state. /// @brief Cleans up global curl state.
~DataDownloader(); ~DataDownloader();
/// @brief Returns a local JSON path, downloading it when cache is missing. /// @brief Returns a local JSON path, downloading it when cache is missing.
std::string DownloadCountriesDatabase( std::string DownloadCountriesDatabase(
const std::string &cache_path, const std::string& cache_path,
const std::string &commit = "c5eb7772" // Stable commit: 2026-03-28 export const std::string& commit =
); "c5eb7772" // Stable commit: 2026-03-28 export
);
private: private:
static bool FileExists(const std::string &file_path); static bool FileExists(const std::string& file_path);
std::shared_ptr<WebClient> web_client_; std::shared_ptr<WebClient> web_client_;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_DOWNLOADER_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_DOWNLOADER_H_

View File

@@ -4,26 +4,26 @@
#include <string> #include <string>
struct BreweryResult { struct BreweryResult {
std::string name; std::string name;
std::string description; std::string description;
}; };
struct UserResult { struct UserResult {
std::string username; std::string username;
std::string bio; std::string bio;
}; };
class DataGenerator { class DataGenerator {
public: public:
virtual ~DataGenerator() = default; virtual ~DataGenerator() = default;
virtual void Load(const std::string &model_path) = 0; virtual void Load(const std::string& model_path) = 0;
virtual BreweryResult GenerateBrewery(const std::string &city_name, virtual BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) = 0; const std::string& region_context) = 0;
virtual UserResult GenerateUser(const std::string &locale) = 0; virtual UserResult GenerateUser(const std::string& locale) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_

View File

@@ -10,32 +10,35 @@ struct llama_model;
struct llama_context; struct llama_context;
class LlamaGenerator final : public DataGenerator { class LlamaGenerator final : public DataGenerator {
public: public:
LlamaGenerator() = default; LlamaGenerator() = default;
~LlamaGenerator() override; ~LlamaGenerator() override;
void SetSamplingOptions(float temperature, float top_p, int seed = -1); void SetSamplingOptions(float temperature, float top_p, int seed = -1);
void Load(const std::string &model_path) override; void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string &city_name, BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) override; const std::string& region_context) override;
UserResult GenerateUser(const std::string &locale) override; UserResult GenerateUser(const std::string& locale) override;
private: private:
std::string Infer(const std::string &prompt, int max_tokens = 10000); std::string Infer(const std::string& prompt, int max_tokens = 10000);
// Overload that allows passing a system message separately so chat-capable // Overload that allows passing a system message separately so chat-capable
// models receive a proper system role instead of having the system text // models receive a proper system role instead of having the system text
// concatenated into the user prompt (helps avoid revealing internal // concatenated into the user prompt (helps avoid revealing internal
// reasoning or instructions in model output). // reasoning or instructions in model output).
std::string Infer(const std::string &system_prompt, const std::string &prompt, std::string Infer(const std::string& system_prompt,
int max_tokens = 10000); const std::string& prompt, int max_tokens = 10000);
llama_model *model_ = nullptr; std::string InferFormatted(const std::string& formatted_prompt,
llama_context *context_ = nullptr; int max_tokens = 10000);
float sampling_temperature_ = 0.8f;
float sampling_top_p_ = 0.92f; llama_model* model_ = nullptr;
uint32_t sampling_seed_ = 0xFFFFFFFFu; llama_context* context_ = nullptr;
float sampling_temperature_ = 0.8f;
float sampling_top_p_ = 0.92f;
uint32_t sampling_seed_ = 0xFFFFFFFFu;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -0,0 +1,32 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
#include <cstddef>
#include <string>
#include <string_view>
#include <utility>
// Forward declarations of llama.cpp types so this header does not pull in
// llama.h.
struct llama_model;
struct llama_vocab;
// NOTE(review): llama.h declares llama_token as int32_t; plain int matches on
// all supported targets, but keep this in sync if llama.h changes.
typedef int llama_token;
// Helper functions shared by the per-method LlamaGenerator .cpp files
// (load.cpp, infer.cpp, generate_brewery.cpp, generate_user.cpp, ...).
/// @brief Prepares region context for prompt use, capped at max_chars
/// characters — presumably truncation; confirm against helpers.cpp.
std::string PrepareRegionContextPublic(std::string_view region_context,
                                       std::size_t max_chars = 700);
/// @brief Parses a two-line model response into a (first, second) pair;
/// error_message is reported when the response lacks the expected shape.
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
    const std::string& raw, const std::string& error_message);
/// @brief Wraps a user prompt in the model's chat template.
std::string ToChatPromptPublic(const llama_model* model,
                               const std::string& user_prompt);
/// @brief Wraps separate system and user prompts in the model's chat
/// template so chat-capable models receive a proper system role.
std::string ToChatPromptPublic(const llama_model* model,
                               const std::string& system_prompt,
                               const std::string& user_prompt);
/// @brief Appends the detokenized text piece for token to output.
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
                            std::string& output);
/// @brief Validates brewery JSON in raw and extracts fields into name_out
/// and description_out; return-value contract (error string vs normalized
/// JSON) is not visible here — TODO confirm against the implementation.
std::string ValidateBreweryJsonPublic(const std::string& raw,
                                      std::string& name_out,
                                      std::string& description_out);
#endif  // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_

View File

@@ -1,27 +1,28 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#include "data_generation/data_generator.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "data_generation/data_generator.h"
class MockGenerator final : public DataGenerator { class MockGenerator final : public DataGenerator {
public: public:
void Load(const std::string &model_path) override; void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string &city_name, BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) override; const std::string& region_context) override;
UserResult GenerateUser(const std::string &locale) override; UserResult GenerateUser(const std::string& locale) override;
private: private:
static std::size_t DeterministicHash(const std::string &a, static std::size_t DeterministicHash(const std::string& a,
const std::string &b); const std::string& b);
static const std::vector<std::string> kBreweryAdjectives; static const std::vector<std::string> kBreweryAdjectives;
static const std::vector<std::string> kBreweryNouns; static const std::vector<std::string> kBreweryNouns;
static const std::vector<std::string> kBreweryDescriptions; static const std::vector<std::string> kBreweryDescriptions;
static const std::vector<std::string> kUsernames; static const std::vector<std::string> kUsernames;
static const std::vector<std::string> kBios; static const std::vector<std::string> kBios;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_

View File

@@ -1,83 +1,84 @@
#ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#include <mutex>
#include <sqlite3.h> #include <sqlite3.h>
#include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
struct Country { struct Country {
/// @brief Country identifier from the source dataset. /// @brief Country identifier from the source dataset.
int id; int id;
/// @brief Country display name. /// @brief Country display name.
std::string name; std::string name;
/// @brief ISO 3166-1 alpha-2 code. /// @brief ISO 3166-1 alpha-2 code.
std::string iso2; std::string iso2;
/// @brief ISO 3166-1 alpha-3 code. /// @brief ISO 3166-1 alpha-3 code.
std::string iso3; std::string iso3;
}; };
struct State { struct State {
/// @brief State or province identifier from the source dataset. /// @brief State or province identifier from the source dataset.
int id; int id;
/// @brief State or province display name. /// @brief State or province display name.
std::string name; std::string name;
/// @brief State or province short code. /// @brief State or province short code.
std::string iso2; std::string iso2;
/// @brief Parent country identifier. /// @brief Parent country identifier.
int country_id; int country_id;
}; };
struct City { struct City {
/// @brief City identifier from the source dataset. /// @brief City identifier from the source dataset.
int id; int id;
/// @brief City display name. /// @brief City display name.
std::string name; std::string name;
/// @brief Parent country identifier. /// @brief Parent country identifier.
int country_id; int country_id;
}; };
/// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks. /// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks.
class SqliteDatabase { class SqliteDatabase {
private: private:
sqlite3 *db_ = nullptr; sqlite3* db_ = nullptr;
std::mutex db_mutex_; std::mutex db_mutex_;
void InitializeSchema(); void InitializeSchema();
public: public:
/// @brief Closes the SQLite connection if initialized. /// @brief Closes the SQLite connection if initialized.
~SqliteDatabase(); ~SqliteDatabase();
/// @brief Opens the SQLite database at db_path and creates schema objects. /// @brief Opens the SQLite database at db_path and creates schema objects.
void Initialize(const std::string &db_path = ":memory:"); void Initialize(const std::string& db_path = ":memory:");
/// @brief Starts a database transaction for batched writes. /// @brief Starts a database transaction for batched writes.
void BeginTransaction(); void BeginTransaction();
/// @brief Commits the active database transaction. /// @brief Commits the active database transaction.
void CommitTransaction(); void CommitTransaction();
/// @brief Inserts a country row. /// @brief Inserts a country row.
void InsertCountry(int id, const std::string &name, const std::string &iso2, void InsertCountry(int id, const std::string& name, const std::string& iso2,
const std::string &iso3); const std::string& iso3);
/// @brief Inserts a state row linked to a country. /// @brief Inserts a state row linked to a country.
void InsertState(int id, int country_id, const std::string &name, void InsertState(int id, int country_id, const std::string& name,
const std::string &iso2); const std::string& iso2);
/// @brief Inserts a city row linked to state and country. /// @brief Inserts a city row linked to state and country.
void InsertCity(int id, int state_id, int country_id, const std::string &name, void InsertCity(int id, int state_id, int country_id,
double latitude, double longitude); const std::string& name, double latitude, double longitude);
/// @brief Returns city records including parent country id. /// @brief Returns city records including parent country id.
std::vector<City> QueryCities(); std::vector<City> QueryCities();
/// @brief Returns countries with optional row limit. /// @brief Returns countries with optional row limit.
std::vector<Country> QueryCountries(int limit = 0); std::vector<Country> QueryCountries(int limit = 0);
/// @brief Returns states with optional row limit. /// @brief Returns states with optional row limit.
std::vector<State> QueryStates(int limit = 0); std::vector<State> QueryStates(int limit = 0);
}; };
#endif // BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #endif // BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_

View File

@@ -1,15 +1,17 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#include <string>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
#include <string>
/// @brief Loads world-city JSON data into SQLite through streaming parsing. /// @brief Loads world-city JSON data into SQLite through streaming parsing.
class JsonLoader { class JsonLoader {
public: public:
/// @brief Parses a JSON file and writes country/state/city rows into db. /// @brief Parses a JSON file and writes country/state/city rows into db.
static void LoadWorldCities(const std::string &json_path, SqliteDatabase &db); static void LoadWorldCities(const std::string& json_path,
SqliteDatabase& db);
}; };
#endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -1,51 +1,52 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#include "database/database.h"
#include <functional> #include <functional>
#include <string> #include <string>
#include "database/database.h"
// Forward declaration to avoid circular dependency // Forward declaration to avoid circular dependency
class SqliteDatabase; class SqliteDatabase;
/// @brief In-memory representation of one parsed city entry. /// @brief In-memory representation of one parsed city entry.
struct CityRecord { struct CityRecord {
int id; int id;
int state_id; int state_id;
int country_id; int country_id;
std::string name; std::string name;
double latitude; double latitude;
double longitude; double longitude;
}; };
/// @brief Streaming SAX parser that emits city records during traversal. /// @brief Streaming SAX parser that emits city records during traversal.
class StreamingJsonParser { class StreamingJsonParser {
public: public:
/// @brief Parses file_path and invokes callbacks for city rows and progress. /// @brief Parses file_path and invokes callbacks for city rows and progress.
static void Parse(const std::string &file_path, SqliteDatabase &db, static void Parse(const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord &)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress = nullptr); std::function<void(size_t, size_t)> on_progress = nullptr);
private: private:
/// @brief Mutable SAX handler state while traversing nested JSON arrays. /// @brief Mutable SAX handler state while traversing nested JSON arrays.
struct ParseState { struct ParseState {
int current_country_id = 0; int current_country_id = 0;
int current_state_id = 0; int current_state_id = 0;
CityRecord current_city = {}; CityRecord current_city = {};
bool building_city = false; bool building_city = false;
std::string current_key; std::string current_key;
int array_depth = 0; int array_depth = 0;
int object_depth = 0; int object_depth = 0;
bool in_countries_array = false; bool in_countries_array = false;
bool in_states_array = false; bool in_states_array = false;
bool in_cities_array = false; bool in_cities_array = false;
std::function<void(const CityRecord &)> on_city; std::function<void(const CityRecord&)> on_city;
std::function<void(size_t, size_t)> on_progress; std::function<void(size_t, size_t)> on_progress;
size_t bytes_processed = 0; size_t bytes_processed = 0;
}; };
}; };
#endif // BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_

View File

@@ -1,29 +1,30 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#include "web_client/web_client.h"
#include <memory> #include <memory>
#include "web_client/web_client.h"
// RAII for curl_global_init/cleanup. // RAII for curl_global_init/cleanup.
// An instance of this class should be created in main() before any curl // An instance of this class should be created in main() before any curl
// operations and exist for the lifetime of the application. // operations and exist for the lifetime of the application.
class CurlGlobalState { class CurlGlobalState {
public: public:
CurlGlobalState(); CurlGlobalState();
~CurlGlobalState(); ~CurlGlobalState();
CurlGlobalState(const CurlGlobalState &) = delete; CurlGlobalState(const CurlGlobalState&) = delete;
CurlGlobalState &operator=(const CurlGlobalState &) = delete; CurlGlobalState& operator=(const CurlGlobalState&) = delete;
}; };
class CURLWebClient : public WebClient { class CURLWebClient : public WebClient {
public: public:
CURLWebClient(); CURLWebClient();
~CURLWebClient() override; ~CURLWebClient() override;
void DownloadToFile(const std::string &url, void DownloadToFile(const std::string& url,
const std::string &file_path) override; const std::string& file_path) override;
std::string Get(const std::string &url) override; std::string Get(const std::string& url) override;
std::string UrlEncode(const std::string &value) override; std::string UrlEncode(const std::string& value) override;
}; };
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_

View File

@@ -4,19 +4,19 @@
#include <string> #include <string>
class WebClient { class WebClient {
public: public:
virtual ~WebClient() = default; virtual ~WebClient() = default;
// Downloads content from a URL to a file. Throws on error. // Downloads content from a URL to a file. Throws on error.
virtual void DownloadToFile(const std::string &url, virtual void DownloadToFile(const std::string& url,
const std::string &file_path) = 0; const std::string& file_path) = 0;
// Performs a GET request and returns the response body as a string. Throws on // Performs a GET request and returns the response body as a string. Throws
// error. // on error.
virtual std::string Get(const std::string &url) = 0; virtual std::string Get(const std::string& url) = 0;
// URL-encodes a string. // URL-encodes a string.
virtual std::string UrlEncode(const std::string &value) = 0; virtual std::string UrlEncode(const std::string& value) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -10,18 +10,18 @@
/// @brief Provides cached Wikipedia summary lookups for city and country pairs. /// @brief Provides cached Wikipedia summary lookups for city and country pairs.
class WikipediaService { class WikipediaService {
public: public:
/// @brief Creates a new Wikipedia service with the provided web client. /// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::shared_ptr<WebClient> client); explicit WikipediaService(std::shared_ptr<WebClient> client);
/// @brief Returns the Wikipedia summary extract for city and country. /// @brief Returns the Wikipedia summary extract for city and country.
[[nodiscard]] std::string GetSummary(std::string_view city, [[nodiscard]] std::string GetSummary(std::string_view city,
std::string_view country); std::string_view country);
private: private:
std::string FetchExtract(std::string_view query); std::string FetchExtract(std::string_view query);
std::shared_ptr<WebClient> client_; std::shared_ptr<WebClient> client_;
std::unordered_map<std::string, std::string> cache_; std::unordered_map<std::string, std::string> cache_;
}; };
#endif // BIERGARTEN_PIPELINE_WIKIPEDIA_WIKIPEDIA_SERVICE_H_ #endif // BIERGARTEN_PIPELINE_WIKIPEDIA_WIKIPEDIA_SERVICE_H_

View File

@@ -1,132 +1,157 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <filesystem> #include <filesystem>
#include <unordered_map> #include <unordered_map>
#include <spdlog/spdlog.h>
#include "data_generation/data_downloader.h" #include "data_generation/data_downloader.h"
#include "json_handling/json_loader.h"
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
#include "json_handling/json_loader.h"
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
BiergartenDataGenerator::BiergartenDataGenerator( BiergartenDataGenerator::BiergartenDataGenerator(
const ApplicationOptions &options, const ApplicationOptions& options, std::shared_ptr<WebClient> web_client,
std::shared_ptr<WebClient> web_client, SqliteDatabase& database)
SqliteDatabase &database)
: options_(options), webClient_(web_client), database_(database) {} : options_(options), webClient_(web_client), database_(database) {}
std::unique_ptr<DataGenerator> BiergartenDataGenerator::InitializeGenerator() { std::unique_ptr<DataGenerator> BiergartenDataGenerator::InitializeGenerator() {
spdlog::info("Initializing brewery generator..."); spdlog::info("Initializing brewery generator...");
std::unique_ptr<DataGenerator> generator; std::unique_ptr<DataGenerator> generator;
if (options_.model_path.empty()) { if (options_.model_path.empty()) {
generator = std::make_unique<MockGenerator>(); generator = std::make_unique<MockGenerator>();
spdlog::info("[Generator] Using MockGenerator (no model path provided)"); spdlog::info("[Generator] Using MockGenerator (no model path provided)");
} else { } else {
auto llama_generator = std::make_unique<LlamaGenerator>(); auto llama_generator = std::make_unique<LlamaGenerator>();
llama_generator->SetSamplingOptions(options_.temperature, options_.top_p, llama_generator->SetSamplingOptions(options_.temperature, options_.top_p,
options_.seed); options_.seed);
spdlog::info( spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, " "[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, "
"seed={})", "seed={})",
options_.model_path, options_.temperature, options_.top_p, options_.model_path, options_.temperature, options_.top_p,
options_.seed); options_.seed);
generator = std::move(llama_generator); generator = std::move(llama_generator);
} }
generator->Load(options_.model_path); generator->Load(options_.model_path);
return generator; return generator;
} }
void BiergartenDataGenerator::LoadGeographicData() { void BiergartenDataGenerator::LoadGeographicData() {
std::string json_path = options_.cache_dir + "/countries+states+cities.json"; std::string json_path = options_.cache_dir + "/countries+states+cities.json";
std::string db_path = options_.cache_dir + "/biergarten-pipeline.db"; std::string db_path = options_.cache_dir + "/biergarten-pipeline.db";
bool has_json_cache = std::filesystem::exists(json_path); bool has_json_cache = std::filesystem::exists(json_path);
bool has_db_cache = std::filesystem::exists(db_path); bool has_db_cache = std::filesystem::exists(db_path);
spdlog::info("Initializing SQLite database at {}...", db_path); spdlog::info("Initializing SQLite database at {}...", db_path);
database_.Initialize(db_path); database_.Initialize(db_path);
if (has_db_cache && has_json_cache) { if (has_db_cache && has_json_cache) {
spdlog::info("[Pipeline] Cache hit: skipping download and parse"); spdlog::info("[Pipeline] Cache hit: skipping download and parse");
} else { } else {
spdlog::info("\n[Pipeline] Downloading geographic data from GitHub..."); spdlog::info("\n[Pipeline] Downloading geographic data from GitHub...");
DataDownloader downloader(webClient_); DataDownloader downloader(webClient_);
downloader.DownloadCountriesDatabase(json_path, options_.commit); downloader.DownloadCountriesDatabase(json_path, options_.commit);
JsonLoader::LoadWorldCities(json_path, database_); JsonLoader::LoadWorldCities(json_path, database_);
} }
} }
void BiergartenDataGenerator::GenerateSampleBreweries() { std::vector<std::pair<City, std::string>>
auto generator = InitializeGenerator(); BiergartenDataGenerator::QueryCitiesWithCountries() {
WikipediaService wikipedia_service(webClient_); spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); auto cities = database_.QueryCities();
auto countries = database_.QueryCountries(50); // Build a quick map of country id -> name for per-city lookups.
auto states = database_.QueryStates(50); auto all_countries = database_.QueryCountries(0);
auto cities = database_.QueryCities(); std::unordered_map<int, std::string> country_map;
for (const auto& c : all_countries) {
country_map[c.id] = c.name;
}
// Build a quick map of country id -> name for per-city lookups. spdlog::info("\nTotal records loaded:");
auto all_countries = database_.QueryCountries(0); spdlog::info(" Countries: {}", database_.QueryCountries(0).size());
std::unordered_map<int, std::string> country_map; spdlog::info(" States: {}", database_.QueryStates(0).size());
for (const auto &c : all_countries) spdlog::info(" Cities: {}", cities.size());
country_map[c.id] = c.name;
spdlog::info("\nTotal records loaded:"); // Cap at 30 entries.
spdlog::info(" Countries: {}", database_.QueryCountries(0).size()); const size_t sample_count = std::min(size_t(30), cities.size());
spdlog::info(" States: {}", database_.QueryStates(0).size()); std::vector<std::pair<City, std::string>> result;
spdlog::info(" Cities: {}", cities.size());
generatedBreweries_.clear(); for (size_t i = 0; i < sample_count; i++) {
const size_t sample_count = std::min(size_t(30), cities.size()); const auto& city = cities[i];
std::string country_name;
const auto country_it = country_map.find(city.country_id);
if (country_it != country_map.end()) {
country_name = country_it->second;
}
result.push_back({city, country_name});
}
spdlog::info("\n=== SAMPLE BREWERY GENERATION ==="); return result;
for (size_t i = 0; i < sample_count; i++) { }
const auto &city = cities[i];
const int city_id = city.id;
const std::string city_name = city.name;
std::string local_country; std::vector<BiergartenDataGenerator::EnrichedCity>
const auto country_it = country_map.find(city.country_id); BiergartenDataGenerator::EnrichWithWikipedia(
if (country_it != country_map.end()) { const std::vector<std::pair<City, std::string>>& cities) {
local_country = country_it->second; WikipediaService wikipedia_service(webClient_);
} std::vector<EnrichedCity> enriched;
const std::string region_context = for (const auto& [city, country_name] : cities) {
wikipedia_service.GetSummary(city_name, local_country); const std::string region_context =
spdlog::debug("[Pipeline] Region context for {}: {}", city_name, wikipedia_service.GetSummary(city.name, country_name);
region_context); spdlog::debug("[Pipeline] Region context for {}: {}", city.name,
region_context);
auto brewery = enriched.push_back({city.id, city.name, country_name, region_context});
generator->GenerateBrewery(city_name, local_country, region_context); }
generatedBreweries_.push_back({city_id, city_name, brewery});
}
spdlog::info("\n=== GENERATED DATA DUMP ==="); return enriched;
for (size_t i = 0; i < generatedBreweries_.size(); i++) { }
const auto &entry = generatedBreweries_[i];
spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.cityId, void BiergartenDataGenerator::GenerateBreweries(
entry.cityName); DataGenerator& generator, const std::vector<EnrichedCity>& cities) {
spdlog::info(" brewery_name=\"{}\"", entry.brewery.name); spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
spdlog::info(" brewery_description=\"{}\"", entry.brewery.description); generatedBreweries_.clear();
}
for (const auto& enriched_city : cities) {
auto brewery = generator.GenerateBrewery(enriched_city.city_name,
enriched_city.country_name,
enriched_city.region_context);
generatedBreweries_.push_back(
{enriched_city.city_id, enriched_city.city_name, brewery});
}
}
void BiergartenDataGenerator::LogResults() const {
spdlog::info("\n=== GENERATED DATA DUMP ===");
for (size_t i = 0; i < generatedBreweries_.size(); i++) {
const auto& entry = generatedBreweries_[i];
spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.city_id,
entry.city_name);
spdlog::info(" brewery_name=\"{}\"", entry.brewery.name);
spdlog::info(" brewery_description=\"{}\"", entry.brewery.description);
}
} }
int BiergartenDataGenerator::Run() { int BiergartenDataGenerator::Run() {
try { try {
LoadGeographicData(); LoadGeographicData();
GenerateSampleBreweries(); auto generator = InitializeGenerator();
auto cities = QueryCitiesWithCountries();
auto enriched = EnrichWithWikipedia(cities);
GenerateBreweries(*generator, enriched);
LogResults();
spdlog::info("\nOK: Pipeline completed successfully"); spdlog::info("\nOK: Pipeline completed successfully");
return 0; return 0;
} catch (const std::exception &e) { } catch (const std::exception& e) {
spdlog::error("ERROR: Pipeline failed: {}", e.what()); spdlog::error("ERROR: Pipeline failed: {}", e.what());
return 1; return 1;
} }
} }

View File

@@ -1,46 +1,49 @@
#include "data_generation/data_downloader.h" #include "data_generation/data_downloader.h"
#include "web_client/web_client.h"
#include <spdlog/spdlog.h>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <spdlog/spdlog.h>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include "web_client/web_client.h"
DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client) DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client)
: web_client_(std::move(web_client)) {} : web_client_(std::move(web_client)) {}
DataDownloader::~DataDownloader() {} DataDownloader::~DataDownloader() {}
bool DataDownloader::FileExists(const std::string &file_path) { bool DataDownloader::FileExists(const std::string& file_path) {
return std::filesystem::exists(file_path); return std::filesystem::exists(file_path);
} }
std::string std::string DataDownloader::DownloadCountriesDatabase(
DataDownloader::DownloadCountriesDatabase(const std::string &cache_path, const std::string& cache_path, const std::string& commit) {
const std::string &commit) { if (FileExists(cache_path)) {
if (FileExists(cache_path)) { spdlog::info("[DataDownloader] Cache hit: {}", cache_path);
spdlog::info("[DataDownloader] Cache hit: {}", cache_path); return cache_path;
return cache_path; }
}
std::string short_commit = commit; std::string short_commit = commit;
if (commit.length() > 7) { if (commit.length() > 7) {
short_commit = commit.substr(0, 7); short_commit = commit.substr(0, 7);
} }
std::string url = "https://raw.githubusercontent.com/dr5hn/" std::string url =
"countries-states-cities-database/" + "https://raw.githubusercontent.com/dr5hn/"
short_commit + "/json/countries+states+cities.json"; "countries-states-cities-database/" +
short_commit + "/json/countries+states+cities.json";
spdlog::info("[DataDownloader] Downloading: {}", url); spdlog::info("[DataDownloader] Downloading: {}", url);
web_client_->DownloadToFile(url, cache_path); web_client_->DownloadToFile(url, cache_path);
std::ifstream file_check(cache_path, std::ios::binary | std::ios::ate); std::ifstream file_check(cache_path, std::ios::binary | std::ios::ate);
std::streamsize size = file_check.tellg(); std::streamsize size = file_check.tellg();
file_check.close(); file_check.close();
spdlog::info("[DataDownloader] OK: Download complete: {} ({:.2f} MB)", spdlog::info("[DataDownloader] OK: Download complete: {} ({:.2f} MB)",
cache_path, (size / (1024.0 * 1024.0))); cache_path, (size / (1024.0 * 1024.0)));
return cache_path; return cache_path;
} }

View File

@@ -0,0 +1,16 @@
#include "data_generation/llama_generator.h"
#include "llama.h"
// Releases llama.cpp resources in reverse order of acquisition: the
// inference context first, then the model it was created from.
LlamaGenerator::~LlamaGenerator() {
   if (context_ != nullptr) {
      llama_free(context_);
      context_ = nullptr;
   }
   if (model_ != nullptr) {
      llama_model_free(model_);
      model_ = nullptr;
   }
   // NOTE(review): llama_backend_free() tears down process-global backend
   // state; this assumes at most one LlamaGenerator is alive at a time —
   // confirm if multiple generator instances are ever created.
   llama_backend_free();
}

View File

@@ -0,0 +1,74 @@
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
// Generates a brewery name and description for the given city via the
// loaded model. Makes up to three attempts, feeding each validation
// error back into the prompt; throws std::runtime_error when no attempt
// yields JSON with non-empty string fields "name" and "description".
BreweryResult LlamaGenerator::GenerateBrewery(
    const std::string& city_name, const std::string& country_name,
    const std::string& region_context) {
   // Whitespace-normalized, length-capped Wikipedia context for the prompt.
   const std::string safe_region_context =
       PrepareRegionContextPublic(region_context);
   const std::string system_prompt =
       "You are the brewmaster and owner of a local craft brewery. "
       "Write a name and a short, soulful description for your brewery that "
       "reflects your pride in the local community and your craft. "
       "The tone should be authentic and welcoming, like a note on a "
       "chalkboard "
       "menu. Output ONLY a single JSON object with keys \"name\" and "
       "\"description\". "
       "Do not include markdown formatting or backticks.";
   std::string prompt =
       "Write a brewery name and place-specific long description for a craft "
       "brewery in " +
       city_name +
       (country_name.empty() ? std::string("")
                             : std::string(", ") + country_name) +
       (safe_region_context.empty()
            ? std::string(".")
            : std::string(". Regional context: ") + safe_region_context);
   const int max_attempts = 3;
   std::string raw;
   std::string last_error;
   for (int attempt = 0; attempt < max_attempts; ++attempt) {
      raw = Infer(system_prompt, prompt, 384);
      spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
                    raw);
      std::string name;
      std::string description;
      // Empty return value means the JSON validated; otherwise it holds a
      // human-readable error used to re-prompt the model.
      const std::string validation_error =
          ValidateBreweryJsonPublic(raw, name, description);
      if (validation_error.empty()) {
         return {std::move(name), std::move(description)};
      }
      last_error = validation_error;
      spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
                   attempt + 1, validation_error);
      // Rewrite the prompt for the retry: state exactly what was wrong and
      // restate the required schema.
      prompt =
          "Your previous response was invalid. Error: " + validation_error +
          "\nReturn ONLY valid JSON with this exact schema: "
          "{\"name\": \"string\", \"description\": \"string\"}."
          "\nDo not include markdown, comments, or extra keys."
          "\n\nLocation: " +
          city_name +
          (country_name.empty() ? std::string("")
                                : std::string(", ") + country_name) +
          (safe_region_context.empty()
               ? std::string("")
               : std::string("\nRegional context: ") + safe_region_context);
   }
   spdlog::error(
       "LlamaGenerator: malformed brewery response after {} attempts: "
       "{}",
       max_attempts, last_error.empty() ? raw : last_error);
   throw std::runtime_error("LlamaGenerator: malformed brewery response");
}

View File

@@ -0,0 +1,57 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
// Generates a craft-beer-enthusiast profile (username + one-sentence
// bio) for `locale`. Makes up to three attempts when the model output
// cannot be reduced to two usable lines; throws std::runtime_error
// after the final failed attempt.
UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
   const std::string system_prompt =
       "You generate plausible social media profiles for craft beer "
       "enthusiasts. "
       "Respond with exactly two lines: "
       "the first line is a username (lowercase, no spaces, 8-20 characters), "
       "the second line is a one-sentence bio (20-40 words). "
       "The profile should feel consistent with the locale. "
       "No preamble, no labels.";
   std::string prompt =
       "Generate a craft beer enthusiast profile. Locale: " + locale;
   const int max_attempts = 3;
   std::string raw;
   for (int attempt = 0; attempt < max_attempts; ++attempt) {
      raw = Infer(system_prompt, prompt, 128);
      spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
                    attempt + 1, raw);
      try {
         // First line -> username, remaining lines joined -> bio.
         auto [username, bio] = ParseTwoLineResponsePublic(
             raw, "LlamaGenerator: malformed user response");
         // Usernames must not contain whitespace; strip any that remain.
         username.erase(
             std::remove_if(username.begin(), username.end(),
                            [](unsigned char ch) { return std::isspace(ch); }),
             username.end());
         if (username.empty() || bio.empty()) {
            throw std::runtime_error("LlamaGenerator: malformed user response");
         }
         // Hard cap the bio length at 200 characters.
         if (bio.size() > 200) bio = bio.substr(0, 200);
         return {username, bio};
      } catch (const std::exception& e) {
         // Parse failure on this attempt: log and fall through to retry.
         spdlog::warn(
             "LlamaGenerator: malformed user response (attempt {}): {}",
             attempt + 1, e.what());
      }
   }
   spdlog::error(
       "LlamaGenerator: malformed user response after {} attempts: {}",
       max_attempts, raw);
   throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -0,0 +1,398 @@
#include <algorithm>
#include <array>
#include <boost/json.hpp>
#include <cctype>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "data_generation/llama_generator.h"
#include "llama.h"
namespace {

// Returns `value` with leading and trailing whitespace removed.
std::string Trim(std::string value) {
   auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
   value.erase(value.begin(),
               std::find_if(value.begin(), value.end(), not_space));
   value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(),
               value.end());
   return value;
}

// Collapses every run of whitespace into a single space and trims the
// result.
std::string CondenseWhitespace(std::string text) {
   std::string out;
   out.reserve(text.size());
   bool in_whitespace = false;
   for (unsigned char ch : text) {
      if (std::isspace(ch)) {
         if (!in_whitespace) {
            out.push_back(' ');
            in_whitespace = true;
         }
         continue;
      }
      in_whitespace = false;
      out.push_back(static_cast<char>(ch));
   }
   return Trim(std::move(out));
}

// Normalizes `region_context` and truncates it to at most `max_chars`
// characters (plus a "..." suffix), preferring to cut at a space in the
// second half so a word is not split mid-token.
std::string PrepareRegionContext(std::string_view region_context,
                                 std::size_t max_chars) {
   std::string normalized = CondenseWhitespace(std::string(region_context));
   if (normalized.size() <= max_chars) {
      return normalized;
   }
   normalized.resize(max_chars);
   const std::size_t last_space = normalized.find_last_of(' ');
   if (last_space != std::string::npos && last_space > max_chars / 2) {
      normalized.resize(last_space);
   }
   normalized += "...";
   return normalized;
}

// Strips list markers ("-", "*", "1.", "2)") and known field labels
// ("name:", "bio:", ...) from the start of a model-output line.
std::string StripCommonPrefix(std::string line) {
   line = Trim(std::move(line));
   if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
      line = Trim(line.substr(1));
   } else {
      // Numbered-list prefix: one or more digits followed by '.' or ')'.
      std::size_t i = 0;
      while (i < line.size() &&
             std::isdigit(static_cast<unsigned char>(line[i]))) {
         ++i;
      }
      if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
         line = Trim(line.substr(i + 1));
      }
   }
   // Case-insensitively removes `label` when the line starts with it.
   auto strip_label = [&line](const std::string& label) {
      if (line.size() >= label.size()) {
         bool matches = true;
         for (std::size_t i = 0; i < label.size(); ++i) {
            if (std::tolower(static_cast<unsigned char>(line[i])) !=
                std::tolower(static_cast<unsigned char>(label[i]))) {
               matches = false;
               break;
            }
         }
         if (matches) {
            line = Trim(line.substr(label.size()));
         }
      }
   };
   strip_label("name:");
   strip_label("brewery name:");
   strip_label("description:");
   strip_label("username:");
   strip_label("bio:");
   return Trim(std::move(line));
}

// Splits raw model output into (first line, remaining lines joined with
// spaces). Lines are cleaned via StripCommonPrefix; empty lines,
// tag-like lines (e.g. "<think>"), and filler starting with
// "okay,"/"hmm" are discarded. Throws runtime_error(error_message) when
// fewer than two usable lines remain.
std::pair<std::string, std::string> ParseTwoLineResponse(
    const std::string& raw, const std::string& error_message) {
   std::string normalized = raw;
   std::replace(normalized.begin(), normalized.end(), '\r', '\n');
   std::vector<std::string> lines;
   std::stringstream stream(normalized);
   std::string line;
   while (std::getline(stream, line)) {
      line = StripCommonPrefix(std::move(line));
      if (!line.empty()) lines.push_back(std::move(line));
   }
   std::vector<std::string> filtered;
   for (auto& l : lines) {
      std::string low = l;
      std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
         return static_cast<char>(std::tolower(c));
      });
      if (!l.empty() && l.front() == '<' && low.back() == '>') continue;
      if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
      filtered.push_back(std::move(l));
   }
   if (filtered.size() < 2) throw std::runtime_error(error_message);
   std::string first = Trim(filtered.front());
   std::string second;
   for (size_t i = 1; i < filtered.size(); ++i) {
      if (!second.empty()) second += ' ';
      second += filtered[i];
   }
   second = Trim(std::move(second));
   if (first.empty() || second.empty()) throw std::runtime_error(error_message);
   return {first, second};
}

// Renders a single user message through the model's chat template;
// returns the prompt unchanged when the model ships no template.
std::string ToChatPrompt(const llama_model* model,
                         const std::string& user_prompt) {
   const char* tmpl = llama_model_chat_template(model, nullptr);
   if (tmpl == nullptr) {
      return user_prompt;
   }
   const llama_chat_message message{"user", user_prompt.c_str()};
   std::vector<char> buffer(
       std::max<std::size_t>(1024, user_prompt.size() * 4));
   int32_t required =
       llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
                                 static_cast<int32_t>(buffer.size()));
   if (required < 0) {
      throw std::runtime_error("LlamaGenerator: failed to apply chat template");
   }
   // A return >= buffer size means the buffer was too small; retry once
   // with exactly the reported capacity.
   if (required >= static_cast<int32_t>(buffer.size())) {
      buffer.resize(static_cast<std::size_t>(required) + 1);
      required =
          llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
                                    static_cast<int32_t>(buffer.size()));
      if (required < 0) {
         throw std::runtime_error(
             "LlamaGenerator: failed to apply chat template");
      }
   }
   return std::string(buffer.data(), static_cast<std::size_t>(required));
}

// Renders a system+user conversation through the model's chat template;
// falls back to "system\n\nuser" concatenation without a template.
std::string ToChatPrompt(const llama_model* model,
                         const std::string& system_prompt,
                         const std::string& user_prompt) {
   const char* tmpl = llama_model_chat_template(model, nullptr);
   if (tmpl == nullptr) {
      return system_prompt + "\n\n" + user_prompt;
   }
   const llama_chat_message messages[2] = {{"system", system_prompt.c_str()},
                                           {"user", user_prompt.c_str()}};
   std::vector<char> buffer(std::max<std::size_t>(
       1024, (system_prompt.size() + user_prompt.size()) * 4));
   int32_t required =
       llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
                                 static_cast<int32_t>(buffer.size()));
   if (required < 0) {
      throw std::runtime_error("LlamaGenerator: failed to apply chat template");
   }
   // Buffer-too-small case: resize to the reported requirement and retry.
   if (required >= static_cast<int32_t>(buffer.size())) {
      buffer.resize(static_cast<std::size_t>(required) + 1);
      required =
          llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
                                    static_cast<int32_t>(buffer.size()));
      if (required < 0) {
         throw std::runtime_error(
             "LlamaGenerator: failed to apply chat template");
      }
   }
   return std::string(buffer.data(), static_cast<std::size_t>(required));
}

// Appends the detokenized text of `token` to `output`. A negative
// return from llama_token_to_piece encodes the required buffer size, so
// retry once with a dynamically sized buffer before giving up.
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
                      std::string& output) {
   std::array<char, 256> buffer{};
   int32_t bytes =
       llama_token_to_piece(vocab, token, buffer.data(),
                            static_cast<int32_t>(buffer.size()), 0, true);
   if (bytes < 0) {
      std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
      bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
                                   static_cast<int32_t>(dynamic_buffer.size()),
                                   0, true);
      if (bytes < 0) {
         throw std::runtime_error(
             "LlamaGenerator: failed to decode sampled token piece");
      }
      output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
      return;
   }
   output.append(buffer.data(), static_cast<std::size_t>(bytes));
}

// Scans `text` for the first balanced top-level {...} object, tracking
// string literals and escapes so braces inside strings are ignored.
// Stores the object in `json_out` and returns true when one is found.
bool ExtractFirstJsonObject(const std::string& text, std::string& json_out) {
   std::size_t start = std::string::npos;
   int depth = 0;
   bool in_string = false;
   bool escaped = false;
   for (std::size_t i = 0; i < text.size(); ++i) {
      const char ch = text[i];
      if (in_string) {
         if (escaped) {
            escaped = false;
         } else if (ch == '\\') {
            escaped = true;
         } else if (ch == '"') {
            in_string = false;
         }
         continue;
      }
      if (ch == '"') {
         in_string = true;
         continue;
      }
      if (ch == '{') {
         if (depth == 0) {
            start = i;
         }
         ++depth;
         continue;
      }
      if (ch == '}') {
         // Stray '}' before any '{' is ignored rather than treated as
         // a malformed object.
         if (depth == 0) {
            continue;
         }
         --depth;
         if (depth == 0 && start != std::string::npos) {
            json_out = text.substr(start, i - start + 1);
            return true;
         }
      }
   }
   return false;
}

// Validates that `raw` is (or contains) a JSON object with non-empty
// string fields "name" and "description", writing the trimmed values to
// the out params. Returns an empty string on success, or a
// human-readable error message suitable for re-prompting the model.
std::string ValidateBreweryJson(const std::string& raw, std::string& name_out,
                                std::string& description_out) {
   auto validate_object = [&](const boost::json::value& jv,
                              std::string& error_out) -> bool {
      if (!jv.is_object()) {
         error_out = "JSON root must be an object";
         return false;
      }
      const auto& obj = jv.get_object();
      if (!obj.contains("name") || !obj.at("name").is_string()) {
         error_out = "JSON field 'name' is missing or not a string";
         return false;
      }
      if (!obj.contains("description") || !obj.at("description").is_string()) {
         error_out = "JSON field 'description' is missing or not a string";
         return false;
      }
      name_out = Trim(std::string(obj.at("name").as_string().c_str()));
      description_out =
          Trim(std::string(obj.at("description").as_string().c_str()));
      if (name_out.empty()) {
         error_out = "JSON field 'name' must not be empty";
         return false;
      }
      if (description_out.empty()) {
         error_out = "JSON field 'description' must not be empty";
         return false;
      }
      // Reject verbatim echoes of the schema example
      // ({"name": "string", ...}).
      std::string name_lower = name_out;
      std::string description_lower = description_out;
      std::transform(
          name_lower.begin(), name_lower.end(), name_lower.begin(),
          [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
      std::transform(description_lower.begin(), description_lower.end(),
                     description_lower.begin(), [](unsigned char c) {
                        return static_cast<char>(std::tolower(c));
                     });
      if (name_lower == "string" || description_lower == "string") {
         error_out = "JSON appears to be a schema placeholder, not content";
         return false;
      }
      error_out.clear();
      return true;
   };
   boost::system::error_code ec;
   boost::json::value jv = boost::json::parse(raw, ec);
   std::string validation_error;
   if (ec) {
      // Whole-string parse failed; try to salvage an embedded object,
      // since models often wrap JSON in prose or code fences.
      std::string extracted;
      if (!ExtractFirstJsonObject(raw, extracted)) {
         return "JSON parse error: " + ec.message();
      }
      ec.clear();
      jv = boost::json::parse(extracted, ec);
      if (ec) {
         return "JSON parse error: " + ec.message();
      }
      if (!validate_object(jv, validation_error)) {
         return validation_error;
      }
      return {};
   }
   if (!validate_object(jv, validation_error)) {
      return validation_error;
   }
   return {};
}

}  // namespace
// Definitions of the public wrapper functions that expose the file-local helpers above to other translation units (presumably declared in llama_generator_helpers.h).
// Public wrapper over the file-local PrepareRegionContext helper.
std::string PrepareRegionContextPublic(std::string_view region_context,
                                       std::size_t max_chars) {
   return PrepareRegionContext(region_context, max_chars);
}
// Public wrapper over the file-local ParseTwoLineResponse helper.
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
    const std::string& raw, const std::string& error_message) {
   return ParseTwoLineResponse(raw, error_message);
}
// Public wrapper over the file-local single-message ToChatPrompt helper.
std::string ToChatPromptPublic(const llama_model* model,
                               const std::string& user_prompt) {
   return ToChatPrompt(model, user_prompt);
}
// Public wrapper over the file-local system+user ToChatPrompt helper.
std::string ToChatPromptPublic(const llama_model* model,
                               const std::string& system_prompt,
                               const std::string& user_prompt) {
   return ToChatPrompt(model, system_prompt, user_prompt);
}
// Public wrapper over the file-local AppendTokenPiece helper.
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
                            std::string& output) {
   AppendTokenPiece(vocab, token, output);
}
// Public wrapper over the file-local ValidateBreweryJson helper.
std::string ValidateBreweryJsonPublic(const std::string& raw,
                                      std::string& name_out,
                                      std::string& description_out) {
   return ValidateBreweryJson(raw, name_out, description_out);
}

View File

@@ -0,0 +1,111 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
#include "llama.h"
// Runs inference on a single user message, rendered through the model's
// chat template when one is available.
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
   return InferFormatted(ToChatPromptPublic(model_, prompt), max_tokens);
}
// Runs inference on a system+user conversation rendered through the
// model's chat template (plain concatenation when no template exists).
std::string LlamaGenerator::Infer(const std::string& system_prompt,
                                  const std::string& prompt, int max_tokens) {
   return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
                         max_tokens);
}
// Runs token-by-token generation on an already chat-formatted prompt.
// Clears the KV cache first so calls are independent, truncates the
// prompt to fit n_batch/n_ctx limits, and samples with the configured
// temperature/top-p/seed until an end-of-generation token or
// `max_tokens` tokens. Throws std::runtime_error on any llama failure.
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
                                           int max_tokens) {
   if (model_ == nullptr || context_ == nullptr)
      throw std::runtime_error("LlamaGenerator: model not loaded");
   const llama_vocab* vocab = llama_model_get_vocab(model_);
   if (vocab == nullptr)
      throw std::runtime_error("LlamaGenerator: vocab unavailable");
   // Reset the KV cache so previous generations cannot leak context.
   llama_memory_clear(llama_get_memory(context_), true);
   std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
   int32_t token_count = llama_tokenize(
       vocab, formatted_prompt.c_str(),
       static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
       static_cast<int32_t>(prompt_tokens.size()), true, true);
   if (token_count < 0) {
      // Negative return encodes the required token count; retry once with
      // a buffer of exactly that size.
      prompt_tokens.resize(static_cast<std::size_t>(-token_count));
      token_count = llama_tokenize(
          vocab, formatted_prompt.c_str(),
          static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
          static_cast<int32_t>(prompt_tokens.size()), true, true);
   }
   if (token_count < 0)
      throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
   const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
   const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
   if (n_ctx <= 1 || n_batch <= 0)
      throw std::runtime_error("LlamaGenerator: invalid context or batch size");
   // Cap generation so prompt + output always fit inside the context.
   const int32_t effective_max_tokens =
       std::max(1, std::min(max_tokens, n_ctx - 1));
   int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
   prompt_budget = std::max<int32_t>(1, prompt_budget);
   prompt_tokens.resize(static_cast<std::size_t>(token_count));
   if (token_count > prompt_budget) {
      // Over-long prompts are truncated (from the tail) rather than
      // rejected; log so prompt-size issues are visible.
      spdlog::warn(
          "LlamaGenerator: prompt too long ({} tokens), truncating to {} "
          "tokens to fit n_batch/n_ctx limits",
          token_count, prompt_budget);
      prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
      token_count = prompt_budget;
   }
   const llama_batch prompt_batch = llama_batch_get_one(
       prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
   if (llama_decode(context_, prompt_batch) != 0)
      throw std::runtime_error("LlamaGenerator: prompt decode failed");
   llama_sampler_chain_params sampler_params =
       llama_sampler_chain_default_params();
   // RAII so the sampler chain is freed even when a decode throws below.
   using SamplerPtr =
       std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
   SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
                      &llama_sampler_free);
   if (!sampler)
      throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
   // Sampling pipeline: temperature scaling -> top-p filtering -> seeded
   // draw from the remaining distribution.
   llama_sampler_chain_add(sampler.get(),
                           llama_sampler_init_temp(sampling_temperature_));
   llama_sampler_chain_add(sampler.get(),
                           llama_sampler_init_top_p(sampling_top_p_, 1));
   llama_sampler_chain_add(sampler.get(),
                           llama_sampler_init_dist(sampling_seed_));
   std::vector<llama_token> generated_tokens;
   generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
   for (int i = 0; i < effective_max_tokens; ++i) {
      const llama_token next =
          llama_sampler_sample(sampler.get(), context_, -1);
      // Stop at any end-of-generation token (EOS/EOT).
      if (llama_vocab_is_eog(vocab, next)) break;
      generated_tokens.push_back(next);
      // Feed the sampled token back so the next sample sees it.
      llama_token token = next;
      const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
      if (llama_decode(context_, one_token_batch) != 0)
         throw std::runtime_error(
             "LlamaGenerator: decode failed during generation");
   }
   // Detokenize the sampled ids into UTF-8 text.
   std::string output;
   for (const llama_token token : generated_tokens)
      AppendTokenPiecePublic(vocab, token, output);
   return output;
}

View File

@@ -0,0 +1,42 @@
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "llama.h"
// Loads a model from `model_path` and creates an inference context with
// n_ctx = 2048. Any previously loaded model/context is released first,
// so Load may be called repeatedly. Throws std::runtime_error on an
// empty path or when model/context creation fails.
void LlamaGenerator::Load(const std::string& model_path) {
   if (model_path.empty())
      throw std::runtime_error("LlamaGenerator: model path must not be empty");
   // Release any existing state so reloading does not leak.
   if (context_ != nullptr) {
      llama_free(context_);
      context_ = nullptr;
   }
   if (model_ != nullptr) {
      llama_model_free(model_);
      model_ = nullptr;
   }
   // NOTE(review): called on every Load; assumed safe to re-initialize —
   // confirm against the llama.cpp backend lifecycle.
   llama_backend_init();
   llama_model_params model_params = llama_model_default_params();
   model_ = llama_model_load_from_file(model_path.c_str(), model_params);
   if (model_ == nullptr) {
      throw std::runtime_error(
          "LlamaGenerator: failed to load model from path: " + model_path);
   }
   llama_context_params context_params = llama_context_default_params();
   context_params.n_ctx = 2048;
   context_ = llama_init_from_model(model_, context_params);
   if (context_ == nullptr) {
      // Do not leak the model when context creation fails.
      llama_model_free(model_);
      model_ = nullptr;
      throw std::runtime_error("LlamaGenerator: failed to create context");
   }
   spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}

View File

@@ -0,0 +1,25 @@
#include <stdexcept>
#include "data_generation/llama_generator.h"
#include "llama.h"
// Validates and stores the sampling parameters used by inference.
// temperature must be >= 0, top_p must lie in (0, 1], and seed must be
// >= 0 (or -1 to request llama.cpp's default random seed). Throws
// std::runtime_error on any invalid value; members are only written
// once every argument has passed validation.
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
                                        int seed) {
   if (temperature < 0.0f) {
      throw std::runtime_error(
          "LlamaGenerator: sampling temperature must be >= 0");
   }
   if (!(top_p > 0.0f && top_p <= 1.0f)) {
      throw std::runtime_error(
          "LlamaGenerator: sampling top-p must be in (0, 1]");
   }
   if (seed < -1) {
      throw std::runtime_error(
          "LlamaGenerator: seed must be >= 0, or -1 for random");
   }
   sampling_temperature_ = temperature;
   sampling_top_p_ = top_p;
   if (seed < 0) {
      // -1 selects llama.cpp's sentinel for a randomly chosen seed.
      sampling_seed_ = static_cast<uint32_t>(LLAMA_DEFAULT_SEED);
   } else {
      sampling_seed_ = static_cast<uint32_t>(seed);
   }
}

View File

@@ -1,734 +0,0 @@
#include <algorithm>
#include <array>
#include <cctype>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "llama.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h"
namespace {
// Returns `value` with leading and trailing whitespace removed.
std::string trim(std::string value) {
  std::size_t first = 0;
  while (first < value.size() &&
         std::isspace(static_cast<unsigned char>(value[first]))) {
    ++first;
  }
  std::size_t last = value.size();
  while (last > first &&
         std::isspace(static_cast<unsigned char>(value[last - 1]))) {
    --last;
  }
  return value.substr(first, last - first);
}
// Collapses every run of whitespace into one space, then trims the ends.
std::string CondenseWhitespace(std::string text) {
  std::string collapsed;
  collapsed.reserve(text.size());
  bool previousWasSpace = false;
  for (char c : text) {
    const bool isSpace = std::isspace(static_cast<unsigned char>(c)) != 0;
    if (isSpace) {
      // Emit one space for the first whitespace char of each run.
      if (!previousWasSpace) {
        collapsed.push_back(' ');
      }
    } else {
      collapsed.push_back(c);
    }
    previousWasSpace = isSpace;
  }
  return trim(std::move(collapsed));
}
// Normalizes whitespace in `regionContext` and clips the result to at
// most `maxChars` characters plus a "..." suffix, preferring to break
// at a space in the second half so a word is not cut in two.
std::string PrepareRegionContext(std::string_view regionContext,
                                 std::size_t maxChars = 700) {
  std::string summary = CondenseWhitespace(std::string(regionContext));
  if (summary.size() > maxChars) {
    summary.resize(maxChars);
    const std::size_t cut = summary.find_last_of(' ');
    if (cut != std::string::npos && cut > maxChars / 2) {
      summary.resize(cut);
    }
    summary += "...";
  }
  return summary;
}
// Strips list markers ("-", "*", "1.", "2)") and known field labels
// ("name:", "bio:", ...) from the start of a model-output line.
std::string stripCommonPrefix(std::string line) {
  line = trim(std::move(line));
  if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
    line = trim(line.substr(1));
  } else {
    // Numbered-list prefix: one or more digits followed by '.' or ')'.
    std::size_t i = 0;
    while (i < line.size() &&
           std::isdigit(static_cast<unsigned char>(line[i]))) {
      ++i;
    }
    if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
      line = trim(line.substr(i + 1));
    }
  }
  // Case-insensitively removes `label` when the line starts with it.
  auto stripLabel = [&line](const std::string &label) {
    if (line.size() >= label.size()) {
      bool matches = true;
      for (std::size_t i = 0; i < label.size(); ++i) {
        if (std::tolower(static_cast<unsigned char>(line[i])) !=
            std::tolower(static_cast<unsigned char>(label[i]))) {
          matches = false;
          break;
        }
      }
      if (matches) {
        line = trim(line.substr(label.size()));
      }
    }
  };
  stripLabel("name:");
  stripLabel("brewery name:");
  stripLabel("description:");
  stripLabel("username:");
  stripLabel("bio:");
  return trim(std::move(line));
}
// Splits raw model output into (first line, remaining lines joined with
// spaces). Lines are cleaned via stripCommonPrefix; empty lines,
// tag-like lines (e.g. "<think>"), and filler starting with
// "okay,"/"hmm" are discarded. Throws runtime_error(errorMessage) when
// fewer than two usable lines remain.
std::pair<std::string, std::string>
parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) {
  std::string normalized = raw;
  std::replace(normalized.begin(), normalized.end(), '\r', '\n');
  std::vector<std::string> lines;
  std::stringstream stream(normalized);
  std::string line;
  while (std::getline(stream, line)) {
    line = stripCommonPrefix(std::move(line));
    if (!line.empty())
      lines.push_back(std::move(line));
  }
  std::vector<std::string> filtered;
  for (auto &l : lines) {
    std::string low = l;
    std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
      return static_cast<char>(std::tolower(c));
    });
    if (!l.empty() && l.front() == '<' && low.back() == '>')
      continue;
    if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0)
      continue;
    filtered.push_back(std::move(l));
  }
  if (filtered.size() < 2)
    throw std::runtime_error(errorMessage);
  std::string first = trim(filtered.front());
  std::string second;
  for (size_t i = 1; i < filtered.size(); ++i) {
    if (!second.empty())
      second += ' ';
    second += filtered[i];
  }
  second = trim(std::move(second));
  if (first.empty() || second.empty())
    throw std::runtime_error(errorMessage);
  return {first, second};
}
// Renders a single user message through the model's chat template;
// returns the prompt unchanged when the model ships no template.
std::string toChatPrompt(const llama_model *model,
                         const std::string &userPrompt) {
  const char *tmpl = llama_model_chat_template(model, nullptr);
  if (tmpl == nullptr) {
    return userPrompt;
  }
  const llama_chat_message message{"user", userPrompt.c_str()};
  std::vector<char> buffer(std::max<std::size_t>(1024, userPrompt.size() * 4));
  int32_t required =
      llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
                                static_cast<int32_t>(buffer.size()));
  if (required < 0) {
    throw std::runtime_error("LlamaGenerator: failed to apply chat template");
  }
  // A return >= buffer size means the buffer was too small; retry once
  // with exactly the reported capacity.
  if (required >= static_cast<int32_t>(buffer.size())) {
    buffer.resize(static_cast<std::size_t>(required) + 1);
    required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
                                         static_cast<int32_t>(buffer.size()));
    if (required < 0) {
      throw std::runtime_error("LlamaGenerator: failed to apply chat template");
    }
  }
  return std::string(buffer.data(), static_cast<std::size_t>(required));
}
// Renders a system + user exchange through the model's chat template.
// Falls back to plain "system\n\nuser" concatenation when the model has no
// template.
//
// Fix: the buffer-sizing expression referred to `systemPrompt`, which does
// not exist (the parameter is `system_prompt`), so this overload failed to
// compile.
std::string toChatPrompt(const llama_model *model,
                         const std::string &system_prompt,
                         const std::string &userPrompt) {
   const char *tmpl = llama_model_chat_template(model, nullptr);
   if (tmpl == nullptr) {
      return system_prompt + "\n\n" + userPrompt;
   }
   const llama_chat_message messages[2] = {{"system", system_prompt.c_str()},
                                           {"user", userPrompt.c_str()}};
   // Generous first guess; template expansion rarely exceeds 4x the input.
   std::vector<char> buffer(std::max<std::size_t>(
       1024, (system_prompt.size() + userPrompt.size()) * 4));
   int32_t required =
       llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
                                 static_cast<int32_t>(buffer.size()));
   if (required < 0) {
      throw std::runtime_error("LlamaGenerator: failed to apply chat template");
   }
   if (required >= static_cast<int32_t>(buffer.size())) {
      // Return value is the size actually needed; grow and retry once.
      buffer.resize(static_cast<std::size_t>(required) + 1);
      required = llama_chat_apply_template(tmpl, messages, 2, true,
                                           buffer.data(),
                                           static_cast<int32_t>(buffer.size()));
      if (required < 0) {
         throw std::runtime_error(
             "LlamaGenerator: failed to apply chat template");
      }
   }
   return std::string(buffer.data(), static_cast<std::size_t>(required));
}
// Detokenizes `token` and appends its text to `output`. Tries a small
// stack buffer first; on overflow llama_token_to_piece returns the negated
// required size, so a heap buffer of that size is used for the retry.
void appendTokenPiece(const llama_vocab *vocab, llama_token token,
                      std::string &output) {
   std::array<char, 256> stackBuf{};
   int32_t written =
       llama_token_to_piece(vocab, token, stackBuf.data(),
                            static_cast<int32_t>(stackBuf.size()), 0, true);
   if (written >= 0) {
      output.append(stackBuf.data(), static_cast<std::size_t>(written));
      return;
   }
   std::vector<char> heapBuf(static_cast<std::size_t>(-written));
   written =
       llama_token_to_piece(vocab, token, heapBuf.data(),
                            static_cast<int32_t>(heapBuf.size()), 0, true);
   if (written < 0) {
      throw std::runtime_error(
          "LlamaGenerator: failed to decode sampled token piece");
   }
   output.append(heapBuf.data(), static_cast<std::size_t>(written));
}
// Scans `text` for the first balanced top-level JSON object and copies it
// into `jsonOut`. Braces inside string literals (including escaped quotes)
// are ignored, as are stray closing braces before any object opens.
// Returns false when no complete object exists; `jsonOut` is untouched then.
bool extractFirstJsonObject(const std::string &text, std::string &jsonOut) {
   std::size_t objectStart = std::string::npos;
   int braceDepth = 0;
   bool insideString = false;
   bool escapePending = false;
   for (std::size_t pos = 0; pos < text.size(); ++pos) {
      const char current = text[pos];
      if (insideString) {
         // Inside a string literal: only track escapes and the closing quote.
         if (escapePending) {
            escapePending = false;
         } else if (current == '\\') {
            escapePending = true;
         } else if (current == '"') {
            insideString = false;
         }
         continue;
      }
      switch (current) {
         case '"':
            insideString = true;
            break;
         case '{':
            if (braceDepth == 0) {
               objectStart = pos;
            }
            ++braceDepth;
            break;
         case '}':
            if (braceDepth > 0) {
               --braceDepth;
               if (braceDepth == 0 && objectStart != std::string::npos) {
                  jsonOut = text.substr(objectStart, pos - objectStart + 1);
                  return true;
               }
            }
            break;
         default:
            break;
      }
   }
   return false;
}
// Parses `raw` as a brewery JSON payload with the expected schema
// {"name": <string>, "description": <string>}.
//
// Returns an empty string on success, writing the trimmed field values to
// nameOut/descriptionOut; otherwise returns a human-readable error message
// (fed back into the model's retry prompt by the caller). When `raw` is not
// valid JSON on its own, the first balanced JSON object embedded in it is
// extracted and parsed instead, since models often wrap JSON in prose or
// markdown fences.
std::string ValidateBreweryJson(const std::string &raw, std::string &nameOut,
                                std::string &descriptionOut) {
   // Validates one parsed JSON value against the schema; reports the first
   // violation via errorOut and fills the output parameters along the way.
   auto validateObject = [&](const boost::json::value &jv,
                             std::string &errorOut) -> bool {
      if (!jv.is_object()) {
         errorOut = "JSON root must be an object";
         return false;
      }
      const auto &obj = jv.get_object();
      if (!obj.contains("name") || !obj.at("name").is_string()) {
         errorOut = "JSON field 'name' is missing or not a string";
         return false;
      }
      if (!obj.contains("description") || !obj.at("description").is_string()) {
         errorOut = "JSON field 'description' is missing or not a string";
         return false;
      }
      nameOut = trim(std::string(obj.at("name").as_string().c_str()));
      descriptionOut =
          trim(std::string(obj.at("description").as_string().c_str()));
      if (nameOut.empty()) {
         errorOut = "JSON field 'name' must not be empty";
         return false;
      }
      if (descriptionOut.empty()) {
         errorOut = "JSON field 'description' must not be empty";
         return false;
      }
      // Reject echoes of the schema itself (a field whose value is the
      // literal word "string"), which models sometimes parrot back.
      std::string nameLower = nameOut;
      std::string descriptionLower = descriptionOut;
      std::transform(
          nameLower.begin(), nameLower.end(), nameLower.begin(),
          [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
      std::transform(descriptionLower.begin(), descriptionLower.end(),
                     descriptionLower.begin(), [](unsigned char c) {
                        return static_cast<char>(std::tolower(c));
                     });
      if (nameLower == "string" || descriptionLower == "string") {
         errorOut = "JSON appears to be a schema placeholder, not content";
         return false;
      }
      errorOut.clear();
      return true;
   };
   boost::system::error_code ec;
   boost::json::value jv = boost::json::parse(raw, ec);
   std::string validationError;
   if (ec) {
      // Whole-string parse failed: fall back to the first embedded object.
      std::string extracted;
      if (!extractFirstJsonObject(raw, extracted)) {
         return "JSON parse error: " + ec.message();
      }
      ec.clear();
      jv = boost::json::parse(extracted, ec);
      if (ec) {
         return "JSON parse error: " + ec.message();
      }
      if (!validateObject(jv, validationError)) {
         return validationError;
      }
      return {};
   }
   if (!validateObject(jv, validationError)) {
      return validationError;
   }
   return {};
}
} // namespace
// Releases the inference context before the model (the context references
// the model), then tears down the llama backend. Safe to run on a generator
// that was never loaded: both pointers start null.
LlamaGenerator::~LlamaGenerator() {
   if (context_ != nullptr) {
      llama_free(context_);
      context_ = nullptr;
   }
   if (model_ != nullptr) {
      llama_model_free(model_);
      model_ = nullptr;
   }
   // NOTE(review): llama_backend_free() is process-global; if multiple
   // generator instances can coexist, freeing the backend here could affect
   // the others — confirm against intended usage.
   llama_backend_free();
}
// Validates and stores the sampling parameters used by Infer().
// temperature must be >= 0, top_p must lie in (0, 1], and seed must be
// >= 0 (or -1 to request llama.cpp's default randomized seed).
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
                                        int seed) {
   if (temperature < 0.0f) {
      throw std::runtime_error(
          "LlamaGenerator: sampling temperature must be >= 0");
   }
   // Negated conjunction deliberately rejects NaN as well as out-of-range.
   if (!(top_p > 0.0f && top_p <= 1.0f)) {
      throw std::runtime_error(
          "LlamaGenerator: sampling top-p must be in (0, 1]");
   }
   if (seed < -1) {
      throw std::runtime_error(
          "LlamaGenerator: seed must be >= 0, or -1 for random");
   }
   sampling_temperature_ = temperature;
   sampling_top_p_ = top_p;
   if (seed < 0) {
      sampling_seed_ = static_cast<uint32_t>(LLAMA_DEFAULT_SEED);
   } else {
      sampling_seed_ = static_cast<uint32_t>(seed);
   }
}
// Loads a model from `model_path` and creates an inference context.
//
// Reloading is supported: any existing context/model is freed first
// (context before model, since the context references the model).
// Throws std::runtime_error when the path is empty, the model fails to
// load, or context creation fails (the model is freed again in that case).
void LlamaGenerator::Load(const std::string &model_path) {
   if (model_path.empty())
      throw std::runtime_error("LlamaGenerator: model path must not be empty");
   if (context_ != nullptr) {
      llama_free(context_);
      context_ = nullptr;
   }
   if (model_ != nullptr) {
      llama_model_free(model_);
      model_ = nullptr;
   }
   // NOTE(review): runs on every (re)load — presumably llama_backend_init is
   // idempotent, but confirm against the llama.cpp version in use.
   llama_backend_init();
   llama_model_params model_params = llama_model_default_params();
   model_ = llama_model_load_from_file(model_path.c_str(), model_params);
   if (model_ == nullptr) {
      throw std::runtime_error(
          "LlamaGenerator: failed to load model from path: " + model_path);
   }
   llama_context_params context_params = llama_context_default_params();
   context_params.n_ctx = 2048;  // fixed context window for this pipeline
   context_ = llama_init_from_model(model_, context_params);
   if (context_ == nullptr) {
      // Keep the object in the "not loaded" state on partial failure.
      llama_model_free(model_);
      model_ = nullptr;
      throw std::runtime_error("LlamaGenerator: failed to create context");
   }
   spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}
// Runs single-turn generation: formats `prompt` through the model's chat
// template, tokenizes and decodes it, then samples up to `max_tokens`
// tokens (clamped to fit the context window). Returns the detokenized
// completion. Throws std::runtime_error on any llama.cpp failure.
//
// Fixes: `prompt_budget` was declared const yet reassigned on the next
// line, which does not compile (the sibling overload already uses a
// mutable local); and the output buffer now reserves the clamped
// `effective_max_tokens` instead of the caller-supplied `max_tokens`,
// avoiding a huge allocation for oversized requests.
std::string LlamaGenerator::Infer(const std::string &prompt, int max_tokens) {
   if (model_ == nullptr || context_ == nullptr)
      throw std::runtime_error("LlamaGenerator: model not loaded");
   const llama_vocab *vocab = llama_model_get_vocab(model_);
   if (vocab == nullptr)
      throw std::runtime_error("LlamaGenerator: vocab unavailable");
   // Clear the KV cache so earlier generations cannot leak into this one.
   llama_memory_clear(llama_get_memory(context_), true);
   const std::string formatted_prompt = toChatPrompt(model_, prompt);
   std::vector<llama_token> promptTokens(formatted_prompt.size() + 8);
   int32_t tokenCount = llama_tokenize(
       vocab, formatted_prompt.c_str(),
       static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
       static_cast<int32_t>(promptTokens.size()), true, true);
   if (tokenCount < 0) {
      // Negative result encodes the required token count; retry once.
      promptTokens.resize(static_cast<std::size_t>(-tokenCount));
      tokenCount = llama_tokenize(
          vocab, formatted_prompt.c_str(),
          static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
          static_cast<int32_t>(promptTokens.size()), true, true);
   }
   if (tokenCount < 0)
      throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
   const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
   const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
   if (nCtx <= 1 || nBatch <= 0) {
      throw std::runtime_error("LlamaGenerator: invalid context or batch size");
   }
   // Reserve room in the context for generated tokens; the prompt gets
   // whatever remains, capped by the batch size.
   const int32_t effective_max_tokens =
       std::max(1, std::min(max_tokens, nCtx - 1));
   int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens);
   prompt_budget = std::max<int32_t>(1, prompt_budget);
   promptTokens.resize(static_cast<std::size_t>(tokenCount));
   if (tokenCount > prompt_budget) {
      spdlog::warn(
          "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
          "to fit n_batch/n_ctx limits",
          tokenCount, prompt_budget);
      promptTokens.resize(static_cast<std::size_t>(prompt_budget));
      tokenCount = prompt_budget;
   }
   const llama_batch promptBatch = llama_batch_get_one(
       promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
   if (llama_decode(context_, promptBatch) != 0)
      throw std::runtime_error("LlamaGenerator: prompt decode failed");
   llama_sampler_chain_params sampler_params =
       llama_sampler_chain_default_params();
   // RAII wrapper so the sampler chain is freed on every exit path.
   using SamplerPtr =
       std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
   SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
                      &llama_sampler_free);
   if (!sampler)
      throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
   llama_sampler_chain_add(sampler.get(),
                           llama_sampler_init_temp(sampling_temperature_));
   llama_sampler_chain_add(sampler.get(),
                           llama_sampler_init_top_p(sampling_top_p_, 1));
   llama_sampler_chain_add(sampler.get(),
                           llama_sampler_init_dist(sampling_seed_));
   std::vector<llama_token> generated_tokens;
   generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
   for (int i = 0; i < effective_max_tokens; ++i) {
      const llama_token next =
          llama_sampler_sample(sampler.get(), context_, -1);
      if (llama_vocab_is_eog(vocab, next))
         break;  // end-of-generation token: stop without emitting it
      generated_tokens.push_back(next);
      llama_token token = next;
      const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
      if (llama_decode(context_, one_token_batch) != 0)
         throw std::runtime_error(
             "LlamaGenerator: decode failed during generation");
   }
   std::string output;
   for (const llama_token token : generated_tokens)
      appendTokenPiece(vocab, token, output);
   return output;
}
std::string LlamaGenerator::Infer(const std::string &system_prompt,
const std::string &prompt, int max_tokens) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab *vocab = llama_model_get_vocab(model_);
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
llama_memory_clear(llama_get_memory(context_), true);
const std::string formatted_prompt =
toChatPrompt(model_, system_prompt, prompt);
std::vector<llama_token> promptTokens(formatted_prompt.size() + 8);
int32_t tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
if (tokenCount < 0) {
promptTokens.resize(static_cast<std::size_t>(-tokenCount));
tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
}
if (tokenCount < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
if (nCtx <= 1 || nBatch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, nCtx - 1));
int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(tokenCount));
if (tokenCount > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
"to fit n_batch/n_ctx limits",
tokenCount, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(prompt_budget));
tokenCount = prompt_budget;
}
const llama_batch promptBatch = llama_batch_get_one(
promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
if (llama_decode(context_, promptBatch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
if (llama_vocab_is_eog(vocab, next))
break;
generated_tokens.push_back(next);
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
std::string output;
for (const llama_token token : generated_tokens)
appendTokenPiece(vocab, token, output);
return output;
}
// Generates a brewery name and description for a location, retrying up to
// three times when the model emits JSON that fails validation. Each retry
// prompt includes the validation error so the model can self-correct.
// Throws std::runtime_error when all attempts produce malformed output.
BreweryResult
LlamaGenerator::GenerateBrewery(const std::string &city_name,
                                const std::string &country_name,
                                const std::string &region_context) {
   const std::string safe_region_context = PrepareRegionContext(region_context);
   // "City" or "City, Country"; shared by the initial and retry prompts.
   const std::string location =
       city_name + (country_name.empty() ? std::string("")
                                         : std::string(", ") + country_name);
   const std::string system_prompt =
       "You are a copywriter for a craft beer travel guide. "
       "Your writing is vivid, specific to place, and avoids generic beer "
       "cliches. "
       "You must output ONLY valid JSON. "
       "The JSON schema must be exactly: {\"name\": \"string\", "
       "\"description\": \"string\"}. "
       "Do not include markdown formatting or backticks.";
   std::string prompt =
       "Write a brewery name and place-specific description for a craft "
       "brewery in " +
       location +
       (safe_region_context.empty()
            ? std::string(".")
            : std::string(". Regional context: ") + safe_region_context);
   constexpr int maxAttempts = 3;
   std::string raw;
   std::string lastError;
   for (int attempt = 1; attempt <= maxAttempts; ++attempt) {
      raw = Infer(system_prompt, prompt, 384);
      spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt,
                    raw);
      std::string name;
      std::string description;
      lastError = ValidateBreweryJson(raw, name, description);
      if (lastError.empty()) {
         return {std::move(name), std::move(description)};
      }
      spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
                   attempt, lastError);
      // Feed the validation error back and restate the required schema.
      prompt = "Your previous response was invalid. Error: " + lastError +
               "\nReturn ONLY valid JSON with this exact schema: "
               "{\"name\": \"string\", \"description\": \"string\"}."
               "\nDo not include markdown, comments, or extra keys."
               "\n\nLocation: " +
               location +
               (safe_region_context.empty()
                    ? std::string("")
                    : std::string("\nRegional context: ") +
                          safe_region_context);
   }
   spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: "
                 "{}",
                 maxAttempts, lastError.empty() ? raw : lastError);
   throw std::runtime_error("LlamaGenerator: malformed brewery response");
}
// Generates a username/bio pair consistent with `locale`, retrying up to
// three times when the model output cannot be parsed into two usable lines.
// Throws std::runtime_error when every attempt fails.
UserResult LlamaGenerator::GenerateUser(const std::string &locale) {
   const std::string system_prompt =
       "You generate plausible social media profiles for craft beer "
       "enthusiasts. "
       "Respond with exactly two lines: "
       "the first line is a username (lowercase, no spaces, 8-20 characters), "
       "the second line is a one-sentence bio (20-40 words). "
       "The profile should feel consistent with the locale. "
       "No preamble, no labels.";
   const std::string prompt =
       "Generate a craft beer enthusiast profile. Locale: " + locale;
   constexpr int maxAttempts = 3;
   std::string raw;
   for (int attempt = 1; attempt <= maxAttempts; ++attempt) {
      raw = Infer(system_prompt, prompt, 128);
      spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
                    attempt, raw);
      try {
         auto [username, bio] = parseTwoLineResponse(
             raw, "LlamaGenerator: malformed user response");
         // Usernames must not contain whitespace; drop any that slipped in.
         std::string compact;
         compact.reserve(username.size());
         for (const char ch : username) {
            if (!std::isspace(static_cast<unsigned char>(ch))) {
               compact += ch;
            }
         }
         username = std::move(compact);
         if (username.empty() || bio.empty()) {
            throw std::runtime_error("LlamaGenerator: malformed user response");
         }
         // Keep bios bounded for downstream storage.
         if (bio.size() > 200) {
            bio.resize(200);
         }
         return {username, bio};
      } catch (const std::exception &e) {
         spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}",
                      attempt, e.what());
      }
   }
   spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}",
                 maxAttempts, raw);
   throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -1,7 +1,7 @@
#include "data_generation/mock_generator.h" #include <string>
#include <vector>
#include <functional> #include "data_generation/mock_generator.h"
#include <spdlog/spdlog.h>
const std::vector<std::string> MockGenerator::kBreweryAdjectives = { const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
@@ -63,42 +63,3 @@ const std::vector<std::string> MockGenerator::kBios = {
"Craft beer fan mapping tasting notes and favorite brew routes.", "Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.", "Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."}; "Keeping a running list of must-try collab releases and tap takeovers."};
void MockGenerator::Load(const std::string & /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}
std::size_t MockGenerator::DeterministicHash(const std::string &a,
const std::string &b) {
std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b);
seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13));
return seed;
}
BreweryResult MockGenerator::GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) {
const std::string location_key =
country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash = region_context.empty()
? std::hash<std::string>{}(location_key)
: DeterministicHash(location_key, region_context);
BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
result.description =
kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
return result;
}
UserResult MockGenerator::GenerateUser(const std::string &locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -0,0 +1,12 @@
#include <string>
#include "data_generation/mock_generator.h"
// Combines two strings into one order-sensitive hash. Uses the
// boost::hash_combine-style golden-ratio mixing constant, then a 13-bit
// left rotation, so swapping `a` and `b` yields a different value.
std::size_t MockGenerator::DeterministicHash(const std::string& a,
                                             const std::string& b) {
   constexpr std::size_t kGoldenRatio = 0x9e3779b97f4a7c15ULL;
   constexpr unsigned kRotation = 13;
   constexpr unsigned kWordBits = sizeof(std::size_t) * 8;
   std::size_t combined = std::hash<std::string>{}(a);
   combined ^= std::hash<std::string>{}(b) + kGoldenRatio + (combined << 6) +
               (combined >> 2);
   return (combined << kRotation) | (combined >> (kWordBits - kRotation));
}

View File

@@ -0,0 +1,21 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
// Deterministically fabricates a brewery for a location: the same
// city/country/context inputs always select the same entries from the
// static adjective/noun/description lists via a derived hash.
BreweryResult MockGenerator::GenerateBrewery(
    const std::string& city_name, const std::string& country_name,
    const std::string& region_context) {
   std::string location_key = city_name;
   if (!country_name.empty()) {
      location_key += "," + country_name;
   }
   std::size_t hash;
   if (region_context.empty()) {
      hash = std::hash<std::string>{}(location_key);
   } else {
      // Mix the context in so different contexts give different breweries.
      hash = DeterministicHash(location_key, region_context);
   }
   const std::string name =
       kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
       kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
   const std::string description =
       kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
   return {name, description};
}

View File

@@ -0,0 +1,13 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
// Deterministically fabricates a user profile: the locale string hashes to
// fixed indices into the canned username and bio lists.
UserResult MockGenerator::GenerateUser(const std::string& locale) {
   const std::size_t hash = std::hash<std::string>{}(locale);
   return {kUsernames[hash % kUsernames.size()],
           kBios[(hash / 11) % kBios.size()]};
}

View File

@@ -0,0 +1,9 @@
#include <spdlog/spdlog.h>
#include <string>
#include "data_generation/mock_generator.h"
// No-op: the mock generator needs no model file. Logs for pipeline parity
// with the real generator's Load().
void MockGenerator::Load(const std::string& /*modelPath*/) {
   spdlog::info("[MockGenerator] No model needed");
}

View File

@@ -1,11 +1,13 @@
#include "database/database.h" #include "database/database.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
void SqliteDatabase::InitializeSchema() { void SqliteDatabase::InitializeSchema() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *schema = R"( const char* schema = R"(
CREATE TABLE IF NOT EXISTS countries ( CREATE TABLE IF NOT EXISTS countries (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT NOT NULL, name TEXT NOT NULL,
@@ -33,218 +35,219 @@ void SqliteDatabase::InitializeSchema() {
); );
)"; )";
char *errMsg = nullptr; char* errMsg = nullptr;
int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg); int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
std::string error = errMsg ? std::string(errMsg) : "Unknown error"; std::string error = errMsg ? std::string(errMsg) : "Unknown error";
sqlite3_free(errMsg); sqlite3_free(errMsg);
throw std::runtime_error("Failed to create schema: " + error); throw std::runtime_error("Failed to create schema: " + error);
} }
} }
SqliteDatabase::~SqliteDatabase() { SqliteDatabase::~SqliteDatabase() {
if (db_) { if (db_) {
sqlite3_close(db_); sqlite3_close(db_);
} }
} }
void SqliteDatabase::Initialize(const std::string &db_path) { void SqliteDatabase::Initialize(const std::string& db_path) {
int rc = sqlite3_open(db_path.c_str(), &db_); int rc = sqlite3_open(db_path.c_str(), &db_);
if (rc) { if (rc) {
throw std::runtime_error("Failed to open SQLite database: " + db_path); throw std::runtime_error("Failed to open SQLite database: " + db_path);
} }
spdlog::info("OK: SQLite database opened: {}", db_path); spdlog::info("OK: SQLite database opened: {}", db_path);
InitializeSchema(); InitializeSchema();
} }
void SqliteDatabase::BeginTransaction() { void SqliteDatabase::BeginTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char* err = nullptr;
if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) != if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) !=
SQLITE_OK) { SQLITE_OK) {
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
sqlite3_free(err); sqlite3_free(err);
throw std::runtime_error("BeginTransaction failed: " + msg); throw std::runtime_error("BeginTransaction failed: " + msg);
} }
} }
void SqliteDatabase::CommitTransaction() { void SqliteDatabase::CommitTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char* err = nullptr;
if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) { if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) {
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
sqlite3_free(err); sqlite3_free(err);
throw std::runtime_error("CommitTransaction failed: " + msg); throw std::runtime_error("CommitTransaction failed: " + msg);
} }
} }
void SqliteDatabase::InsertCountry(int id, const std::string &name, void SqliteDatabase::InsertCountry(int id, const std::string& name,
const std::string &iso2, const std::string& iso2,
const std::string &iso3) { const std::string& iso3) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO countries (id, name, iso2, iso3) INSERT OR IGNORE INTO countries (id, name, iso2, iso3)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare country insert"); throw std::runtime_error("Failed to prepare country insert");
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_STATIC);
if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert country"); throw std::runtime_error("Failed to insert country");
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertState(int id, int country_id, const std::string &name, void SqliteDatabase::InsertState(int id, int country_id,
const std::string &iso2) { const std::string& name,
std::lock_guard<std::mutex> lock(db_mutex_); const std::string& iso2) {
std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO states (id, country_id, name, iso2) INSERT OR IGNORE INTO states (id, country_id, name, iso2)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare state insert"); throw std::runtime_error("Failed to prepare state insert");
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, country_id); sqlite3_bind_int(stmt, 2, country_id);
sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC);
if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert state"); throw std::runtime_error("Failed to insert state");
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertCity(int id, int state_id, int country_id, void SqliteDatabase::InsertCity(int id, int state_id, int country_id,
const std::string &name, double latitude, const std::string& name, double latitude,
double longitude) { double longitude) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude) INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare city insert"); throw std::runtime_error("Failed to prepare city insert");
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, state_id); sqlite3_bind_int(stmt, 2, state_id);
sqlite3_bind_int(stmt, 3, country_id); sqlite3_bind_int(stmt, 3, country_id);
sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_double(stmt, 5, latitude); sqlite3_bind_double(stmt, 5, latitude);
sqlite3_bind_double(stmt, 6, longitude); sqlite3_bind_double(stmt, 6, longitude);
if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert city"); throw std::runtime_error("Failed to insert city");
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
std::vector<City> SqliteDatabase::QueryCities() { std::vector<City> SqliteDatabase::QueryCities() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<City> cities; std::vector<City> cities;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
const char *query = "SELECT id, name, country_id FROM cities ORDER BY name"; const char* query = "SELECT id, name, country_id FROM cities ORDER BY name";
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
throw std::runtime_error("Failed to prepare query"); throw std::runtime_error("Failed to prepare query");
} }
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
int country_id = sqlite3_column_int(stmt, 2); int country_id = sqlite3_column_int(stmt, 2);
cities.push_back({id, name ? std::string(name) : "", country_id}); cities.push_back({id, name ? std::string(name) : "", country_id});
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
return cities; return cities;
} }
std::vector<Country> SqliteDatabase::QueryCountries(int limit) { std::vector<Country> SqliteDatabase::QueryCountries(int limit) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<Country> countries; std::vector<Country> countries;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
std::string query = std::string query =
"SELECT id, name, iso2, iso3 FROM countries ORDER BY name"; "SELECT id, name, iso2, iso3 FROM countries ORDER BY name";
if (limit > 0) { if (limit > 0) {
query += " LIMIT " + std::to_string(limit); query += " LIMIT " + std::to_string(limit);
} }
int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
throw std::runtime_error("Failed to prepare countries query"); throw std::runtime_error("Failed to prepare countries query");
} }
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
const char *iso2 = const char* iso2 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
const char *iso3 = const char* iso3 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 3)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 3));
countries.push_back({id, name ? std::string(name) : "", countries.push_back({id, name ? std::string(name) : "",
iso2 ? std::string(iso2) : "", iso2 ? std::string(iso2) : "",
iso3 ? std::string(iso3) : ""}); iso3 ? std::string(iso3) : ""});
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
return countries; return countries;
} }
std::vector<State> SqliteDatabase::QueryStates(int limit) { std::vector<State> SqliteDatabase::QueryStates(int limit) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<State> states; std::vector<State> states;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
std::string query = std::string query =
"SELECT id, name, iso2, country_id FROM states ORDER BY name"; "SELECT id, name, iso2, country_id FROM states ORDER BY name";
if (limit > 0) { if (limit > 0) {
query += " LIMIT " + std::to_string(limit); query += " LIMIT " + std::to_string(limit);
} }
int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
throw std::runtime_error("Failed to prepare states query"); throw std::runtime_error("Failed to prepare states query");
} }
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
const char *iso2 = const char* iso2 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
int country_id = sqlite3_column_int(stmt, 3); int country_id = sqlite3_column_int(stmt, 3);
states.push_back({id, name ? std::string(name) : "", states.push_back({id, name ? std::string(name) : "",
iso2 ? std::string(iso2) : "", country_id}); iso2 ? std::string(iso2) : "", country_id});
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
return states; return states;
} }

View File

@@ -1,65 +1,66 @@
#include <chrono> #include "json_handling/json_loader.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include "json_handling/json_loader.h" #include <chrono>
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
void JsonLoader::LoadWorldCities(const std::string &json_path, void JsonLoader::LoadWorldCities(const std::string& json_path,
SqliteDatabase &db) { SqliteDatabase& db) {
constexpr size_t kBatchSize = 10000; constexpr size_t kBatchSize = 10000;
auto startTime = std::chrono::high_resolution_clock::now(); auto startTime = std::chrono::high_resolution_clock::now();
spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", json_path); spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", json_path);
db.BeginTransaction(); db.BeginTransaction();
bool transactionOpen = true; bool transactionOpen = true;
size_t citiesProcessed = 0; size_t citiesProcessed = 0;
try { try {
StreamingJsonParser::Parse( StreamingJsonParser::Parse(
json_path, db, json_path, db,
[&](const CityRecord &record) { [&](const CityRecord& record) {
db.InsertCity(record.id, record.state_id, record.country_id, db.InsertCity(record.id, record.state_id, record.country_id,
record.name, record.latitude, record.longitude); record.name, record.latitude, record.longitude);
++citiesProcessed; ++citiesProcessed;
if (citiesProcessed % kBatchSize == 0) { if (citiesProcessed % kBatchSize == 0) {
db.CommitTransaction(); db.CommitTransaction();
db.BeginTransaction(); db.BeginTransaction();
} }
}, },
[&](size_t current, size_t /*total*/) { [&](size_t current, size_t /*total*/) {
if (current % kBatchSize == 0 && current > 0) { if (current % kBatchSize == 0 && current > 0) {
spdlog::info(" [Progress] Parsed {} cities...", current); spdlog::info(" [Progress] Parsed {} cities...", current);
} }
}); });
spdlog::info(" OK: Parsed all cities from JSON"); spdlog::info(" OK: Parsed all cities from JSON");
if (transactionOpen) { if (transactionOpen) {
db.CommitTransaction(); db.CommitTransaction();
transactionOpen = false; transactionOpen = false;
} }
} catch (...) { } catch (...) {
if (transactionOpen) { if (transactionOpen) {
db.CommitTransaction(); db.CommitTransaction();
} }
throw; throw;
} }
auto endTime = std::chrono::high_resolution_clock::now(); auto endTime = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>( auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
endTime - startTime); endTime - startTime);
spdlog::info("\n=== World City Data Loading Summary ===\n"); spdlog::info("\n=== World City Data Loading Summary ===\n");
spdlog::info("Cities inserted: {}", citiesProcessed); spdlog::info("Cities inserted: {}", citiesProcessed);
spdlog::info("Elapsed time: {} ms", duration.count()); spdlog::info("Elapsed time: {} ms", duration.count());
long long throughput = long long throughput =
(citiesProcessed > 0 && duration.count() > 0) (citiesProcessed > 0 && duration.count() > 0)
? (1000LL * static_cast<long long>(citiesProcessed)) / ? (1000LL * static_cast<long long>(citiesProcessed)) /
static_cast<long long>(duration.count()) static_cast<long long>(duration.count())
: 0LL; : 0LL;
spdlog::info("Throughput: {} cities/sec", throughput); spdlog::info("Throughput: {} cities/sec", throughput);
spdlog::info("=======================================\n"); spdlog::info("=======================================\n");
} }

View File

@@ -1,288 +1,289 @@
#include <cstdio> #include "json_handling/stream_parser.h"
#include <stdexcept>
#include <spdlog/spdlog.h>
#include <boost/json.hpp> #include <boost/json.hpp>
#include <boost/json/basic_parser_impl.hpp> #include <boost/json/basic_parser_impl.hpp>
#include <spdlog/spdlog.h> #include <cstdio>
#include <stdexcept>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h"
class CityRecordHandler { class CityRecordHandler {
friend class boost::json::basic_parser<CityRecordHandler>; friend class boost::json::basic_parser<CityRecordHandler>;
public: public:
static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1);
struct ParseContext { struct ParseContext {
SqliteDatabase *db = nullptr; SqliteDatabase* db = nullptr;
std::function<void(const CityRecord &)> on_city; std::function<void(const CityRecord&)> on_city;
std::function<void(size_t, size_t)> on_progress; std::function<void(size_t, size_t)> on_progress;
size_t cities_emitted = 0; size_t cities_emitted = 0;
size_t total_file_size = 0; size_t total_file_size = 0;
int countries_inserted = 0; int countries_inserted = 0;
int states_inserted = 0; int states_inserted = 0;
}; };
explicit CityRecordHandler(ParseContext &ctx) : context(ctx) {} explicit CityRecordHandler(ParseContext& ctx) : context(ctx) {}
private: private:
ParseContext &context; ParseContext& context;
int depth = 0; int depth = 0;
bool in_countries_array = false; bool in_countries_array = false;
bool in_country_object = false; bool in_country_object = false;
bool in_states_array = false; bool in_states_array = false;
bool in_state_object = false; bool in_state_object = false;
bool in_cities_array = false; bool in_cities_array = false;
bool building_city = false; bool building_city = false;
int current_country_id = 0; int current_country_id = 0;
int current_state_id = 0; int current_state_id = 0;
CityRecord current_city = {}; CityRecord current_city = {};
std::string current_key; std::string current_key;
std::string current_key_val; std::string current_key_val;
std::string current_string_val; std::string current_string_val;
std::string country_info[3]; std::string country_info[3];
std::string state_info[2]; std::string state_info[2];
// Boost.JSON SAX Hooks // Boost.JSON SAX Hooks
bool on_document_begin(boost::system::error_code &) { return true; } bool on_document_begin(boost::system::error_code&) { return true; }
bool on_document_end(boost::system::error_code &) { return true; } bool on_document_end(boost::system::error_code&) { return true; }
bool on_array_begin(boost::system::error_code &) { bool on_array_begin(boost::system::error_code&) {
depth++; depth++;
if (depth == 1) { if (depth == 1) {
in_countries_array = true; in_countries_array = true;
} else if (depth == 3 && current_key == "states") { } else if (depth == 3 && current_key == "states") {
in_states_array = true; in_states_array = true;
} else if (depth == 5 && current_key == "cities") { } else if (depth == 5 && current_key == "cities") {
in_cities_array = true; in_cities_array = true;
}
return true;
}
bool on_array_end(std::size_t, boost::system::error_code &) {
if (depth == 1) {
in_countries_array = false;
} else if (depth == 3) {
in_states_array = false;
} else if (depth == 5) {
in_cities_array = false;
}
depth--;
return true;
}
bool on_object_begin(boost::system::error_code &) {
depth++;
if (depth == 2 && in_countries_array) {
in_country_object = true;
current_country_id = 0;
country_info[0].clear();
country_info[1].clear();
country_info[2].clear();
} else if (depth == 4 && in_states_array) {
in_state_object = true;
current_state_id = 0;
state_info[0].clear();
state_info[1].clear();
} else if (depth == 6 && in_cities_array) {
building_city = true;
current_city = {};
}
return true;
}
bool on_object_end(std::size_t, boost::system::error_code &) {
if (depth == 6 && building_city) {
if (current_city.id > 0 && current_state_id > 0 &&
current_country_id > 0) {
current_city.state_id = current_state_id;
current_city.country_id = current_country_id;
try {
context.on_city(current_city);
context.cities_emitted++;
if (context.on_progress && context.cities_emitted % 10000 == 0) {
context.on_progress(context.cities_emitted,
context.total_file_size);
}
} catch (const std::exception &e) {
spdlog::warn("Record parsing failed: {}", e.what());
}
} }
building_city = false; return true;
} else if (depth == 4 && in_state_object) { }
if (current_state_id > 0 && current_country_id > 0) {
try { bool on_array_end(std::size_t, boost::system::error_code&) {
context.db->InsertState(current_state_id, current_country_id, if (depth == 1) {
state_info[0], state_info[1]); in_countries_array = false;
context.states_inserted++; } else if (depth == 3) {
} catch (const std::exception &e) { in_states_array = false;
spdlog::warn("Record parsing failed: {}", e.what()); } else if (depth == 5) {
} in_cities_array = false;
} }
in_state_object = false; depth--;
} else if (depth == 2 && in_country_object) { return true;
if (current_country_id > 0) { }
try {
context.db->InsertCountry(current_country_id, country_info[0], bool on_object_begin(boost::system::error_code&) {
country_info[1], country_info[2]); depth++;
context.countries_inserted++; if (depth == 2 && in_countries_array) {
} catch (const std::exception &e) { in_country_object = true;
spdlog::warn("Record parsing failed: {}", e.what()); current_country_id = 0;
} country_info[0].clear();
country_info[1].clear();
country_info[2].clear();
} else if (depth == 4 && in_states_array) {
in_state_object = true;
current_state_id = 0;
state_info[0].clear();
state_info[1].clear();
} else if (depth == 6 && in_cities_array) {
building_city = true;
current_city = {};
} }
in_country_object = false; return true;
} }
depth--; bool on_object_end(std::size_t, boost::system::error_code&) {
return true; if (depth == 6 && building_city) {
} if (current_city.id > 0 && current_state_id > 0 &&
current_country_id > 0) {
current_city.state_id = current_state_id;
current_city.country_id = current_country_id;
bool on_key_part(boost::json::string_view s, std::size_t, try {
boost::system::error_code &) { context.on_city(current_city);
current_key_val.append(s.data(), s.size()); context.cities_emitted++;
return true;
}
bool on_key(boost::json::string_view s, std::size_t, if (context.on_progress && context.cities_emitted % 10000 == 0) {
boost::system::error_code &) { context.on_progress(context.cities_emitted,
current_key_val.append(s.data(), s.size()); context.total_file_size);
current_key = current_key_val; }
current_key_val.clear(); } catch (const std::exception& e) {
return true; spdlog::warn("Record parsing failed: {}", e.what());
} }
}
bool on_string_part(boost::json::string_view s, std::size_t, building_city = false;
boost::system::error_code &) { } else if (depth == 4 && in_state_object) {
current_string_val.append(s.data(), s.size()); if (current_state_id > 0 && current_country_id > 0) {
return true; try {
} context.db->InsertState(current_state_id, current_country_id,
state_info[0], state_info[1]);
bool on_string(boost::json::string_view s, std::size_t, context.states_inserted++;
boost::system::error_code &) { } catch (const std::exception& e) {
current_string_val.append(s.data(), s.size()); spdlog::warn("Record parsing failed: {}", e.what());
}
if (building_city && current_key == "name") { }
current_city.name = current_string_val; in_state_object = false;
} else if (in_state_object && current_key == "name") { } else if (depth == 2 && in_country_object) {
state_info[0] = current_string_val; if (current_country_id > 0) {
} else if (in_state_object && current_key == "iso2") { try {
state_info[1] = current_string_val; context.db->InsertCountry(current_country_id, country_info[0],
} else if (in_country_object && current_key == "name") { country_info[1], country_info[2]);
country_info[0] = current_string_val; context.countries_inserted++;
} else if (in_country_object && current_key == "iso2") { } catch (const std::exception& e) {
country_info[1] = current_string_val; spdlog::warn("Record parsing failed: {}", e.what());
} else if (in_country_object && current_key == "iso3") { }
country_info[2] = current_string_val; }
} in_country_object = false;
current_string_val.clear();
return true;
}
bool on_number_part(boost::json::string_view, boost::system::error_code &) {
return true;
}
bool on_int64(int64_t i, boost::json::string_view,
boost::system::error_code &) {
if (building_city && current_key == "id") {
current_city.id = static_cast<int>(i);
} else if (in_state_object && current_key == "id") {
current_state_id = static_cast<int>(i);
} else if (in_country_object && current_key == "id") {
current_country_id = static_cast<int>(i);
}
return true;
}
bool on_uint64(uint64_t u, boost::json::string_view,
boost::system::error_code &ec) {
return on_int64(static_cast<int64_t>(u), "", ec);
}
bool on_double(double d, boost::json::string_view,
boost::system::error_code &) {
if (building_city) {
if (current_key == "latitude") {
current_city.latitude = d;
} else if (current_key == "longitude") {
current_city.longitude = d;
} }
}
return true;
}
bool on_bool(bool, boost::system::error_code &) { return true; } depth--;
bool on_null(boost::system::error_code &) { return true; } return true;
bool on_comment_part(boost::json::string_view, boost::system::error_code &) { }
return true;
} bool on_key_part(boost::json::string_view s, std::size_t,
bool on_comment(boost::json::string_view, boost::system::error_code &) { boost::system::error_code&) {
return true; current_key_val.append(s.data(), s.size());
} return true;
}
bool on_key(boost::json::string_view s, std::size_t,
boost::system::error_code&) {
current_key_val.append(s.data(), s.size());
current_key = current_key_val;
current_key_val.clear();
return true;
}
bool on_string_part(boost::json::string_view s, std::size_t,
boost::system::error_code&) {
current_string_val.append(s.data(), s.size());
return true;
}
bool on_string(boost::json::string_view s, std::size_t,
boost::system::error_code&) {
current_string_val.append(s.data(), s.size());
if (building_city && current_key == "name") {
current_city.name = current_string_val;
} else if (in_state_object && current_key == "name") {
state_info[0] = current_string_val;
} else if (in_state_object && current_key == "iso2") {
state_info[1] = current_string_val;
} else if (in_country_object && current_key == "name") {
country_info[0] = current_string_val;
} else if (in_country_object && current_key == "iso2") {
country_info[1] = current_string_val;
} else if (in_country_object && current_key == "iso3") {
country_info[2] = current_string_val;
}
current_string_val.clear();
return true;
}
bool on_number_part(boost::json::string_view, boost::system::error_code&) {
return true;
}
bool on_int64(int64_t i, boost::json::string_view,
boost::system::error_code&) {
if (building_city && current_key == "id") {
current_city.id = static_cast<int>(i);
} else if (in_state_object && current_key == "id") {
current_state_id = static_cast<int>(i);
} else if (in_country_object && current_key == "id") {
current_country_id = static_cast<int>(i);
}
return true;
}
bool on_uint64(uint64_t u, boost::json::string_view,
boost::system::error_code& ec) {
return on_int64(static_cast<int64_t>(u), "", ec);
}
bool on_double(double d, boost::json::string_view,
boost::system::error_code&) {
if (building_city) {
if (current_key == "latitude") {
current_city.latitude = d;
} else if (current_key == "longitude") {
current_city.longitude = d;
}
}
return true;
}
bool on_bool(bool, boost::system::error_code&) { return true; }
bool on_null(boost::system::error_code&) { return true; }
bool on_comment_part(boost::json::string_view, boost::system::error_code&) {
return true;
}
bool on_comment(boost::json::string_view, boost::system::error_code&) {
return true;
}
}; };
void StreamingJsonParser::Parse( void StreamingJsonParser::Parse(
const std::string &file_path, SqliteDatabase &db, const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord &)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress) { std::function<void(size_t, size_t)> on_progress) {
spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path);
spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path); FILE* file = std::fopen(file_path.c_str(), "rb");
if (!file) {
throw std::runtime_error("Failed to open JSON file: " + file_path);
}
FILE *file = std::fopen(file_path.c_str(), "rb"); size_t total_size = 0;
if (!file) { if (std::fseek(file, 0, SEEK_END) == 0) {
throw std::runtime_error("Failed to open JSON file: " + file_path); long file_size = std::ftell(file);
} if (file_size > 0) {
total_size = static_cast<size_t>(file_size);
size_t total_size = 0;
if (std::fseek(file, 0, SEEK_END) == 0) {
long file_size = std::ftell(file);
if (file_size > 0) {
total_size = static_cast<size_t>(file_size);
}
std::rewind(file);
}
CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0,
total_size, 0, 0};
boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx);
char buf[65536];
size_t bytes_read;
boost::system::error_code ec;
while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) {
char const *p = buf;
std::size_t remain = bytes_read;
while (remain > 0) {
std::size_t consumed = parser.write_some(true, p, remain, ec);
if (ec) {
std::fclose(file);
throw std::runtime_error("JSON parse error: " + ec.message());
} }
p += consumed; std::rewind(file);
remain -= consumed; }
}
}
parser.write_some(false, nullptr, 0, ec); // Signal EOF CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, total_size,
std::fclose(file); 0, 0};
boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx);
if (ec) { char buf[65536];
throw std::runtime_error("JSON parse error at EOF: " + ec.message()); size_t bytes_read;
} boost::system::error_code ec;
spdlog::info(" OK: Parsed {} countries, {} states, {} cities", while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) {
ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted); char const* p = buf;
std::size_t remain = bytes_read;
while (remain > 0) {
std::size_t consumed = parser.write_some(true, p, remain, ec);
if (ec) {
std::fclose(file);
throw std::runtime_error("JSON parse error: " + ec.message());
}
p += consumed;
remain -= consumed;
}
}
parser.write_some(false, nullptr, 0, ec); // Signal EOF
std::fclose(file);
if (ec) {
throw std::runtime_error("JSON parse error at EOF: " + ec.message());
}
spdlog::info(" OK: Parsed {} countries, {} states, {} cities",
ctx.countries_inserted, ctx.states_inserted,
ctx.cities_emitted);
} }

View File

@@ -1,139 +1,141 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
#include <cstdio>
#include <curl/curl.h> #include <curl/curl.h>
#include <cstdio>
#include <fstream> #include <fstream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
CurlGlobalState::CurlGlobalState() { CurlGlobalState::CurlGlobalState() {
if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) { if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) {
throw std::runtime_error( throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl globally"); "[CURLWebClient] Failed to initialize libcurl globally");
} }
} }
CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); } CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); }
namespace { namespace {
// curl write callback that appends response data into a std::string // curl write callback that appends response data into a std::string
size_t WriteCallbackString(void *contents, size_t size, size_t nmemb, size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
void *userp) { void* userp) {
size_t realsize = size * nmemb; size_t realsize = size * nmemb;
auto *s = static_cast<std::string *>(userp); auto* s = static_cast<std::string*>(userp);
s->append(static_cast<char *>(contents), realsize); s->append(static_cast<char*>(contents), realsize);
return realsize; return realsize;
} }
// curl write callback that writes to a file stream // curl write callback that writes to a file stream
size_t WriteCallbackFile(void *contents, size_t size, size_t nmemb, size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb,
void *userp) { void* userp) {
size_t realsize = size * nmemb; size_t realsize = size * nmemb;
auto *outFile = static_cast<std::ofstream *>(userp); auto* outFile = static_cast<std::ofstream*>(userp);
outFile->write(static_cast<char *>(contents), realsize); outFile->write(static_cast<char*>(contents), realsize);
return realsize; return realsize;
} }
// RAII wrapper for CURL handle using unique_ptr // RAII wrapper for CURL handle using unique_ptr
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>; using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
CurlHandle create_handle() { CurlHandle create_handle() {
CURL *handle = curl_easy_init(); CURL* handle = curl_easy_init();
if (!handle) { if (!handle) {
throw std::runtime_error( throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle"); "[CURLWebClient] Failed to initialize libcurl handle");
} }
return CurlHandle(handle, &curl_easy_cleanup); return CurlHandle(handle, &curl_easy_cleanup);
} }
void set_common_get_options(CURL *curl, const std::string &url, void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) { long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L); curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout); curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout); curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip"); curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
} }
} // namespace } // namespace
CURLWebClient::CURLWebClient() {} CURLWebClient::CURLWebClient() {}
CURLWebClient::~CURLWebClient() {} CURLWebClient::~CURLWebClient() {}
void CURLWebClient::DownloadToFile(const std::string &url, void CURLWebClient::DownloadToFile(const std::string& url,
const std::string &file_path) { const std::string& file_path) {
auto curl = create_handle(); auto curl = create_handle();
std::ofstream outFile(file_path, std::ios::binary); std::ofstream outFile(file_path, std::ios::binary);
if (!outFile.is_open()) { if (!outFile.is_open()) {
throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " + throw std::runtime_error(
file_path); "[CURLWebClient] Cannot open file for writing: " + file_path);
} }
set_common_get_options(curl.get(), url, 30L, 300L); set_common_get_options(curl.get(), url, 30L, 300L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile); curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
static_cast<void *>(&outFile)); static_cast<void*>(&outFile));
CURLcode res = curl_easy_perform(curl.get()); CURLcode res = curl_easy_perform(curl.get());
outFile.close(); outFile.close();
if (res != CURLE_OK) { if (res != CURLE_OK) {
std::remove(file_path.c_str()); std::remove(file_path.c_str());
std::string error = std::string("[CURLWebClient] Download failed: ") + std::string error = std::string("[CURLWebClient] Download failed: ") +
curl_easy_strerror(res); curl_easy_strerror(res);
throw std::runtime_error(error); throw std::runtime_error(error);
} }
long httpCode = 0; long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) { if (httpCode != 200) {
std::remove(file_path.c_str()); std::remove(file_path.c_str());
std::stringstream ss; std::stringstream ss;
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
} }
std::string CURLWebClient::Get(const std::string &url) { std::string CURLWebClient::Get(const std::string& url) {
auto curl = create_handle(); auto curl = create_handle();
std::string response_string; std::string response_string;
set_common_get_options(curl.get(), url, 10L, 20L); set_common_get_options(curl.get(), url, 10L, 20L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString); curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string); curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
CURLcode res = curl_easy_perform(curl.get()); CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) { if (res != CURLE_OK) {
std::string error = std::string error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res); std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error); throw std::runtime_error(error);
} }
long httpCode = 0; long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) { if (httpCode != 200) {
std::stringstream ss; std::stringstream ss;
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
return response_string; return response_string;
} }
std::string CURLWebClient::UrlEncode(const std::string &value) { std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs. // A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char *output = curl_easy_escape(nullptr, value.c_str(), 0); char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (output) { if (output) {
std::string result(output); std::string result(output);
curl_free(output); curl_free(output);
return result; return result;
} }
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed"); throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
} }

View File

@@ -1,77 +1,78 @@
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
WikipediaService::WikipediaService(std::shared_ptr<IWebClient> client) #include <boost/json.hpp>
WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {} : client_(std::move(client)) {}
std::string WikipediaService::FetchExtract(std::string_view query) { std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string encoded = client_->UrlEncode(std::string(query)); const std::string encoded = client_->UrlEncode(std::string(query));
const std::string url = const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=true&format=json"; "&prop=extracts&explaintext=true&format=json";
const std::string body = client_->Get(url); const std::string body = client_->Get(url);
boost::system::error_code ec; boost::system::error_code ec;
boost::json::value doc = boost::json::parse(body, ec); boost::json::value doc = boost::json::parse(body, ec);
if (!ec && doc.is_object()) { if (!ec && doc.is_object()) {
auto &pages = doc.at("query").at("pages").get_object(); auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) { if (!pages.empty()) {
auto &page = pages.begin()->value().get_object(); auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) { if (page.contains("extract") && page.at("extract").is_string()) {
std::string extract(page.at("extract").as_string().c_str()); std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'", spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query); extract.size(), query);
return extract; return extract;
}
} }
} }
}
return {}; return {};
} }
std::string WikipediaService::GetSummary(std::string_view city, std::string WikipediaService::GetSummary(std::string_view city,
std::string_view country) { std::string_view country) {
const std::string key = std::string(city) + "|" + std::string(country); const std::string key = std::string(city) + "|" + std::string(country);
const auto cacheIt = cache_.find(key); const auto cacheIt = cache_.find(key);
if (cacheIt != cache_.end()) { if (cacheIt != cache_.end()) {
return cacheIt->second; return cacheIt->second;
} }
std::string result; std::string result;
if (!client_) { if (!client_) {
cache_.emplace(key, result); cache_.emplace(key, result);
return result; return result;
} }
std::string regionQuery(city); std::string regionQuery(city);
if (!country.empty()) { if (!country.empty()) {
regionQuery += ", "; regionQuery += ", ";
regionQuery += country; regionQuery += country;
} }
const std::string beerQuery = "beer in " + std::string(city); const std::string beerQuery = "beer in " + std::string(city);
try { try {
const std::string regionExtract = FetchExtract(regionQuery); const std::string regionExtract = FetchExtract(regionQuery);
const std::string beerExtract = FetchExtract(beerQuery); const std::string beerExtract = FetchExtract(beerQuery);
if (!regionExtract.empty()) { if (!regionExtract.empty()) {
result += regionExtract; result += regionExtract;
} }
if (!beerExtract.empty()) { if (!beerExtract.empty()) {
if (!result.empty()) if (!result.empty()) result += "\n\n";
result += "\n\n"; result += beerExtract;
result += beerExtract; }
} } catch (const std::runtime_error& e) {
} catch (const std::runtime_error &e) { spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery,
spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery, e.what());
e.what()); }
}
cache_.emplace(key, result); cache_.emplace(key, result);
return result; return result;
} }