fix: llama backend lifetime, Wikipedia enrichment depth, and misc cleanup

This commit is contained in:
Aaron Po
2026-04-09 21:59:13 -04:00
parent 824f5b2b4f
commit b53f9e5582
17 changed files with 161 additions and 104 deletions

View File

@@ -120,6 +120,7 @@ set(SOURCES
src/web_client/curl_web_client_destructor.cpp
src/web_client/curl_web_client_download_to_file.cpp
src/web_client/curl_web_client_get.cpp
src/web_client/curl_web_client_utils.cpp
src/web_client/curl_web_client_url_encode.cpp
# Data generation modules
src/data_generation/llama/destructor.cpp

View File

@@ -2,6 +2,26 @@
Biergarten Pipeline is a C++23 command-line tool that reads a local city list, resolves contextual enrichment for each sampled city through an injected service, and generates brewery names and descriptions. The current code samples up to four locations per run, then uses either a local GGUF model or the mock generator to produce the output.
## Hardware & GPU Config
### x86/64 Linux, NVIDIA RTX 2000
- **Host**: ThinkPad P1 Gen 7 (Fedora 43)
- **CPU**: Intel Core Ultra 7 155H
- **GPU**: NVIDIA RTX 2000 Ada Generation
- **Memory**: 32GB
- **Model**: Qwen3-8B-Q6-K
- **Inference**: llama.cpp with CUDA 12.x support
### ARM MacOS, M1 Pro
- **Host**: MacBook Pro 14" (2021)
- **CPU**: Apple M1 Pro (8-core)
- **GPU**: Apple M1 Pro (14-core) [Integrated]
- **Memory**: 16GB
- **Model**: Qwen3-8B-Q6-K
- **Inference**: llama.cpp with Metal (MPS) support
## Pipeline
| Stage | What happens |

View File

@@ -7,6 +7,7 @@
*/
#include <cstdint>
#include <random>
#include <string>
#include "data_generation/data_generator.h"
@@ -114,7 +115,7 @@ class LlamaGenerator final : public DataGenerator {
llama_context* context_ = nullptr;
float sampling_temperature_ = 0.8f;
float sampling_top_p_ = 0.92f;
uint32_t sampling_seed_ = 0xFFFFFFFFu;
std::mt19937 rng_;
uint32_t n_ctx_ = 8192;
std::string brewery_system_prompt_;
};

View File

@@ -21,7 +21,7 @@ typedef int llama_token;
* @return Processed region context.
*/
std::string PrepareRegionContextPublic(std::string_view region_context,
std::size_t max_chars = 700);
std::size_t max_chars = 2000);
/**
* @brief Parses a response expected to contain two logical lines.

View File

@@ -0,0 +1,32 @@
#ifndef BIERGARTEN_PIPELINE_LLAMA_BACKEND_STATE_H_
#define BIERGARTEN_PIPELINE_LLAMA_BACKEND_STATE_H_
/**
* @file llama_backend_state.h
* @brief RAII guard for llama.cpp backend process lifetime.
*/
#include <llama.h>
/**
* @brief RAII wrapper for llama_backend_init and llama_backend_free.
*
* Create one instance in application startup before using llama.cpp and keep
* it alive for application lifetime.
*/
class LlamaBackendState {
public:
/// @brief Initializes global llama backend state.
LlamaBackendState() { llama_backend_init(); }
/// @brief Cleans up global llama backend state.
~LlamaBackendState() { llama_backend_free(); }
/// @brief Non-copyable type.
LlamaBackendState(const LlamaBackendState&) = delete;
/// @brief Non-copyable type.
LlamaBackendState& operator=(const LlamaBackendState&) = delete;
};
#endif // BIERGARTEN_PIPELINE_LLAMA_BACKEND_STATE_H_

View File

@@ -24,9 +24,10 @@ class WikipediaService final : public IEnrichmentService {
[[nodiscard]] std::string GetLocationContext(const Location& loc) override;
private:
std::string FetchExtract(std::string_view query) const;
std::string FetchExtract(std::string_view query);
std::shared_ptr<WebClient> client_;
std::unordered_map<std::string, std::string> cache_;
std::unordered_map<std::string, std::string> extract_cache_;
};
#endif // BIERGARTEN_PIPELINE_WIKIPEDIA_SERVICE_H_

View File

@@ -3,8 +3,7 @@
* @brief LlamaGenerator constructor implementation.
*/
#include <llama.h>
#include <random>
#include <stdexcept>
#include <string>
@@ -12,7 +11,8 @@
#include "data_generation/llama_generator.h"
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
const std::string& model_path) {
const std::string& model_path)
: rng_() {
if (model_path.empty()) {
throw std::runtime_error("LlamaGenerator: model path must not be empty");
}
@@ -39,15 +39,13 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
sampling_temperature_ = options.temperature;
sampling_top_p_ = options.top_p;
sampling_seed_ = (options.seed < 0)
? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
: static_cast<uint32_t>(options.seed);
if (options.seed == -1) {
std::random_device random_device;
rng_.seed(random_device());
} else {
rng_.seed(static_cast<uint32_t>(options.seed));
}
n_ctx_ = options.n_ctx;
try {
Load(model_path);
} catch (...) {
llama_backend_free();
throw;
}
}

View File

@@ -23,9 +23,4 @@ LlamaGenerator::~LlamaGenerator() {
llama_model_free(model_);
model_ = nullptr;
}
/**
* Clean up the backend (GPU/CPU acceleration resources)
*/
llama_backend_free();
}

View File

@@ -145,8 +145,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
* Distribution sampler: selects actual token using configured seed for
* reproducibility
*/
llama_sampler_chain_add(sampler.get(),
llama_sampler_init_dist(sampling_seed_));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
/**
* TOKEN GENERATION LOOP
@@ -187,10 +186,5 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
for (const llama_token token : generated_tokens)
AppendTokenPiecePublic(vocab, token, output);
/**
* Advance seed for next generation to improve output diversity
*/
sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1;
return output;
}

View File

@@ -6,6 +6,7 @@
#include <spdlog/spdlog.h>
#include <algorithm>
#include <stdexcept>
#include <string>
@@ -22,11 +23,6 @@ void LlamaGenerator::Load(const std::string& model_path) {
model_ = nullptr;
}
/**
* Initialize the llama backend (one-time setup for GPU/CPU acceleration)
*/
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
if (model_ == nullptr) {
@@ -36,7 +32,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
llama_context_params context_params = llama_context_default_params();
context_params.n_ctx = n_ctx_;
context_params.n_batch = n_ctx_; // Set batch size equal to context window
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(512));
context_ = llama_init_from_model(model_, context_params);
if (context_ == nullptr) {

View File

@@ -16,6 +16,7 @@
#include "biergarten_data_generator.h"
#include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h"
#include "llama_backend_state.h"
#include "services/enrichment_service.h"
#include "services/wikipedia_service.h"
#include "web_client/curl_web_client.h"
@@ -116,6 +117,7 @@ auto ParseArguments(const int argc, char** argv,
auto main(const int argc, char** argv) noexcept -> int {
try {
const CurlGlobalState curl_state;
const LlamaBackendState llama_backend_state;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
ApplicationOptions options;

View File

@@ -11,19 +11,24 @@
#include "services/wikipedia_service.h"
auto WikipediaService::FetchExtract(std::string_view query) const
-> std::string {
const std::string encoded = client_->UrlEncode(std::string(query));
auto WikipediaService::FetchExtract(std::string_view query) -> std::string {
const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) {
return cache_it->second;
}
const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json";
const std::string body = client_->Get(url);
const std::string body = this->client_->Get(url);
boost::system::error_code ec;
boost::json::value doc = boost::json::parse(body, ec);
boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error);
if (!ec && doc.is_object()) {
if (!parse_error && doc.is_object()) {
try {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
@@ -32,9 +37,11 @@ auto WikipediaService::FetchExtract(std::string_view query) const
std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
this->extract_cache_.emplace(cache_key, extract);
return extract;
}
}
this->extract_cache_.emplace(cache_key, std::string{});
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
@@ -42,9 +49,9 @@ auto WikipediaService::FetchExtract(std::string_view query) const
query, e.what());
return {};
}
} else if (ec) {
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
ec.message());
parse_error.message());
}
return {};

View File

@@ -30,20 +30,22 @@ auto WikipediaService::GetLocationContext(const Location& loc) -> std::string {
}
const std::string beer_query = "beer in " + loc.country;
const std::string city_beer_query = "beer in " + loc.city;
try {
const std::string region_extract = FetchExtract(region_query);
const std::string beer_extract = FetchExtract(beer_query);
if (!region_extract.empty()) {
result += region_extract;
auto append_extract = [&result](const std::string& extract) -> void {
if (extract.empty()) {
return;
}
if (!beer_extract.empty()) {
if (!result.empty()) {
result += "\n\n";
}
result += beer_extract;
}
result += extract;
};
try {
append_extract(FetchExtract(region_query));
append_extract(FetchExtract(beer_query));
append_extract(FetchExtract(city_beer_query));
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
e.what());

View File

@@ -7,35 +7,12 @@
#include <cstdio>
#include <fstream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include "curl_web_client_utils.h"
#include "web_client/curl_web_client.h"
// RAII wrapper for CURL handle using unique_ptr
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (!handle) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
static void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
// curl write callback that writes to a file stream
static size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb,
void* userp) {
@@ -55,7 +32,7 @@ void CURLWebClient::DownloadToFile(const std::string& url,
"[CURLWebClient] Cannot open file for writing: " + file_path);
}
set_common_get_options(curl.get(), url, 30L, 300L);
set_common_get_options(curl.get(), url, {30L, 300L});
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
static_cast<void*>(&outFile));

View File

@@ -5,36 +5,13 @@
#include <curl/curl.h>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include "curl_web_client_utils.h"
#include "web_client/curl_web_client.h"
// RAII wrapper for CURL handle using unique_ptr
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
static CurlHandle create_handle() {
CURL* handle = curl_easy_init();
if (!handle) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
static void set_common_get_options(CURL* curl, const std::string& url,
long connect_timeout, long total_timeout) {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
// curl write callback that appends response data into a std::string
static size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
void* userp) {
@@ -48,7 +25,7 @@ std::string CURLWebClient::Get(const std::string& url) {
auto curl = create_handle();
std::string response_string;
set_common_get_options(curl.get(), url, 10L, 20L);
set_common_get_options(curl.get(), url, {10L, 20L});
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);

View File

@@ -0,0 +1,28 @@
/**
* @file web_client/curl_web_client_utils.cpp
* @brief Shared CURLWebClient helper implementations.
*/
#include "curl_web_client_utils.h"
#include <stdexcept>
auto create_handle() -> CurlHandle {
CURL* handle = curl_easy_init();
if (handle == nullptr) {
throw std::runtime_error(
"[CURLWebClient] Failed to initialize libcurl handle");
}
return CurlHandle(handle, &curl_easy_cleanup);
}
auto set_common_get_options(CURL* curl, const std::string& url,
CurlTimeouts timeouts) -> void {
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, timeouts.connect_timeout);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeouts.total_timeout);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}

View File

@@ -0,0 +1,26 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
/**
* @file web_client/curl_web_client_utils.h
* @brief Shared helpers for CURLWebClient request setup.
*/
#include <curl/curl.h>
#include <memory>
#include <string>
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
struct CurlTimeouts {
long connect_timeout;
long total_timeout;
};
CurlHandle create_handle();
void set_common_get_options(CURL* curl, const std::string& url,
CurlTimeouts timeouts);
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_