format codebase

This commit is contained in:
Aaron Po
2026-04-02 21:46:46 -04:00
parent ba165d8aa7
commit 3af053f0eb
31 changed files with 1479 additions and 1445 deletions

View File

@@ -1,10 +1,5 @@
--- ---
BasedOnStyle: Google BasedOnStyle: Google
Standard: c++23 ColumnLimit: 80
ColumnLimit: 100 IndentWidth: 3
IndentWidth: 2
DerivePointerAlignment: false
PointerAlignment: Left
SortIncludes: true
IncludeBlocks: Preserve
... ...

View File

@@ -90,7 +90,11 @@ set(PIPELINE_SOURCES
src/data_generation/llama/generate_brewery.cpp src/data_generation/llama/generate_brewery.cpp
src/data_generation/llama/generate_user.cpp src/data_generation/llama/generate_user.cpp
src/data_generation/llama/helpers.cpp src/data_generation/llama/helpers.cpp
src/data_generation/mock_generator.cpp src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/load.cpp
src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp
src/json_handling/stream_parser.cpp src/json_handling/stream_parser.cpp
src/wikipedia/wikipedia_service.cpp src/wikipedia/wikipedia_service.cpp
src/main.cpp src/main.cpp

View File

@@ -9,7 +9,7 @@
/// @brief Downloads and caches source geography JSON payloads. /// @brief Downloads and caches source geography JSON payloads.
class DataDownloader { class DataDownloader {
public: public:
/// @brief Initializes global curl state used by this downloader. /// @brief Initializes global curl state used by this downloader.
explicit DataDownloader(std::shared_ptr<WebClient> web_client); explicit DataDownloader(std::shared_ptr<WebClient> web_client);
@@ -18,12 +18,13 @@ public:
/// @brief Returns a local JSON path, downloading it when cache is missing. /// @brief Returns a local JSON path, downloading it when cache is missing.
std::string DownloadCountriesDatabase( std::string DownloadCountriesDatabase(
const std::string &cache_path, const std::string& cache_path,
const std::string &commit = "c5eb7772" // Stable commit: 2026-03-28 export const std::string& commit =
"c5eb7772" // Stable commit: 2026-03-28 export
); );
private: private:
static bool FileExists(const std::string &file_path); static bool FileExists(const std::string& file_path);
std::shared_ptr<WebClient> web_client_; std::shared_ptr<WebClient> web_client_;
}; };

View File

@@ -14,16 +14,16 @@ struct UserResult {
}; };
class DataGenerator { class DataGenerator {
public: public:
virtual ~DataGenerator() = default; virtual ~DataGenerator() = default;
virtual void Load(const std::string &model_path) = 0; virtual void Load(const std::string& model_path) = 0;
virtual BreweryResult GenerateBrewery(const std::string &city_name, virtual BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) = 0; const std::string& region_context) = 0;
virtual UserResult GenerateUser(const std::string &locale) = 0; virtual UserResult GenerateUser(const std::string& locale) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_ #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_

View File

@@ -10,29 +10,29 @@ struct llama_model;
struct llama_context; struct llama_context;
class LlamaGenerator final : public DataGenerator { class LlamaGenerator final : public DataGenerator {
public: public:
LlamaGenerator() = default; LlamaGenerator() = default;
~LlamaGenerator() override; ~LlamaGenerator() override;
void SetSamplingOptions(float temperature, float top_p, int seed = -1); void SetSamplingOptions(float temperature, float top_p, int seed = -1);
void Load(const std::string &model_path) override; void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string &city_name, BreweryResult GenerateBrewery(const std::string& city_name,
const std::string &country_name, const std::string& country_name,
const std::string &region_context) override; const std::string& region_context) override;
UserResult GenerateUser(const std::string &locale) override; UserResult GenerateUser(const std::string& locale) override;
private: private:
std::string Infer(const std::string &prompt, int max_tokens = 10000); std::string Infer(const std::string& prompt, int max_tokens = 10000);
// Overload that allows passing a system message separately so chat-capable // Overload that allows passing a system message separately so chat-capable
// models receive a proper system role instead of having the system text // models receive a proper system role instead of having the system text
// concatenated into the user prompt (helps avoid revealing internal // concatenated into the user prompt (helps avoid revealing internal
// reasoning or instructions in model output). // reasoning or instructions in model output).
std::string Infer(const std::string &system_prompt, const std::string &prompt, std::string Infer(const std::string& system_prompt,
int max_tokens = 10000); const std::string& prompt, int max_tokens = 10000);
llama_model *model_ = nullptr; llama_model* model_ = nullptr;
llama_context *context_ = nullptr; llama_context* context_ = nullptr;
float sampling_temperature_ = 0.8f; float sampling_temperature_ = 0.8f;
float sampling_top_p_ = 0.92f; float sampling_top_p_ = 0.92f;
uint32_t sampling_seed_ = 0xFFFFFFFFu; uint32_t sampling_seed_ = 0xFFFFFFFFu;

View File

@@ -12,18 +12,17 @@ typedef int llama_token;
std::string PrepareRegionContextPublic(std::string_view region_context, std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars = 700); std::size_t max_chars = 700);
std::pair<std::string, std::string> std::pair<std::string, std::string> ParseTwoLineResponsePublic(
ParseTwoLineResponsePublic(const std::string& raw, const std::string& raw, const std::string& error_message);
const std::string& error_message);
std::string ToChatPromptPublic(const llama_model *model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt); const std::string& user_prompt);
std::string ToChatPromptPublic(const llama_model *model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt); const std::string& user_prompt);
void AppendTokenPiecePublic(const llama_vocab *vocab, llama_token token, void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output); std::string& output);
std::string ValidateBreweryJsonPublic(const std::string& raw, std::string ValidateBreweryJsonPublic(const std::string& raw,

View File

@@ -1,21 +1,22 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#include "data_generation/data_generator.h"
#include <string> #include <string>
#include <vector> #include <vector>
class MockGenerator final : public DataGenerator { #include "data_generation/data_generator.h"
public:
void Load(const std::string &model_path) override;
BreweryResult GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) override;
UserResult GenerateUser(const std::string &locale) override;
private: class MockGenerator final : public DataGenerator {
static std::size_t DeterministicHash(const std::string &a, public:
const std::string &b); void Load(const std::string& model_path) override;
BreweryResult GenerateBrewery(const std::string& city_name,
const std::string& country_name,
const std::string& region_context) override;
UserResult GenerateUser(const std::string& locale) override;
private:
static std::size_t DeterministicHash(const std::string& a,
const std::string& b);
static const std::vector<std::string> kBreweryAdjectives; static const std::vector<std::string> kBreweryAdjectives;
static const std::vector<std::string> kBreweryNouns; static const std::vector<std::string> kBreweryNouns;

View File

@@ -1,8 +1,9 @@
#ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#include <mutex>
#include <sqlite3.h> #include <sqlite3.h>
#include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
@@ -39,18 +40,18 @@ struct City {
/// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks. /// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks.
class SqliteDatabase { class SqliteDatabase {
private: private:
sqlite3 *db_ = nullptr; sqlite3* db_ = nullptr;
std::mutex db_mutex_; std::mutex db_mutex_;
void InitializeSchema(); void InitializeSchema();
public: public:
/// @brief Closes the SQLite connection if initialized. /// @brief Closes the SQLite connection if initialized.
~SqliteDatabase(); ~SqliteDatabase();
/// @brief Opens the SQLite database at db_path and creates schema objects. /// @brief Opens the SQLite database at db_path and creates schema objects.
void Initialize(const std::string &db_path = ":memory:"); void Initialize(const std::string& db_path = ":memory:");
/// @brief Starts a database transaction for batched writes. /// @brief Starts a database transaction for batched writes.
void BeginTransaction(); void BeginTransaction();
@@ -59,16 +60,16 @@ public:
void CommitTransaction(); void CommitTransaction();
/// @brief Inserts a country row. /// @brief Inserts a country row.
void InsertCountry(int id, const std::string &name, const std::string &iso2, void InsertCountry(int id, const std::string& name, const std::string& iso2,
const std::string &iso3); const std::string& iso3);
/// @brief Inserts a state row linked to a country. /// @brief Inserts a state row linked to a country.
void InsertState(int id, int country_id, const std::string &name, void InsertState(int id, int country_id, const std::string& name,
const std::string &iso2); const std::string& iso2);
/// @brief Inserts a city row linked to state and country. /// @brief Inserts a city row linked to state and country.
void InsertCity(int id, int state_id, int country_id, const std::string &name, void InsertCity(int id, int state_id, int country_id,
double latitude, double longitude); const std::string& name, double latitude, double longitude);
/// @brief Returns city records including parent country id. /// @brief Returns city records including parent country id.
std::vector<City> QueryCities(); std::vector<City> QueryCities();

View File

@@ -1,15 +1,17 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#include <string>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
#include <string>
/// @brief Loads world-city JSON data into SQLite through streaming parsing. /// @brief Loads world-city JSON data into SQLite through streaming parsing.
class JsonLoader { class JsonLoader {
public: public:
/// @brief Parses a JSON file and writes country/state/city rows into db. /// @brief Parses a JSON file and writes country/state/city rows into db.
static void LoadWorldCities(const std::string &json_path, SqliteDatabase &db); static void LoadWorldCities(const std::string& json_path,
SqliteDatabase& db);
}; };
#endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -1,10 +1,11 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#include "database/database.h"
#include <functional> #include <functional>
#include <string> #include <string>
#include "database/database.h"
// Forward declaration to avoid circular dependency // Forward declaration to avoid circular dependency
class SqliteDatabase; class SqliteDatabase;
@@ -20,13 +21,13 @@ struct CityRecord {
/// @brief Streaming SAX parser that emits city records during traversal. /// @brief Streaming SAX parser that emits city records during traversal.
class StreamingJsonParser { class StreamingJsonParser {
public: public:
/// @brief Parses file_path and invokes callbacks for city rows and progress. /// @brief Parses file_path and invokes callbacks for city rows and progress.
static void Parse(const std::string &file_path, SqliteDatabase &db, static void Parse(const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord &)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress = nullptr); std::function<void(size_t, size_t)> on_progress = nullptr);
private: private:
/// @brief Mutable SAX handler state while traversing nested JSON arrays. /// @brief Mutable SAX handler state while traversing nested JSON arrays.
struct ParseState { struct ParseState {
int current_country_id = 0; int current_country_id = 0;
@@ -42,7 +43,7 @@ private:
bool in_states_array = false; bool in_states_array = false;
bool in_cities_array = false; bool in_cities_array = false;
std::function<void(const CityRecord &)> on_city; std::function<void(const CityRecord&)> on_city;
std::function<void(size_t, size_t)> on_progress; std::function<void(size_t, size_t)> on_progress;
size_t bytes_processed = 0; size_t bytes_processed = 0;
}; };

View File

@@ -1,29 +1,30 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#include "web_client/web_client.h"
#include <memory> #include <memory>
#include "web_client/web_client.h"
// RAII for curl_global_init/cleanup. // RAII for curl_global_init/cleanup.
// An instance of this class should be created in main() before any curl // An instance of this class should be created in main() before any curl
// operations and exist for the lifetime of the application. // operations and exist for the lifetime of the application.
class CurlGlobalState { class CurlGlobalState {
public: public:
CurlGlobalState(); CurlGlobalState();
~CurlGlobalState(); ~CurlGlobalState();
CurlGlobalState(const CurlGlobalState &) = delete; CurlGlobalState(const CurlGlobalState&) = delete;
CurlGlobalState &operator=(const CurlGlobalState &) = delete; CurlGlobalState& operator=(const CurlGlobalState&) = delete;
}; };
class CURLWebClient : public WebClient { class CURLWebClient : public WebClient {
public: public:
CURLWebClient(); CURLWebClient();
~CURLWebClient() override; ~CURLWebClient() override;
void DownloadToFile(const std::string &url, void DownloadToFile(const std::string& url,
const std::string &file_path) override; const std::string& file_path) override;
std::string Get(const std::string &url) override; std::string Get(const std::string& url) override;
std::string UrlEncode(const std::string &value) override; std::string UrlEncode(const std::string& value) override;
}; };
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_

View File

@@ -4,19 +4,19 @@
#include <string> #include <string>
class WebClient { class WebClient {
public: public:
virtual ~WebClient() = default; virtual ~WebClient() = default;
// Downloads content from a URL to a file. Throws on error. // Downloads content from a URL to a file. Throws on error.
virtual void DownloadToFile(const std::string &url, virtual void DownloadToFile(const std::string& url,
const std::string &file_path) = 0; const std::string& file_path) = 0;
// Performs a GET request and returns the response body as a string. Throws on // Performs a GET request and returns the response body as a string. Throws
// error. // on error.
virtual std::string Get(const std::string &url) = 0; virtual std::string Get(const std::string& url) = 0;
// URL-encodes a string. // URL-encodes a string.
virtual std::string UrlEncode(const std::string &value) = 0; virtual std::string UrlEncode(const std::string& value) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -10,7 +10,7 @@
/// @brief Provides cached Wikipedia summary lookups for city and country pairs. /// @brief Provides cached Wikipedia summary lookups for city and country pairs.
class WikipediaService { class WikipediaService {
public: public:
/// @brief Creates a new Wikipedia service with the provided web client. /// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::shared_ptr<WebClient> client); explicit WikipediaService(std::shared_ptr<WebClient> client);
@@ -18,7 +18,7 @@ public:
[[nodiscard]] std::string GetSummary(std::string_view city, [[nodiscard]] std::string GetSummary(std::string_view city,
std::string_view country); std::string_view country);
private: private:
std::string FetchExtract(std::string_view query); std::string FetchExtract(std::string_view query);
std::shared_ptr<WebClient> client_; std::shared_ptr<WebClient> client_;
std::unordered_map<std::string, std::string> cache_; std::unordered_map<std::string, std::string> cache_;

View File

@@ -1,23 +1,25 @@
#include "data_generation/data_downloader.h" #include "data_generation/data_downloader.h"
#include "web_client/web_client.h"
#include <spdlog/spdlog.h>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <spdlog/spdlog.h>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include "web_client/web_client.h"
DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client) DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client)
: web_client_(std::move(web_client)) {} : web_client_(std::move(web_client)) {}
DataDownloader::~DataDownloader() {} DataDownloader::~DataDownloader() {}
bool DataDownloader::FileExists(const std::string &file_path) { bool DataDownloader::FileExists(const std::string& file_path) {
return std::filesystem::exists(file_path); return std::filesystem::exists(file_path);
} }
std::string std::string DataDownloader::DownloadCountriesDatabase(
DataDownloader::DownloadCountriesDatabase(const std::string &cache_path, const std::string& cache_path, const std::string& commit) {
const std::string &commit) {
if (FileExists(cache_path)) { if (FileExists(cache_path)) {
spdlog::info("[DataDownloader] Cache hit: {}", cache_path); spdlog::info("[DataDownloader] Cache hit: {}", cache_path);
return cache_path; return cache_path;
@@ -28,7 +30,8 @@ DataDownloader::DownloadCountriesDatabase(const std::string &cache_path,
short_commit = commit.substr(0, 7); short_commit = commit.substr(0, 7);
} }
std::string url = "https://raw.githubusercontent.com/dr5hn/" std::string url =
"https://raw.githubusercontent.com/dr5hn/"
"countries-states-cities-database/" + "countries-states-cities-database/" +
short_commit + "/json/countries+states+cities.json"; short_commit + "/json/countries+states+cities.json";

View File

@@ -1,6 +1,5 @@
#include "llama.h"
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
LlamaGenerator::~LlamaGenerator() { LlamaGenerator::~LlamaGenerator() {
if (context_ != nullptr) { if (context_ != nullptr) {

View File

@@ -1,25 +1,25 @@
#include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
BreweryResult BreweryResult LlamaGenerator::GenerateBrewery(
LlamaGenerator::GenerateBrewery(const std::string& city_name, const std::string& city_name, const std::string& country_name,
const std::string& country_name,
const std::string& region_context) { const std::string& region_context) {
const std::string safe_region_context = const std::string safe_region_context =
PrepareRegionContextPublic(region_context); PrepareRegionContextPublic(region_context);
const std::string system_prompt = const std::string system_prompt =
"You are a copywriter for a craft beer travel guide. " "You are the brewmaster and owner of a local craft brewery. "
"Your writing is vivid, specific to place, and avoids generic beer " "Write a name and a short, soulful description for your brewery that "
"cliches. " "reflects your pride in the local community and your craft. "
"You must output ONLY valid JSON. " "The tone should be authentic and welcoming, like a note on a "
"The JSON schema must be exactly: {\"name\": \"string\", " "chalkboard "
"\"description\": \"string\"}. " "menu. Output ONLY a single JSON object with keys \"name\" and "
"\"description\". "
"Do not include markdown formatting or backticks."; "Do not include markdown formatting or backticks.";
std::string prompt = std::string prompt =
@@ -52,7 +52,8 @@ LlamaGenerator::GenerateBrewery(const std::string& city_name,
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error); attempt + 1, validation_error);
prompt = "Your previous response was invalid. Error: " + validation_error + prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with this exact schema: " "\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}." "{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys." "\nDo not include markdown, comments, or extra keys."
@@ -65,7 +66,8 @@ LlamaGenerator::GenerateBrewery(const std::string& city_name,
: std::string("\nRegional context: ") + safe_region_context); : std::string("\nRegional context: ") + safe_region_context);
} }
spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: " spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: "
"{}", "{}",
max_attempts, last_error.empty() ? raw : last_error); max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response"); throw std::runtime_error("LlamaGenerator: malformed brewery response");

View File

@@ -1,9 +1,9 @@
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
@@ -40,17 +40,18 @@ UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
throw std::runtime_error("LlamaGenerator: malformed user response"); throw std::runtime_error("LlamaGenerator: malformed user response");
} }
if (bio.size() > 200) if (bio.size() > 200) bio = bio.substr(0, 200);
bio = bio.substr(0, 200);
return {username, bio}; return {username, bio};
} catch (const std::exception &e) { } catch (const std::exception& e) {
spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}", spdlog::warn(
"LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what()); attempt + 1, e.what());
} }
} }
spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}", spdlog::error(
"LlamaGenerator: malformed user response after {} attempts: {}",
max_attempts, raw); max_attempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response"); throw std::runtime_error("LlamaGenerator: malformed user response");
} }

View File

@@ -1,15 +1,14 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <boost/json.hpp>
#include <cctype> #include <cctype>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
#include "llama.h"
#include <boost/json.hpp>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
namespace { namespace {
@@ -103,8 +102,8 @@ std::string StripCommonPrefix(std::string line) {
return Trim(std::move(line)); return Trim(std::move(line));
} }
std::pair<std::string, std::string> std::pair<std::string, std::string> ParseTwoLineResponse(
ParseTwoLineResponse(const std::string& raw, const std::string& error_message) { const std::string& raw, const std::string& error_message) {
std::string normalized = raw; std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n'); std::replace(normalized.begin(), normalized.end(), '\r', '\n');
@@ -113,50 +112,45 @@ ParseTwoLineResponse(const std::string& raw, const std::string& error_message) {
std::string line; std::string line;
while (std::getline(stream, line)) { while (std::getline(stream, line)) {
line = StripCommonPrefix(std::move(line)); line = StripCommonPrefix(std::move(line));
if (!line.empty()) if (!line.empty()) lines.push_back(std::move(line));
lines.push_back(std::move(line));
} }
std::vector<std::string> filtered; std::vector<std::string> filtered;
for (auto &l : lines) { for (auto& l : lines) {
std::string low = l; std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c)); return static_cast<char>(std::tolower(c));
}); });
if (!l.empty() && l.front() == '<' && low.back() == '>') if (!l.empty() && l.front() == '<' && low.back() == '>') continue;
continue; if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0)
continue;
filtered.push_back(std::move(l)); filtered.push_back(std::move(l));
} }
if (filtered.size() < 2) if (filtered.size() < 2) throw std::runtime_error(error_message);
throw std::runtime_error(error_message);
std::string first = Trim(filtered.front()); std::string first = Trim(filtered.front());
std::string second; std::string second;
for (size_t i = 1; i < filtered.size(); ++i) { for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) if (!second.empty()) second += ' ';
second += ' ';
second += filtered[i]; second += filtered[i];
} }
second = Trim(std::move(second)); second = Trim(std::move(second));
if (first.empty() || second.empty()) if (first.empty() || second.empty()) throw std::runtime_error(error_message);
throw std::runtime_error(error_message);
return {first, second}; return {first, second};
} }
std::string ToChatPrompt(const llama_model *model, std::string ToChatPrompt(const llama_model* model,
const std::string& user_prompt) { const std::string& user_prompt) {
const char *tmpl = llama_model_chat_template(model, nullptr); const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) { if (tmpl == nullptr) {
return user_prompt; return user_prompt;
} }
const llama_chat_message message{"user", user_prompt.c_str()}; const llama_chat_message message{"user", user_prompt.c_str()};
std::vector<char> buffer(std::max<std::size_t>(1024, user_prompt.size() * 4)); std::vector<char> buffer(
std::max<std::size_t>(1024, user_prompt.size() * 4));
int32_t required = int32_t required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size())); static_cast<int32_t>(buffer.size()));
@@ -167,20 +161,22 @@ std::string ToChatPrompt(const llama_model *model,
if (required >= static_cast<int32_t>(buffer.size())) { if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1); buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size())); static_cast<int32_t>(buffer.size()));
if (required < 0) { if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template"); throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
} }
} }
return std::string(buffer.data(), static_cast<std::size_t>(required)); return std::string(buffer.data(), static_cast<std::size_t>(required));
} }
std::string ToChatPrompt(const llama_model *model, std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt) { const std::string& user_prompt) {
const char *tmpl = llama_model_chat_template(model, nullptr); const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) { if (tmpl == nullptr) {
return system_prompt + "\n\n" + user_prompt; return system_prompt + "\n\n" + user_prompt;
} }
@@ -200,17 +196,19 @@ std::string ToChatPrompt(const llama_model *model,
if (required >= static_cast<int32_t>(buffer.size())) { if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1); buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size())); static_cast<int32_t>(buffer.size()));
if (required < 0) { if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template"); throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
} }
} }
return std::string(buffer.data(), static_cast<std::size_t>(required)); return std::string(buffer.data(), static_cast<std::size_t>(required));
} }
void AppendTokenPiece(const llama_vocab *vocab, llama_token token, void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
std::array<char, 256> buffer{}; std::array<char, 256> buffer{};
int32_t bytes = int32_t bytes =
@@ -220,8 +218,8 @@ void AppendTokenPiece(const llama_vocab *vocab, llama_token token,
if (bytes < 0) { if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes)); std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), 0, static_cast<int32_t>(dynamic_buffer.size()),
true); 0, true);
if (bytes < 0) { if (bytes < 0) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece"); "LlamaGenerator: failed to decode sampled token piece");
@@ -372,24 +370,23 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
return PrepareRegionContext(region_context, max_chars); return PrepareRegionContext(region_context, max_chars);
} }
std::pair<std::string, std::string> std::pair<std::string, std::string> ParseTwoLineResponsePublic(
ParseTwoLineResponsePublic(const std::string& raw, const std::string& raw, const std::string& error_message) {
const std::string& error_message) {
return ParseTwoLineResponse(raw, error_message); return ParseTwoLineResponse(raw, error_message);
} }
std::string ToChatPromptPublic(const llama_model *model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt) { const std::string& user_prompt) {
return ToChatPrompt(model, user_prompt); return ToChatPrompt(model, user_prompt);
} }
std::string ToChatPromptPublic(const llama_model *model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt, const std::string& system_prompt,
const std::string& user_prompt) { const std::string& user_prompt) {
return ToChatPrompt(model, system_prompt, user_prompt); return ToChatPrompt(model, system_prompt, user_prompt);
} }
void AppendTokenPiecePublic(const llama_vocab *vocab, llama_token token, void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output) { std::string& output) {
AppendTokenPiece(vocab, token, output); AppendTokenPiece(vocab, token, output);
} }

View File

@@ -1,20 +1,20 @@
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
#include "llama.h"
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
#include "llama.h"
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) { std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
if (model_ == nullptr || context_ == nullptr) if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded"); throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab *vocab = llama_model_get_vocab(model_); const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable"); throw std::runtime_error("LlamaGenerator: vocab unavailable");
@@ -45,14 +45,16 @@ std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size"); throw std::runtime_error("LlamaGenerator: invalid context or batch size");
} }
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget); prompt_budget = std::max<int32_t>(1, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(token_count)); prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) { if (token_count > prompt_budget) {
spdlog::warn( spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " "LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens "
"to fit n_batch/n_ctx limits", "to fit n_batch/n_ctx limits",
token_count, prompt_budget); token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget)); prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
@@ -84,9 +86,9 @@ std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
generated_tokens.reserve(static_cast<std::size_t>(max_tokens)); generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) { for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); const llama_token next =
if (llama_vocab_is_eog(vocab, next)) llama_sampler_sample(sampler.get(), context_, -1);
break; if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next); generated_tokens.push_back(next);
llama_token token = next; llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1); const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
@@ -106,7 +108,7 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt,
if (model_ == nullptr || context_ == nullptr) if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded"); throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab *vocab = llama_model_get_vocab(model_); const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr) if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable"); throw std::runtime_error("LlamaGenerator: vocab unavailable");
@@ -138,14 +140,16 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt,
throw std::runtime_error("LlamaGenerator: invalid context or batch size"); throw std::runtime_error("LlamaGenerator: invalid context or batch size");
} }
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget); prompt_budget = std::max<int32_t>(1, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(token_count)); prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) { if (token_count > prompt_budget) {
spdlog::warn( spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " "LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens "
"to fit n_batch/n_ctx limits", "to fit n_batch/n_ctx limits",
token_count, prompt_budget); token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget)); prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
@@ -177,9 +181,9 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt,
generated_tokens.reserve(static_cast<std::size_t>(max_tokens)); generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) { for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); const llama_token next =
if (llama_vocab_is_eog(vocab, next)) llama_sampler_sample(sampler.get(), context_, -1);
break; if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next); generated_tokens.push_back(next);
llama_token token = next; llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1); const llama_batch one_token_batch = llama_batch_get_one(&token, 1);

View File

@@ -1,10 +1,10 @@
#include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "llama.h"
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) { void LlamaGenerator::Load(const std::string& model_path) {
if (model_path.empty()) if (model_path.empty())

View File

@@ -1,8 +1,7 @@
#include <stdexcept> #include <stdexcept>
#include "llama.h"
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p, void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
int seed) { int seed) {

View File

@@ -1,7 +1,7 @@
#include "data_generation/mock_generator.h" #include <string>
#include <vector>
#include <functional> #include "data_generation/mock_generator.h"
#include <spdlog/spdlog.h>
const std::vector<std::string> MockGenerator::kBreweryAdjectives = { const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
@@ -63,42 +63,3 @@ const std::vector<std::string> MockGenerator::kBios = {
"Craft beer fan mapping tasting notes and favorite brew routes.", "Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.", "Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."}; "Keeping a running list of must-try collab releases and tap takeovers."};
void MockGenerator::Load(const std::string & /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}
std::size_t MockGenerator::DeterministicHash(const std::string &a,
const std::string &b) {
std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b);
seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13));
return seed;
}
BreweryResult MockGenerator::GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) {
const std::string location_key =
country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash = region_context.empty()
? std::hash<std::string>{}(location_key)
: DeterministicHash(location_key, region_context);
BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
result.description =
kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
return result;
}
UserResult MockGenerator::GenerateUser(const std::string &locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -0,0 +1,12 @@
#include <string>
#include "data_generation/mock_generator.h"
std::size_t MockGenerator::DeterministicHash(const std::string& a,
const std::string& b) {
std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b);
seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13));
return seed;
}

View File

@@ -0,0 +1,21 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
BreweryResult MockGenerator::GenerateBrewery(
const std::string& city_name, const std::string& country_name,
const std::string& region_context) {
const std::string location_key =
country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash =
region_context.empty() ? std::hash<std::string>{}(location_key)
: DeterministicHash(location_key, region_context);
BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
result.description =
kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
return result;
}

View File

@@ -0,0 +1,13 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -0,0 +1,9 @@
#include <spdlog/spdlog.h>
#include <string>
#include "data_generation/mock_generator.h"
void MockGenerator::Load(const std::string& /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}

View File

@@ -1,11 +1,13 @@
#include "database/database.h" #include "database/database.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
void SqliteDatabase::InitializeSchema() { void SqliteDatabase::InitializeSchema() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *schema = R"( const char* schema = R"(
CREATE TABLE IF NOT EXISTS countries ( CREATE TABLE IF NOT EXISTS countries (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT NOT NULL, name TEXT NOT NULL,
@@ -33,7 +35,7 @@ void SqliteDatabase::InitializeSchema() {
); );
)"; )";
char *errMsg = nullptr; char* errMsg = nullptr;
int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg); int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
std::string error = errMsg ? std::string(errMsg) : "Unknown error"; std::string error = errMsg ? std::string(errMsg) : "Unknown error";
@@ -48,7 +50,7 @@ SqliteDatabase::~SqliteDatabase() {
} }
} }
void SqliteDatabase::Initialize(const std::string &db_path) { void SqliteDatabase::Initialize(const std::string& db_path) {
int rc = sqlite3_open(db_path.c_str(), &db_); int rc = sqlite3_open(db_path.c_str(), &db_);
if (rc) { if (rc) {
throw std::runtime_error("Failed to open SQLite database: " + db_path); throw std::runtime_error("Failed to open SQLite database: " + db_path);
@@ -59,7 +61,7 @@ void SqliteDatabase::Initialize(const std::string &db_path) {
void SqliteDatabase::BeginTransaction() { void SqliteDatabase::BeginTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char* err = nullptr;
if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) != if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) !=
SQLITE_OK) { SQLITE_OK) {
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
@@ -70,7 +72,7 @@ void SqliteDatabase::BeginTransaction() {
void SqliteDatabase::CommitTransaction() { void SqliteDatabase::CommitTransaction() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char* err = nullptr;
if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) { if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) {
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
sqlite3_free(err); sqlite3_free(err);
@@ -78,17 +80,17 @@ void SqliteDatabase::CommitTransaction() {
} }
} }
void SqliteDatabase::InsertCountry(int id, const std::string &name, void SqliteDatabase::InsertCountry(int id, const std::string& name,
const std::string &iso2, const std::string& iso2,
const std::string &iso3) { const std::string& iso3) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO countries (id, name, iso2, iso3) INSERT OR IGNORE INTO countries (id, name, iso2, iso3)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare country insert"); throw std::runtime_error("Failed to prepare country insert");
@@ -104,16 +106,17 @@ void SqliteDatabase::InsertCountry(int id, const std::string &name,
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertState(int id, int country_id, const std::string &name, void SqliteDatabase::InsertState(int id, int country_id,
const std::string &iso2) { const std::string& name,
const std::string& iso2) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO states (id, country_id, name, iso2) INSERT OR IGNORE INTO states (id, country_id, name, iso2)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare state insert"); throw std::runtime_error("Failed to prepare state insert");
@@ -130,16 +133,16 @@ void SqliteDatabase::InsertState(int id, int country_id, const std::string &name
} }
void SqliteDatabase::InsertCity(int id, int state_id, int country_id, void SqliteDatabase::InsertCity(int id, int state_id, int country_id,
const std::string &name, double latitude, const std::string& name, double latitude,
double longitude) { double longitude) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char* query = R"(
INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude) INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare city insert"); throw std::runtime_error("Failed to prepare city insert");
@@ -160,9 +163,9 @@ void SqliteDatabase::InsertCity(int id, int state_id, int country_id,
std::vector<City> SqliteDatabase::QueryCities() { std::vector<City> SqliteDatabase::QueryCities() {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<City> cities; std::vector<City> cities;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
const char *query = "SELECT id, name, country_id FROM cities ORDER BY name"; const char* query = "SELECT id, name, country_id FROM cities ORDER BY name";
int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
@@ -171,8 +174,8 @@ std::vector<City> SqliteDatabase::QueryCities() {
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
int country_id = sqlite3_column_int(stmt, 2); int country_id = sqlite3_column_int(stmt, 2);
cities.push_back({id, name ? std::string(name) : "", country_id}); cities.push_back({id, name ? std::string(name) : "", country_id});
} }
@@ -185,7 +188,7 @@ std::vector<Country> SqliteDatabase::QueryCountries(int limit) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<Country> countries; std::vector<Country> countries;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
std::string query = std::string query =
"SELECT id, name, iso2, iso3 FROM countries ORDER BY name"; "SELECT id, name, iso2, iso3 FROM countries ORDER BY name";
@@ -201,12 +204,12 @@ std::vector<Country> SqliteDatabase::QueryCountries(int limit) {
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
const char *iso2 = const char* iso2 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
const char *iso3 = const char* iso3 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 3)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 3));
countries.push_back({id, name ? std::string(name) : "", countries.push_back({id, name ? std::string(name) : "",
iso2 ? std::string(iso2) : "", iso2 ? std::string(iso2) : "",
iso3 ? std::string(iso3) : ""}); iso3 ? std::string(iso3) : ""});
@@ -220,7 +223,7 @@ std::vector<State> SqliteDatabase::QueryStates(int limit) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<State> states; std::vector<State> states;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt* stmt = nullptr;
std::string query = std::string query =
"SELECT id, name, iso2, country_id FROM states ORDER BY name"; "SELECT id, name, iso2, country_id FROM states ORDER BY name";
@@ -236,10 +239,10 @@ std::vector<State> SqliteDatabase::QueryStates(int limit) {
while (sqlite3_step(stmt) == SQLITE_ROW) { while (sqlite3_step(stmt) == SQLITE_ROW) {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char* name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
const char *iso2 = const char* iso2 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2)); reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
int country_id = sqlite3_column_int(stmt, 3); int country_id = sqlite3_column_int(stmt, 3);
states.push_back({id, name ? std::string(name) : "", states.push_back({id, name ? std::string(name) : "",
iso2 ? std::string(iso2) : "", country_id}); iso2 ? std::string(iso2) : "", country_id});

View File

@@ -1,12 +1,13 @@
#include <chrono> #include "json_handling/json_loader.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include "json_handling/json_loader.h" #include <chrono>
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
void JsonLoader::LoadWorldCities(const std::string &json_path, void JsonLoader::LoadWorldCities(const std::string& json_path,
SqliteDatabase &db) { SqliteDatabase& db) {
constexpr size_t kBatchSize = 10000; constexpr size_t kBatchSize = 10000;
auto startTime = std::chrono::high_resolution_clock::now(); auto startTime = std::chrono::high_resolution_clock::now();
@@ -19,7 +20,7 @@ void JsonLoader::LoadWorldCities(const std::string &json_path,
try { try {
StreamingJsonParser::Parse( StreamingJsonParser::Parse(
json_path, db, json_path, db,
[&](const CityRecord &record) { [&](const CityRecord& record) {
db.InsertCity(record.id, record.state_id, record.country_id, db.InsertCity(record.id, record.state_id, record.country_id,
record.name, record.latitude, record.longitude); record.name, record.latitude, record.longitude);
++citiesProcessed; ++citiesProcessed;

View File

@@ -1,25 +1,26 @@
#include <cstdio> #include "json_handling/stream_parser.h"
#include <stdexcept>
#include <spdlog/spdlog.h>
#include <boost/json.hpp> #include <boost/json.hpp>
#include <boost/json/basic_parser_impl.hpp> #include <boost/json/basic_parser_impl.hpp>
#include <spdlog/spdlog.h> #include <cstdio>
#include <stdexcept>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h"
class CityRecordHandler { class CityRecordHandler {
friend class boost::json::basic_parser<CityRecordHandler>; friend class boost::json::basic_parser<CityRecordHandler>;
public: public:
static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1); static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1);
struct ParseContext { struct ParseContext {
SqliteDatabase *db = nullptr; SqliteDatabase* db = nullptr;
std::function<void(const CityRecord &)> on_city; std::function<void(const CityRecord&)> on_city;
std::function<void(size_t, size_t)> on_progress; std::function<void(size_t, size_t)> on_progress;
size_t cities_emitted = 0; size_t cities_emitted = 0;
size_t total_file_size = 0; size_t total_file_size = 0;
@@ -27,10 +28,10 @@ public:
int states_inserted = 0; int states_inserted = 0;
}; };
explicit CityRecordHandler(ParseContext &ctx) : context(ctx) {} explicit CityRecordHandler(ParseContext& ctx) : context(ctx) {}
private: private:
ParseContext &context; ParseContext& context;
int depth = 0; int depth = 0;
bool in_countries_array = false; bool in_countries_array = false;
@@ -51,10 +52,10 @@ private:
std::string state_info[2]; std::string state_info[2];
// Boost.JSON SAX Hooks // Boost.JSON SAX Hooks
bool on_document_begin(boost::system::error_code &) { return true; } bool on_document_begin(boost::system::error_code&) { return true; }
bool on_document_end(boost::system::error_code &) { return true; } bool on_document_end(boost::system::error_code&) { return true; }
bool on_array_begin(boost::system::error_code &) { bool on_array_begin(boost::system::error_code&) {
depth++; depth++;
if (depth == 1) { if (depth == 1) {
in_countries_array = true; in_countries_array = true;
@@ -66,7 +67,7 @@ private:
return true; return true;
} }
bool on_array_end(std::size_t, boost::system::error_code &) { bool on_array_end(std::size_t, boost::system::error_code&) {
if (depth == 1) { if (depth == 1) {
in_countries_array = false; in_countries_array = false;
} else if (depth == 3) { } else if (depth == 3) {
@@ -78,7 +79,7 @@ private:
return true; return true;
} }
bool on_object_begin(boost::system::error_code &) { bool on_object_begin(boost::system::error_code&) {
depth++; depth++;
if (depth == 2 && in_countries_array) { if (depth == 2 && in_countries_array) {
in_country_object = true; in_country_object = true;
@@ -98,7 +99,7 @@ private:
return true; return true;
} }
bool on_object_end(std::size_t, boost::system::error_code &) { bool on_object_end(std::size_t, boost::system::error_code&) {
if (depth == 6 && building_city) { if (depth == 6 && building_city) {
if (current_city.id > 0 && current_state_id > 0 && if (current_city.id > 0 && current_state_id > 0 &&
current_country_id > 0) { current_country_id > 0) {
@@ -113,7 +114,7 @@ private:
context.on_progress(context.cities_emitted, context.on_progress(context.cities_emitted,
context.total_file_size); context.total_file_size);
} }
} catch (const std::exception &e) { } catch (const std::exception& e) {
spdlog::warn("Record parsing failed: {}", e.what()); spdlog::warn("Record parsing failed: {}", e.what());
} }
} }
@@ -124,7 +125,7 @@ private:
context.db->InsertState(current_state_id, current_country_id, context.db->InsertState(current_state_id, current_country_id,
state_info[0], state_info[1]); state_info[0], state_info[1]);
context.states_inserted++; context.states_inserted++;
} catch (const std::exception &e) { } catch (const std::exception& e) {
spdlog::warn("Record parsing failed: {}", e.what()); spdlog::warn("Record parsing failed: {}", e.what());
} }
} }
@@ -135,7 +136,7 @@ private:
context.db->InsertCountry(current_country_id, country_info[0], context.db->InsertCountry(current_country_id, country_info[0],
country_info[1], country_info[2]); country_info[1], country_info[2]);
context.countries_inserted++; context.countries_inserted++;
} catch (const std::exception &e) { } catch (const std::exception& e) {
spdlog::warn("Record parsing failed: {}", e.what()); spdlog::warn("Record parsing failed: {}", e.what());
} }
} }
@@ -147,13 +148,13 @@ private:
} }
bool on_key_part(boost::json::string_view s, std::size_t, bool on_key_part(boost::json::string_view s, std::size_t,
boost::system::error_code &) { boost::system::error_code&) {
current_key_val.append(s.data(), s.size()); current_key_val.append(s.data(), s.size());
return true; return true;
} }
bool on_key(boost::json::string_view s, std::size_t, bool on_key(boost::json::string_view s, std::size_t,
boost::system::error_code &) { boost::system::error_code&) {
current_key_val.append(s.data(), s.size()); current_key_val.append(s.data(), s.size());
current_key = current_key_val; current_key = current_key_val;
current_key_val.clear(); current_key_val.clear();
@@ -161,13 +162,13 @@ private:
} }
bool on_string_part(boost::json::string_view s, std::size_t, bool on_string_part(boost::json::string_view s, std::size_t,
boost::system::error_code &) { boost::system::error_code&) {
current_string_val.append(s.data(), s.size()); current_string_val.append(s.data(), s.size());
return true; return true;
} }
bool on_string(boost::json::string_view s, std::size_t, bool on_string(boost::json::string_view s, std::size_t,
boost::system::error_code &) { boost::system::error_code&) {
current_string_val.append(s.data(), s.size()); current_string_val.append(s.data(), s.size());
if (building_city && current_key == "name") { if (building_city && current_key == "name") {
@@ -188,12 +189,12 @@ private:
return true; return true;
} }
bool on_number_part(boost::json::string_view, boost::system::error_code &) { bool on_number_part(boost::json::string_view, boost::system::error_code&) {
return true; return true;
} }
bool on_int64(int64_t i, boost::json::string_view, bool on_int64(int64_t i, boost::json::string_view,
boost::system::error_code &) { boost::system::error_code&) {
if (building_city && current_key == "id") { if (building_city && current_key == "id") {
current_city.id = static_cast<int>(i); current_city.id = static_cast<int>(i);
} else if (in_state_object && current_key == "id") { } else if (in_state_object && current_key == "id") {
@@ -205,12 +206,12 @@ private:
} }
bool on_uint64(uint64_t u, boost::json::string_view, bool on_uint64(uint64_t u, boost::json::string_view,
boost::system::error_code &ec) { boost::system::error_code& ec) {
return on_int64(static_cast<int64_t>(u), "", ec); return on_int64(static_cast<int64_t>(u), "", ec);
} }
bool on_double(double d, boost::json::string_view, bool on_double(double d, boost::json::string_view,
boost::system::error_code &) { boost::system::error_code&) {
if (building_city) { if (building_city) {
if (current_key == "latitude") { if (current_key == "latitude") {
current_city.latitude = d; current_city.latitude = d;
@@ -221,24 +222,23 @@ private:
return true; return true;
} }
bool on_bool(bool, boost::system::error_code &) { return true; } bool on_bool(bool, boost::system::error_code&) { return true; }
bool on_null(boost::system::error_code &) { return true; } bool on_null(boost::system::error_code&) { return true; }
bool on_comment_part(boost::json::string_view, boost::system::error_code &) { bool on_comment_part(boost::json::string_view, boost::system::error_code&) {
return true; return true;
} }
bool on_comment(boost::json::string_view, boost::system::error_code &) { bool on_comment(boost::json::string_view, boost::system::error_code&) {
return true; return true;
} }
}; };
void StreamingJsonParser::Parse( void StreamingJsonParser::Parse(
const std::string &file_path, SqliteDatabase &db, const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord &)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress) { std::function<void(size_t, size_t)> on_progress) {
spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path); spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path);
FILE *file = std::fopen(file_path.c_str(), "rb"); FILE* file = std::fopen(file_path.c_str(), "rb");
if (!file) { if (!file) {
throw std::runtime_error("Failed to open JSON file: " + file_path); throw std::runtime_error("Failed to open JSON file: " + file_path);
} }
@@ -252,8 +252,8 @@ void StreamingJsonParser::Parse(
std::rewind(file); std::rewind(file);
} }
CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, total_size,
total_size, 0, 0}; 0, 0};
boost::json::basic_parser<CityRecordHandler> parser( boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx); boost::json::parse_options{}, ctx);
@@ -262,7 +262,7 @@ void StreamingJsonParser::Parse(
boost::system::error_code ec; boost::system::error_code ec;
while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) { while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) {
char const *p = buf; char const* p = buf;
std::size_t remain = bytes_read; std::size_t remain = bytes_read;
while (remain > 0) { while (remain > 0) {
@@ -284,5 +284,6 @@ void StreamingJsonParser::Parse(
} }
spdlog::info(" OK: Parsed {} countries, {} states, {} cities", spdlog::info(" OK: Parsed {} countries, {} states, {} cities",
ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted); ctx.countries_inserted, ctx.states_inserted,
ctx.cities_emitted);
} }

View File

@@ -1,6 +1,8 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
#include <cstdio>
#include <curl/curl.h> #include <curl/curl.h>
#include <cstdio>
#include <fstream> #include <fstream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
@@ -17,20 +19,20 @@ CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); }
namespace { namespace {
// curl write callback that appends response data into a std::string // curl write callback that appends response data into a std::string
size_t WriteCallbackString(void *contents, size_t size, size_t nmemb, size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
void *userp) { void* userp) {
size_t realsize = size * nmemb; size_t realsize = size * nmemb;
auto *s = static_cast<std::string *>(userp); auto* s = static_cast<std::string*>(userp);
s->append(static_cast<char *>(contents), realsize); s->append(static_cast<char*>(contents), realsize);
return realsize; return realsize;
} }
// curl write callback that writes to a file stream // curl write callback that writes to a file stream
size_t WriteCallbackFile(void *contents, size_t size, size_t nmemb, size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb,
void *userp) { void* userp) {
size_t realsize = size * nmemb; size_t realsize = size * nmemb;
auto *outFile = static_cast<std::ofstream *>(userp); auto* outFile = static_cast<std::ofstream*>(userp);
outFile->write(static_cast<char *>(contents), realsize); outFile->write(static_cast<char*>(contents), realsize);
return realsize; return realsize;
} }
@@ -38,7 +40,7 @@ size_t WriteCallbackString(void *contents, size_t size, size_t nmemb,
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>; using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
CurlHandle create_handle() { CurlHandle create_handle() {
CURL *handle = curl_easy_init(); CURL* handle = curl_easy_init();
if (!handle) { if (!handle) {
throw std::runtime_error( throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle"); "[CURLWebClient] Failed to initialize libcurl handle");
@@ -46,7 +48,7 @@ CurlHandle create_handle() {
return CurlHandle(handle, &curl_easy_cleanup); return CurlHandle(handle, &curl_easy_cleanup);
} }
void set_common_get_options(CURL *curl, const std::string &url, void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) { long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
@@ -62,20 +64,20 @@ CURLWebClient::CURLWebClient() {}
CURLWebClient::~CURLWebClient() {} CURLWebClient::~CURLWebClient() {}
void CURLWebClient::DownloadToFile(const std::string &url, void CURLWebClient::DownloadToFile(const std::string& url,
const std::string &file_path) { const std::string& file_path) {
auto curl = create_handle(); auto curl = create_handle();
std::ofstream outFile(file_path, std::ios::binary); std::ofstream outFile(file_path, std::ios::binary);
if (!outFile.is_open()) { if (!outFile.is_open()) {
throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " + throw std::runtime_error(
file_path); "[CURLWebClient] Cannot open file for writing: " + file_path);
} }
set_common_get_options(curl.get(), url, 30L, 300L); set_common_get_options(curl.get(), url, 30L, 300L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile); curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
static_cast<void *>(&outFile)); static_cast<void*>(&outFile));
CURLcode res = curl_easy_perform(curl.get()); CURLcode res = curl_easy_perform(curl.get());
outFile.close(); outFile.close();
@@ -98,7 +100,7 @@ void CURLWebClient::DownloadToFile(const std::string &url,
} }
} }
std::string CURLWebClient::Get(const std::string &url) { std::string CURLWebClient::Get(const std::string& url) {
auto curl = create_handle(); auto curl = create_handle();
std::string response_string; std::string response_string;
@@ -126,9 +128,9 @@ std::string CURLWebClient::Get(const std::string &url) {
return response_string; return response_string;
} }
std::string CURLWebClient::UrlEncode(const std::string &value) { std::string CURLWebClient::UrlEncode(const std::string& value) {
// A NULL handle is fine for UTF-8 encoding according to libcurl docs. // A NULL handle is fine for UTF-8 encoding according to libcurl docs.
char *output = curl_easy_escape(nullptr, value.c_str(), 0); char* output = curl_easy_escape(nullptr, value.c_str(), 0);
if (output) { if (output) {
std::string result(output); std::string result(output);

View File

@@ -1,7 +1,9 @@
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <boost/json.hpp>
WikipediaService::WikipediaService(std::shared_ptr<WebClient> client) WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {} : client_(std::move(client)) {}
@@ -17,9 +19,9 @@ std::string WikipediaService::FetchExtract(std::string_view query) {
boost::json::value doc = boost::json::parse(body, ec); boost::json::value doc = boost::json::parse(body, ec);
if (!ec && doc.is_object()) { if (!ec && doc.is_object()) {
auto &pages = doc.at("query").at("pages").get_object(); auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) { if (!pages.empty()) {
auto &page = pages.begin()->value().get_object(); auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) { if (page.contains("extract") && page.at("extract").is_string()) {
std::string extract(page.at("extract").as_string().c_str()); std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'", spdlog::debug("WikipediaService fetched {} chars for '{}'",
@@ -63,11 +65,10 @@ std::string WikipediaService::GetSummary(std::string_view city,
result += regionExtract; result += regionExtract;
} }
if (!beerExtract.empty()) { if (!beerExtract.empty()) {
if (!result.empty()) if (!result.empty()) result += "\n\n";
result += "\n\n";
result += beerExtract; result += beerExtract;
} }
} catch (const std::runtime_error &e) { } catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery, spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery,
e.what()); e.what());
} }