Compare commits

4 Commits

Author SHA1 Message Date
Aaron Po
077f6ab4ae edit prompt 2026-04-02 22:56:18 -04:00
Aaron Po
534403734a Refactor BiergartenDataGenerator and LlamaGenerator 2026-04-02 22:46:00 -04:00
Aaron Po
3af053f0eb format codebase 2026-04-02 21:46:46 -04:00
Aaron Po
ba165d8aa7 Separate llama generator class src file into method files 2026-04-02 21:37:46 -04:00
34 changed files with 1879 additions and 1754 deletions

View File

@@ -1,10 +1,5 @@
--- ---
BasedOnStyle: Google BasedOnStyle: Google
Standard: c++23 ColumnLimit: 80
ColumnLimit: 100 IndentWidth: 3
IndentWidth: 2
DerivePointerAlignment: false
PointerAlignment: Left
SortIncludes: true
IncludeBlocks: Preserve
... ...

View File

@@ -83,8 +83,18 @@ set(PIPELINE_SOURCES
src/data_generation/data_downloader.cpp src/data_generation/data_downloader.cpp
src/database/database.cpp src/database/database.cpp
src/json_handling/json_loader.cpp src/json_handling/json_loader.cpp
src/data_generation/llama_generator.cpp src/data_generation/llama/destructor.cpp
src/data_generation/mock_generator.cpp src/data_generation/llama/set_sampling_options.cpp
src/data_generation/llama/load.cpp
src/data_generation/llama/infer.cpp
src/data_generation/llama/generate_brewery.cpp
src/data_generation/llama/generate_user.cpp
src/data_generation/llama/helpers.cpp
src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/load.cpp
src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp
src/json_handling/stream_parser.cpp src/json_handling/stream_parser.cpp
src/wikipedia/wikipedia_service.cpp src/wikipedia/wikipedia_service.cpp
src/main.cpp src/main.cpp

View File

@@ -3,23 +3,24 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
#include "database/database.h" #include "database/database.h"
#include "web_client/web_client.h" #include "web_client/web_client.h"
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
/** /**
* @brief Program options for the Biergarten pipeline application. * @brief Program options for the Biergarten pipeline application.
*/ */
struct ApplicationOptions { struct ApplicationOptions {
/// @brief Path to the LLM model file (gguf format); mutually exclusive with use_mocked. /// @brief Path to the LLM model file (gguf format); mutually exclusive with
/// use_mocked.
std::string model_path; std::string model_path;
/// @brief Use mocked generator instead of LLM; mutually exclusive with model_path. /// @brief Use mocked generator instead of LLM; mutually exclusive with
/// model_path.
bool use_mocked = false; bool use_mocked = false;
/// @brief Directory for cached JSON and database files. /// @brief Directory for cached JSON and database files.
@@ -28,24 +29,24 @@ struct ApplicationOptions {
/// @brief LLM sampling temperature (0.0 to 1.0, higher = more random). /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
float temperature = 0.8f; float temperature = 0.8f;
/// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more random). /// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more
/// random).
float top_p = 0.92f; float top_p = 0.92f;
/// @brief Random seed for sampling (-1 for random, otherwise non-negative). /// @brief Random seed for sampling (-1 for random, otherwise non-negative).
int seed = -1; int seed = -1;
/// @brief Git commit hash for database consistency (always pinned to c5eb7772). /// @brief Git commit hash for database consistency (always pinned to
/// c5eb7772).
std::string commit = "c5eb7772"; std::string commit = "c5eb7772";
}; };
#endif // BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_
/** /**
* @brief Main data generator class for the Biergarten pipeline. * @brief Main data generator class for the Biergarten pipeline.
* *
* This class encapsulates the core logic for generating brewery data. * This class encapsulates the core logic for generating brewery data.
* It handles database initialization, data loading/downloading, and brewery generation. * It handles database initialization, data loading/downloading, and brewery
* generation.
*/ */
class BiergartenDataGenerator { class BiergartenDataGenerator {
public: public:
@@ -83,6 +84,16 @@ private:
/// @brief Database dependency. /// @brief Database dependency.
SqliteDatabase& database_; SqliteDatabase& database_;
/**
* @brief Enriched city data with Wikipedia context.
*/
struct EnrichedCity {
int city_id;
std::string city_name;
std::string country_name;
std::string region_context;
};
/** /**
* @brief Initialize the data generator based on options. * @brief Initialize the data generator based on options.
* *
@@ -98,19 +109,45 @@ private:
void LoadGeographicData(); void LoadGeographicData();
/** /**
* @brief Generate sample breweries for demonstration. * @brief Query cities from database and build country name map.
*
* @return Vector of (City, country_name) pairs capped at 30 entries.
*/ */
void GenerateSampleBreweries(); std::vector<std::pair<City, std::string>> QueryCitiesWithCountries();
/**
* @brief Enrich cities with Wikipedia summaries.
*
* @param cities Vector of (City, country_name) pairs.
* @return Vector of enriched city data with context.
*/
std::vector<EnrichedCity> EnrichWithWikipedia(
const std::vector<std::pair<City, std::string>>& cities);
/**
* @brief Generate breweries for enriched cities.
*
* @param generator The data generator instance.
* @param cities Vector of enriched city data.
*/
void GenerateBreweries(DataGenerator& generator,
const std::vector<EnrichedCity>& cities);
/**
* @brief Log the generated brewery results.
*/
void LogResults() const;
/** /**
* @brief Helper struct to store generated brewery data. * @brief Helper struct to store generated brewery data.
*/ */
struct GeneratedBrewery { struct GeneratedBrewery {
int cityId; int city_id;
std::string cityName; std::string city_name;
BreweryResult brewery; BreweryResult brewery;
}; };
/// @brief Stores generated brewery data. /// @brief Stores generated brewery data.
std::vector<GeneratedBrewery> generatedBreweries_; std::vector<GeneratedBrewery> generatedBreweries_;
}; };
#endif // BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_

View File

@@ -19,7 +19,8 @@ public:
/// @brief Returns a local JSON path, downloading it when cache is missing. /// @brief Returns a local JSON path, downloading it when cache is missing.
std::string DownloadCountriesDatabase( std::string DownloadCountriesDatabase(
const std::string& cache_path, const std::string& cache_path,
const std::string &commit = "c5eb7772" // Stable commit: 2026-03-28 export const std::string& commit =
"c5eb7772" // Stable commit: 2026-03-28 export
); );
private: private:

View File

@@ -28,7 +28,10 @@ private:
// models receive a proper system role instead of having the system text // models receive a proper system role instead of having the system text
// concatenated into the user prompt (helps avoid revealing internal // concatenated into the user prompt (helps avoid revealing internal
// reasoning or instructions in model output). // reasoning or instructions in model output).
std::string Infer(const std::string &system_prompt, const std::string &prompt, std::string Infer(const std::string& system_prompt,
const std::string& prompt, int max_tokens = 10000);
std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = 10000); int max_tokens = 10000);
llama_model* model_ = nullptr; llama_model* model_ = nullptr;

View File

@@ -0,0 +1,32 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
#include <string>
#include <utility>
struct llama_model;
struct llama_vocab;
typedef int llama_token;
// Helper functions for LlamaGenerator methods
std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars = 700);
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message);
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt);
std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt);
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output);
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out);
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_

View File

@@ -1,10 +1,11 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#include "data_generation/data_generator.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "data_generation/data_generator.h"
class MockGenerator final : public DataGenerator { class MockGenerator final : public DataGenerator {
public: public:
void Load(const std::string& model_path) override; void Load(const std::string& model_path) override;

View File

@@ -1,8 +1,9 @@
#ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#include <mutex>
#include <sqlite3.h> #include <sqlite3.h>
#include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
@@ -67,8 +68,8 @@ public:
const std::string& iso2); const std::string& iso2);
/// @brief Inserts a city row linked to state and country. /// @brief Inserts a city row linked to state and country.
void InsertCity(int id, int state_id, int country_id, const std::string &name, void InsertCity(int id, int state_id, int country_id,
double latitude, double longitude); const std::string& name, double latitude, double longitude);
/// @brief Returns city records including parent country id. /// @brief Returns city records including parent country id.
std::vector<City> QueryCities(); std::vector<City> QueryCities();

View File

@@ -1,15 +1,17 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#include <string>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
#include <string>
/// @brief Loads world-city JSON data into SQLite through streaming parsing. /// @brief Loads world-city JSON data into SQLite through streaming parsing.
class JsonLoader { class JsonLoader {
public: public:
/// @brief Parses a JSON file and writes country/state/city rows into db. /// @brief Parses a JSON file and writes country/state/city rows into db.
static void LoadWorldCities(const std::string &json_path, SqliteDatabase &db); static void LoadWorldCities(const std::string& json_path,
SqliteDatabase& db);
}; };
#endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -1,10 +1,11 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#include "database/database.h"
#include <functional> #include <functional>
#include <string> #include <string>
#include "database/database.h"
// Forward declaration to avoid circular dependency // Forward declaration to avoid circular dependency
class SqliteDatabase; class SqliteDatabase;

View File

@@ -1,9 +1,10 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#include "web_client/web_client.h"
#include <memory> #include <memory>
#include "web_client/web_client.h"
// RAII for curl_global_init/cleanup. // RAII for curl_global_init/cleanup.
// An instance of this class should be created in main() before any curl // An instance of this class should be created in main() before any curl
// operations and exist for the lifetime of the application. // operations and exist for the lifetime of the application.

View File

@@ -11,8 +11,8 @@ public:
virtual void DownloadToFile(const std::string& url, virtual void DownloadToFile(const std::string& url,
const std::string& file_path) = 0; const std::string& file_path) = 0;
// Performs a GET request and returns the response body as a string. Throws on // Performs a GET request and returns the response body as a string. Throws
// error. // on error.
virtual std::string Get(const std::string& url) = 0; virtual std::string Get(const std::string& url) = 0;
// URL-encodes a string. // URL-encodes a string.

View File

@@ -1,20 +1,19 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <filesystem> #include <filesystem>
#include <unordered_map> #include <unordered_map>
#include <spdlog/spdlog.h>
#include "data_generation/data_downloader.h" #include "data_generation/data_downloader.h"
#include "json_handling/json_loader.h"
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h" #include "data_generation/mock_generator.h"
#include "json_handling/json_loader.h"
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
BiergartenDataGenerator::BiergartenDataGenerator( BiergartenDataGenerator::BiergartenDataGenerator(
const ApplicationOptions &options, const ApplicationOptions& options, std::shared_ptr<WebClient> web_client,
std::shared_ptr<WebClient> web_client,
SqliteDatabase& database) SqliteDatabase& database)
: options_(options), webClient_(web_client), database_(database) {} : options_(options), webClient_(web_client), database_(database) {}
@@ -62,57 +61,79 @@ void BiergartenDataGenerator::LoadGeographicData() {
} }
} }
void BiergartenDataGenerator::GenerateSampleBreweries() { std::vector<std::pair<City, std::string>>
auto generator = InitializeGenerator(); BiergartenDataGenerator::QueryCitiesWithCountries() {
WikipediaService wikipedia_service(webClient_);
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
auto countries = database_.QueryCountries(50);
auto states = database_.QueryStates(50);
auto cities = database_.QueryCities(); auto cities = database_.QueryCities();
// Build a quick map of country id -> name for per-city lookups. // Build a quick map of country id -> name for per-city lookups.
auto all_countries = database_.QueryCountries(0); auto all_countries = database_.QueryCountries(0);
std::unordered_map<int, std::string> country_map; std::unordered_map<int, std::string> country_map;
for (const auto &c : all_countries) for (const auto& c : all_countries) {
country_map[c.id] = c.name; country_map[c.id] = c.name;
}
spdlog::info("\nTotal records loaded:"); spdlog::info("\nTotal records loaded:");
spdlog::info(" Countries: {}", database_.QueryCountries(0).size()); spdlog::info(" Countries: {}", database_.QueryCountries(0).size());
spdlog::info(" States: {}", database_.QueryStates(0).size()); spdlog::info(" States: {}", database_.QueryStates(0).size());
spdlog::info(" Cities: {}", cities.size()); spdlog::info(" Cities: {}", cities.size());
generatedBreweries_.clear(); // Cap at 30 entries.
const size_t sample_count = std::min(size_t(30), cities.size()); const size_t sample_count = std::min(size_t(30), cities.size());
std::vector<std::pair<City, std::string>> result;
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
for (size_t i = 0; i < sample_count; i++) { for (size_t i = 0; i < sample_count; i++) {
const auto& city = cities[i]; const auto& city = cities[i];
const int city_id = city.id; std::string country_name;
const std::string city_name = city.name;
std::string local_country;
const auto country_it = country_map.find(city.country_id); const auto country_it = country_map.find(city.country_id);
if (country_it != country_map.end()) { if (country_it != country_map.end()) {
local_country = country_it->second; country_name = country_it->second;
}
result.push_back({city, country_name});
} }
return result;
}
std::vector<BiergartenDataGenerator::EnrichedCity>
BiergartenDataGenerator::EnrichWithWikipedia(
const std::vector<std::pair<City, std::string>>& cities) {
WikipediaService wikipedia_service(webClient_);
std::vector<EnrichedCity> enriched;
for (const auto& [city, country_name] : cities) {
const std::string region_context = const std::string region_context =
wikipedia_service.GetSummary(city_name, local_country); wikipedia_service.GetSummary(city.name, country_name);
spdlog::debug("[Pipeline] Region context for {}: {}", city_name, spdlog::debug("[Pipeline] Region context for {}: {}", city.name,
region_context); region_context);
auto brewery = enriched.push_back({city.id, city.name, country_name, region_context});
generator->GenerateBrewery(city_name, local_country, region_context);
generatedBreweries_.push_back({city_id, city_name, brewery});
} }
return enriched;
}
void BiergartenDataGenerator::GenerateBreweries(
DataGenerator& generator, const std::vector<EnrichedCity>& cities) {
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
generatedBreweries_.clear();
for (const auto& enriched_city : cities) {
auto brewery = generator.GenerateBrewery(enriched_city.city_name,
enriched_city.country_name,
enriched_city.region_context);
generatedBreweries_.push_back(
{enriched_city.city_id, enriched_city.city_name, brewery});
}
}
void BiergartenDataGenerator::LogResults() const {
spdlog::info("\n=== GENERATED DATA DUMP ==="); spdlog::info("\n=== GENERATED DATA DUMP ===");
for (size_t i = 0; i < generatedBreweries_.size(); i++) { for (size_t i = 0; i < generatedBreweries_.size(); i++) {
const auto& entry = generatedBreweries_[i]; const auto& entry = generatedBreweries_[i];
spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.cityId, spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.city_id,
entry.cityName); entry.city_name);
spdlog::info(" brewery_name=\"{}\"", entry.brewery.name); spdlog::info(" brewery_name=\"{}\"", entry.brewery.name);
spdlog::info(" brewery_description=\"{}\"", entry.brewery.description); spdlog::info(" brewery_description=\"{}\"", entry.brewery.description);
} }
@@ -121,7 +142,11 @@ void BiergartenDataGenerator::GenerateSampleBreweries() {
int BiergartenDataGenerator::Run() { int BiergartenDataGenerator::Run() {
try { try {
LoadGeographicData(); LoadGeographicData();
GenerateSampleBreweries(); auto generator = InitializeGenerator();
auto cities = QueryCitiesWithCountries();
auto enriched = EnrichWithWikipedia(cities);
GenerateBreweries(*generator, enriched);
LogResults();
spdlog::info("\nOK: Pipeline completed successfully"); spdlog::info("\nOK: Pipeline completed successfully");
return 0; return 0;

View File

@@ -1,11 +1,14 @@
#include "data_generation/data_downloader.h" #include "data_generation/data_downloader.h"
#include "web_client/web_client.h"
#include <spdlog/spdlog.h>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <spdlog/spdlog.h>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include "web_client/web_client.h"
DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client) DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client)
: web_client_(std::move(web_client)) {} : web_client_(std::move(web_client)) {}
@@ -15,9 +18,8 @@ bool DataDownloader::FileExists(const std::string &file_path) {
return std::filesystem::exists(file_path); return std::filesystem::exists(file_path);
} }
std::string std::string DataDownloader::DownloadCountriesDatabase(
DataDownloader::DownloadCountriesDatabase(const std::string &cache_path, const std::string& cache_path, const std::string& commit) {
const std::string &commit) {
if (FileExists(cache_path)) { if (FileExists(cache_path)) {
spdlog::info("[DataDownloader] Cache hit: {}", cache_path); spdlog::info("[DataDownloader] Cache hit: {}", cache_path);
return cache_path; return cache_path;
@@ -28,7 +30,8 @@ DataDownloader::DownloadCountriesDatabase(const std::string &cache_path,
short_commit = commit.substr(0, 7); short_commit = commit.substr(0, 7);
} }
std::string url = "https://raw.githubusercontent.com/dr5hn/" std::string url =
"https://raw.githubusercontent.com/dr5hn/"
"countries-states-cities-database/" + "countries-states-cities-database/" +
short_commit + "/json/countries+states+cities.json"; short_commit + "/json/countries+states+cities.json";

View File

@@ -0,0 +1,16 @@
#include "data_generation/llama_generator.h"
#include "llama.h"
LlamaGenerator::~LlamaGenerator() {
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
llama_backend_free();
}

View File

@@ -0,0 +1,74 @@
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
BreweryResult LlamaGenerator::GenerateBrewery(
const std::string& city_name, const std::string& country_name,
const std::string& region_context) {
const std::string safe_region_context =
PrepareRegionContextPublic(region_context);
const std::string system_prompt =
"You are the brewmaster and owner of a local craft brewery. "
"Write a name and a short, soulful description for your brewery that "
"reflects your pride in the local community and your craft. "
"The tone should be authentic and welcoming, like a note on a "
"chalkboard "
"menu. Output ONLY a single JSON object with keys \"name\" and "
"\"description\". "
"Do not include markdown formatting or backticks.";
std::string prompt =
"Write a brewery name and place-specific long description for a craft "
"brewery in " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string(".")
: std::string(". Regional context: ") + safe_region_context);
const int max_attempts = 3;
std::string raw;
std::string last_error;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
raw = Infer(system_prompt, prompt, 384);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
std::string name;
std::string description;
const std::string validation_error =
ValidateBreweryJsonPublic(raw, name, description);
if (validation_error.empty()) {
return {std::move(name), std::move(description)};
}
last_error = validation_error;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error);
prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys."
"\n\nLocation: " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string("")
: std::string("\nRegional context: ") + safe_region_context);
}
spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
}

View File

@@ -0,0 +1,57 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
const std::string system_prompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
const int max_attempts = 3;
std::string raw;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
raw = Infer(system_prompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
auto [username, bio] = ParseTwoLineResponsePublic(
raw, "LlamaGenerator: malformed user response");
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
if (bio.size() > 200) bio = bio.substr(0, 200);
return {username, bio};
} catch (const std::exception& e) {
spdlog::warn(
"LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
spdlog::error(
"LlamaGenerator: malformed user response after {} attempts: {}",
max_attempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -0,0 +1,398 @@
#include <algorithm>
#include <array>
#include <boost/json.hpp>
#include <cctype>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "data_generation/llama_generator.h"
#include "llama.h"
namespace {
std::string Trim(std::string value) {
auto not_space = [](unsigned char ch) { return !std::isspace(ch); };
value.erase(value.begin(),
std::find_if(value.begin(), value.end(), not_space));
value.erase(std::find_if(value.rbegin(), value.rend(), not_space).base(),
value.end());
return value;
}
std::string CondenseWhitespace(std::string text) {
std::string out;
out.reserve(text.size());
bool in_whitespace = false;
for (unsigned char ch : text) {
if (std::isspace(ch)) {
if (!in_whitespace) {
out.push_back(' ');
in_whitespace = true;
}
continue;
}
in_whitespace = false;
out.push_back(static_cast<char>(ch));
}
return Trim(std::move(out));
}
std::string PrepareRegionContext(std::string_view region_context,
std::size_t max_chars) {
std::string normalized = CondenseWhitespace(std::string(region_context));
if (normalized.size() <= max_chars) {
return normalized;
}
normalized.resize(max_chars);
const std::size_t last_space = normalized.find_last_of(' ');
if (last_space != std::string::npos && last_space > max_chars / 2) {
normalized.resize(last_space);
}
normalized += "...";
return normalized;
}
std::string StripCommonPrefix(std::string line) {
line = Trim(std::move(line));
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = Trim(line.substr(1));
} else {
std::size_t i = 0;
while (i < line.size() &&
std::isdigit(static_cast<unsigned char>(line[i]))) {
++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = Trim(line.substr(i + 1));
}
}
auto strip_label = [&line](const std::string& label) {
if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = Trim(line.substr(label.size()));
}
}
};
strip_label("name:");
strip_label("brewery name:");
strip_label("description:");
strip_label("username:");
strip_label("bio:");
return Trim(std::move(line));
}
std::pair<std::string, std::string> ParseTwoLineResponse(
const std::string& raw, const std::string& error_message) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = StripCommonPrefix(std::move(line));
if (!line.empty()) lines.push_back(std::move(line));
}
std::vector<std::string> filtered;
for (auto& l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (!l.empty() && l.front() == '<' && low.back() == '>') continue;
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
filtered.push_back(std::move(l));
}
if (filtered.size() < 2) throw std::runtime_error(error_message);
std::string first = Trim(filtered.front());
std::string second;
for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) second += ' ';
second += filtered[i];
}
second = Trim(std::move(second));
if (first.empty() || second.empty()) throw std::runtime_error(error_message);
return {first, second};
}
std::string ToChatPrompt(const llama_model* model,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return user_prompt;
}
const llama_chat_message message{"user", user_prompt.c_str()};
std::vector<char> buffer(
std::max<std::size_t>(1024, user_prompt.size() * 4));
int32_t required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
std::string ToChatPrompt(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
const char* tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return system_prompt + "\n\n" + user_prompt;
}
const llama_chat_message messages[2] = {{"system", system_prompt.c_str()},
{"user", user_prompt.c_str()}};
std::vector<char> buffer(std::max<std::size_t>(
1024, (system_prompt.size() + user_prompt.size()) * 4));
int32_t required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
std::string& output) {
std::array<char, 256> buffer{};
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true);
if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()),
0, true);
if (bytes < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece");
}
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
return;
}
output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
bool ExtractFirstJsonObject(const std::string& text, std::string& json_out) {
std::size_t start = std::string::npos;
int depth = 0;
bool in_string = false;
bool escaped = false;
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
if (in_string) {
if (escaped) {
escaped = false;
} else if (ch == '\\') {
escaped = true;
} else if (ch == '"') {
in_string = false;
}
continue;
}
if (ch == '"') {
in_string = true;
continue;
}
if (ch == '{') {
if (depth == 0) {
start = i;
}
++depth;
continue;
}
if (ch == '}') {
if (depth == 0) {
continue;
}
--depth;
if (depth == 0 && start != std::string::npos) {
json_out = text.substr(start, i - start + 1);
return true;
}
}
}
return false;
}
std::string ValidateBreweryJson(const std::string& raw, std::string& name_out,
std::string& description_out) {
auto validate_object = [&](const boost::json::value& jv,
std::string& error_out) -> bool {
if (!jv.is_object()) {
error_out = "JSON root must be an object";
return false;
}
const auto& obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
error_out = "JSON field 'name' is missing or not a string";
return false;
}
if (!obj.contains("description") || !obj.at("description").is_string()) {
error_out = "JSON field 'description' is missing or not a string";
return false;
}
name_out = Trim(std::string(obj.at("name").as_string().c_str()));
description_out =
Trim(std::string(obj.at("description").as_string().c_str()));
if (name_out.empty()) {
error_out = "JSON field 'name' must not be empty";
return false;
}
if (description_out.empty()) {
error_out = "JSON field 'description' must not be empty";
return false;
}
std::string name_lower = name_out;
std::string description_lower = description_out;
std::transform(
name_lower.begin(), name_lower.end(), name_lower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(description_lower.begin(), description_lower.end(),
description_lower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (name_lower == "string" || description_lower == "string") {
error_out = "JSON appears to be a schema placeholder, not content";
return false;
}
error_out.clear();
return true;
};
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
std::string validation_error;
if (ec) {
std::string extracted;
if (!ExtractFirstJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
}
if (!validate_object(jv, validation_error)) {
return validation_error;
}
return {};
}
} // namespace
// Forward declarations for helper functions exposed to other translation units
std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars) {
return PrepareRegionContext(region_context, max_chars);
}
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
const std::string& raw, const std::string& error_message) {
return ParseTwoLineResponse(raw, error_message);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt) {
return ToChatPrompt(model, user_prompt);
}
std::string ToChatPromptPublic(const llama_model* model,
const std::string& system_prompt,
const std::string& user_prompt) {
return ToChatPrompt(model, system_prompt, user_prompt);
}
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
std::string& output) {
AppendTokenPiece(vocab, token, output);
}
std::string ValidateBreweryJsonPublic(const std::string& raw,
std::string& name_out,
std::string& description_out) {
return ValidateBreweryJson(raw, name_out, description_out);
}

View File

@@ -0,0 +1,111 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h"
#include "llama.h"
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, prompt), max_tokens);
}
std::string LlamaGenerator::Infer(const std::string& system_prompt,
const std::string& prompt, int max_tokens) {
return InferFormatted(ToChatPromptPublic(model_, system_prompt, prompt),
max_tokens);
}
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
int max_tokens) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab* vocab = llama_model_get_vocab(model_);
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
llama_memory_clear(llama_get_memory(context_), true);
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + 8);
int32_t token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
if (token_count < 0) {
prompt_tokens.resize(static_cast<std::size_t>(-token_count));
token_count = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()), true, true);
}
if (token_count < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t n_ctx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t n_batch = static_cast<int32_t>(llama_n_batch(context_));
if (n_ctx <= 1 || n_batch <= 0)
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens to fit n_batch/n_ctx limits",
token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
token_count = prompt_budget;
}
const llama_batch prompt_batch = llama_batch_get_one(
prompt_tokens.data(), static_cast<int32_t>(prompt_tokens.size()));
if (llama_decode(context_, prompt_batch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next =
llama_sampler_sample(sampler.get(), context_, -1);
if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next);
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
std::string output;
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
return output;
}

View File

@@ -0,0 +1,42 @@
#include <spdlog/spdlog.h>
#include <stdexcept>
#include <string>
#include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) {
if (model_path.empty())
throw std::runtime_error("LlamaGenerator: model path must not be empty");
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path);
}
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = 2048;
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {
llama_model_free(model_);
model_ = nullptr;
throw std::runtime_error("LlamaGenerator: failed to create context");
}
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}

View File

@@ -0,0 +1,25 @@
#include <stdexcept>
#include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
int seed) {
if (temperature < 0.0f) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (!(top_p > 0.0f && top_p <= 1.0f)) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
sampling_temperature_ = temperature;
sampling_top_p_ = top_p;
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed);
}

View File

@@ -1,734 +0,0 @@
#include <algorithm>
#include <array>
#include <cctype>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "llama.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h"
namespace {
std::string trim(std::string value) {
auto notSpace = [](unsigned char ch) { return !std::isspace(ch); };
value.erase(value.begin(),
std::find_if(value.begin(), value.end(), notSpace));
value.erase(std::find_if(value.rbegin(), value.rend(), notSpace).base(),
value.end());
return value;
}
std::string CondenseWhitespace(std::string text) {
std::string out;
out.reserve(text.size());
bool inWhitespace = false;
for (unsigned char ch : text) {
if (std::isspace(ch)) {
if (!inWhitespace) {
out.push_back(' ');
inWhitespace = true;
}
continue;
}
inWhitespace = false;
out.push_back(static_cast<char>(ch));
}
return trim(std::move(out));
}
std::string PrepareRegionContext(std::string_view regionContext,
std::size_t maxChars = 700) {
std::string normalized = CondenseWhitespace(std::string(regionContext));
if (normalized.size() <= maxChars) {
return normalized;
}
normalized.resize(maxChars);
const std::size_t lastSpace = normalized.find_last_of(' ');
if (lastSpace != std::string::npos && lastSpace > maxChars / 2) {
normalized.resize(lastSpace);
}
normalized += "...";
return normalized;
}
std::string stripCommonPrefix(std::string line) {
line = trim(std::move(line));
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
line = trim(line.substr(1));
} else {
std::size_t i = 0;
while (i < line.size() &&
std::isdigit(static_cast<unsigned char>(line[i]))) {
++i;
}
if (i > 0 && i < line.size() && (line[i] == '.' || line[i] == ')')) {
line = trim(line.substr(i + 1));
}
}
auto stripLabel = [&line](const std::string &label) {
if (line.size() >= label.size()) {
bool matches = true;
for (std::size_t i = 0; i < label.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(line[i])) !=
std::tolower(static_cast<unsigned char>(label[i]))) {
matches = false;
break;
}
}
if (matches) {
line = trim(line.substr(label.size()));
}
}
};
stripLabel("name:");
stripLabel("brewery name:");
stripLabel("description:");
stripLabel("username:");
stripLabel("bio:");
return trim(std::move(line));
}
std::pair<std::string, std::string>
parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) {
std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
std::vector<std::string> lines;
std::stringstream stream(normalized);
std::string line;
while (std::getline(stream, line)) {
line = stripCommonPrefix(std::move(line));
if (!line.empty())
lines.push_back(std::move(line));
}
std::vector<std::string> filtered;
for (auto &l : lines) {
std::string low = l;
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (!l.empty() && l.front() == '<' && low.back() == '>')
continue;
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0)
continue;
filtered.push_back(std::move(l));
}
if (filtered.size() < 2)
throw std::runtime_error(errorMessage);
std::string first = trim(filtered.front());
std::string second;
for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty())
second += ' ';
second += filtered[i];
}
second = trim(std::move(second));
if (first.empty() || second.empty())
throw std::runtime_error(errorMessage);
return {first, second};
}
std::string toChatPrompt(const llama_model *model,
const std::string &userPrompt) {
const char *tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return userPrompt;
}
const llama_chat_message message{"user", userPrompt.c_str()};
std::vector<char> buffer(std::max<std::size_t>(1024, userPrompt.size() * 4));
int32_t required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
std::string toChatPrompt(const llama_model *model,
const std::string &system_prompt,
const std::string &userPrompt) {
const char *tmpl = llama_model_chat_template(model, nullptr);
if (tmpl == nullptr) {
return system_prompt + "\n\n" + userPrompt;
}
const llama_chat_message messages[2] = {{"system", system_prompt.c_str()},
{"user", userPrompt.c_str()}};
std::vector<char> buffer(std::max<std::size_t>(
1024, (systemPrompt.size() + userPrompt.size()) * 4));
int32_t required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size()));
if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template");
}
}
return std::string(buffer.data(), static_cast<std::size_t>(required));
}
void appendTokenPiece(const llama_vocab *vocab, llama_token token,
std::string &output) {
std::array<char, 256> buffer{};
int32_t bytes =
llama_token_to_piece(vocab, token, buffer.data(),
static_cast<int32_t>(buffer.size()), 0, true);
if (bytes < 0) {
std::vector<char> dynamicBuffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamicBuffer.data(),
static_cast<int32_t>(dynamicBuffer.size()), 0,
true);
if (bytes < 0) {
throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece");
}
output.append(dynamicBuffer.data(), static_cast<std::size_t>(bytes));
return;
}
output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
bool extractFirstJsonObject(const std::string &text, std::string &jsonOut) {
std::size_t start = std::string::npos;
int depth = 0;
bool inString = false;
bool escaped = false;
for (std::size_t i = 0; i < text.size(); ++i) {
const char ch = text[i];
if (inString) {
if (escaped) {
escaped = false;
} else if (ch == '\\') {
escaped = true;
} else if (ch == '"') {
inString = false;
}
continue;
}
if (ch == '"') {
inString = true;
continue;
}
if (ch == '{') {
if (depth == 0) {
start = i;
}
++depth;
continue;
}
if (ch == '}') {
if (depth == 0) {
continue;
}
--depth;
if (depth == 0 && start != std::string::npos) {
jsonOut = text.substr(start, i - start + 1);
return true;
}
}
}
return false;
}
std::string ValidateBreweryJson(const std::string &raw, std::string &nameOut,
std::string &descriptionOut) {
auto validateObject = [&](const boost::json::value &jv,
std::string &errorOut) -> bool {
if (!jv.is_object()) {
errorOut = "JSON root must be an object";
return false;
}
const auto &obj = jv.get_object();
if (!obj.contains("name") || !obj.at("name").is_string()) {
errorOut = "JSON field 'name' is missing or not a string";
return false;
}
if (!obj.contains("description") || !obj.at("description").is_string()) {
errorOut = "JSON field 'description' is missing or not a string";
return false;
}
nameOut = trim(std::string(obj.at("name").as_string().c_str()));
descriptionOut =
trim(std::string(obj.at("description").as_string().c_str()));
if (nameOut.empty()) {
errorOut = "JSON field 'name' must not be empty";
return false;
}
if (descriptionOut.empty()) {
errorOut = "JSON field 'description' must not be empty";
return false;
}
std::string nameLower = nameOut;
std::string descriptionLower = descriptionOut;
std::transform(
nameLower.begin(), nameLower.end(), nameLower.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
std::transform(descriptionLower.begin(), descriptionLower.end(),
descriptionLower.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (nameLower == "string" || descriptionLower == "string") {
errorOut = "JSON appears to be a schema placeholder, not content";
return false;
}
errorOut.clear();
return true;
};
boost::system::error_code ec;
boost::json::value jv = boost::json::parse(raw, ec);
std::string validationError;
if (ec) {
std::string extracted;
if (!extractFirstJsonObject(raw, extracted)) {
return "JSON parse error: " + ec.message();
}
ec.clear();
jv = boost::json::parse(extracted, ec);
if (ec) {
return "JSON parse error: " + ec.message();
}
if (!validateObject(jv, validationError)) {
return validationError;
}
return {};
}
if (!validateObject(jv, validationError)) {
return validationError;
}
return {};
}
} // namespace
LlamaGenerator::~LlamaGenerator() {
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
llama_backend_free();
}
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
int seed) {
if (temperature < 0.0f) {
throw std::runtime_error(
"LlamaGenerator: sampling temperature must be >= 0");
}
if (!(top_p > 0.0f && top_p <= 1.0f)) {
throw std::runtime_error(
"LlamaGenerator: sampling top-p must be in (0, 1]");
}
if (seed < -1) {
throw std::runtime_error(
"LlamaGenerator: seed must be >= 0, or -1 for random");
}
sampling_temperature_ = temperature;
sampling_top_p_ = top_p;
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(seed);
}
void LlamaGenerator::Load(const std::string &model_path) {
if (model_path.empty())
throw std::runtime_error("LlamaGenerator: model path must not be empty");
if (context_ != nullptr) {
llama_free(context_);
context_ = nullptr;
}
if (model_ != nullptr) {
llama_model_free(model_);
model_ = nullptr;
}
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
throw std::runtime_error(
"LlamaGenerator: failed to load model from path: " + model_path);
}
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = 2048;
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {
llama_model_free(model_);
model_ = nullptr;
throw std::runtime_error("LlamaGenerator: failed to create context");
}
spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}
std::string LlamaGenerator::Infer(const std::string &prompt, int max_tokens) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab *vocab = llama_model_get_vocab(model_);
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
llama_memory_clear(llama_get_memory(context_), true);
const std::string formatted_prompt = toChatPrompt(model_, prompt);
std::vector<llama_token> promptTokens(formatted_prompt.size() + 8);
int32_t tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
if (tokenCount < 0) {
promptTokens.resize(static_cast<std::size_t>(-tokenCount));
tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
}
if (tokenCount < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
if (nCtx <= 1 || nBatch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, nCtx - 1));
const int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(tokenCount));
if (tokenCount > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
"to fit n_batch/n_ctx limits",
tokenCount, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(prompt_budget));
tokenCount = prompt_budget;
}
const llama_batch promptBatch = llama_batch_get_one(
promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
if (llama_decode(context_, promptBatch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
if (llama_vocab_is_eog(vocab, next))
break;
generated_tokens.push_back(next);
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
std::string output;
for (const llama_token token : generated_tokens)
appendTokenPiece(vocab, token, output);
return output;
}
std::string LlamaGenerator::Infer(const std::string &system_prompt,
const std::string &prompt, int max_tokens) {
if (model_ == nullptr || context_ == nullptr)
throw std::runtime_error("LlamaGenerator: model not loaded");
const llama_vocab *vocab = llama_model_get_vocab(model_);
if (vocab == nullptr)
throw std::runtime_error("LlamaGenerator: vocab unavailable");
llama_memory_clear(llama_get_memory(context_), true);
const std::string formatted_prompt =
toChatPrompt(model_, system_prompt, prompt);
std::vector<llama_token> promptTokens(formatted_prompt.size() + 8);
int32_t tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
if (tokenCount < 0) {
promptTokens.resize(static_cast<std::size_t>(-tokenCount));
tokenCount = llama_tokenize(
vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
static_cast<int32_t>(promptTokens.size()), true, true);
}
if (tokenCount < 0)
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
if (nCtx <= 1 || nBatch <= 0) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
}
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, nCtx - 1));
int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(tokenCount));
if (tokenCount > prompt_budget) {
spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
"to fit n_batch/n_ctx limits",
tokenCount, prompt_budget);
promptTokens.resize(static_cast<std::size_t>(prompt_budget));
tokenCount = prompt_budget;
}
const llama_batch promptBatch = llama_batch_get_one(
promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
if (llama_decode(context_, promptBatch) != 0)
throw std::runtime_error("LlamaGenerator: prompt decode failed");
llama_sampler_chain_params sampler_params =
llama_sampler_chain_default_params();
using SamplerPtr =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
&llama_sampler_free);
if (!sampler)
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_temp(sampling_temperature_));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_top_p(sampling_top_p_, 1));
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
std::vector<llama_token> generated_tokens;
generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
if (llama_vocab_is_eog(vocab, next))
break;
generated_tokens.push_back(next);
llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
if (llama_decode(context_, one_token_batch) != 0)
throw std::runtime_error(
"LlamaGenerator: decode failed during generation");
}
std::string output;
for (const llama_token token : generated_tokens)
appendTokenPiece(vocab, token, output);
return output;
}
BreweryResult
LlamaGenerator::GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) {
const std::string safe_region_context = PrepareRegionContext(region_context);
const std::string system_prompt =
"You are a copywriter for a craft beer travel guide. "
"Your writing is vivid, specific to place, and avoids generic beer "
"cliches. "
"You must output ONLY valid JSON. "
"The JSON schema must be exactly: {\"name\": \"string\", "
"\"description\": \"string\"}. "
"Do not include markdown formatting or backticks.";
std::string prompt =
"Write a brewery name and place-specific description for a craft "
"brewery in " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string(".")
: std::string(". Regional context: ") + safe_region_context);
const int maxAttempts = 3;
std::string raw;
std::string lastError;
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
raw = Infer(system_prompt, prompt, 384);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
std::string name;
std::string description;
const std::string validationError =
ValidateBreweryJson(raw, name, description);
if (validationError.empty()) {
return {std::move(name), std::move(description)};
}
lastError = validationError;
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validationError);
prompt = "Your previous response was invalid. Error: " + validationError +
"\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys."
"\n\nLocation: " +
city_name +
(country_name.empty() ? std::string("")
: std::string(", ") + country_name) +
(safe_region_context.empty()
? std::string("")
: std::string("\nRegional context: ") + safe_region_context);
}
spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: "
"{}",
maxAttempts, lastError.empty() ? raw : lastError);
throw std::runtime_error("LlamaGenerator: malformed brewery response");
}
UserResult LlamaGenerator::GenerateUser(const std::string &locale) {
const std::string system_prompt =
"You generate plausible social media profiles for craft beer "
"enthusiasts. "
"Respond with exactly two lines: "
"the first line is a username (lowercase, no spaces, 8-20 characters), "
"the second line is a one-sentence bio (20-40 words). "
"The profile should feel consistent with the locale. "
"No preamble, no labels.";
std::string prompt =
"Generate a craft beer enthusiast profile. Locale: " + locale;
const int maxAttempts = 3;
std::string raw;
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
raw = Infer(system_prompt, prompt, 128);
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
attempt + 1, raw);
try {
auto [username, bio] =
parseTwoLineResponse(raw, "LlamaGenerator: malformed user response");
username.erase(
std::remove_if(username.begin(), username.end(),
[](unsigned char ch) { return std::isspace(ch); }),
username.end());
if (username.empty() || bio.empty()) {
throw std::runtime_error("LlamaGenerator: malformed user response");
}
if (bio.size() > 200)
bio = bio.substr(0, 200);
return {username, bio};
} catch (const std::exception &e) {
spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what());
}
}
spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}",
maxAttempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -1,7 +1,7 @@
#include "data_generation/mock_generator.h" #include <string>
#include <vector>
#include <functional> #include "data_generation/mock_generator.h"
#include <spdlog/spdlog.h>
const std::vector<std::string> MockGenerator::kBreweryAdjectives = { const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
@@ -63,42 +63,3 @@ const std::vector<std::string> MockGenerator::kBios = {
"Craft beer fan mapping tasting notes and favorite brew routes.", "Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.", "Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."}; "Keeping a running list of must-try collab releases and tap takeovers."};
void MockGenerator::Load(const std::string & /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}
std::size_t MockGenerator::DeterministicHash(const std::string &a,
const std::string &b) {
std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b);
seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13));
return seed;
}
BreweryResult MockGenerator::GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) {
const std::string location_key =
country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash = region_context.empty()
? std::hash<std::string>{}(location_key)
: DeterministicHash(location_key, region_context);
BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
result.description =
kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
return result;
}
UserResult MockGenerator::GenerateUser(const std::string &locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -0,0 +1,12 @@
#include <string>
#include "data_generation/mock_generator.h"
std::size_t MockGenerator::DeterministicHash(const std::string& a,
const std::string& b) {
std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b);
seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13));
return seed;
}

View File

@@ -0,0 +1,21 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
BreweryResult MockGenerator::GenerateBrewery(
const std::string& city_name, const std::string& country_name,
const std::string& region_context) {
const std::string location_key =
country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash =
region_context.empty() ? std::hash<std::string>{}(location_key)
: DeterministicHash(location_key, region_context);
BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
result.description =
kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
return result;
}

View File

@@ -0,0 +1,13 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -0,0 +1,9 @@
#include <spdlog/spdlog.h>
#include <string>
#include "data_generation/mock_generator.h"
void MockGenerator::Load(const std::string& /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}

View File

@@ -1,5 +1,7 @@
#include "database/database.h" #include "database/database.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
void SqliteDatabase::InitializeSchema() { void SqliteDatabase::InitializeSchema() {
@@ -104,7 +106,8 @@ void SqliteDatabase::InsertCountry(int id, const std::string &name,
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertState(int id, int country_id, const std::string &name, void SqliteDatabase::InsertState(int id, int country_id,
const std::string& name,
const std::string& iso2) { const std::string& iso2) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);

View File

@@ -1,8 +1,9 @@
#include <chrono> #include "json_handling/json_loader.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include "json_handling/json_loader.h" #include <chrono>
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
void JsonLoader::LoadWorldCities(const std::string& json_path, void JsonLoader::LoadWorldCities(const std::string& json_path,

View File

@@ -1,12 +1,13 @@
#include <cstdio> #include "json_handling/stream_parser.h"
#include <stdexcept>
#include <spdlog/spdlog.h>
#include <boost/json.hpp> #include <boost/json.hpp>
#include <boost/json/basic_parser_impl.hpp> #include <boost/json/basic_parser_impl.hpp>
#include <spdlog/spdlog.h> #include <cstdio>
#include <stdexcept>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h"
class CityRecordHandler { class CityRecordHandler {
friend class boost::json::basic_parser<CityRecordHandler>; friend class boost::json::basic_parser<CityRecordHandler>;
@@ -235,7 +236,6 @@ void StreamingJsonParser::Parse(
const std::string& file_path, SqliteDatabase& db, const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord&)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress) { std::function<void(size_t, size_t)> on_progress) {
spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path); spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path);
FILE* file = std::fopen(file_path.c_str(), "rb"); FILE* file = std::fopen(file_path.c_str(), "rb");
@@ -252,8 +252,8 @@ void StreamingJsonParser::Parse(
std::rewind(file); std::rewind(file);
} }
CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, total_size,
total_size, 0, 0}; 0, 0};
boost::json::basic_parser<CityRecordHandler> parser( boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx); boost::json::parse_options{}, ctx);
@@ -284,5 +284,6 @@ void StreamingJsonParser::Parse(
} }
spdlog::info(" OK: Parsed {} countries, {} states, {} cities", spdlog::info(" OK: Parsed {} countries, {} states, {} cities",
ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted); ctx.countries_inserted, ctx.states_inserted,
ctx.cities_emitted);
} }

View File

@@ -1,6 +1,8 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
#include <cstdio>
#include <curl/curl.h> #include <curl/curl.h>
#include <cstdio>
#include <fstream> #include <fstream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
@@ -68,8 +70,8 @@ void CURLWebClient::DownloadToFile(const std::string &url,
std::ofstream outFile(file_path, std::ios::binary); std::ofstream outFile(file_path, std::ios::binary);
if (!outFile.is_open()) { if (!outFile.is_open()) {
throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " + throw std::runtime_error(
file_path); "[CURLWebClient] Cannot open file for writing: " + file_path);
} }
set_common_get_options(curl.get(), url, 30L, 300L); set_common_get_options(curl.get(), url, 30L, 300L);

View File

@@ -1,8 +1,10 @@
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
WikipediaService::WikipediaService(std::shared_ptr<IWebClient> client) #include <boost/json.hpp>
WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {} : client_(std::move(client)) {}
std::string WikipediaService::FetchExtract(std::string_view query) { std::string WikipediaService::FetchExtract(std::string_view query) {
@@ -63,8 +65,7 @@ std::string WikipediaService::GetSummary(std::string_view city,
result += regionExtract; result += regionExtract;
} }
if (!beerExtract.empty()) { if (!beerExtract.empty()) {
if (!result.empty()) if (!result.empty()) result += "\n\n";
result += "\n\n";
result += beerExtract; result += beerExtract;
} }
} catch (const std::runtime_error& e) { } catch (const std::runtime_error& e) {