format codebase

This commit is contained in:
Aaron Po
2026-04-02 21:46:46 -04:00
parent ba165d8aa7
commit 3af053f0eb
31 changed files with 1479 additions and 1445 deletions

View File

@@ -1,10 +1,5 @@
--- ---
BasedOnStyle: Google BasedOnStyle: Google
Standard: c++23 ColumnLimit: 80
ColumnLimit: 100 IndentWidth: 3
IndentWidth: 2
DerivePointerAlignment: false
PointerAlignment: Left
SortIncludes: true
IncludeBlocks: Preserve
... ...

View File

@@ -90,7 +90,11 @@ set(PIPELINE_SOURCES
src/data_generation/llama/generate_brewery.cpp src/data_generation/llama/generate_brewery.cpp
src/data_generation/llama/generate_user.cpp src/data_generation/llama/generate_user.cpp
src/data_generation/llama/helpers.cpp src/data_generation/llama/helpers.cpp
src/data_generation/mock_generator.cpp src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/load.cpp
src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp
src/json_handling/stream_parser.cpp src/json_handling/stream_parser.cpp
src/wikipedia/wikipedia_service.cpp src/wikipedia/wikipedia_service.cpp
src/main.cpp src/main.cpp

View File

@@ -9,22 +9,23 @@
/// @brief Downloads and caches source geography JSON payloads. /// @brief Downloads and caches source geography JSON payloads.
class DataDownloader { class DataDownloader {
public: public:
/// @brief Initializes global curl state used by this downloader. /// @brief Initializes global curl state used by this downloader.
explicit DataDownloader(std::shared_ptr<WebClient> web_client); explicit DataDownloader(std::shared_ptr<WebClient> web_client);
/// @brief Cleans up global curl state. /// @brief Cleans up global curl state.
~DataDownloader(); ~DataDownloader();
/// @brief Returns a local JSON path, downloading it when cache is missing. /// @brief Returns a local JSON path, downloading it when cache is missing.
std::string DownloadCountriesDatabase( std::string DownloadCountriesDatabase(
const std::string &cache_path, const std::string& cache_path,
const std::string &commit = "c5eb7772" // Stable commit: 2026-03-28 export const std::string& commit =
); "c5eb7772" // Stable commit: 2026-03-28 export
);
private: private:
static bool FileExists(const std::string &file_path); static bool FileExists(const std::string& file_path);
std::shared_ptr<WebClient> web_client_; std::shared_ptr<WebClient> web_client_;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_DOWNLOADER_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_DOWNLOADER_H_

View File

@@ -4,26 +4,26 @@
#include <string> #include <string>
struct BreweryResult { struct BreweryResult {
std::string name; std::string name;
std::string description; std::string description;
}; };
struct UserResult { struct UserResult {
std::string username; std::string username;
std::string bio; std::string bio;
}; };
class DataGenerator { class DataGenerator {
public: public:
virtual ~DataGenerator() = default; virtual ~DataGenerator() = default;
virtual void Load(const std::string &model_path) = 0; virtual void Load(const std::string& model_path) = 0;
virtual BreweryResult GenerateBrewery(const std::string &city_name, virtual BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) = 0; const std::string& region_context) = 0;
virtual UserResult GenerateUser(const std::string &locale) = 0; virtual UserResult GenerateUser(const std::string& locale) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_

View File

@@ -10,32 +10,32 @@ struct llama_model;
struct llama_context; struct llama_context;
class LlamaGenerator final : public DataGenerator { class LlamaGenerator final : public DataGenerator {
public: public:
LlamaGenerator() = default; LlamaGenerator() = default;
~LlamaGenerator() override; ~LlamaGenerator() override;
void SetSamplingOptions(float temperature, float top_p, int seed = -1); void SetSamplingOptions(float temperature, float top_p, int seed = -1);
void Load(const std::string &model_path) override; void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string &city_name, BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) override; const std::string& region_context) override;
UserResult GenerateUser(const std::string &locale) override; UserResult GenerateUser(const std::string& locale) override;
private: private:
std::string Infer(const std::string &prompt, int max_tokens = 10000); std::string Infer(const std::string& prompt, int max_tokens = 10000);
// Overload that allows passing a system message separately so chat-capable // Overload that allows passing a system message separately so chat-capable
// models receive a proper system role instead of having the system text // models receive a proper system role instead of having the system text
// concatenated into the user prompt (helps avoid revealing internal // concatenated into the user prompt (helps avoid revealing internal
// reasoning or instructions in model output). // reasoning or instructions in model output).
std::string Infer(const std::string &system_prompt, const std::string &prompt, std::string Infer(const std::string& system_prompt,
int max_tokens = 10000); const std::string& prompt, int max_tokens = 10000);
llama_model *model_ = nullptr; llama_model* model_ = nullptr;
llama_context *context_ = nullptr; llama_context* context_ = nullptr;
float sampling_temperature_ = 0.8f; float sampling_temperature_ = 0.8f;
float sampling_top_p_ = 0.92f; float sampling_top_p_ = 0.92f;
uint32_t sampling_seed_ = 0xFFFFFFFFu; uint32_t sampling_seed_ = 0xFFFFFFFFu;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -12,18 +12,17 @@ typedef int llama_token;
std::string PrepareRegionContextPublic(std::string_view region_context, std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars = 700); std::size_t max_chars = 700);
std::pair<std::string, std::string> std::pair<std::string, std::string> ParseTwoLineResponsePublic(
ParseTwoLineResponsePublic(const std::string& raw, const std::string& raw, const std::string& error_message);
const std::string& error_message);
std::string ToChatPromptPublic(const llama_model *model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt); const std::string& user_prompt);
std::string ToChatPromptPublic(const llama_model *model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt); const std::string& user_prompt);
void AppendTokenPiecePublic(const llama_vocab *vocab, llama_token token, void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output); std::string& output);
std::string ValidateBreweryJsonPublic(const std::string& raw, std::string ValidateBreweryJsonPublic(const std::string& raw,

View File

@@ -1,27 +1,28 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#include "data_generation/data_generator.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "data_generation/data_generator.h"
class MockGenerator final : public DataGenerator { class MockGenerator final : public DataGenerator {
public: public:
void Load(const std::string &model_path) override; void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string &city_name, BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) override; const std::string& region_context) override;
UserResult GenerateUser(const std::string &locale) override; UserResult GenerateUser(const std::string& locale) override;
private: private:
static std::size_t DeterministicHash(const std::string &a, static std::size_t DeterministicHash(const std::string& a,
const std::string &b); const std::string& b);
static const std::vector<std::string> kBreweryAdjectives; static const std::vector<std::string> kBreweryAdjectives;
static const std::vector<std::string> kBreweryNouns; static const std::vector<std::string> kBreweryNouns;
static const std::vector<std::string> kBreweryDescriptions; static const std::vector<std::string> kBreweryDescriptions;
static const std::vector<std::string> kUsernames; static const std::vector<std::string> kUsernames;
static const std::vector<std::string> kBios; static const std::vector<std::string> kBios;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_

View File

@@ -1,83 +1,84 @@
#ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#include <mutex>
#include <sqlite3.h> #include <sqlite3.h>
#include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
struct Country { struct Country {
/// @brief Country identifier from the source dataset. /// @brief Country identifier from the source dataset.
int id; int id;
/// @brief Country display name. /// @brief Country display name.
std::string name; std::string name;
/// @brief ISO 3166-1 alpha-2 code. /// @brief ISO 3166-1 alpha-2 code.
std::string iso2; std::string iso2;
/// @brief ISO 3166-1 alpha-3 code. /// @brief ISO 3166-1 alpha-3 code.
std::string iso3; std::string iso3;
}; };
struct State { struct State {
/// @brief State or province identifier from the source dataset. /// @brief State or province identifier from the source dataset.
int id; int id;
/// @brief State or province display name. /// @brief State or province display name.
std::string name; std::string name;
/// @brief State or province short code. /// @brief State or province short code.
std::string iso2; std::string iso2;
/// @brief Parent country identifier. /// @brief Parent country identifier.
int country_id; int country_id;
}; };
struct City { struct City {
/// @brief City identifier from the source dataset. /// @brief City identifier from the source dataset.
int id; int id;
/// @brief City display name. /// @brief City display name.
std::string name; std::string name;
/// @brief Parent country identifier. /// @brief Parent country identifier.
int country_id; int country_id;
}; };
/// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks. /// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks.
class SqliteDatabase { class SqliteDatabase {
private: private:
sqlite3 *db_ = nullptr; sqlite3* db_ = nullptr;
std::mutex db_mutex_; std::mutex db_mutex_;
void InitializeSchema(); void InitializeSchema();
public: public:
/// @brief Closes the SQLite connection if initialized. /// @brief Closes the SQLite connection if initialized.
~SqliteDatabase(); ~SqliteDatabase();
/// @brief Opens the SQLite database at db_path and creates schema objects. /// @brief Opens the SQLite database at db_path and creates schema objects.
void Initialize(const std::string &db_path = ":memory:"); void Initialize(const std::string& db_path = ":memory:");
/// @brief Starts a database transaction for batched writes. /// @brief Starts a database transaction for batched writes.
void BeginTransaction(); void BeginTransaction();
/// @brief Commits the active database transaction. /// @brief Commits the active database transaction.
void CommitTransaction(); void CommitTransaction();
/// @brief Inserts a country row. /// @brief Inserts a country row.
void InsertCountry(int id, const std::string &name, const std::string &iso2, void InsertCountry(int id, const std::string& name, const std::string& iso2,
const std::string &iso3); const std::string& iso3);
/// @brief Inserts a state row linked to a country. /// @brief Inserts a state row linked to a country.
void InsertState(int id, int country_id, const std::string &name, void InsertState(int id, int country_id, const std::string& name,
const std::string &iso2); const std::string& iso2);
/// @brief Inserts a city row linked to state and country. /// @brief Inserts a city row linked to state and country.
void InsertCity(int id, int state_id, int country_id, const std::string &name, void InsertCity(int id, int state_id, int country_id,
double latitude, double longitude); const std::string& name, double latitude, double longitude);
/// @brief Returns city records including parent country id. /// @brief Returns city records including parent country id.
std::vector<City> QueryCities(); std::vector<City> QueryCities();
/// @brief Returns countries with optional row limit. /// @brief Returns countries with optional row limit.
std::vector<Country> QueryCountries(int limit = 0); std::vector<Country> QueryCountries(int limit = 0);
/// @brief Returns states with optional row limit. /// @brief Returns states with optional row limit.
std::vector<State> QueryStates(int limit = 0); std::vector<State> QueryStates(int limit = 0);
}; };
#endif // BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #endif // BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_

View File

@@ -1,15 +1,17 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#include <string>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
#include <string>
/// @brief Loads world-city JSON data into SQLite through streaming parsing. /// @brief Loads world-city JSON data into SQLite through streaming parsing.
class JsonLoader { class JsonLoader {
public: public:
/// @brief Parses a JSON file and writes country/state/city rows into db. /// @brief Parses a JSON file and writes country/state/city rows into db.
static void LoadWorldCities(const std::string &json_path, SqliteDatabase &db); static void LoadWorldCities(const std::string& json_path,
SqliteDatabase& db);
}; };
#endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -1,51 +1,52 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#include "database/database.h"
#include <functional> #include <functional>
#include <string> #include <string>
#include "database/database.h"
// Forward declaration to avoid circular dependency // Forward declaration to avoid circular dependency
class SqliteDatabase; class SqliteDatabase;
/// @brief In-memory representation of one parsed city entry. /// @brief In-memory representation of one parsed city entry.
struct CityRecord { struct CityRecord {
int id; int id;
int state_id; int state_id;
int country_id; int country_id;
std::string name; std::string name;
double latitude; double latitude;
double longitude; double longitude;
}; };
/// @brief Streaming SAX parser that emits city records during traversal. /// @brief Streaming SAX parser that emits city records during traversal.
class StreamingJsonParser { class StreamingJsonParser {
public: public:
/// @brief Parses file_path and invokes callbacks for city rows and progress. /// @brief Parses file_path and invokes callbacks for city rows and progress.
static void Parse(const std::string &file_path, SqliteDatabase &db, static void Parse(const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord &)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress = nullptr); std::function<void(size_t, size_t)> on_progress = nullptr);
private: private:
/// @brief Mutable SAX handler state while traversing nested JSON arrays. /// @brief Mutable SAX handler state while traversing nested JSON arrays.
struct ParseState { struct ParseState {
int current_country_id = 0; int current_country_id = 0;
int current_state_id = 0; int current_state_id = 0;
CityRecord current_city = {}; CityRecord current_city = {};
bool building_city = false; bool building_city = false;
std::string current_key; std::string current_key;
int array_depth = 0; int array_depth = 0;
int object_depth = 0; int object_depth = 0;
bool in_countries_array = false; bool in_countries_array = false;
bool in_states_array = false; bool in_states_array = false;
bool in_cities_array = false; bool in_cities_array = false;
std::function<void(const CityRecord &)> on_city; std::function<void(const CityRecord&)> on_city;
std::function<void(size_t, size_t)> on_progress; std::function<void(size_t, size_t)> on_progress;
size_t bytes_processed = 0; size_t bytes_processed = 0;
}; };
}; };
#endif // BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_

View File

@@ -1,29 +1,30 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#include "web_client/web_client.h"
#include <memory> #include <memory>
#include "web_client/web_client.h"
// RAII for curl_global_init/cleanup. // RAII for curl_global_init/cleanup.
// An instance of this class should be created in main() before any curl // An instance of this class should be created in main() before any curl
// operations and exist for the lifetime of the application. // operations and exist for the lifetime of the application.
class CurlGlobalState { class CurlGlobalState {
public: public:
CurlGlobalState(); CurlGlobalState();
~CurlGlobalState(); ~CurlGlobalState();
CurlGlobalState(const CurlGlobalState &) = delete; CurlGlobalState(const CurlGlobalState&) = delete;
CurlGlobalState &operator=(const CurlGlobalState &) = delete; CurlGlobalState& operator=(const CurlGlobalState&) = delete;
}; };
class CURLWebClient : public WebClient { class CURLWebClient : public WebClient {
public: public:
CURLWebClient(); CURLWebClient();
~CURLWebClient() override; ~CURLWebClient() override;
void DownloadToFile(const std::string &url, void DownloadToFile(const std::string& url,
const std::string &file_path) override; const std::string& file_path) override;
std::string Get(const std::string &url) override; std::string Get(const std::string& url) override;
std::string UrlEncode(const std::string &value) override; std::string UrlEncode(const std::string& value) override;
}; };
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_

View File

@@ -4,19 +4,19 @@
#include <string> #include <string>
class WebClient { class WebClient {
public: public:
virtual ~WebClient() = default; virtual ~WebClient() = default;
// Downloads content from a URL to a file. Throws on error. // Downloads content from a URL to a file. Throws on error.
virtual void DownloadToFile(const std::string &url, virtual void DownloadToFile(const std::string& url,
const std::string &file_path) = 0; const std::string& file_path) = 0;
// Performs a GET request and returns the response body as a string. Throws on // Performs a GET request and returns the response body as a string. Throws
// error. // on error.
virtual std::string Get(const std::string &url) = 0; virtual std::string Get(const std::string& url) = 0;
// URL-encodes a string. // URL-encodes a string.
virtual std::string UrlEncode(const std::string &value) = 0; virtual std::string UrlEncode(const std::string& value) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -10,18 +10,18 @@
/// @brief Provides cached Wikipedia summary lookups for city and country pairs. /// @brief Provides cached Wikipedia summary lookups for city and country pairs.
class WikipediaService { class WikipediaService {
public: public:
/// @brief Creates a new Wikipedia service with the provided web client. /// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::shared_ptr<WebClient> client); explicit WikipediaService(std::shared_ptr<WebClient> client);
/// @brief Returns the Wikipedia summary extract for city and country. /// @brief Returns the Wikipedia summary extract for city and country.
[[nodiscard]] std::string GetSummary(std::string_view city, [[nodiscard]] std::string GetSummary(std::string_view city,
std::string_view country); std::string_view country);
private: private:
std::string FetchExtract(std::string_view query); std::string FetchExtract(std::string_view query);
std::shared_ptr<WebClient> client_; std::shared_ptr<WebClient> client_;
std::unordered_map<std::string, std::string> cache_; std::unordered_map<std::string, std::string> cache_;
}; };
#endif // BIERGARTEN_PIPELINE_WIKIPEDIA_WIKIPEDIA_SERVICE_H_ #endif // BIERGARTEN_PIPELINE_WIKIPEDIA_WIKIPEDIA_SERVICE_H_

View File

@@ -1,46 +1,49 @@
#include "data_generation/data_downloader.h" #include "data_generation/data_downloader.h"
#include "web_client/web_client.h"
#include <spdlog/spdlog.h>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <spdlog/spdlog.h>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include "web_client/web_client.h"
DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client) DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client)
: web_client_(std::move(web_client)) {} : web_client_(std::move(web_client)) {}
DataDownloader::~DataDownloader() {} DataDownloader::~DataDownloader() {}
bool DataDownloader::FileExists(const std::string &file_path) { bool DataDownloader::FileExists(const std::string& file_path) {
return std::filesystem::exists(file_path); return std::filesystem::exists(file_path);
} }
std::string std::string DataDownloader::DownloadCountriesDatabase(
DataDownloader::DownloadCountriesDatabase(const std::string &cache_path, const std::string& cache_path, const std::string& commit) {
const std::string &commit) { if (FileExists(cache_path)) {
if (FileExists(cache_path)) { spdlog::info("[DataDownloader] Cache hit: {}", cache_path);
spdlog::info("[DataDownloader] Cache hit: {}", cache_path); return cache_path;
return cache_path; }
}
std::string short_commit = commit; std::string short_commit = commit;
if (commit.length() > 7) { if (commit.length() > 7) {
short_commit = commit.substr(0, 7); short_commit = commit.substr(0, 7);
} }
std::string url = "https://raw.githubusercontent.com/dr5hn/" std::string url =
"countries-states-cities-database/" + "https://raw.githubusercontent.com/dr5hn/"
short_commit + "/json/countries+states+cities.json"; "countries-states-cities-database/" +
short_commit + "/json/countries+states+cities.json";
spdlog::info("[DataDownloader] Downloading: {}", url); spdlog::info("[DataDownloader] Downloading: {}", url);
web_client_->DownloadToFile(url, cache_path); web_client_->DownloadToFile(url, cache_path);
std::ifstream file_check(cache_path, std::ios::binary | std::ios::ate); std::ifstream file_check(cache_path, std::ios::binary | std::ios::ate);
std::streamsize size = file_check.tellg(); std::streamsize size = file_check.tellg();
file_check.close(); file_check.close();
spdlog::info("[DataDownloader] OK: Download complete: {} ({:.2f} MB)", spdlog::info("[DataDownloader] OK: Download complete: {} ({:.2f} MB)",
cache_path, (size / (1024.0 * 1024.0))); cache_path, (size / (1024.0 * 1024.0)));
return cache_path; return cache_path;
} }

View File

@@ -1,17 +1,16 @@
#include "data_generation/llama_generator.h"
#include "llama.h" #include "llama.h"
#include "data_generation/llama_generator.h"
LlamaGenerator::~LlamaGenerator() { LlamaGenerator::~LlamaGenerator() {
if (context_ != nullptr) { if (context_ != nullptr) {
llama_free(context_); llama_free(context_);
context_ = nullptr; context_ = nullptr;
} }
if (model_ != nullptr) { if (model_ != nullptr) {
llama_model_free(model_); llama_model_free(model_);
model_ = nullptr; model_ = nullptr;
} }
llama_backend_free(); llama_backend_free();
} }

View File

@@ -1,72 +1,74 @@
#include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
BreweryResult BreweryResult LlamaGenerator::GenerateBrewery(
LlamaGenerator::GenerateBrewery(const std::string& city_name, const std::string& city_name, const std::string& country_name,
const std::string& country_name, const std::string& region_context) {
const std::string& region_context) { const std::string safe_region_context =
const std::string safe_region_context = PrepareRegionContextPublic(region_context);
PrepareRegionContextPublic(region_context);
const std::string system_prompt = const std::string system_prompt =
"You are a copywriter for a craft beer travel guide. " "You are the brewmaster and owner of a local craft brewery. "
"Your writing is vivid, specific to place, and avoids generic beer " "Write a name and a short, soulful description for your brewery that "
"cliches. " "reflects your pride in the local community and your craft. "
"You must output ONLY valid JSON. " "The tone should be authentic and welcoming, like a note on a "
"The JSON schema must be exactly: {\"name\": \"string\", " "chalkboard "
"\"description\": \"string\"}. " "menu. Output ONLY a single JSON object with keys \"name\" and "
"Do not include markdown formatting or backticks."; "\"description\". "
"Do not include markdown formatting or backticks.";
std::string prompt = std::string prompt =
"Write a brewery name and place-specific description for a craft " "Write a brewery name and place-specific description for a craft "
"brewery in " + "brewery in " +
city_name + city_name +
(country_name.empty() ? std::string("") (country_name.empty() ? std::string("")
: std::string(", ") + country_name) + : std::string(", ") + country_name) +
(safe_region_context.empty() (safe_region_context.empty()
? std::string(".") ? std::string(".")
: std::string(". Regional context: ") + safe_region_context); : std::string(". Regional context: ") + safe_region_context);
const int max_attempts = 3; const int max_attempts = 3;
std::string raw; std::string raw;
std::string last_error; std::string last_error;
for (int attempt = 0; attempt < max_attempts; ++attempt) { for (int attempt = 0; attempt < max_attempts; ++attempt) {
raw = Infer(system_prompt, prompt, 384); raw = Infer(system_prompt, prompt, 384);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw); raw);
std::string name; std::string name;
std::string description; std::string description;
const std::string validation_error = const std::string validation_error =
ValidateBreweryJsonPublic(raw, name, description); ValidateBreweryJsonPublic(raw, name, description);
if (validation_error.empty()) { if (validation_error.empty()) {
return {std::move(name), std::move(description)}; return {std::move(name), std::move(description)};
} }
last_error = validation_error; last_error = validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error); attempt + 1, validation_error);
prompt = "Your previous response was invalid. Error: " + validation_error + prompt =
"\nReturn ONLY valid JSON with this exact schema: " "Your previous response was invalid. Error: " + validation_error +
"{\"name\": \"string\", \"description\": \"string\"}." "\nReturn ONLY valid JSON with this exact schema: "
"\nDo not include markdown, comments, or extra keys." "{\"name\": \"string\", \"description\": \"string\"}."
"\n\nLocation: " + "\nDo not include markdown, comments, or extra keys."
city_name + "\n\nLocation: " +
(country_name.empty() ? std::string("") city_name +
: std::string(", ") + country_name) + (country_name.empty() ? std::string("")
(safe_region_context.empty() : std::string(", ") + country_name) +
? std::string("") (safe_region_context.empty()
: std::string("\nRegional context: ") + safe_region_context); ? std::string("")
} : std::string("\nRegional context: ") + safe_region_context);
}
spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: " spdlog::error(
"{}", "LlamaGenerator: malformed brewery response after {} attempts: "
max_attempts, last_error.empty() ? raw : last_error); "{}",
throw std::runtime_error("LlamaGenerator: malformed brewery response"); max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
} }

View File

@@ -1,56 +1,57 @@
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
/// @brief Generates a username and short bio for a craft beer enthusiast.
///
/// Asks the model for a two-line profile, removes all whitespace from the
/// username and caps the bio at 200 bytes. A response that cannot be parsed
/// into a non-empty username and bio is retried, up to three attempts in total.
/// Exceptions thrown by Infer itself are not caught here and propagate.
///
/// @param locale Locale text appended to the user prompt.
/// @return The parsed username and bio.
/// @throws std::runtime_error If every attempt yields a malformed response.
UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
   const std::string system_prompt =
       "You generate plausible social media profiles for craft beer "
       "enthusiasts. "
       "Respond with exactly two lines: "
       "the first line is a username (lowercase, no spaces, 8-20 characters), "
       "the second line is a one-sentence bio (20-40 words). "
       "The profile should feel consistent with the locale. "
       "No preamble, no labels.";

   std::string prompt =
       "Generate a craft beer enthusiast profile. Locale: " + locale;

   const int max_attempts = 3;
   std::string raw;
   for (int attempt = 0; attempt < max_attempts; ++attempt) {
      raw = Infer(system_prompt, prompt, 128);
      spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
                    attempt + 1, raw);

      try {
         auto [username, bio] = ParseTwoLineResponsePublic(
             raw, "LlamaGenerator: malformed user response");

         // Usernames must not contain whitespace; drop any the model emitted.
         username.erase(
             std::remove_if(username.begin(), username.end(),
                            [](unsigned char ch) { return std::isspace(ch); }),
             username.end());

         if (username.empty() || bio.empty()) {
            throw std::runtime_error("LlamaGenerator: malformed user response");
         }

         constexpr std::size_t kMaxBioBytes = 200;
         if (bio.size() > kMaxBioBytes) {
            // Cut on a UTF-8 code-point boundary: if the first dropped byte is
            // a continuation byte (10xxxxxx), step back to the lead byte so no
            // partial character is kept. A valid sequence has at most three
            // continuation bytes, which bounds the back-off.
            std::size_t cut = kMaxBioBytes;
            for (int backed = 0;
                 backed < 3 &&
                 (static_cast<unsigned char>(bio[cut]) & 0xC0) == 0x80;
                 ++backed) {
               --cut;
            }
            bio.resize(cut);
         }

         return {username, bio};
      } catch (const std::exception& e) {
         spdlog::warn(
             "LlamaGenerator: malformed user response (attempt {}): {}",
             attempt + 1, e.what());
      }
   }

   spdlog::error(
       "LlamaGenerator: malformed user response after {} attempts: {}",
       max_attempts, raw);
   throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -1,367 +1,365 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <boost/json.hpp>
#include <cctype> #include <cctype>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
#include "llama.h"
#include <boost/json.hpp>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
namespace { namespace {
/// @brief Returns @p value with leading and trailing whitespace removed.
///
/// Whitespace is whatever std::isspace reports for the current C locale.
std::string Trim(std::string value) {
   const auto is_blank = [](char c) {
      return std::isspace(static_cast<unsigned char>(c)) != 0;
   };

   // Drop the tail first so the front erase shifts as few bytes as possible.
   std::size_t last = value.size();
   while (last > 0 && is_blank(value[last - 1])) {
      --last;
   }
   value.resize(last);

   std::size_t first = 0;
   while (first < value.size() && is_blank(value[first])) {
      ++first;
   }
   value.erase(0, first);

   return value;
}
std::string CondenseWhitespace(std::string text) { std::string CondenseWhitespace(std::string text) {
std::string out; std::string out;
out.reserve(text.size()); out.reserve(text.size());
bool in_whitespace = false; bool in_whitespace = false;
for (unsigned char ch : text) { for (unsigned char ch : text) {
if (std::isspace(ch)) { if (std::isspace(ch)) {
if (!in_whitespace) { if (!in_whitespace) {
out.push_back(' '); out.push_back(' ');
in_whitespace = true; in_whitespace = true;
}
continue;
} }
continue;
}
in_whitespace = false; in_whitespace = false;
out.push_back(static_cast<char>(ch)); out.push_back(static_cast<char>(ch));
} }
return Trim(std::move(out)); return Trim(std::move(out));
} }
std::string PrepareRegionContext(std::string_view region_context, std::string PrepareRegionContext(std::string_view region_context,
std::size_t max_chars) { std::size_t max_chars) {
std::string normalized = CondenseWhitespace(std::string(region_context)); std::string normalized = CondenseWhitespace(std::string(region_context));
if (normalized.size() <= max_chars) { if (normalized.size() <= max_chars) {
return normalized; return normalized;
} }
normalized.resize(max_chars); normalized.resize(max_chars);
const std::size_t last_space = normalized.find_last_of(' '); const std::size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) { if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space); normalized.resize(last_space);
} }
normalized += "..."; normalized += "...";
return normalized; return normalized;
} }
std::string StripCommonPrefix(std::string line) { std::string StripCommonPrefix(std::string line) {
line = Trim(std::move(line)); line = Trim(std::move(line));
if (!line.empty() && (line[0] == '-' || line[0] == '*')) { if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = Trim(line.substr(1)); line = Trim(line.substr(1));
} else { } else {
std::size_t i = 0; std::size_t i = 0;
while (i < line.size() && while (i < line.size() &&
std::isdigit(static_cast<unsigned char>(line[i]))) { std::isdigit(static_cast<unsigned char>(line[i]))) {
++i; ++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = Trim(line.substr(i + 1));
}
}
auto strip_label = [&line](const std::string& label) {
if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
} }
if (matches) { if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = Trim(line.substr(label.size())); line = Trim(line.substr(i + 1));
} }
} }
};
strip_label("name:"); auto strip_label = [&line](const std::string& label) {
strip_label("brewery name:"); if (line.size() >= label.size()) {
strip_label("description:"); bool matches = true;
strip_label("username:"); for (std::size_t i = 0; i < label.size(); ++i) {
strip_label("bio:"); if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = Trim(line.substr(label.size()));
}
}
};
return Trim(std::move(line)); strip_label("name:");
strip_label("brewery name:");
strip_label("description:");
strip_label("username:");
strip_label("bio:");
return Trim(std::move(line));
} }
std::pair<std::string, std::string> std::pair<std::string, std::string> ParseTwoLineResponse(
ParseTwoLineResponse(const std::string& raw, const std::string& error_message) { const std::string& raw, const std::string& error_message) {
std::string normalized = raw; std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n'); std::replace(normalized.begin(), normalized.end(), '\r', '\n');
std::vector<std::string> lines; std::vector<std::string> lines;
std::stringstream stream(normalized); std::stringstream stream(normalized);
std::string line; std::string line;
while (std::getline(stream, line)) { while (std::getline(stream, line)) {
line = StripCommonPrefix(std::move(line)); line = StripCommonPrefix(std::move(line));
if (!line.empty()) if (!line.empty()) lines.push_back(std::move(line));
lines.push_back(std::move(line)); }
}
std::vector<std::string> filtered; std::vector<std::string> filtered;
for (auto &l : lines) { for (auto& l : lines) {
std::string low = l; std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c)); return static_cast<char>(std::tolower(c));
}); });
if (!l.empty() && l.front() == '<' && low.back() == '>') if (!l.empty() && l.front() == '<' && low.back() == '>') continue;
continue; if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) filtered.push_back(std::move(l));
continue; }
filtered.push_back(std::move(l));
}
if (filtered.size() < 2) if (filtered.size() < 2) throw std::runtime_error(error_message);
throw std::runtime_error(error_message);
std::string first = Trim(filtered.front()); std::string first = Trim(filtered.front());
std::string second; std::string second;
for (size_t i = 1; i < filtered.size(); ++i) { for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) if (!second.empty()) second += ' ';
second += ' '; second += filtered[i];
second += filtered[i]; }
} second = Trim(std::move(second));
second = Trim(std::move(second));
if (first.empty() || second.empty()) if (first.empty() || second.empty()) throw std::runtime_error(error_message);
throw std::runtime_error(error_message); return {first, second};
return {first, second};
} }
std::string ToChatPrompt(const llama_model *model, std::string ToChatPrompt(const llama_model* model,
const std::string& user_prompt) { const std::string& user_prompt) {
const char *tmpl = llama_model_chat_template(model, nullptr); const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) { if (tmpl == nullptr) {
return user_prompt; return user_prompt;
} }
const llama_chat_message message{"user", user_prompt.c_str()}; const llama_chat_message message{"user", user_prompt.c_str()};
std::vector<char> buffer(std::max<std::size_t>(1024, user_prompt.size() * 4)); std::vector<char> buffer(
int32_t required = std::max<std::size_t>(1024, user_prompt.size() * 4));
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), int32_t required =
static_cast<int32_t>(buffer.size())); llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) { if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template"); throw std::runtime_error("LlamaGenerator: failed to apply chat template");
} }
}
return std::string(buffer.data(), static_cast<std::size_t>(required)); if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
} }
std::string ToChatPrompt(const llama_model *model, std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt) { const std::string& user_prompt) {
const char *tmpl = llama_model_chat_template(model, nullptr); const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) { if (tmpl == nullptr) {
return system_prompt + "\n\n" + user_prompt; return system_prompt + "\n\n" + user_prompt;
} }
const llama_chat_message messages[2] = {{"system", system_prompt.c_str()}, const llama_chat_message messages[2] = {{"system", system_prompt.c_str()},
{"user", user_prompt.c_str()}}; {"user", user_prompt.c_str()}};
std::vector<char> buffer(std::max<std::size_t>( std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4)); 1024, (system_prompt.size() + user_prompt.size()) * 4));
int32_t required = int32_t required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size())); static_cast<int32_t>(buffer.size()));
if (required < 0) { if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template"); throw std::runtime_error("LlamaGenerator: failed to apply chat template");
} }
}
return std::string(buffer.data(), static_cast<std::size_t>(required)); if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
} }
void AppendTokenPiece(const llama_vocab *vocab, llama_token token, void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
std::array<char, 256> buffer{}; std::array<char, 256> buffer{};
int32_t bytes = int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(), llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true); static_cast<int32_t>(buffer.size()), 0, true);
if (bytes < 0) { if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes)); std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), 0, static_cast<int32_t>(dynamic_buffer.size()),
true); 0, true);
if (bytes < 0) { if (bytes < 0) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece"); "LlamaGenerator: failed to decode sampled token piece");
} }
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes)); output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
return; return;
} }
output.append(buffer.data(), static_cast<std::size_t>(bytes)); output.append(buffer.data(), static_cast<std::size_t>(bytes));
} }
/// @brief Finds the first balanced `{...}` object embedded in @p text.
///
/// Braces are counted while skipping over JSON string literals (including
/// escaped quotes) so that braces inside strings do not affect nesting.
/// Quotes are only treated as string delimiters while inside an object: a
/// stray '"' in surrounding prose (for example `12" tall`) must not hide the
/// object that follows it.
///
/// @param text Text that may contain a JSON object surrounded by other output.
/// @param json_out Receives the object text on success; untouched otherwise.
/// @return true if a balanced object was found, false otherwise.
bool ExtractFirstJsonObject(const std::string& text, std::string& json_out) {
   std::size_t start = std::string::npos;
   int depth = 0;
   bool in_string = false;
   bool escaped = false;

   for (std::size_t i = 0; i < text.size(); ++i) {
      const char ch = text[i];

      if (in_string) {
         if (escaped) {
            escaped = false;
         } else if (ch == '\\') {
            escaped = true;
         } else if (ch == '"') {
            in_string = false;
         }
         continue;
      }

      switch (ch) {
         case '"':
            // Outside any object a quote is just prose; ignore it.
            if (depth > 0) {
               in_string = true;
            }
            break;
         case '{':
            if (depth == 0) {
               start = i;
            }
            ++depth;
            break;
         case '}':
            // A closing brace with nothing open is ignored.
            if (depth > 0 && --depth == 0) {
               json_out = text.substr(start, i - start + 1);
               return true;
            }
            break;
         default:
            break;
      }
   }

   return false;
}
std::string ValidateBreweryJson(const std::string& raw, std::string& name_out, std::string ValidateBreweryJson(const std::string& raw, std::string& name_out,
std::string& description_out) { std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv, auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool { std::string& error_out) -> bool {
if (!jv.is_object()) { if (!jv.is_object()) {
error_out = "JSON root must be an object"; error_out = "JSON root must be an object";
return false; return false;
} }
const auto& obj = jv.get_object(); const auto& obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) { if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string"; error_out = "JSON field 'name' is missing or not a string";
return false; return false;
} }
if (!obj.contains("description") || !obj.at("description").is_string()) { if (!obj.contains("description") || !obj.at("description").is_string()) {
error_out = "JSON field 'description' is missing or not a string"; error_out = "JSON field 'description' is missing or not a string";
return false; return false;
} }
name_out = Trim(std::string(obj.at("name").as_string().c_str())); name_out = Trim(std::string(obj.at("name").as_string().c_str()));
description_out = description_out =
Trim(std::string(obj.at("description").as_string().c_str())); Trim(std::string(obj.at("description").as_string().c_str()));
if (name_out.empty()) { if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty"; error_out = "JSON field 'name' must not be empty";
return false; return false;
} }
if (description_out.empty()) { if (description_out.empty()) {
error_out = "JSON field 'description' must not be empty"; error_out = "JSON field 'description' must not be empty";
return false; return false;
} }
std::string name_lower = name_out; std::string name_lower = name_out;
std::string description_lower = description_out; std::string description_lower = description_out;
std::transform( std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(), name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); }); [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(), std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) { description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c)); return static_cast<char>(std::tolower(c));
}); });
if (name_lower == "string" || description_lower == "string") { if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content"; error_out = "JSON appears to be a schema placeholder, not content";
return false; return false;
} }
error_out.clear(); error_out.clear();
return true; return true;
}; };
boost::system::error_code ec; boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec); boost::json::value jv = boost::json::parse(raw, ec);
std::string validation_error; std::string validation_error;
if (ec) { if (ec) {
std::string extracted; std::string extracted;
if (!ExtractFirstJsonObject(raw, extracted)) { if (!ExtractFirstJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message(); return "JSON parse error: " + ec.message();
} }
ec.clear(); ec.clear();
jv = boost::json::parse(extracted, ec); jv = boost::json::parse(extracted, ec);
if (ec) { if (ec) {
return "JSON parse error: " + ec.message(); return "JSON parse error: " + ec.message();
} }
if (!validate_object(jv, validation_error)) { if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
}
if (!validate_object(jv, validation_error)) {
return validation_error; return validation_error;
} }
return {}; return {};
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
} }
} // namespace } // namespace
@@ -369,33 +367,32 @@ std::string ValidateBreweryJson(const std::string& raw, std::string& name_out,
// Public wrappers that expose the file-local helpers above to other translation units
std::string PrepareRegionContextPublic(std::string_view region_context, std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars) { std::size_t max_chars) {
return PrepareRegionContext(region_context, max_chars); return PrepareRegionContext(region_context, max_chars);
} }
std::pair<std::string, std::string> std::pair<std::string, std::string> ParseTwoLineResponsePublic(
ParseTwoLineResponsePublic(const std::string& raw, const std::string& raw, const std::string& error_message) {
const std::string& error_message) { return ParseTwoLineResponse(raw, error_message);
return ParseTwoLineResponse(raw, error_message);
} }
std::string ToChatPromptPublic(const llama_model *model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt) { const std::string& user_prompt) {
return ToChatPrompt(model, user_prompt); return ToChatPrompt(model, user_prompt);
} }
std::string ToChatPromptPublic(const llama_model *model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt) { const std::string& user_prompt) {
return ToChatPrompt(model, system_prompt, user_prompt); return ToChatPrompt(model, system_prompt, user_prompt);
} }
void AppendTokenPiecePublic(const llama_vocab *vocab, llama_token token, void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
AppendTokenPiece(vocab, token, output); AppendTokenPiece(vocab, token, output);
} }
std::string ValidateBreweryJsonPublic(const std::string& raw, std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out, std::string& name_out,
std::string& description_out) { std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out); return ValidateBreweryJson(raw, name_out, description_out);
} }

View File

@@ -1,195 +1,199 @@
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
#include "llama.h"
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
#include "llama.h"
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) { std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
if (model_ == nullptr || context_ == nullptr) if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded"); throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab *vocab = llama_model_get_vocab(model_); const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable"); throw std::runtime_error("LlamaGenerator: vocab unavailable");
llama_memory_clear(llama_get_memory(context_), true); llama_memory_clear(llama_get_memory(context_), true);
const std::string formatted_prompt = ToChatPromptPublic(model_, prompt); const std::string formatted_prompt = ToChatPromptPublic(model_, prompt);
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8); std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
int32_t token_count = llama_tokenize( int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(), vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(), static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true); static_cast<int32_t>(prompt_tokens.size()), true, true);
if (token_count < 0) { if (token_count < 0) {
prompt_tokens.resize(static_cast<std::size_t>(-token_count)); prompt_tokens.resize(static_cast<std::size_t>(-token_count));
token_count = llama_tokenize( token_count = llama_tokenize(
vocab, formatted_prompt.c_str(), vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(), static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true); static_cast<int32_t>(prompt_tokens.size()), true, true);
} }
if (token_count < 0) if (token_count < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_)); const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_)); const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0) { if (n_ctx <= 1 || n_batch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size"); throw std::runtime_error("LlamaGenerator: invalid context or batch size");
} }
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); const int32_t effective_max_tokens =
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); std::max(1, std::min(max_tokens, n_ctx - 1));
prompt_budget = std::max<int32_t>(1, prompt_budget); int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(token_count)); prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) { if (token_count > prompt_budget) {
spdlog::warn( spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " "LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"to fit n_batch/n_ctx limits", "tokens "
token_count, prompt_budget); "to fit n_batch/n_ctx limits",
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget)); token_count, prompt_budget);
token_count = prompt_budget; prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
} token_count = prompt_budget;
}
const llama_batch prompt_batch = llama_batch_get_one( const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size())); prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0) if (llama_decode(context_, prompt_batch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed"); throw std::runtime_error("LlamaGenerator: prompt decode failed");
llama_sampler_chain_params sampler_params = llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params(); llama_sampler_chain_default_params();
using SamplerPtr = using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>; std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params), SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free); &llama_sampler_free);
if (!sampler) if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
llama_sampler_chain_add(sampler.get(), llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_)); llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(), llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1)); llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(), llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_)); llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generated_tokens; std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(max_tokens)); generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) { for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); const llama_token next =
if (llama_vocab_is_eog(vocab, next)) llama_sampler_sample(sampler.get(), context_, -1);
break; if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next); generated_tokens.push_back(next);
llama_token token = next; llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1); const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0) if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: decode failed during generation"); "LlamaGenerator: decode failed during generation");
} }
std::string output; std::string output;
for (const llama_token token : generated_tokens) for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output); AppendTokenPiecePublic(vocab, token, output);
return output; return output;
} }
std::string LlamaGenerator::Infer(const std::string& system_prompt, std::string LlamaGenerator::Infer(const std::string& system_prompt,
const std::string& prompt, int max_tokens) { const std::string& prompt, int max_tokens) {
if (model_ == nullptr || context_ == nullptr) if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded"); throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab *vocab = llama_model_get_vocab(model_); const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable"); throw std::runtime_error("LlamaGenerator: vocab unavailable");
llama_memory_clear(llama_get_memory(context_), true); llama_memory_clear(llama_get_memory(context_), true);
const std::string formatted_prompt = const std::string formatted_prompt =
ToChatPromptPublic(model_, system_prompt, prompt); ToChatPromptPublic(model_, system_prompt, prompt);
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8); std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
int32_t token_count = llama_tokenize( int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(), vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(), static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true); static_cast<int32_t>(prompt_tokens.size()), true, true);
if (token_count < 0) { if (token_count < 0) {
prompt_tokens.resize(static_cast<std::size_t>(-token_count)); prompt_tokens.resize(static_cast<std::size_t>(-token_count));
token_count = llama_tokenize( token_count = llama_tokenize(
vocab, formatted_prompt.c_str(), vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(), static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true); static_cast<int32_t>(prompt_tokens.size()), true, true);
} }
if (token_count < 0) if (token_count < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_)); const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_)); const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0) { if (n_ctx <= 1 || n_batch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size"); throw std::runtime_error("LlamaGenerator: invalid context or batch size");
} }
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); const int32_t effective_max_tokens =
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); std::max(1, std::min(max_tokens, n_ctx - 1));
prompt_budget = std::max<int32_t>(1, prompt_budget); int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(token_count)); prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) { if (token_count > prompt_budget) {
spdlog::warn( spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " "LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"to fit n_batch/n_ctx limits", "tokens "
token_count, prompt_budget); "to fit n_batch/n_ctx limits",
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget)); token_count, prompt_budget);
token_count = prompt_budget; prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
} token_count = prompt_budget;
}
const llama_batch prompt_batch = llama_batch_get_one( const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size())); prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0) if (llama_decode(context_, prompt_batch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed"); throw std::runtime_error("LlamaGenerator: prompt decode failed");
llama_sampler_chain_params sampler_params = llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params(); llama_sampler_chain_default_params();
using SamplerPtr = using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>; std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params), SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free); &llama_sampler_free);
if (!sampler) if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
llama_sampler_chain_add(sampler.get(), llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_)); llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(), llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1)); llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(), llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_)); llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generated_tokens; std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(max_tokens)); generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) { for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); const llama_token next =
if (llama_vocab_is_eog(vocab, next)) llama_sampler_sample(sampler.get(), context_, -1);
break; if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next); generated_tokens.push_back(next);
llama_token token = next; llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1); const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0) if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: decode failed during generation"); "LlamaGenerator: decode failed during generation");
} }
std::string output; std::string output;
for (const llama_token token : generated_tokens) for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output); AppendTokenPiecePublic(vocab, token, output);
return output; return output;
} }

View File

@@ -1,42 +1,42 @@
#include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "llama.h"
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
/// @brief Loads a llama model from disk and creates an inference context.
///
/// Any context and model already held by this generator are freed first, so
/// a failed load leaves both members null rather than pointing at the
/// previous model.
///
/// @param model_path Path of the model file; must not be empty.
/// @throws std::runtime_error if the path is empty, the model cannot be
///         loaded, or the context cannot be created (the freshly loaded
///         model is released again in that last case).
// NOTE(review): every line below shows the pre-format and post-format text
// of the diff side by side; the code tokens themselves are untouched.
void LlamaGenerator::Load(const std::string& model_path) { void LlamaGenerator::Load(const std::string& model_path) {
if (model_path.empty()) if (model_path.empty())
throw std::runtime_error("LlamaGenerator: model path must not be empty"); throw std::runtime_error("LlamaGenerator: model path must not be empty");
// Release the previous context before the model it was created from.
if (context_ != nullptr) { if (context_ != nullptr) {
llama_free(context_); llama_free(context_);
context_ = nullptr; context_ = nullptr;
} }
if (model_ != nullptr) { if (model_ != nullptr) {
llama_model_free(model_); llama_model_free(model_);
model_ = nullptr; model_ = nullptr;
} }
llama_backend_init(); llama_backend_init();
llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params); model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) { if (model_ == nullptr) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path); "LlamaGenerator: failed to load model from path: " + model_path);
} }
llama_context_params context_params = llama_context_default_params(); llama_context_params context_params = llama_context_default_params();
// Context window is fixed at 2048 tokens; all other parameters are defaults.
context_params.n_ctx = 2048; context_params.n_ctx = 2048;
context_ = llama_init_from_model(model_, context_params); context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) { if (context_ == nullptr) {
llama_model_free(model_); llama_model_free(model_);
model_ = nullptr; model_ = nullptr;
throw std::runtime_error("LlamaGenerator: failed to create context"); throw std::runtime_error("LlamaGenerator: failed to create context");
} }
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path); spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
} }

View File

@@ -1,26 +1,25 @@
#include <stdexcept> #include <stdexcept>
#include "llama.h"
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p, void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
int seed) { int seed) {
if (temperature < 0.0f) { if (temperature < 0.0f) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0"); "LlamaGenerator: sampling temperature must be >= 0");
} }
if (!(top_p > 0.0f && top_p <= 1.0f)) { if (!(top_p > 0.0f && top_p <= 1.0f)) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]"); "LlamaGenerator: sampling top-p must be in (0, 1]");
} }
if (seed < -1) { if (seed < -1) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random"); "LlamaGenerator: seed must be >= 0, or -1 for random");
} }
sampling_temperature_ = temperature; sampling_temperature_ = temperature;
sampling_top_p_ = top_p; sampling_top_p_ = top_p;
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED) sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed); : static_cast<uint32_t>(seed);
} }

View File

@@ -1,7 +1,7 @@
#include "data_generation/mock_generator.h" #include <string>
#include <vector>
#include <functional> #include "data_generation/mock_generator.h"
#include <spdlog/spdlog.h>
const std::vector<std::string> MockGenerator::kBreweryAdjectives = { const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
@@ -63,42 +63,3 @@ const std::vector<std::string> MockGenerator::kBios = {
"Craft beer fan mapping tasting notes and favorite brew routes.", "Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.", "Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."}; "Keeping a running list of must-try collab releases and tap takeovers."};
/// @brief Satisfies the generator interface; the mock has no model to load.
/// @param model_path Ignored.
void MockGenerator::Load(const std::string &model_path) {
  static_cast<void>(model_path);  // intentionally unused
  spdlog::info("[MockGenerator] No model needed");
}
/// @brief Combines the std::hash values of two strings into a single value.
/// @param a First string; its hash seeds the result.
/// @param b Second string; its hash is mixed into the seed.
/// @return The mixed hash, rotated left by 13 bits.
std::size_t MockGenerator::DeterministicHash(const std::string &a,
                                             const std::string &b) {
  const std::hash<std::string> hasher{};
  const std::size_t first = hasher(a);
  const std::size_t second = hasher(b);

  // Additive constant plus shifted copies of the seed, folded in with XOR.
  const std::size_t combined =
      first ^ (second + 0x9e3779b97f4a7c15ULL + (first << 6) + (first >> 2));

  // Rotate left by 13 bits across the full width of std::size_t.
  const std::size_t bit_width = sizeof(std::size_t) * 8;
  return (combined << 13) | (combined >> (bit_width - 13));
}
/// @brief Picks a brewery name and description from the static word lists.
///
/// The selection is driven purely by a hash of the inputs, so identical
/// arguments yield identical results within one build.
///
/// @param city_name City used as the first part of the location key.
/// @param country_name Optional country; appended as ",<country>" when set.
/// @param region_context Optional extra text mixed into the hash when set.
/// @return A BreweryResult with `name` ("<adjective> <noun>") and
///         `description` filled in.
BreweryResult MockGenerator::GenerateBrewery(const std::string &city_name,
                                             const std::string &country_name,
                                             const std::string &region_context) {
  std::string key = city_name;
  if (!country_name.empty()) {
    key += ",";
    key += country_name;
  }

  std::size_t digest = 0;
  if (region_context.empty()) {
    digest = std::hash<std::string>{}(key);
  } else {
    digest = DeterministicHash(key, region_context);
  }

  // Different divisors decorrelate the three list indices.
  const std::string &adjective =
      kBreweryAdjectives[digest % kBreweryAdjectives.size()];
  const std::string &noun = kBreweryNouns[(digest / 7) % kBreweryNouns.size()];

  BreweryResult brewery;
  brewery.name = adjective + " " + noun;
  brewery.description =
      kBreweryDescriptions[(digest / 13) % kBreweryDescriptions.size()];
  return brewery;
}
/// @brief Picks a username and bio from the static lists based on `locale`.
/// @param locale Text whose std::hash selects both entries.
/// @return A UserResult with `username` and `bio` filled in.
UserResult MockGenerator::GenerateUser(const std::string &locale) {
  const std::size_t locale_hash = std::hash<std::string>{}(locale);
  const std::size_t username_index = locale_hash % kUsernames.size();
  // Dividing first keeps the bio choice from tracking the username choice.
  const std::size_t bio_index = (locale_hash / 11) % kBios.size();

  UserResult user;
  user.username = kUsernames[username_index];
  user.bio = kBios[bio_index];
  return user;
}

View File

@@ -0,0 +1,12 @@
#include <string>
#include "data_generation/mock_generator.h"
/// @brief Combines the std::hash values of two strings into a single value.
/// @param a First string; its hash seeds the result.
/// @param b Second string; its hash is mixed into the seed.
/// @return The mixed hash, rotated left by 13 bits.
std::size_t MockGenerator::DeterministicHash(const std::string& a,
                                             const std::string& b) {
   const std::size_t hash_a = std::hash<std::string>{}(a);
   const std::size_t hash_b = std::hash<std::string>{}(b);

   // Additive constant plus shifted copies of the seed, folded in with XOR.
   std::size_t mixed = hash_a;
   mixed ^= hash_b + 0x9e3779b97f4a7c15ULL + (hash_a << 6) + (hash_a >> 2);

   // Rotate left by 13 bits across the full width of std::size_t.
   const std::size_t high_bits = mixed << 13;
   const std::size_t low_bits = mixed >> ((sizeof(std::size_t) * 8) - 13);
   return high_bits | low_bits;
}

View File

@@ -0,0 +1,21 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
/// @brief Picks a brewery name and description from the static word lists.
///
/// The selection is driven purely by a hash of the inputs, so identical
/// arguments yield identical results within one build.
///
/// @param city_name City used as the first part of the location key.
/// @param country_name Optional country; appended as ",<country>" when set.
/// @param region_context Optional extra text mixed into the hash when set.
/// @return A BreweryResult with `name` ("<adjective> <noun>") and
///         `description` filled in.
BreweryResult MockGenerator::GenerateBrewery(
    const std::string& city_name, const std::string& country_name,
    const std::string& region_context) {
   std::string location = city_name;
   if (!country_name.empty()) {
      location += ",";
      location += country_name;
   }

   std::size_t selector = 0;
   if (region_context.empty()) {
      selector = std::hash<std::string>{}(location);
   } else {
      selector = DeterministicHash(location, region_context);
   }

   // Different divisors decorrelate the three list indices.
   const std::size_t adjective_index = selector % kBreweryAdjectives.size();
   const std::size_t noun_index = (selector / 7) % kBreweryNouns.size();
   const std::size_t description_index =
       (selector / 13) % kBreweryDescriptions.size();

   BreweryResult brewery;
   brewery.name =
       kBreweryAdjectives[adjective_index] + " " + kBreweryNouns[noun_index];
   brewery.description = kBreweryDescriptions[description_index];
   return brewery;
}

View File

@@ -0,0 +1,13 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
/// @brief Picks a username and bio from the static lists based on `locale`.
/// @param locale Text whose std::hash selects both entries.
/// @return A UserResult with `username` and `bio` filled in.
UserResult MockGenerator::GenerateUser(const std::string& locale) {
   const std::size_t selector = std::hash<std::string>{}(locale);

   UserResult profile;
   profile.username = kUsernames[selector % kUsernames.size()];
   // Dividing first keeps the bio choice from tracking the username choice.
   const std::size_t bio_slot = (selector / 11) % kBios.size();
   profile.bio = kBios[bio_slot];
   return profile;
}

View File

@@ -0,0 +1,9 @@
#include <spdlog/spdlog.h>
#include <string>
#include "data_generation/mock_generator.h"
/// @brief Satisfies the generator interface; the mock has no model to load.
/// @param model_path Ignored.
void MockGenerator::Load(const std::string& model_path) {
   (void)model_path;  // the mock never reads a model file
   spdlog::info("[MockGenerator] No model needed");
}

View File

@@ -1,11 +1,13 @@
#include "database/database.h" #include "database/database.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
void SqliteDatabase::InitializeSchema() { void SqliteDatabase::InitializeSchema() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *schema = R"( const char* schema = R"(
CREATE TABLE IF NOT EXISTS countries ( CREATE TABLE IF NOT EXISTS countries (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT NOT NULL, name TEXT NOT NULL,
@@ -33,218 +35,219 @@ void SqliteDatabase::InitializeSchema() {
); );
)"; )";
char *errMsg = nullptr; char* errMsg = nullptr;
int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg); int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
std::string error = errMsg ? std::string(errMsg) : "Unknown error"; std::string error = errMsg ? std::string(errMsg) : "Unknown error";
sqlite3_free(errMsg); sqlite3_free(errMsg);
throw std::runtime_error("Failed to create schema: " + error); throw std::runtime_error("Failed to create schema: " + error);
} }
} }
/// @brief Closes the SQLite connection if one is held.
// NOTE(review): the return value of sqlite3_close is ignored here.
SqliteDatabase::~SqliteDatabase() { SqliteDatabase::~SqliteDatabase() {
if (db_) { if (db_) {
sqlite3_close(db_); sqlite3_close(db_);
} }
} }
void SqliteDatabase::Initialize(const std::string &db_path) { void SqliteDatabase::Initialize(const std::string& db_path) {
int rc = sqlite3_open(db_path.c_str(), &db_); int rc = sqlite3_open(db_path.c_str(), &db_);
if (rc) { if (rc) {
throw std::runtime_error("Failed to open SQLite database: " + db_path); throw std::runtime_error("Failed to open SQLite database: " + db_path);
} }
spdlog::info("OK: SQLite database opened: {}", db_path); spdlog::info("OK: SQLite database opened: {}", db_path);
InitializeSchema(); InitializeSchema();
} }
/// @brief Issues "BEGIN TRANSACTION" on the connection while holding
///        db_mutex_.
/// @throws std::runtime_error carrying SQLite's error text (or "unknown"
///         when none was provided) if the statement fails.
void SqliteDatabase::BeginTransaction() { void SqliteDatabase::BeginTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char* err = nullptr;
if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) != if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) !=
SQLITE_OK) { SQLITE_OK) {
// Copy the message before freeing the buffer sqlite3_exec allocated.
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
sqlite3_free(err); sqlite3_free(err);
throw std::runtime_error("BeginTransaction failed: " + msg); throw std::runtime_error("BeginTransaction failed: " + msg);
} }
} }
/// @brief Issues "COMMIT" on the connection while holding db_mutex_.
/// @throws std::runtime_error carrying SQLite's error text (or "unknown"
///         when none was provided) if the statement fails.
void SqliteDatabase::CommitTransaction() { void SqliteDatabase::CommitTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char* err = nullptr;
if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) { if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) {
// Copy the message before freeing the buffer sqlite3_exec allocated.
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
sqlite3_free(err); sqlite3_free(err);
throw std::runtime_error("CommitTransaction failed: " + msg); throw std::runtime_error("CommitTransaction failed: " + msg);
} }
} }
void SqliteDatabase::InsertCountry(int id, const std::string &name, void SqliteDatabase::InsertCountry(int id, const std::string& name,
const std::string &iso2, const std::string& iso2,
const std::string &iso3) { const std::string& iso3) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO countries (id, name, iso2, iso3) INSERT OR IGNORE INTO countries (id, name, iso2, iso3)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare country insert"); throw std::runtime_error("Failed to prepare country insert");
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 3, iso2.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, iso3.c_str(), -1, SQLITE_STATIC);
if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert country"); throw std::runtime_error("Failed to insert country");
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertState(int id, int country_id, const std::string &name, void SqliteDatabase::InsertState(int id, int country_id,
const std::string &iso2) { const std::string& name,
std::lock_guard<std::mutex> lock(db_mutex_); const std::string& iso2) {
std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO states (id, country_id, name, iso2) INSERT OR IGNORE INTO states (id, country_id, name, iso2)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare state insert"); throw std::runtime_error("Failed to prepare state insert");
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, country_id); sqlite3_bind_int(stmt, 2, country_id);
sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC);
if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert state"); throw std::runtime_error("Failed to insert state");
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertCity(int id, int state_id, int country_id, void SqliteDatabase::InsertCity(int id, int state_id, int country_id,
const std::string &name, double latitude, const std::string& name, double latitude,
double longitude) { double longitude) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude) INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare city insert"); throw std::runtime_error("Failed to prepare city insert");
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, state_id); sqlite3_bind_int(stmt, 2, state_id);
sqlite3_bind_int(stmt, 3, country_id); sqlite3_bind_int(stmt, 3, country_id);
sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_double(stmt, 5, latitude); sqlite3_bind_double(stmt, 5, latitude);
sqlite3_bind_double(stmt, 6, longitude); sqlite3_bind_double(stmt, 6, longitude);
if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
throw std::runtime_error("Failed to insert city"); throw std::runtime_error("Failed to insert city");
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
std::vector<City> SqliteDatabase::QueryCities() { std::vector<City> SqliteDatabase::QueryCities() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<City> cities; std::vector<City> cities;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
const char *query = "SELECT id, name, country_id FROM cities ORDER BY name"; const char* query = "SELECT id, name, country_id FROM cities ORDER BY name";
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
throw std::runtime_error("Failed to prepare query"); throw std::runtime_error("Failed to prepare query");
} }
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
int country_id = sqlite3_column_int(stmt, 2); int country_id = sqlite3_column_int(stmt, 2);
cities.push_back({id, name ? std::string(name) : "", country_id}); cities.push_back({id, name ? std::string(name) : "", country_id});
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
return cities; return cities;
} }
std::vector<Country> SqliteDatabase::QueryCountries(int limit) { std::vector<Country> SqliteDatabase::QueryCountries(int limit) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<Country> countries; std::vector<Country> countries;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
std::string query = std::string query =
"SELECT id, name, iso2, iso3 FROM countries ORDER BY name"; "SELECT id, name, iso2, iso3 FROM countries ORDER BY name";
if (limit > 0) { if (limit > 0) {
query += " LIMIT " + std::to_string(limit); query += " LIMIT " + std::to_string(limit);
} }
int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
throw std::runtime_error("Failed to prepare countries query"); throw std::runtime_error("Failed to prepare countries query");
} }
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
const char *iso2 = const char* iso2 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
const char *iso3 = const char* iso3 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 3)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 3));
countries.push_back({id, name ? std::string(name) : "", countries.push_back({id, name ? std::string(name) : "",
iso2 ? std::string(iso2) : "", iso2 ? std::string(iso2) : "",
iso3 ? std::string(iso3) : ""}); iso3 ? std::string(iso3) : ""});
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
return countries; return countries;
} }
std::vector<State> SqliteDatabase::QueryStates(int limit) { std::vector<State> SqliteDatabase::QueryStates(int limit) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<State> states; std::vector<State> states;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
std::string query = std::string query =
"SELECT id, name, iso2, country_id FROM states ORDER BY name"; "SELECT id, name, iso2, country_id FROM states ORDER BY name";
if (limit > 0) { if (limit > 0) {
query += " LIMIT " + std::to_string(limit); query += " LIMIT " + std::to_string(limit);
} }
int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
throw std::runtime_error("Failed to prepare states query"); throw std::runtime_error("Failed to prepare states query");
} }
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
const char *iso2 = const char* iso2 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
int country_id = sqlite3_column_int(stmt, 3); int country_id = sqlite3_column_int(stmt, 3);
states.push_back({id, name ? std::string(name) : "", states.push_back({id, name ? std::string(name) : "",
iso2 ? std::string(iso2) : "", country_id}); iso2 ? std::string(iso2) : "", country_id});
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
return states; return states;
} }

View File

@@ -1,65 +1,66 @@
#include <chrono> #include "json_handling/json_loader.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include "json_handling/json_loader.h" #include <chrono>
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
void JsonLoader::LoadWorldCities(const std::string &json_path, void JsonLoader::LoadWorldCities(const std::string& json_path,
SqliteDatabase &db) { SqliteDatabase& db) {
constexpr size_t kBatchSize = 10000; constexpr size_t kBatchSize = 10000;
auto startTime = std::chrono::high_resolution_clock::now(); auto startTime = std::chrono::high_resolution_clock::now();
spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", json_path); spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", json_path);
db.BeginTransaction(); db.BeginTransaction();
bool transactionOpen = true; bool transactionOpen = true;
size_t citiesProcessed = 0; size_t citiesProcessed = 0;
try { try {
StreamingJsonParser::Parse( StreamingJsonParser::Parse(
json_path, db, json_path, db,
[&](const CityRecord &record) { [&](const CityRecord& record) {
db.InsertCity(record.id, record.state_id, record.country_id, db.InsertCity(record.id, record.state_id, record.country_id,
record.name, record.latitude, record.longitude); record.name, record.latitude, record.longitude);
++citiesProcessed; ++citiesProcessed;
if (citiesProcessed % kBatchSize == 0) { if (citiesProcessed % kBatchSize == 0) {
db.CommitTransaction(); db.CommitTransaction();
db.BeginTransaction(); db.BeginTransaction();
} }
}, },
[&](size_t current, size_t /*total*/) { [&](size_t current, size_t /*total*/) {
if (current % kBatchSize == 0 && current > 0) { if (current % kBatchSize == 0 && current > 0) {
spdlog::info(" [Progress] Parsed {} cities...", current); spdlog::info(" [Progress] Parsed {} cities...", current);
} }
}); });
spdlog::info(" OK: Parsed all cities from JSON"); spdlog::info(" OK: Parsed all cities from JSON");
if (transactionOpen) { if (transactionOpen) {
db.CommitTransaction(); db.CommitTransaction();
transactionOpen = false; transactionOpen = false;
} }
} catch (...) { } catch (...) {
if (transactionOpen) { if (transactionOpen) {
db.CommitTransaction(); db.CommitTransaction();
} }
throw; throw;
} }
auto endTime = std::chrono::high_resolution_clock::now(); auto endTime = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>( auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
endTime - startTime); endTime - startTime);
spdlog::info("\n=== World City Data Loading Summary ===\n"); spdlog::info("\n=== World City Data Loading Summary ===\n");
spdlog::info("Cities inserted: {}", citiesProcessed); spdlog::info("Cities inserted: {}", citiesProcessed);
spdlog::info("Elapsed time: {} ms", duration.count()); spdlog::info("Elapsed time: {} ms", duration.count());
long long throughput = long long throughput =
(citiesProcessed > 0 && duration.count() > 0) (citiesProcessed > 0 && duration.count() > 0)
? (1000LL * static_cast<long long>(citiesProcessed)) / ? (1000LL * static_cast<long long>(citiesProcessed)) /
static_cast<long long>(duration.count()) static_cast<long long>(duration.count())
: 0LL; : 0LL;
spdlog::info("Throughput: {} cities/sec", throughput); spdlog::info("Throughput: {} cities/sec", throughput);
spdlog::info("=======================================\n"); spdlog::info("=======================================\n");
} }

View File

@@ -1,288 +1,289 @@
#include <cstdio> #include "json_handling/stream_parser.h"
#include <stdexcept>
#include <spdlog/spdlog.h>
#include <boost/json.hpp> #include <boost/json.hpp>
#include <boost/json/basic_parser_impl.hpp> #include <boost/json/basic_parser_impl.hpp>
#include <spdlog/spdlog.h> #include <cstdio>
#include <stdexcept>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h"
class CityRecordHandler { class CityRecordHandler {
friend class boost::json::basic_parser<CityRecordHandler>; friend class boost::json::basic_parser<CityRecordHandler>;
public: public:
static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1);
struct ParseContext { struct ParseContext {
SqliteDatabase *db = nullptr; SqliteDatabase* db = nullptr;
std::function<void(const CityRecord &)> on_city; std::function<void(const CityRecord&)> on_city;
std::function<void(size_t, size_t)> on_progress; std::function<void(size_t, size_t)> on_progress;
size_t cities_emitted = 0; size_t cities_emitted = 0;
size_t total_file_size = 0; size_t total_file_size = 0;
int countries_inserted = 0; int countries_inserted = 0;
int states_inserted = 0; int states_inserted = 0;
}; };
explicit CityRecordHandler(ParseContext &ctx) : context(ctx) {} explicit CityRecordHandler(ParseContext& ctx) : context(ctx) {}
private: private:
ParseContext &context; ParseContext& context;
int depth = 0; int depth = 0;
bool in_countries_array = false; bool in_countries_array = false;
bool in_country_object = false; bool in_country_object = false;
bool in_states_array = false; bool in_states_array = false;
bool in_state_object = false; bool in_state_object = false;
bool in_cities_array = false; bool in_cities_array = false;
bool building_city = false; bool building_city = false;
int current_country_id = 0; int current_country_id = 0;
int current_state_id = 0; int current_state_id = 0;
CityRecord current_city = {}; CityRecord current_city = {};
std::string current_key; std::string current_key;
std::string current_key_val; std::string current_key_val;
std::string current_string_val; std::string current_string_val;
std::string country_info[3]; std::string country_info[3];
std::string state_info[2]; std::string state_info[2];
// Boost.JSON SAX Hooks // Boost.JSON SAX Hooks
bool on_document_begin(boost::system::error_code &) { return true; } bool on_document_begin(boost::system::error_code&) { return true; }
bool on_document_end(boost::system::error_code &) { return true; } bool on_document_end(boost::system::error_code&) { return true; }
bool on_array_begin(boost::system::error_code &) { bool on_array_begin(boost::system::error_code&) {
depth++; depth++;
if (depth == 1) { if (depth == 1) {
in_countries_array = true; in_countries_array = true;
} else if (depth == 3 && current_key == "states") { } else if (depth == 3 && current_key == "states") {
in_states_array = true; in_states_array = true;
} else if (depth == 5 && current_key == "cities") { } else if (depth == 5 && current_key == "cities") {
in_cities_array = true; in_cities_array = true;
}
return true;
}
bool on_array_end(std::size_t, boost::system::error_code &) {
if (depth == 1) {
in_countries_array = false;
} else if (depth == 3) {
in_states_array = false;
} else if (depth == 5) {
in_cities_array = false;
}
depth--;
return true;
}
/// @brief Starts a country, state or city record depending on the level
/// reached and on which enclosing array is active.
///
/// Per-record scratch state is reset so values from the previous record of
/// the same kind cannot leak into the new one.
bool on_object_begin(boost::system::error_code&) {
   ++depth;
   switch (depth) {
      case 2:
         if (in_countries_array) {
            in_country_object = true;
            current_country_id = 0;
            for (std::string& field : country_info) field.clear();
         }
         break;
      case 4:
         if (in_states_array) {
            in_state_object = true;
            current_state_id = 0;
            for (std::string& field : state_info) field.clear();
         }
         break;
      case 6:
         if (in_cities_array) {
            building_city = true;
            current_city = {};
         }
         break;
      default:
         break;
   }
   return true;
}
bool on_object_end(std::size_t, boost::system::error_code &) {
if (depth == 6 && building_city) {
if (current_city.id > 0 && current_state_id > 0 &&
current_country_id > 0) {
current_city.state_id = current_state_id;
current_city.country_id = current_country_id;
try {
context.on_city(current_city);
context.cities_emitted++;
if (context.on_progress && context.cities_emitted % 10000 == 0) {
context.on_progress(context.cities_emitted,
context.total_file_size);
}
} catch (const std::exception &e) {
spdlog::warn("Record parsing failed: {}", e.what());
}
} }
building_city = false; return true;
} else if (depth == 4 && in_state_object) { }
if (current_state_id > 0 && current_country_id > 0) {
try { bool on_array_end(std::size_t, boost::system::error_code&) {
context.db->InsertState(current_state_id, current_country_id, if (depth == 1) {
state_info[0], state_info[1]); in_countries_array = false;
context.states_inserted++; } else if (depth == 3) {
} catch (const std::exception &e) { in_states_array = false;
spdlog::warn("Record parsing failed: {}", e.what()); } else if (depth == 5) {
} in_cities_array = false;
} }
in_state_object = false; depth--;
} else if (depth == 2 && in_country_object) { return true;
if (current_country_id > 0) { }
try {
context.db->InsertCountry(current_country_id, country_info[0], bool on_object_begin(boost::system::error_code&) {
country_info[1], country_info[2]); depth++;
context.countries_inserted++; if (depth == 2 && in_countries_array) {
} catch (const std::exception &e) { in_country_object = true;
spdlog::warn("Record parsing failed: {}", e.what()); current_country_id = 0;
} country_info[0].clear();
country_info[1].clear();
country_info[2].clear();
} else if (depth == 4 && in_states_array) {
in_state_object = true;
current_state_id = 0;
state_info[0].clear();
state_info[1].clear();
} else if (depth == 6 && in_cities_array) {
building_city = true;
current_city = {};
} }
in_country_object = false; return true;
} }
depth--; bool on_object_end(std::size_t, boost::system::error_code&) {
return true; if (depth == 6 && building_city) {
} if (current_city.id > 0 && current_state_id > 0 &&
current_country_id > 0) {
current_city.state_id = current_state_id;
current_city.country_id = current_country_id;
bool on_key_part(boost::json::string_view s, std::size_t, try {
boost::system::error_code &) { context.on_city(current_city);
current_key_val.append(s.data(), s.size()); context.cities_emitted++;
return true;
}
bool on_key(boost::json::string_view s, std::size_t, if (context.on_progress && context.cities_emitted % 10000 == 0) {
boost::system::error_code &) { context.on_progress(context.cities_emitted,
current_key_val.append(s.data(), s.size()); context.total_file_size);
current_key = current_key_val; }
current_key_val.clear(); } catch (const std::exception& e) {
return true; spdlog::warn("Record parsing failed: {}", e.what());
} }
}
bool on_string_part(boost::json::string_view s, std::size_t, building_city = false;
boost::system::error_code &) { } else if (depth == 4 && in_state_object) {
current_string_val.append(s.data(), s.size()); if (current_state_id > 0 && current_country_id > 0) {
return true; try {
} context.db->InsertState(current_state_id, current_country_id,
state_info[0], state_info[1]);
bool on_string(boost::json::string_view s, std::size_t, context.states_inserted++;
boost::system::error_code &) { } catch (const std::exception& e) {
current_string_val.append(s.data(), s.size()); spdlog::warn("Record parsing failed: {}", e.what());
}
if (building_city && current_key == "name") { }
current_city.name = current_string_val; in_state_object = false;
} else if (in_state_object && current_key == "name") { } else if (depth == 2 && in_country_object) {
state_info[0] = current_string_val; if (current_country_id > 0) {
} else if (in_state_object && current_key == "iso2") { try {
state_info[1] = current_string_val; context.db->InsertCountry(current_country_id, country_info[0],
} else if (in_country_object && current_key == "name") { country_info[1], country_info[2]);
country_info[0] = current_string_val; context.countries_inserted++;
} else if (in_country_object && current_key == "iso2") { } catch (const std::exception& e) {
country_info[1] = current_string_val; spdlog::warn("Record parsing failed: {}", e.what());
} else if (in_country_object && current_key == "iso3") { }
country_info[2] = current_string_val; }
} in_country_object = false;
current_string_val.clear();
return true;
}
/// @brief Partial number text is ignored; the typed number hooks receive the
/// parsed value.
bool on_number_part(boost::json::string_view, boost::system::error_code&) {
   return true;
}
/// @brief Records an integer "id" for the innermost active record.
///
/// Integers under any other key are ignored. The value is narrowed to int.
bool on_int64(int64_t i, boost::json::string_view,
              boost::system::error_code&) {
   if (current_key != "id") {
      return true;
   }

   const int id = static_cast<int>(i);
   if (building_city) {
      current_city.id = id;
   } else if (in_state_object) {
      current_state_id = id;
   } else if (in_country_object) {
      current_country_id = id;
   }
   return true;
}
/// @brief Unsigned integers are converted to int64_t and handled by on_int64.
bool on_uint64(uint64_t u, boost::json::string_view,
               boost::system::error_code& ec) {
   const auto as_signed = static_cast<int64_t>(u);
   return on_int64(as_signed, "", ec);
}
/// @brief Stores a floating-point "latitude" or "longitude" on the city being
/// built; doubles anywhere else are ignored.
bool on_double(double d, boost::json::string_view,
               boost::system::error_code&) {
   if (!building_city) {
      return true;
   }

   if (current_key == "latitude") {
      current_city.latitude = d;
   } else if (current_key == "longitude") {
      current_city.longitude = d;
   }
   return true;
}
/// @brief Boolean values are not used by this handler.
bool on_bool(bool, boost::system::error_code&) {
   return true;
}

/// @brief Null values are not used by this handler.
bool on_null(boost::system::error_code&) {
   return true;
}

/// @brief Partial comment text is ignored.
bool on_comment_part(boost::json::string_view, boost::system::error_code&) {
   return true;
}

/// @brief Comments are ignored.
bool on_comment(boost::json::string_view, boost::system::error_code&) {
   return true;
}
/// @brief Completes the current object key.
///
/// The final fragment is added to whatever on_key_part buffered, the result
/// becomes `current_key`, and the buffer is emptied for the next key.
bool on_key(boost::json::string_view s, std::size_t,
            boost::system::error_code&) {
   current_key_val.append(s.begin(), s.end());
   current_key.assign(current_key_val);
   current_key_val.clear();
   return true;
}
/// @brief Buffers a partial string value until on_string delivers the final
/// piece.
bool on_string_part(boost::json::string_view s, std::size_t,
                    boost::system::error_code&) {
   current_string_val.append(s.begin(), s.end());
   return true;
}
/// @brief Completes a string value and stores it in the record it belongs to.
///
/// Only the keys "name", "iso2" and "iso3" are kept. For each key the
/// innermost active record wins: city, then state, then country for "name";
/// state, then country for "iso2"; country only for "iso3". The fragment
/// buffer is emptied afterwards regardless of whether the value was used.
bool on_string(boost::json::string_view s, std::size_t,
               boost::system::error_code&) {
   current_string_val.append(s.begin(), s.end());

   if (current_key == "name") {
      if (building_city) {
         current_city.name = current_string_val;
      } else if (in_state_object) {
         state_info[0] = current_string_val;
      } else if (in_country_object) {
         country_info[0] = current_string_val;
      }
   } else if (current_key == "iso2") {
      if (in_state_object) {
         state_info[1] = current_string_val;
      } else if (in_country_object) {
         country_info[1] = current_string_val;
      }
   } else if (current_key == "iso3") {
      if (in_country_object) {
         country_info[2] = current_string_val;
      }
   }

   current_string_val.clear();
   return true;
}
/// @brief Partial number text is ignored; the typed number hooks receive the
/// parsed value.
bool on_number_part(boost::json::string_view, boost::system::error_code&) {
   return true;
}
/// @brief Records an integer "id" for the innermost active record.
///
/// Integers under any other key are ignored. The value is narrowed to int.
bool on_int64(int64_t i, boost::json::string_view,
              boost::system::error_code&) {
   if (current_key != "id") {
      return true;
   }

   const int id = static_cast<int>(i);
   if (building_city) {
      current_city.id = id;
   } else if (in_state_object) {
      current_state_id = id;
   } else if (in_country_object) {
      current_country_id = id;
   }
   return true;
}
/// @brief Unsigned integers are converted to int64_t and handled by on_int64.
bool on_uint64(uint64_t u, boost::json::string_view,
               boost::system::error_code& ec) {
   const auto as_signed = static_cast<int64_t>(u);
   return on_int64(as_signed, "", ec);
}
/// @brief Stores a floating-point "latitude" or "longitude" on the city being
/// built; doubles anywhere else are ignored.
bool on_double(double d, boost::json::string_view,
               boost::system::error_code&) {
   if (!building_city) {
      return true;
   }

   if (current_key == "latitude") {
      current_city.latitude = d;
   } else if (current_key == "longitude") {
      current_city.longitude = d;
   }
   return true;
}
/// @brief Boolean values are not used by this handler.
bool on_bool(bool, boost::system::error_code&) {
   return true;
}

/// @brief Null values are not used by this handler.
bool on_null(boost::system::error_code&) {
   return true;
}

/// @brief Partial comment text is ignored.
bool on_comment_part(boost::json::string_view, boost::system::error_code&) {
   return true;
}

/// @brief Comments are ignored.
bool on_comment(boost::json::string_view, boost::system::error_code&) {
   return true;
}
}; };
void StreamingJsonParser::Parse( void StreamingJsonParser::Parse(
const std::string &file_path, SqliteDatabase &db, const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord &)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress) { std::function<void(size_t, size_t)> on_progress) {
spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path);
spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path); FILE* file = std::fopen(file_path.c_str(), "rb");
if (!file) {
throw std::runtime_error("Failed to open JSON file: " + file_path);
}
FILE *file = std::fopen(file_path.c_str(), "rb"); size_t total_size = 0;
if (!file) { if (std::fseek(file, 0, SEEK_END) == 0) {
throw std::runtime_error("Failed to open JSON file: " + file_path); long file_size = std::ftell(file);
} if (file_size > 0) {
total_size = static_cast<size_t>(file_size);
size_t total_size = 0;
if (std::fseek(file, 0, SEEK_END) == 0) {
long file_size = std::ftell(file);
if (file_size > 0) {
total_size = static_cast<size_t>(file_size);
}
std::rewind(file);
}
CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0,
total_size, 0, 0};
boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx);
char buf[65536];
size_t bytes_read;
boost::system::error_code ec;
while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) {
char const *p = buf;
std::size_t remain = bytes_read;
while (remain > 0) {
std::size_t consumed = parser.write_some(true, p, remain, ec);
if (ec) {
std::fclose(file);
throw std::runtime_error("JSON parse error: " + ec.message());
} }
p += consumed; std::rewind(file);
remain -= consumed; }
}
}
parser.write_some(false, nullptr, 0, ec); // Signal EOF CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, total_size,
std::fclose(file); 0, 0};
boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx);
if (ec) { char buf[65536];
throw std::runtime_error("JSON parse error at EOF: " + ec.message()); size_t bytes_read;
} boost::system::error_code ec;
spdlog::info(" OK: Parsed {} countries, {} states, {} cities", while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) {
ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted); char const* p = buf;
std::size_t remain = bytes_read;
while (remain > 0) {
std::size_t consumed = parser.write_some(true, p, remain, ec);
if (ec) {
std::fclose(file);
throw std::runtime_error("JSON parse error: " + ec.message());
}
p += consumed;
remain -= consumed;
}
}
parser.write_some(false, nullptr, 0, ec); // Signal EOF
std::fclose(file);
if (ec) {
throw std::runtime_error("JSON parse error at EOF: " + ec.message());
}
spdlog::info(" OK: Parsed {} countries, {} states, {} cities",
ctx.countries_inserted, ctx.states_inserted,
ctx.cities_emitted);
} }

View File

@@ -1,139 +1,141 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
#include <cstdio>
#include <curl/curl.h> #include <curl/curl.h>
#include <cstdio>
#include <fstream> #include <fstream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
/// @brief Performs libcurl's global initialization.
/// @throws std::runtime_error if curl_global_init does not return CURLE_OK.
CurlGlobalState::CurlGlobalState() {
   const CURLcode init_result = curl_global_init(CURL_GLOBAL_DEFAULT);
   if (init_result == CURLE_OK) {
      return;
   }
   throw std::runtime_error(
       "[CURLWebClient] Failed to initialize libcurl globally");
}
/// @brief Releases libcurl's global state.
CurlGlobalState::~CurlGlobalState() {
   curl_global_cleanup();
}
namespace { namespace {
// curl write callback that appends response data into a std::string.
// `userp` must point to the std::string that collects the body; the return
// value is the number of bytes appended (size * nmemb).
size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
                           void* userp) {
   const size_t byte_count = size * nmemb;
   const char* chunk = static_cast<const char*>(contents);
   static_cast<std::string*>(userp)->append(chunk, byte_count);
   return byte_count;
}
// curl write callback that writes to a file stream.
//
// `userp` must point to the std::ofstream receiving the body. Returns the
// number of bytes written (size * nmemb) on success. When the stream reports
// a failure it returns 0 instead: libcurl treats a return value that differs
// from the byte count as a write error and aborts the transfer, so a failed
// write is no longer reported as a successful download.
size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb,
                         void* userp) {
   const size_t realsize = size * nmemb;
   auto* outFile = static_cast<std::ofstream*>(userp);
   outFile->write(static_cast<const char*>(contents),
                  static_cast<std::streamsize>(realsize));
   if (outFile->fail()) {
      return 0;
   }
   return realsize;
}
// Owning CURL easy-handle type: curl_easy_cleanup runs when the unique_ptr is
// destroyed.
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;

// Creates a new easy handle wrapped in CurlHandle.
// Throws std::runtime_error when curl_easy_init returns null.
CurlHandle create_handle() {
   CurlHandle owned(curl_easy_init(), &curl_easy_cleanup);
   if (owned == nullptr) {
      throw std::runtime_error(
          "[CURLWebClient] Failed to initialize libcurl handle");
   }
   return owned;
}
void set_common_get_options(CURL *curl, const std::string &url, void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) { long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L); curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout); curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout); curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip"); curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
} }
} // namespace } // namespace
/// @brief Constructs a client; handles are created per request, so there is
/// nothing to set up here.
CURLWebClient::CURLWebClient() = default;

/// @brief Destroys the client; no resources are held between requests.
CURLWebClient::~CURLWebClient() = default;
void CURLWebClient::DownloadToFile(const std::string &url, void CURLWebClient::DownloadToFile(const std::string& url,
const std::string &file_path) { const std::string& file_path) {
auto curl = create_handle(); auto curl = create_handle();
std::ofstream outFile(file_path, std::ios::binary); std::ofstream outFile(file_path, std::ios::binary);
if (!outFile.is_open()) { if (!outFile.is_open()) {
throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " + throw std::runtime_error(
file_path); "[CURLWebClient] Cannot open file for writing: " + file_path);
} }
set_common_get_options(curl.get(), url, 30L, 300L); set_common_get_options(curl.get(), url, 30L, 300L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile); curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
static_cast<void *>(&outFile)); static_cast<void*>(&outFile));
CURLcode res = curl_easy_perform(curl.get()); CURLcode res = curl_easy_perform(curl.get());
outFile.close(); outFile.close();
if (res != CURLE_OK) { if (res != CURLE_OK) {
std::remove(file_path.c_str()); std::remove(file_path.c_str());
std::string error = std::string("[CURLWebClient] Download failed: ") + std::string error = std::string("[CURLWebClient] Download failed: ") +
curl_easy_strerror(res); curl_easy_strerror(res);
throw std::runtime_error(error); throw std::runtime_error(error);
} }
long httpCode = 0; long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) { if (httpCode != 200) {
std::remove(file_path.c_str()); std::remove(file_path.c_str());
std::stringstream ss; std::stringstream ss;
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
} }
std::string CURLWebClient::Get(const std::string &url) { std::string CURLWebClient::Get(const std::string& url) {
auto curl = create_handle(); auto curl = create_handle();
std::string response_string; std::string response_string;
set_common_get_options(curl.get(), url, 10L, 20L); set_common_get_options(curl.get(), url, 10L, 20L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString); curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string); curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
CURLcode res = curl_easy_perform(curl.get()); CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) { if (res != CURLE_OK) {
std::string error = std::string error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res); std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error); throw std::runtime_error(error);
} }
long httpCode = 0; long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) { if (httpCode != 200) {
std::stringstream ss; std::stringstream ss;
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
return response_string; return response_string;
} }
std::string CURLWebClient::UrlEncode(const std::string &value) { std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs. // A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char *output = curl_easy_escape(nullptr, value.c_str(), 0); char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (output) { if (output) {
std::string result(output); std::string result(output);
curl_free(output); curl_free(output);
return result; return result;
} }
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed"); throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
} }

View File

@@ -1,77 +1,78 @@
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <boost/json.hpp>
WikipediaService::WikipediaService(std::shared_ptr<WebClient> client) WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {} : client_(std::move(client)) {}
std::string WikipediaService::FetchExtract(std::string_view query) { std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string encoded = client_->UrlEncode(std::string(query)); const std::string encoded = client_->UrlEncode(std::string(query));
const std::string url = const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=true&format=json"; "&prop=extracts&explaintext=true&format=json";
const std::string body = client_->Get(url); const std::string body = client_->Get(url);
boost::system::error_code ec; boost::system::error_code ec;
boost::json::value doc = boost::json::parse(body, ec); boost::json::value doc = boost::json::parse(body, ec);
if (!ec && doc.is_object()) { if (!ec && doc.is_object()) {
auto &pages = doc.at("query").at("pages").get_object(); auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) { if (!pages.empty()) {
auto &page = pages.begin()->value().get_object(); auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) { if (page.contains("extract") && page.at("extract").is_string()) {
std::string extract(page.at("extract").as_string().c_str()); std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'", spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query); extract.size(), query);
return extract; return extract;
}
} }
} }
}
return {}; return {};
} }
std::string WikipediaService::GetSummary(std::string_view city, std::string WikipediaService::GetSummary(std::string_view city,
std::string_view country) { std::string_view country) {
const std::string key = std::string(city) + "|" + std::string(country); const std::string key = std::string(city) + "|" + std::string(country);
const auto cacheIt = cache_.find(key); const auto cacheIt = cache_.find(key);
if (cacheIt != cache_.end()) { if (cacheIt != cache_.end()) {
return cacheIt->second; return cacheIt->second;
} }
std::string result; std::string result;
if (!client_) { if (!client_) {
cache_.emplace(key, result); cache_.emplace(key, result);
return result; return result;
} }
std::string regionQuery(city); std::string regionQuery(city);
if (!country.empty()) { if (!country.empty()) {
regionQuery += ", "; regionQuery += ", ";
regionQuery += country; regionQuery += country;
} }
const std::string beerQuery = "beer in " + std::string(city); const std::string beerQuery = "beer in " + std::string(city);
try { try {
const std::string regionExtract = FetchExtract(regionQuery); const std::string regionExtract = FetchExtract(regionQuery);
const std::string beerExtract = FetchExtract(beerQuery); const std::string beerExtract = FetchExtract(beerQuery);
if (!regionExtract.empty()) { if (!regionExtract.empty()) {
result += regionExtract; result += regionExtract;
} }
if (!beerExtract.empty()) { if (!beerExtract.empty()) {
if (!result.empty()) if (!result.empty()) result += "\n\n";
result += "\n\n"; result += beerExtract;
result += beerExtract; }
} } catch (const std::runtime_error& e) {
} catch (const std::runtime_error &e) { spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery,
spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery, e.what());
e.what()); }
}
cache_.emplace(key, result); cache_.emplace(key, result);
return result; return result;
} }