mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 10:04:00 +00:00
fix: llama backend lifetime, Wikipedia enrichment depth, and misc cleanup
This commit is contained in:
@@ -3,8 +3,7 @@
|
||||
* @brief LlamaGenerator constructor implementation.
|
||||
*/
|
||||
|
||||
#include <llama.h>
|
||||
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
@@ -12,7 +11,8 @@
|
||||
#include "data_generation/llama_generator.h"
|
||||
|
||||
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||
const std::string& model_path) {
|
||||
const std::string& model_path)
|
||||
: rng_() {
|
||||
if (model_path.empty()) {
|
||||
throw std::runtime_error("LlamaGenerator: model path must not be empty");
|
||||
}
|
||||
@@ -39,15 +39,13 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||
|
||||
sampling_temperature_ = options.temperature;
|
||||
sampling_top_p_ = options.top_p;
|
||||
sampling_seed_ = (options.seed < 0)
|
||||
? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
|
||||
: static_cast<uint32_t>(options.seed);
|
||||
if (options.seed == -1) {
|
||||
std::random_device random_device;
|
||||
rng_.seed(random_device());
|
||||
} else {
|
||||
rng_.seed(static_cast<uint32_t>(options.seed));
|
||||
}
|
||||
n_ctx_ = options.n_ctx;
|
||||
|
||||
try {
|
||||
Load(model_path);
|
||||
} catch (...) {
|
||||
llama_backend_free();
|
||||
throw;
|
||||
}
|
||||
Load(model_path);
|
||||
}
|
||||
|
||||
@@ -23,9 +23,4 @@ LlamaGenerator::~LlamaGenerator() {
|
||||
llama_model_free(model_);
|
||||
model_ = nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up the backend (GPU/CPU acceleration resources)
|
||||
*/
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
@@ -145,8 +145,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
* Distribution sampler: selects actual token using configured seed for
|
||||
* reproducibility
|
||||
*/
|
||||
llama_sampler_chain_add(sampler.get(),
|
||||
llama_sampler_init_dist(sampling_seed_));
|
||||
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng_()));
|
||||
|
||||
/**
|
||||
* TOKEN GENERATION LOOP
|
||||
@@ -187,10 +186,5 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
for (const llama_token token : generated_tokens)
|
||||
AppendTokenPiecePublic(vocab, token, output);
|
||||
|
||||
/**
|
||||
* Advance seed for next generation to improve output diversity
|
||||
*/
|
||||
sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1;
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
@@ -22,11 +23,6 @@ void LlamaGenerator::Load(const std::string& model_path) {
|
||||
model_ = nullptr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the llama backend (one-time setup for GPU/CPU acceleration)
|
||||
*/
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_ = llama_model_load_from_file(model_path.c_str(), model_params);
|
||||
if (model_ == nullptr) {
|
||||
@@ -36,7 +32,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
|
||||
|
||||
llama_context_params context_params = llama_context_default_params();
|
||||
context_params.n_ctx = n_ctx_;
|
||||
context_params.n_batch = n_ctx_; // Set batch size equal to context window
|
||||
context_params.n_batch = std::min(n_ctx_, static_cast<uint32_t>(512));
|
||||
|
||||
context_ = llama_init_from_model(model_, context_params);
|
||||
if (context_ == nullptr) {
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "biergarten_data_generator.h"
|
||||
#include "data_generation/llama_generator.h"
|
||||
#include "data_generation/mock_generator.h"
|
||||
#include "llama_backend_state.h"
|
||||
#include "services/enrichment_service.h"
|
||||
#include "services/wikipedia_service.h"
|
||||
#include "web_client/curl_web_client.h"
|
||||
@@ -116,6 +117,7 @@ auto ParseArguments(const int argc, char** argv,
|
||||
auto main(const int argc, char** argv) noexcept -> int {
|
||||
try {
|
||||
const CurlGlobalState curl_state;
|
||||
const LlamaBackendState llama_backend_state;
|
||||
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
|
||||
|
||||
ApplicationOptions options;
|
||||
|
||||
@@ -11,19 +11,24 @@
|
||||
|
||||
#include "services/wikipedia_service.h"
|
||||
|
||||
auto WikipediaService::FetchExtract(std::string_view query) const
|
||||
-> std::string {
|
||||
const std::string encoded = client_->UrlEncode(std::string(query));
|
||||
auto WikipediaService::FetchExtract(std::string_view query) -> std::string {
|
||||
const std::string cache_key(query);
|
||||
const auto cache_it = this->extract_cache_.find(cache_key);
|
||||
if (cache_it != this->extract_cache_.end()) {
|
||||
return cache_it->second;
|
||||
}
|
||||
|
||||
const std::string encoded = this->client_->UrlEncode(cache_key);
|
||||
const std::string url =
|
||||
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
|
||||
"&prop=extracts&explaintext=1&format=json";
|
||||
|
||||
const std::string body = client_->Get(url);
|
||||
const std::string body = this->client_->Get(url);
|
||||
|
||||
boost::system::error_code ec;
|
||||
boost::json::value doc = boost::json::parse(body, ec);
|
||||
boost::system::error_code parse_error;
|
||||
boost::json::value doc = boost::json::parse(body, parse_error);
|
||||
|
||||
if (!ec && doc.is_object()) {
|
||||
if (!parse_error && doc.is_object()) {
|
||||
try {
|
||||
auto& pages = doc.at("query").at("pages").get_object();
|
||||
if (!pages.empty()) {
|
||||
@@ -32,9 +37,11 @@ auto WikipediaService::FetchExtract(std::string_view query) const
|
||||
std::string extract(page.at("extract").as_string().c_str());
|
||||
spdlog::debug("WikipediaService fetched {} chars for '{}'",
|
||||
extract.size(), query);
|
||||
this->extract_cache_.emplace(cache_key, extract);
|
||||
return extract;
|
||||
}
|
||||
}
|
||||
this->extract_cache_.emplace(cache_key, std::string{});
|
||||
} catch (const std::exception& e) {
|
||||
spdlog::warn(
|
||||
"WikipediaService: failed to parse response structure for '{}': "
|
||||
@@ -42,9 +49,9 @@ auto WikipediaService::FetchExtract(std::string_view query) const
|
||||
query, e.what());
|
||||
return {};
|
||||
}
|
||||
} else if (ec) {
|
||||
} else if (parse_error) {
|
||||
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
|
||||
ec.message());
|
||||
parse_error.message());
|
||||
}
|
||||
|
||||
return {};
|
||||
|
||||
@@ -30,20 +30,22 @@ auto WikipediaService::GetLocationContext(const Location& loc) -> std::string {
|
||||
}
|
||||
|
||||
const std::string beer_query = "beer in " + loc.country;
|
||||
const std::string city_beer_query = "beer in " + loc.city;
|
||||
|
||||
auto append_extract = [&result](const std::string& extract) -> void {
|
||||
if (extract.empty()) {
|
||||
return;
|
||||
}
|
||||
if (!result.empty()) {
|
||||
result += "\n\n";
|
||||
}
|
||||
result += extract;
|
||||
};
|
||||
|
||||
try {
|
||||
const std::string region_extract = FetchExtract(region_query);
|
||||
const std::string beer_extract = FetchExtract(beer_query);
|
||||
|
||||
if (!region_extract.empty()) {
|
||||
result += region_extract;
|
||||
}
|
||||
if (!beer_extract.empty()) {
|
||||
if (!result.empty()) {
|
||||
result += "\n\n";
|
||||
}
|
||||
result += beer_extract;
|
||||
}
|
||||
append_extract(FetchExtract(region_query));
|
||||
append_extract(FetchExtract(beer_query));
|
||||
append_extract(FetchExtract(city_beer_query));
|
||||
} catch (const std::runtime_error& e) {
|
||||
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
|
||||
e.what());
|
||||
|
||||
@@ -7,35 +7,12 @@
|
||||
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "curl_web_client_utils.h"
|
||||
#include "web_client/curl_web_client.h"
|
||||
|
||||
// RAII wrapper for CURL handle using unique_ptr
|
||||
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
|
||||
|
||||
static CurlHandle create_handle() {
|
||||
CURL* handle = curl_easy_init();
|
||||
if (!handle) {
|
||||
throw std::runtime_error(
|
||||
"[CURLWebClient] Failed to initialize libcurl handle");
|
||||
}
|
||||
return CurlHandle(handle, &curl_easy_cleanup);
|
||||
}
|
||||
|
||||
static void set_common_get_options(CURL* curl, const std::string& url,
|
||||
long connect_timeout, long total_timeout) {
|
||||
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
|
||||
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
|
||||
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
|
||||
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
|
||||
}
|
||||
|
||||
// curl write callback that writes to a file stream
|
||||
static size_t WriteCallbackFile(void* contents, size_t size, size_t nmemb,
|
||||
void* userp) {
|
||||
@@ -55,7 +32,7 @@ void CURLWebClient::DownloadToFile(const std::string& url,
|
||||
"[CURLWebClient] Cannot open file for writing: " + file_path);
|
||||
}
|
||||
|
||||
set_common_get_options(curl.get(), url, 30L, 300L);
|
||||
set_common_get_options(curl.get(), url, {30L, 300L});
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
|
||||
static_cast<void*>(&outFile));
|
||||
|
||||
@@ -5,36 +5,13 @@
|
||||
|
||||
#include <curl/curl.h>
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "curl_web_client_utils.h"
|
||||
#include "web_client/curl_web_client.h"
|
||||
|
||||
// RAII wrapper for CURL handle using unique_ptr
|
||||
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
|
||||
|
||||
static CurlHandle create_handle() {
|
||||
CURL* handle = curl_easy_init();
|
||||
if (!handle) {
|
||||
throw std::runtime_error(
|
||||
"[CURLWebClient] Failed to initialize libcurl handle");
|
||||
}
|
||||
return CurlHandle(handle, &curl_easy_cleanup);
|
||||
}
|
||||
|
||||
static void set_common_get_options(CURL* curl, const std::string& url,
|
||||
long connect_timeout, long total_timeout) {
|
||||
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
|
||||
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
|
||||
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
|
||||
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
|
||||
}
|
||||
|
||||
// curl write callback that appends response data into a std::string
|
||||
static size_t WriteCallbackString(void* contents, size_t size, size_t nmemb,
|
||||
void* userp) {
|
||||
@@ -48,7 +25,7 @@ std::string CURLWebClient::Get(const std::string& url) {
|
||||
auto curl = create_handle();
|
||||
|
||||
std::string response_string;
|
||||
set_common_get_options(curl.get(), url, 10L, 20L);
|
||||
set_common_get_options(curl.get(), url, {10L, 20L});
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
|
||||
|
||||
|
||||
28
pipeline/src/web_client/curl_web_client_utils.cpp
Normal file
28
pipeline/src/web_client/curl_web_client_utils.cpp
Normal file
@@ -0,0 +1,28 @@
|
||||
/**
|
||||
* @file web_client/curl_web_client_utils.cpp
|
||||
* @brief Shared CURLWebClient helper implementations.
|
||||
*/
|
||||
|
||||
#include "curl_web_client_utils.h"
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
auto create_handle() -> CurlHandle {
|
||||
CURL* handle = curl_easy_init();
|
||||
if (handle == nullptr) {
|
||||
throw std::runtime_error(
|
||||
"[CURLWebClient] Failed to initialize libcurl handle");
|
||||
}
|
||||
return CurlHandle(handle, &curl_easy_cleanup);
|
||||
}
|
||||
|
||||
auto set_common_get_options(CURL* curl, const std::string& url,
|
||||
CurlTimeouts timeouts) -> void {
|
||||
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
|
||||
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, timeouts.connect_timeout);
|
||||
curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeouts.total_timeout);
|
||||
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
|
||||
}
|
||||
26
pipeline/src/web_client/curl_web_client_utils.h
Normal file
26
pipeline/src/web_client/curl_web_client_utils.h
Normal file
@@ -0,0 +1,26 @@
|
||||
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
|
||||
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
|
||||
|
||||
/**
|
||||
* @file web_client/curl_web_client_utils.h
|
||||
* @brief Shared helpers for CURLWebClient request setup.
|
||||
*/
|
||||
|
||||
#include <curl/curl.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
|
||||
|
||||
struct CurlTimeouts {
|
||||
long connect_timeout;
|
||||
long total_timeout;
|
||||
};
|
||||
|
||||
CurlHandle create_handle();
|
||||
|
||||
void set_common_get_options(CURL* curl, const std::string& url,
|
||||
CurlTimeouts timeouts);
|
||||
|
||||
#endif // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_UTILS_H_
|
||||
Reference in New Issue
Block a user