format codebase

This commit is contained in:
Aaron Po
2026-04-02 21:46:46 -04:00
parent ba165d8aa7
commit 3af053f0eb
31 changed files with 1479 additions and 1445 deletions

View File

@@ -1,10 +1,5 @@
--- ---
BasedOnStyle: Google BasedOnStyle: Google
Standard: c++23 ColumnLimit: 80
ColumnLimit: 100 IndentWidth: 3
IndentWidth: 2
DerivePointerAlignment: false
PointerAlignment: Left
SortIncludes: true
IncludeBlocks: Preserve
... ...

View File

@@ -90,7 +90,11 @@ set(PIPELINE_SOURCES
src/data_generation/llama/generate_brewery.cpp src/data_generation/llama/generate_brewery.cpp
src/data_generation/llama/generate_user.cpp src/data_generation/llama/generate_user.cpp
src/data_generation/llama/helpers.cpp src/data_generation/llama/helpers.cpp
src/data_generation/mock_generator.cpp src/data_generation/mock/data.cpp
src/data_generation/mock/deterministic_hash.cpp
src/data_generation/mock/load.cpp
src/data_generation/mock/generate_brewery.cpp
src/data_generation/mock/generate_user.cpp
src/json_handling/stream_parser.cpp src/json_handling/stream_parser.cpp
src/wikipedia/wikipedia_service.cpp src/wikipedia/wikipedia_service.cpp
src/main.cpp src/main.cpp

View File

@@ -19,7 +19,8 @@ public:
/// @brief Returns a local JSON path, downloading it when cache is missing. /// @brief Returns a local JSON path, downloading it when cache is missing.
std::string DownloadCountriesDatabase( std::string DownloadCountriesDatabase(
const std::string& cache_path, const std::string& cache_path,
const std::string &commit = "c5eb7772" // Stable commit: 2026-03-28 export const std::string& commit =
"c5eb7772" // Stable commit: 2026-03-28 export
); );
private: private:

View File

@@ -28,8 +28,8 @@ private:
// models receive a proper system role instead of having the system text // models receive a proper system role instead of having the system text
// concatenated into the user prompt (helps avoid revealing internal // concatenated into the user prompt (helps avoid revealing internal
// reasoning or instructions in model output). // reasoning or instructions in model output).
std::string Infer(const std::string &system_prompt, const std::string &prompt, std::string Infer(const std::string& system_prompt,
int max_tokens = 10000); const std::string& prompt, int max_tokens = 10000);
llama_model* model_ = nullptr; llama_model* model_ = nullptr;
llama_context* context_ = nullptr; llama_context* context_ = nullptr;

View File

@@ -12,9 +12,8 @@ typedef int llama_token;
std::string PrepareRegionContextPublic(std::string_view region_context, std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars = 700); std::size_t max_chars = 700);
std::pair<std::string, std::string> std::pair<std::string, std::string> ParseTwoLineResponsePublic(
ParseTwoLineResponsePublic(const std::string& raw, const std::string& raw, const std::string& error_message);
const std::string& error_message);
std::string ToChatPromptPublic(const llama_model* model, std::string ToChatPromptPublic(const llama_model* model,
const std::string& user_prompt); const std::string& user_prompt);

View File

@@ -1,10 +1,11 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_ #define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#include "data_generation/data_generator.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "data_generation/data_generator.h"
class MockGenerator final : public DataGenerator { class MockGenerator final : public DataGenerator {
public: public:
void Load(const std::string& model_path) override; void Load(const std::string& model_path) override;

View File

@@ -1,8 +1,9 @@
#ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_ #define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#include <mutex>
#include <sqlite3.h> #include <sqlite3.h>
#include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
@@ -67,8 +68,8 @@ public:
const std::string& iso2); const std::string& iso2);
/// @brief Inserts a city row linked to state and country. /// @brief Inserts a city row linked to state and country.
void InsertCity(int id, int state_id, int country_id, const std::string &name, void InsertCity(int id, int state_id, int country_id,
double latitude, double longitude); const std::string& name, double latitude, double longitude);
/// @brief Returns city records including parent country id. /// @brief Returns city records including parent country id.
std::vector<City> QueryCities(); std::vector<City> QueryCities();

View File

@@ -1,15 +1,17 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#include <string>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
#include <string>
/// @brief Loads world-city JSON data into SQLite through streaming parsing. /// @brief Loads world-city JSON data into SQLite through streaming parsing.
class JsonLoader { class JsonLoader {
public: public:
/// @brief Parses a JSON file and writes country/state/city rows into db. /// @brief Parses a JSON file and writes country/state/city rows into db.
static void LoadWorldCities(const std::string &json_path, SqliteDatabase &db); static void LoadWorldCities(const std::string& json_path,
SqliteDatabase& db);
}; };
#endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_ #endif // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -1,10 +1,11 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_ #define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#include "database/database.h"
#include <functional> #include <functional>
#include <string> #include <string>
#include "database/database.h"
// Forward declaration to avoid circular dependency // Forward declaration to avoid circular dependency
class SqliteDatabase; class SqliteDatabase;

View File

@@ -1,9 +1,10 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_ #define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#include "web_client/web_client.h"
#include <memory> #include <memory>
#include "web_client/web_client.h"
// RAII for curl_global_init/cleanup. // RAII for curl_global_init/cleanup.
// An instance of this class should be created in main() before any curl // An instance of this class should be created in main() before any curl
// operations and exist for the lifetime of the application. // operations and exist for the lifetime of the application.

View File

@@ -11,8 +11,8 @@ public:
virtual void DownloadToFile(const std::string& url, virtual void DownloadToFile(const std::string& url,
const std::string& file_path) = 0; const std::string& file_path) = 0;
// Performs a GET request and returns the response body as a string. Throws on // Performs a GET request and returns the response body as a string. Throws
// error. // on error.
virtual std::string Get(const std::string& url) = 0; virtual std::string Get(const std::string& url) = 0;
// URL-encodes a string. // URL-encodes a string.

View File

@@ -1,11 +1,14 @@
#include "data_generation/data_downloader.h" #include "data_generation/data_downloader.h"
#include "web_client/web_client.h"
#include <spdlog/spdlog.h>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <spdlog/spdlog.h>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include "web_client/web_client.h"
DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client) DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client)
: web_client_(std::move(web_client)) {} : web_client_(std::move(web_client)) {}
@@ -15,9 +18,8 @@ bool DataDownloader::FileExists(const std::string &file_path) {
return std::filesystem::exists(file_path); return std::filesystem::exists(file_path);
} }
std::string std::string DataDownloader::DownloadCountriesDatabase(
DataDownloader::DownloadCountriesDatabase(const std::string &cache_path, const std::string& cache_path, const std::string& commit) {
const std::string &commit) {
if (FileExists(cache_path)) { if (FileExists(cache_path)) {
spdlog::info("[DataDownloader] Cache hit: {}", cache_path); spdlog::info("[DataDownloader] Cache hit: {}", cache_path);
return cache_path; return cache_path;
@@ -28,7 +30,8 @@ DataDownloader::DownloadCountriesDatabase(const std::string &cache_path,
short_commit = commit.substr(0, 7); short_commit = commit.substr(0, 7);
} }
std::string url = "https://raw.githubusercontent.com/dr5hn/" std::string url =
"https://raw.githubusercontent.com/dr5hn/"
"countries-states-cities-database/" + "countries-states-cities-database/" +
short_commit + "/json/countries+states+cities.json"; short_commit + "/json/countries+states+cities.json";

View File

@@ -1,6 +1,5 @@
#include "llama.h"
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
LlamaGenerator::~LlamaGenerator() { LlamaGenerator::~LlamaGenerator() {
if (context_ != nullptr) { if (context_ != nullptr) {

View File

@@ -1,25 +1,25 @@
#include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
BreweryResult BreweryResult LlamaGenerator::GenerateBrewery(
LlamaGenerator::GenerateBrewery(const std::string& city_name, const std::string& city_name, const std::string& country_name,
const std::string& country_name,
const std::string& region_context) { const std::string& region_context) {
const std::string safe_region_context = const std::string safe_region_context =
PrepareRegionContextPublic(region_context); PrepareRegionContextPublic(region_context);
const std::string system_prompt = const std::string system_prompt =
"You are a copywriter for a craft beer travel guide. " "You are the brewmaster and owner of a local craft brewery. "
"Your writing is vivid, specific to place, and avoids generic beer " "Write a name and a short, soulful description for your brewery that "
"cliches. " "reflects your pride in the local community and your craft. "
"You must output ONLY valid JSON. " "The tone should be authentic and welcoming, like a note on a "
"The JSON schema must be exactly: {\"name\": \"string\", " "chalkboard "
"\"description\": \"string\"}. " "menu. Output ONLY a single JSON object with keys \"name\" and "
"\"description\". "
"Do not include markdown formatting or backticks."; "Do not include markdown formatting or backticks.";
std::string prompt = std::string prompt =
@@ -52,7 +52,8 @@ LlamaGenerator::GenerateBrewery(const std::string& city_name,
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
attempt + 1, validation_error); attempt + 1, validation_error);
prompt = "Your previous response was invalid. Error: " + validation_error + prompt =
"Your previous response was invalid. Error: " + validation_error +
"\nReturn ONLY valid JSON with this exact schema: " "\nReturn ONLY valid JSON with this exact schema: "
"{\"name\": \"string\", \"description\": \"string\"}." "{\"name\": \"string\", \"description\": \"string\"}."
"\nDo not include markdown, comments, or extra keys." "\nDo not include markdown, comments, or extra keys."
@@ -65,7 +66,8 @@ LlamaGenerator::GenerateBrewery(const std::string& city_name,
: std::string("\nRegional context: ") + safe_region_context); : std::string("\nRegional context: ") + safe_region_context);
} }
spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: " spdlog::error(
"LlamaGenerator: malformed brewery response after {} attempts: "
"{}", "{}",
max_attempts, last_error.empty() ? raw : last_error); max_attempts, last_error.empty() ? raw : last_error);
throw std::runtime_error("LlamaGenerator: malformed brewery response"); throw std::runtime_error("LlamaGenerator: malformed brewery response");

View File

@@ -1,9 +1,9 @@
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
@@ -40,17 +40,18 @@ UserResult LlamaGenerator::GenerateUser(const std::string& locale) {
throw std::runtime_error("LlamaGenerator: malformed user response"); throw std::runtime_error("LlamaGenerator: malformed user response");
} }
if (bio.size() > 200) if (bio.size() > 200) bio = bio.substr(0, 200);
bio = bio.substr(0, 200);
return {username, bio}; return {username, bio};
} catch (const std::exception& e) { } catch (const std::exception& e) {
spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}", spdlog::warn(
"LlamaGenerator: malformed user response (attempt {}): {}",
attempt + 1, e.what()); attempt + 1, e.what());
} }
} }
spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}", spdlog::error(
"LlamaGenerator: malformed user response after {} attempts: {}",
max_attempts, raw); max_attempts, raw);
throw std::runtime_error("LlamaGenerator: malformed user response"); throw std::runtime_error("LlamaGenerator: malformed user response");
} }

View File

@@ -1,15 +1,14 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <boost/json.hpp>
#include <cctype> #include <cctype>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
#include "llama.h"
#include <boost/json.hpp>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
namespace { namespace {
@@ -103,8 +102,8 @@ std::string StripCommonPrefix(std::string line) {
return Trim(std::move(line)); return Trim(std::move(line));
} }
std::pair<std::string, std::string> std::pair<std::string, std::string> ParseTwoLineResponse(
ParseTwoLineResponse(const std::string& raw, const std::string& error_message) { const std::string& raw, const std::string& error_message) {
std::string normalized = raw; std::string normalized = raw;
std::replace(normalized.begin(), normalized.end(), '\r', '\n'); std::replace(normalized.begin(), normalized.end(), '\r', '\n');
@@ -113,8 +112,7 @@ ParseTwoLineResponse(const std::string& raw, const std::string& error_message) {
std::string line; std::string line;
while (std::getline(stream, line)) { while (std::getline(stream, line)) {
line = StripCommonPrefix(std::move(line)); line = StripCommonPrefix(std::move(line));
if (!line.empty()) if (!line.empty()) lines.push_back(std::move(line));
lines.push_back(std::move(line));
} }
std::vector<std::string> filtered; std::vector<std::string> filtered;
@@ -123,27 +121,22 @@ ParseTwoLineResponse(const std::string& raw, const std::string& error_message) {
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c)); return static_cast<char>(std::tolower(c));
}); });
if (!l.empty() && l.front() == '<' && low.back() == '>') if (!l.empty() && l.front() == '<' && low.back() == '>') continue;
continue; if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0)
continue;
filtered.push_back(std::move(l)); filtered.push_back(std::move(l));
} }
if (filtered.size() < 2) if (filtered.size() < 2) throw std::runtime_error(error_message);
throw std::runtime_error(error_message);
std::string first = Trim(filtered.front()); std::string first = Trim(filtered.front());
std::string second; std::string second;
for (size_t i = 1; i < filtered.size(); ++i) { for (size_t i = 1; i < filtered.size(); ++i) {
if (!second.empty()) if (!second.empty()) second += ' ';
second += ' ';
second += filtered[i]; second += filtered[i];
} }
second = Trim(std::move(second)); second = Trim(std::move(second));
if (first.empty() || second.empty()) if (first.empty() || second.empty()) throw std::runtime_error(error_message);
throw std::runtime_error(error_message);
return {first, second}; return {first, second};
} }
@@ -156,7 +149,8 @@ std::string ToChatPrompt(const llama_model *model,
const llama_chat_message message{"user", user_prompt.c_str()}; const llama_chat_message message{"user", user_prompt.c_str()};
std::vector<char> buffer(std::max<std::size_t>(1024, user_prompt.size() * 4)); std::vector<char> buffer(
std::max<std::size_t>(1024, user_prompt.size() * 4));
int32_t required = int32_t required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size())); static_cast<int32_t>(buffer.size()));
@@ -167,10 +161,12 @@ std::string ToChatPrompt(const llama_model *model,
if (required >= static_cast<int32_t>(buffer.size())) { if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1); buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(), required =
llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
static_cast<int32_t>(buffer.size())); static_cast<int32_t>(buffer.size()));
if (required < 0) { if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template"); throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
} }
} }
@@ -200,10 +196,12 @@ std::string ToChatPrompt(const llama_model *model,
if (required >= static_cast<int32_t>(buffer.size())) { if (required >= static_cast<int32_t>(buffer.size())) {
buffer.resize(static_cast<std::size_t>(required) + 1); buffer.resize(static_cast<std::size_t>(required) + 1);
required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(), required =
llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
static_cast<int32_t>(buffer.size())); static_cast<int32_t>(buffer.size()));
if (required < 0) { if (required < 0) {
throw std::runtime_error("LlamaGenerator: failed to apply chat template"); throw std::runtime_error(
"LlamaGenerator: failed to apply chat template");
} }
} }
@@ -220,8 +218,8 @@ void AppendTokenPiece(const llama_vocab *vocab, llama_token token,
if (bytes < 0) { if (bytes < 0) {
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes)); std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(), bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
static_cast<int32_t>(dynamic_buffer.size()), 0, static_cast<int32_t>(dynamic_buffer.size()),
true); 0, true);
if (bytes < 0) { if (bytes < 0) {
throw std::runtime_error( throw std::runtime_error(
"LlamaGenerator: failed to decode sampled token piece"); "LlamaGenerator: failed to decode sampled token piece");
@@ -372,9 +370,8 @@ std::string PrepareRegionContextPublic(std::string_view region_context,
return PrepareRegionContext(region_context, max_chars); return PrepareRegionContext(region_context, max_chars);
} }
std::pair<std::string, std::string> std::pair<std::string, std::string> ParseTwoLineResponsePublic(
ParseTwoLineResponsePublic(const std::string& raw, const std::string& raw, const std::string& error_message) {
const std::string& error_message) {
return ParseTwoLineResponse(raw, error_message); return ParseTwoLineResponse(raw, error_message);
} }

View File

@@ -1,14 +1,14 @@
#include <spdlog/spdlog.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
#include "llama.h"
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "data_generation/llama_generator_helpers.h" #include "data_generation/llama_generator_helpers.h"
#include "llama.h"
std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) { std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
if (model_ == nullptr || context_ == nullptr) if (model_ == nullptr || context_ == nullptr)
@@ -45,14 +45,16 @@ std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
throw std::runtime_error("LlamaGenerator: invalid context or batch size"); throw std::runtime_error("LlamaGenerator: invalid context or batch size");
} }
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget); prompt_budget = std::max<int32_t>(1, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(token_count)); prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) { if (token_count > prompt_budget) {
spdlog::warn( spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " "LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens "
"to fit n_batch/n_ctx limits", "to fit n_batch/n_ctx limits",
token_count, prompt_budget); token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget)); prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
@@ -84,9 +86,9 @@ std::string LlamaGenerator::Infer(const std::string& prompt, int max_tokens) {
generated_tokens.reserve(static_cast<std::size_t>(max_tokens)); generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) { for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); const llama_token next =
if (llama_vocab_is_eog(vocab, next)) llama_sampler_sample(sampler.get(), context_, -1);
break; if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next); generated_tokens.push_back(next);
llama_token token = next; llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1); const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
@@ -138,14 +140,16 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt,
throw std::runtime_error("LlamaGenerator: invalid context or batch size"); throw std::runtime_error("LlamaGenerator: invalid context or batch size");
} }
const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, n_ctx - 1)); const int32_t effective_max_tokens =
std::max(1, std::min(max_tokens, n_ctx - 1));
int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens); int32_t prompt_budget = std::min(n_batch, n_ctx - effective_max_tokens);
prompt_budget = std::max<int32_t>(1, prompt_budget); prompt_budget = std::max<int32_t>(1, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(token_count)); prompt_tokens.resize(static_cast<std::size_t>(token_count));
if (token_count > prompt_budget) { if (token_count > prompt_budget) {
spdlog::warn( spdlog::warn(
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " "LlamaGenerator: prompt too long ({} tokens), truncating to {} "
"tokens "
"to fit n_batch/n_ctx limits", "to fit n_batch/n_ctx limits",
token_count, prompt_budget); token_count, prompt_budget);
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget)); prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
@@ -177,9 +181,9 @@ std::string LlamaGenerator::Infer(const std::string& system_prompt,
generated_tokens.reserve(static_cast<std::size_t>(max_tokens)); generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
for (int i = 0; i < effective_max_tokens; ++i) { for (int i = 0; i < effective_max_tokens; ++i) {
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); const llama_token next =
if (llama_vocab_is_eog(vocab, next)) llama_sampler_sample(sampler.get(), context_, -1);
break; if (llama_vocab_is_eog(vocab, next)) break;
generated_tokens.push_back(next); generated_tokens.push_back(next);
llama_token token = next; llama_token token = next;
const llama_batch one_token_batch = llama_batch_get_one(&token, 1); const llama_batch one_token_batch = llama_batch_get_one(&token, 1);

View File

@@ -1,10 +1,10 @@
#include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "llama.h"
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::Load(const std::string& model_path) { void LlamaGenerator::Load(const std::string& model_path) {
if (model_path.empty()) if (model_path.empty())

View File

@@ -1,8 +1,7 @@
#include <stdexcept> #include <stdexcept>
#include "llama.h"
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include "llama.h"
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p, void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
int seed) { int seed) {

View File

@@ -1,7 +1,7 @@
#include "data_generation/mock_generator.h" #include <string>
#include <vector>
#include <functional> #include "data_generation/mock_generator.h"
#include <spdlog/spdlog.h>
const std::vector<std::string> MockGenerator::kBreweryAdjectives = { const std::vector<std::string> MockGenerator::kBreweryAdjectives = {
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
@@ -63,42 +63,3 @@ const std::vector<std::string> MockGenerator::kBios = {
"Craft beer fan mapping tasting notes and favorite brew routes.", "Craft beer fan mapping tasting notes and favorite brew routes.",
"Always ready to trade recommendations for underrated local breweries.", "Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."}; "Keeping a running list of must-try collab releases and tap takeovers."};
void MockGenerator::Load(const std::string & /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}
std::size_t MockGenerator::DeterministicHash(const std::string &a,
const std::string &b) {
std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b);
seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13));
return seed;
}
BreweryResult MockGenerator::GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) {
const std::string location_key =
country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash = region_context.empty()
? std::hash<std::string>{}(location_key)
: DeterministicHash(location_key, region_context);
BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
result.description =
kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
return result;
}
UserResult MockGenerator::GenerateUser(const std::string &locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -0,0 +1,12 @@
#include <string>
#include "data_generation/mock_generator.h"
std::size_t MockGenerator::DeterministicHash(const std::string& a,
const std::string& b) {
std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b);
seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13));
return seed;
}

View File

@@ -0,0 +1,21 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
BreweryResult MockGenerator::GenerateBrewery(
const std::string& city_name, const std::string& country_name,
const std::string& region_context) {
const std::string location_key =
country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash =
region_context.empty() ? std::hash<std::string>{}(location_key)
: DeterministicHash(location_key, region_context);
BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
kBreweryNouns[(hash / 7) % kBreweryNouns.size()];
result.description =
kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()];
return result;
}

View File

@@ -0,0 +1,13 @@
#include <functional>
#include <string>
#include "data_generation/mock_generator.h"
UserResult MockGenerator::GenerateUser(const std::string& locale) {
const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result;
result.username = kUsernames[hash % kUsernames.size()];
result.bio = kBios[(hash / 11) % kBios.size()];
return result;
}

View File

@@ -0,0 +1,9 @@
#include <spdlog/spdlog.h>
#include <string>
#include "data_generation/mock_generator.h"
void MockGenerator::Load(const std::string& /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed");
}

View File

@@ -1,5 +1,7 @@
#include "database/database.h" #include "database/database.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
void SqliteDatabase::InitializeSchema() { void SqliteDatabase::InitializeSchema() {
@@ -104,7 +106,8 @@ void SqliteDatabase::InsertCountry(int id, const std::string &name,
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertState(int id, int country_id, const std::string &name, void SqliteDatabase::InsertState(int id, int country_id,
const std::string& name,
const std::string& iso2) { const std::string& iso2) {
std::lock_guard<std::mutex> lock(db_mutex_); std::lock_guard<std::mutex> lock(db_mutex_);

View File

@@ -1,8 +1,9 @@
#include <chrono> #include "json_handling/json_loader.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include "json_handling/json_loader.h" #include <chrono>
#include "json_handling/stream_parser.h" #include "json_handling/stream_parser.h"
void JsonLoader::LoadWorldCities(const std::string& json_path, void JsonLoader::LoadWorldCities(const std::string& json_path,

View File

@@ -1,12 +1,13 @@
#include <cstdio> #include "json_handling/stream_parser.h"
#include <stdexcept>
#include <spdlog/spdlog.h>
#include <boost/json.hpp> #include <boost/json.hpp>
#include <boost/json/basic_parser_impl.hpp> #include <boost/json/basic_parser_impl.hpp>
#include <spdlog/spdlog.h> #include <cstdio>
#include <stdexcept>
#include "database/database.h" #include "database/database.h"
#include "json_handling/stream_parser.h"
class CityRecordHandler { class CityRecordHandler {
friend class boost::json::basic_parser<CityRecordHandler>; friend class boost::json::basic_parser<CityRecordHandler>;
@@ -235,7 +236,6 @@ void StreamingJsonParser::Parse(
const std::string& file_path, SqliteDatabase& db, const std::string& file_path, SqliteDatabase& db,
std::function<void(const CityRecord&)> on_city, std::function<void(const CityRecord&)> on_city,
std::function<void(size_t, size_t)> on_progress) { std::function<void(size_t, size_t)> on_progress) {
spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path); spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path);
FILE* file = std::fopen(file_path.c_str(), "rb"); FILE* file = std::fopen(file_path.c_str(), "rb");
@@ -252,8 +252,8 @@ void StreamingJsonParser::Parse(
std::rewind(file); std::rewind(file);
} }
CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0, total_size,
total_size, 0, 0}; 0, 0};
boost::json::basic_parser<CityRecordHandler> parser( boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx); boost::json::parse_options{}, ctx);
@@ -284,5 +284,6 @@ void StreamingJsonParser::Parse(
} }
spdlog::info(" OK: Parsed {} countries, {} states, {} cities", spdlog::info(" OK: Parsed {} countries, {} states, {} cities",
ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted); ctx.countries_inserted, ctx.states_inserted,
ctx.cities_emitted);
} }

View File

@@ -1,6 +1,8 @@
#include "web_client/curl_web_client.h" #include "web_client/curl_web_client.h"
#include <cstdio>
#include <curl/curl.h> #include <curl/curl.h>
#include <cstdio>
#include <fstream> #include <fstream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
@@ -68,8 +70,8 @@ void CURLWebClient::DownloadToFile(const std::string &url,
std::ofstream outFile(file_path, std::ios::binary); std::ofstream outFile(file_path, std::ios::binary);
if (!outFile.is_open()) { if (!outFile.is_open()) {
throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " + throw std::runtime_error(
file_path); "[CURLWebClient] Cannot open file for writing: " + file_path);
} }
set_common_get_options(curl.get(), url, 30L, 300L); set_common_get_options(curl.get(), url, 30L, 300L);

View File

@@ -1,7 +1,9 @@
#include "wikipedia/wikipedia_service.h" #include "wikipedia/wikipedia_service.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <boost/json.hpp>
WikipediaService::WikipediaService(std::shared_ptr<WebClient> client) WikipediaService::WikipediaService(std::shared_ptr<WebClient> client)
: client_(std::move(client)) {} : client_(std::move(client)) {}
@@ -63,8 +65,7 @@ std::string WikipediaService::GetSummary(std::string_view city,
result += regionExtract; result += regionExtract;
} }
if (!beerExtract.empty()) { if (!beerExtract.empty()) {
if (!result.empty()) if (!result.empty()) result += "\n\n";
result += "\n\n";
result += beerExtract; result += beerExtract;
} }
} catch (const std::runtime_error& e) { } catch (const std::runtime_error& e) {