fix: llama backend lifetime, Wikipedia enrichment depth, and misc cleanup

This commit is contained in:
Aaron Po
2026-04-09 21:59:13 -04:00
parent 824f5b2b4f
commit b53f9e5582
17 changed files with 161 additions and 104 deletions

View File

@@ -3,8 +3,7 @@
* @brief LlamaGenerator constructor implementation.
*/
#include <llama.h>
#include <random>
#include <stdexcept>
#include <string>
@@ -12,7 +11,8 @@
#include "data_generation/llama_generator.h"
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path) {
const std::string& model_path)
: rng_() {
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
@@ -39,15 +39,13 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_seed_ = (options.seed < 0)
? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(options.seed);
if (options.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
} else {
rng_.seed(static_cast<uint32_t>(options.seed));
}
n_ctx_ = options.n_ctx;
try {
Load(model_path);
} catch (...) {
llama_backend_free();
throw;
}
Load(model_path);
}

View File

@@ -23,9 +23,4 @@ LlamaGenerator::~LlamaGenerator() {
llama_model_free(model_);
model_ = nullptr;
}
/**
* Clean up the backend (GPU/CPU acceleration resources)
*/
llama_backend_free();
}

View File

@@ -145,8 +145,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
* Distribution sampler: selects actual token using configured seed for
* reproducibility
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
/**
* TOKEN GENERATION LOOP
@@ -187,10 +186,5 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
/**
* Advance seed for next generation to improve output diversity
*/
sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1;
return output;
}

View File

@@ -6,6 +6,7 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept>
#include <string>
@@ -22,11 +23,6 @@ void LlamaGenerator::Load(const std::string& model_path) {
model_ = nullptr;
}
/**
* Initialize the llama backend (one-time setup for GPU/CPU acceleration)
*/
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
@@ -36,7 +32,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = n_ctx_;
context_params.n_batch = n_ctx_; // Set batch size equal to context window
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(512));
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {

View File

@@ -16,6 +16,7 @@
#include "biergarten_data_generator.h"
#include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h"
#include "llama_backend_state.h"
#include "services/enrichment_service.h"
#include "services/wikipedia_service.h"
#include "web_client/curl_web_client.h"
@@ -116,6 +117,7 @@ auto ParseArguments(const int argc, char** argv,
auto main(const int argc, char** argv) noexcept -> int {
try {
const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
ApplicationOptions options;

View File

@@ -11,19 +11,24 @@
#include "services/wikipedia_service.h"
auto WikipediaService::FetchExtract(std::string_view query) const
-> std::string {
const std::string encoded = client_->UrlEncode(std::string(query));
auto WikipediaService::FetchExtract(std::string_view query) -> std::string {
const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) {
return cache_it->second;
}
const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json";
const std::string body = client_->Get(url);
const std::string body = this->client_->Get(url);
boost::system::error_code ec;
boost::json::value doc = boost::json::parse(body, ec);
boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error);
if (!ec && doc.is_object()) {
if (!parse_error && doc.is_object()) {
try {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
@@ -32,9 +37,11 @@ auto WikipediaService::FetchExtract(std::string_view query) const
std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
this->extract_cache_.emplace(cache_key, extract);
return extract;
}
}
this->extract_cache_.emplace(cache_key, std::string{});
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
@@ -42,9 +49,9 @@ auto WikipediaService::FetchExtract(std::string_view query) const
query, e.what());
return {};
}
} else if (ec) {
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
ec.message());
parse_error.message());
}
return {};

View File

@@ -30,20 +30,22 @@ auto WikipediaService::GetLocationContext(const Location& loc) -> std::string {
}
const std::string beer_query = "beer in " + loc.country;
const std::string city_beer_query = "beer in " + loc.city;
auto append_extract = [&result](const std::string& extract) -> void {
if (extract.empty()) {
return;
}
if (!result.empty()) {
result += "\n\n";
}
result += extract;
};
try {
const std::string region_extract = FetchExtract(region_query);
const std::string beer_extract = FetchExtract(beer_query);
if (!region_extract.empty()) {
result += region_extract;
}
if (!beer_extract.empty()) {
if (!result.empty()) {
result += "\n\n";
}
result += beer_extract;
}
append_extract(FetchExtract(region_query));
append_extract(FetchExtract(beer_query));
append_extract(FetchExtract(city_beer_query));
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
e.what());

View File

@@ -7,35 +7,12 @@
#include <cstdio>
#include <fstream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include "curl_web_client_utils.h"
#include "web_client/curl_web_client.h"
// RAII wrapper for CURL handle using unique_ptr
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (!handle) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
static void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
// curl write callback that writes to a file stream
static size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb,
void* userp) {
@@ -55,7 +32,7 @@ void CURLWebClient::DownloadToFile(const std::string& url,
"[CURLWebClient] Cannot open file for writing: " + file_path);
}
set_common_get_options(curl.get(), url, 30L, 300L);
set_common_get_options(curl.get(), url, {30L, 300L});
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
static_cast<void*>(&outFile));

View File

@@ -5,36 +5,13 @@
#include <curl/curl.h>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include "curl_web_client_utils.h"
#include "web_client/curl_web_client.h"
// RAII wrapper for CURL handle using unique_ptr
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (!handle) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
static void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
// curl write callback that appends response data into a std::string
static size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
void* userp) {
@@ -48,7 +25,7 @@ std::string CURLWebClient::Get(const std::string& url) {
auto curl = create_handle();
std::string response_string;
set_common_get_options(curl.get(), url, 10L, 20L);
set_common_get_options(curl.get(), url, {10L, 20L});
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);

View File

@@ -0,0 +1,28 @@
/**
* @file web_client/curl_web_client_utils.cpp
* @brief Shared CURLWebClient helper implementations.
*/
#include "curl_web_client_utils.h"
#include <stdexcept>
auto create_handle() -> CurlHandle {
CURL* handle = curl_easy_init();
if (handle == nullptr) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
auto set_common_get_options(CURL* curl, const std::string& url,
CurlTimeouts timeouts) -> void {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, timeouts.connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeouts.total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}

View File

@@ -0,0 +1,26 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
/**
* @file web_client/curl_web_client_utils.h
* @brief Shared helpers for CURLWebClient request setup.
*/
#include <curl/curl.h>
#include <memory>
#include <string>
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
struct CurlTimeouts {
long connect_timeout;
long total_timeout;
};
CurlHandle create_handle();
void set_common_get_options(CURL* curl, const std::string& url,
CurlTimeouts timeouts);
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_