5 Commits

Author SHA1 Message Date
Aaron Po
5abb3f2e24 Add mock enrichment process 2026-05-14 13:49:59 -04:00
Aaron Po
a057b9197f Add location count to application options and as a cli arg 2026-05-13 22:04:48 -04:00
Aaron Po
773e7c774b Add timeout for enrichment, refactor json deserialization 2026-05-13 12:44:30 -04:00
b7c0b1c8d4 Fix mistake in .gitattributes
archive/* is incorrect as it will ignore sub-dirs
2026-05-12 01:05:07 -04:00
b8ebe03921 Pipeline: Add Runpod docker configuration (#222)
* Begin work on Runpod docker config

* Reduce docker image size

* Create .dockerignore
2026-05-12 00:44:09 -04:00
23 changed files with 346 additions and 207 deletions

2
.gitattributes vendored
View File

@@ -1 +1 @@
archive/* linguist-vendored archive/** linguist-vendored

View File

@@ -137,7 +137,8 @@ set(HTTPLIB_REQUIRE_OPENSSL ON CACHE BOOL "Require OpenSSL for cpp-httplib" FORC
FetchContent_MakeAvailable(cpp-httplib) FetchContent_MakeAvailable(cpp-httplib)
# 5. Executable & Sources # 5. Executable & Sources
add_executable(${PROJECT_NAME}) add_executable(${PROJECT_NAME}
includes/services/enrichment/mock_enrichment.h)
# --- Entry point --- # --- Entry point ---
target_sources(${PROJECT_NAME} PRIVATE target_sources(${PROJECT_NAME} PRIVATE
@@ -194,9 +195,9 @@ endif()
# --- services: wikipedia --- # --- services: wikipedia ---
target_sources(${PROJECT_NAME} PRIVATE target_sources(${PROJECT_NAME} PRIVATE
src/services/wikipedia/wikipedia_service.cc src/services/enrichment/wikipedia/wikipedia_service.cc
src/services/wikipedia/fetch_extract.cc src/services/enrichment/wikipedia/fetch_extract.cc
src/services/wikipedia/get_summary.cc src/services/enrichment/wikipedia/get_summary.cc
) )
# --- services: sqlite --- # --- services: sqlite ---

View File

@@ -12,8 +12,8 @@
#include "data_generation/data_generator.h" #include "data_generation/data_generator.h"
#include "data_model/generated_models.h" #include "data_model/generated_models.h"
#include "services/enrichment/enrichment_service.h"
#include "services/database/export_service.h" #include "services/database/export_service.h"
#include "services/enrichment/enrichment_service.h"
/** /**
* @brief Main data generator class for the Biergarten pipeline. * @brief Main data generator class for the Biergarten pipeline.
@@ -32,7 +32,8 @@ class BiergartenDataGenerator {
*/ */
BiergartenDataGenerator(std::unique_ptr<IEnrichmentService> context_service, BiergartenDataGenerator(std::unique_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator, std::unique_ptr<DataGenerator> generator,
std::unique_ptr<IExportService> exporter); std::unique_ptr<IExportService> exporter,
const ApplicationOptions& application_options);
/** /**
* @brief Run the data generation pipeline. * @brief Run the data generation pipeline.
@@ -56,12 +57,14 @@ class BiergartenDataGenerator {
/// @brief Storage backend for generated brewery records. /// @brief Storage backend for generated brewery records.
std::unique_ptr<IExportService> exporter_; std::unique_ptr<IExportService> exporter_;
const ApplicationOptions application_options_;
/** /**
* @brief Load locations from JSON and sample cities. * @brief Load locations from JSON and sample cities.
* *
* @return Vector of sampled locations capped at 50 entries. * @return Vector of sampled locations capped at 50 entries.
*/ */
static std::vector<Location> QueryCitiesWithCountries(); std::vector<Location> QueryCitiesWithCountries();
/** /**
* @brief Generate breweries for enriched cities. * @brief Generate breweries for enriched cities.

View File

@@ -83,6 +83,9 @@ struct SamplingOptions {
/// @brief Random seed (-1 for random, otherwise non-negative). /// @brief Random seed (-1 for random, otherwise non-negative).
int seed = -1; int seed = -1;
/// @brief Number of layers to offload to GPU.
int n_gpu_layers = 0;
}; };
/** /**
@@ -95,8 +98,7 @@ struct GeneratorOptions {
/// @brief Use mocked generator instead of actual LLM inference. /// @brief Use mocked generator instead of actual LLM inference.
bool use_mocked = false; bool use_mocked = false;
/// @brief Number of layers to offload to GPU.
int n_gpu_layers = 0;
/// @brief Specific sampling parameters for this generator. /// @brief Specific sampling parameters for this generator.
/// If nullopt, the application should use global defaults. /// If nullopt, the application should use global defaults.
@@ -116,6 +118,10 @@ struct PipelineOptions {
/// @brief Path for application logs. /// @brief Path for application logs.
std::filesystem::path log_path; std::filesystem::path log_path;
/// @brief Number of locations to sample from the dataset
/// More locations -> more users/more breweries
uint32_t location_count;
}; };
/** /**

View File

@@ -0,0 +1,17 @@
//
// Created by aaronpo on 13/05/2026.
//
#ifndef BIERGARTEN_PIPELINE_INCLUDES_SERVICES_ENRICHMENT_MOCK_ENRICHMENT_H_
#define BIERGARTEN_PIPELINE_INCLUDES_SERVICES_ENRICHMENT_MOCK_ENRICHMENT_H_
#include <string>
#include "enrichment_service.h"
class MockEnrichmentService final : public IEnrichmentService {
public:
std::string GetLocationContext(const Location& /*loc*/) override {
return {};
}
};
#endif // BIERGARTEN_PIPELINE_INCLUDES_SERVICES_ENRICHMENT_MOCK_ENRICHMENT_H_

View File

@@ -15,10 +15,10 @@
#include "web_client/web_client.h" #include "web_client/web_client.h"
/// @brief Provides Wikipedia summary lookups backed by cached raw extracts. /// @brief Provides Wikipedia summary lookups backed by cached raw extracts.
class WikipediaService final : public IEnrichmentService { class WikipediaEnrichmentService final : public IEnrichmentService {
public: public:
/// @brief Creates a new Wikipedia service with the provided web client. /// @brief Creates a new Wikipedia service with the provided web client.
explicit WikipediaService(std::unique_ptr<WebClient> client); explicit WikipediaEnrichmentService(std::unique_ptr<WebClient> client);
/// @brief Returns the Wikipedia-derived context for a location. /// @brief Returns the Wikipedia-derived context for a location.
[[nodiscard]] std::string GetLocationContext(const Location& loc) override; [[nodiscard]] std::string GetLocationContext(const Location& loc) override;

View File

@@ -42,7 +42,7 @@ public:
* @param value Raw string to encode. * @param value Raw string to encode.
* @return Percent-encoded string safe for use in a URL. * @return Percent-encoded string safe for use in a URL.
*/ */
std::string UrlEncode(const std::string& value) override; std::string EncodeURL(const std::string& value) override;
}; };

View File

@@ -30,7 +30,7 @@ class WebClient {
* @param value Raw string value. * @param value Raw string value.
* @return Encoded value safe for URL usage. * @return Encoded value safe for URL usage.
*/ */
virtual std::string UrlEncode(const std::string& value) = 0; virtual std::string EncodeURL(const std::string& value) = 0;
}; };
#endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_ #endif // BIERGARTEN_PIPELINE_INCLUDES_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -0,0 +1,9 @@
# Ignore model files!
*.gguf
*.bin
models/
weights/
# Ignore local build folders
build/
.git/

View File

@@ -1,65 +1,72 @@
# Phase 1: Pull prebuilt binaries # --- Stage 1: Build Environment (The "Heavy" Stage) ---
FROM ghcr.io/ggml-org/llama.cpp:full-cuda AS llama-bin FROM nvidia/cuda:12.6.3-devel-ubuntu24.04 AS builder
# Phase 2: Building environment
FROM nvidia/cuda:12.6.3-devel-ubuntu24.04
ENV DEBIAN_FRONTEND=noninteractive \ ENV DEBIAN_FRONTEND=noninteractive \
CMAKE_GENERATOR=Ninja \ CMAKE_GENERATOR=Ninja
APP_ROOT=/workspace/app \
BUILD_DIR=/workspace/app/build
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \ build-essential ca-certificates curl git libboost-json-dev \
ca-certificates \ libboost-program-options-dev libssl-dev ninja-build pkg-config zlib1g-dev \
curl \
git \
libboost-json-dev \
libboost-program-options-dev \
libssl-dev \
ninja-build \
pkg-config \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Install modern CMake via curl (Ubuntu 24.04 'apt' version can be laggy) # Install modern CMake
RUN curl -L https://github.com/Kitware/CMake/releases/download/v3.31.0/cmake-3.31.0-linux-x86_64.sh -o cmake.sh && \ RUN curl -L https://github.com/Kitware/CMake/releases/download/v3.31.0/cmake-3.31.0-linux-x86_64.sh -o cmake.sh && \
sh cmake.sh --skip-license --prefix=/usr/local && rm cmake.sh sh cmake.sh --skip-license --prefix=/usr/local && rm cmake.sh
# Copy backends to /usr/local/lib and register with ldconfig so the # Get headers for C++ build
# runtime linker can resolve libllama.so, libggml.so, libggml-base.so etc. RUN curl -L https://github.com/ggml-org/llama.cpp/archive/refs/tags/b9012.tar.gz -o /tmp/llama-src.tar.gz && \
COPY --from=llama-bin /app/lib*.so* /usr/local/lib/ tar -xzf /tmp/llama-src.tar.gz -C /tmp && \
RUN ldconfig cp -r /tmp/llama.cpp-b9012/include/* /usr/local/include/ && \
cp -r /tmp/llama.cpp-b9012/ggml/include/* /usr/local/include/
# Headers for C++ Build # Pull llama.cpp binaries to use during build if needed
RUN git clone --depth 1 -b b9012 https://github.com/ggml-org/llama.cpp.git /tmp/llama-src && \ COPY --from=ghcr.io/ggml-org/llama.cpp:full-cuda /app/lib*.so* /usr/local/lib/
cp -r /tmp/llama-src/include/* /usr/local/include/ && \
cp -r /tmp/llama-src/ggml/include/* /usr/local/include/ && \
rm -rf /tmp/llama-src
ENV LD_LIBRARY_PATH="/usr/local/lib:${LD_LIBRARY_PATH}" WORKDIR /app
WORKDIR /workspace/app
COPY . . COPY . .
# Build the C++ pipeline # Build the C++ pipeline
RUN cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release && \ RUN cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release && \
cmake --build build -j$(nproc) cmake --build build -j$(nproc)
# Co-locate GGML backend plugins with the executable. # --- Stage 2: Runtime Environment (The "Slim" Stage) ---
# ggml_backend_load_all() searches the executable directory first when FROM nvidia/cuda:12.6.3-runtime-ubuntu24.04 AS runtime
# GGML_BACKEND_DIR is not set. Copying the ggml-*.so plugin files here
# ensures the loader finds them without any environment variable. # Install only necessary runtime shared libraries
# libllama.so, libggml.so, and libggml-base.so are NOT copied here — RUN apt-get update && apt-get install -y --no-install-recommends \
# those are proper shared libraries resolved via ldconfig/LD_LIBRARY_PATH. curl \
RUN cp /usr/local/lib/libggml-cuda.so /workspace/app/build/ 2>/dev/null || true && \ ca-certificates \
cp /usr/local/lib/libggml-cpu*.so /workspace/app/build/ 2>/dev/null || true && \ libboost-json1.83.0 \
cp /usr/local/lib/libggml-blas*.so /workspace/app/build/ 2>/dev/null || true && \ libboost-program-options1.83.0 \
cp /usr/local/lib/libggml-rpc*.so /workspace/app/build/ 2>/dev/null || true libgomp1 \
libssl3 \
zlib1g \
&& rm -rf /var/lib/apt/lists/*
ENV APP_ROOT=/app \
LD_LIBRARY_PATH="/usr/local/lib:${LD_LIBRARY_PATH}"
WORKDIR /app/build
# Copy only the compiled binaries from the builder
COPY --from=builder /app/build/biergarten-pipeline ./
# Copy required config files
COPY locations.json /app/build/
COPY beer-styles.json /app/build/
# Copy prompt templates
COPY prompts /app/prompts
# Copy only the necessary shared libraries from builder/llama-bin
COPY --from=ghcr.io/ggml-org/llama.cpp:full-cuda /app/lib*.so* /usr/local/lib/
# Co-locate plugins
RUN cp /usr/local/lib/libggml-cuda.so . 2>/dev/null || true && \
cp /usr/local/lib/libggml-cpu*.so . 2>/dev/null || true
# Setup Start Script # Setup Start Script
COPY runpod/start.sh /usr/local/bin/biergarten-start COPY ./runpod/start.sh /usr/local/bin/biergarten-start
RUN chmod +x /usr/local/bin/biergarten-start RUN chmod +x /usr/local/bin/biergarten-start
WORKDIR /workspace/app/build
ENTRYPOINT ["/usr/local/bin/biergarten-start"] ENTRYPOINT ["/usr/local/bin/biergarten-start"]

View File

@@ -1,49 +1,58 @@
#!/bin/bash #!/bin/bash
set -e set -e
# Configuration / Defaults
MODEL_PATH="${BIERGARTEN_MODEL_PATH:-/workspace/models/google_gemma-4-E4B-it-Q6_K.gguf}" MODEL_PATH="${BIERGARTEN_MODEL_PATH:-/workspace/models/google_gemma-4-E4B-it-Q6_K.gguf}"
OUTPUT_DIR="${BIERGARTEN_OUTPUT_DIR:-/workspace/output}" OUTPUT_DIR="${BIERGARTEN_OUTPUT_DIR:-/workspace/output}"
LOG_PATH="${BIERGARTEN_LOG_PATH:-/workspace/logs/pipeline.log}" LOG_PATH="${BIERGARTEN_LOG_PATH:-/workspace/logs/pipeline.log}"
EXECUTABLE="/workspace/app/build/biergarten-pipeline" EXECUTABLE="/app/build/biergarten-pipeline"
PROMPT_DIR="/workspace/app/build/prompts" PROMPT_DIR="/app/prompts"
echo "--- Starting Biergarten Pipeline Environment Check ---" echo "--- Starting Biergarten Pipeline Environment Check ---"
# 1. Ensure volume mount directories exist # Ensure directories exist
mkdir -p "$OUTPUT_DIR" mkdir -p "$OUTPUT_DIR"
mkdir -p "$(dirname "$LOG_PATH")" mkdir -p "$(dirname "$LOG_PATH")"
mkdir -p "$(dirname "$MODEL_PATH")"
# 2. Check for model file # Download model if missing
if [ ! -f "$MODEL_PATH" ]; then if [ ! -f "$MODEL_PATH" ]; then
echo "ERROR: Model not found at $MODEL_PATH" echo "Model not found. Downloading (this may take a while)..."
echo "Current /workspace/models contents:"
ls -lh /workspace/models 2>/dev/null || echo "(directory does not exist)" curl -L -C - \
-o "$MODEL_PATH" \
"https://huggingface.co/bartowski/google_gemma-4-E4B-it-GGUF/resolve/main/google_gemma-4-E4B-it-Q6_K.gguf?download=true"
echo "Download complete."
fi
# Verify model exists
if [ ! -f "$MODEL_PATH" ]; then
echo "ERROR: Model still not found after download attempt."
exit 1 exit 1
fi fi
# 3. Build the command arguments # Default GPU layers
GL_LAYERS="${BIERGARTEN_GL_LAYERS:-40}"
# Build args
ARGS=( ARGS=(
"--model" "$MODEL_PATH" "--model" "$MODEL_PATH"
"--prompt-dir" "$PROMPT_DIR" "--prompt-dir" "$PROMPT_DIR"
"--output" "$OUTPUT_DIR" "--output" "$OUTPUT_DIR"
"--log-path" "$LOG_PATH" "--log-path" "$LOG_PATH"
"--n-gpu-layers" "$GL_LAYERS"
) )
# Optional hyperparameters # Optional params
[[ -n "$BIERGARTEN_TEMPERATURE" ]] && ARGS+=("--temperature" "$BIERGARTEN_TEMPERATURE") [[ -n "$BIERGARTEN_TEMPERATURE" ]] && ARGS+=("--temperature" "$BIERGARTEN_TEMPERATURE")
[[ -n "$BIERGARTEN_TOP_P" ]] && ARGS+=("--top-p" "$BIERGARTEN_TOP_P") [[ -n "$BIERGARTEN_TOP_P" ]] && ARGS+=("--top-p" "$BIERGARTEN_TOP_P")
[[ -n "$BIERGARTEN_TOP_K" ]] && ARGS+=("--top-k" "$BIERGARTEN_TOP_K") [[ -n "$BIERGARTEN_TOP_K" ]] && ARGS+=("--top-k" "$BIERGARTEN_TOP_K")
[[ -n "$BIERGARTEN_N_CTX" ]] && ARGS+=("--n-ctx" "$BIERGARTEN_N_CTX") [[ -n "$BIERGARTEN_N_CTX" ]] && ARGS+=("--n-ctx" "$BIERGARTEN_N_CTX")
[[ -n "$BIERGARTEN_SEED" ]] && ARGS+=("--seed" "$BIERGARTEN_SEED") [[ -n "$BIERGARTEN_SEED" ]] && ARGS+=("--seed" "$BIERGARTEN_SEED")
[[ -n "$BIERGARTEN_GL_LAYERS" ]] && ARGS+=("--n-gpu-layers" "$BIERGARTEN_GL_LAYERS")
# Append any extra custom args # Extra args
if [[ -n "$BIERGARTEN_EXTRA_ARGS" ]]; then [[ -n "$BIERGARTEN_EXTRA_ARGS" ]] && ARGS+=($BIERGARTEN_EXTRA_ARGS)
ARGS+=($BIERGARTEN_EXTRA_ARGS)
fi
echo "--- Executing: $EXECUTABLE ${ARGS[*]} ---" echo "--- Executing: $EXECUTABLE ${ARGS[*]} ---"
# Execute the binary directly, replacing the shell process
exec "$EXECUTABLE" "${ARGS[@]}" exec "$EXECUTABLE" "${ARGS[@]}"

View File

@@ -30,6 +30,8 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
"Context window size in tokens"); "Context window size in tokens");
opt("seed", prog_opts::value<int>()->default_value(sampling_defaults.seed), opt("seed", prog_opts::value<int>()->default_value(sampling_defaults.seed),
"Sampler seed: -1 for random, otherwise non-negative integer"); "Sampler seed: -1 for random, otherwise non-negative integer");
opt("n-gpu-layers", prog_opts::value<int>()->default_value(0),
"Number of layers to offload to GPU");
}; };
// --mocked and --model are mutually exclusive; validation is enforced below // --mocked and --model are mutually exclusive; validation is enforced below
@@ -50,8 +52,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
opt("prompt-dir", prog_opts::value<std::string>()->default_value(""), opt("prompt-dir", prog_opts::value<std::string>()->default_value(""),
"Directory containing named prompt files (e.g. BREWERY_GENERATION.md)." "Directory containing named prompt files (e.g. BREWERY_GENERATION.md)."
" Required when not using --mocked."); " Required when not using --mocked.");
opt("n-gpu-layers", prog_opts::value<int>()->default_value(0), opt("location-count", prog_opts::value<uint32_t>()->default_value(10));
"Number of layers to offload to GPU");
}; };
add_sampling_options(); add_sampling_options();
@@ -84,6 +85,8 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
options.pipeline.output_path = var_map["output"].as<std::string>(); options.pipeline.output_path = var_map["output"].as<std::string>();
options.pipeline.log_path = var_map["log-path"].as<std::string>(); options.pipeline.log_path = var_map["log-path"].as<std::string>();
options.pipeline.prompt_dir = var_map["prompt-dir"].as<std::string>(); options.pipeline.prompt_dir = var_map["prompt-dir"].as<std::string>();
options.pipeline.location_count =
var_map["location-count"].as<uint32_t>();
const bool use_mocked = var_map["mocked"].as<bool>(); const bool use_mocked = var_map["mocked"].as<bool>();
const std::string model_path = var_map["model"].as<std::string>(); const std::string model_path = var_map["model"].as<std::string>();
@@ -113,7 +116,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
options.generator.use_mocked = use_mocked; options.generator.use_mocked = use_mocked;
options.generator.model_path = model_path; options.generator.model_path = model_path;
options.generator.n_gpu_layers = n_gpu_layers; // options.generator.n_gpu_layers = n_gpu_layers;
// Only populate sampling config when the user explicitly overrides at // Only populate sampling config when the user explicitly overrides at
// least one value. Leaving it as std::nullopt lets LlamaGenerator fall // least one value. Leaving it as std::nullopt lets LlamaGenerator fall
@@ -122,7 +125,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
const bool user_provided_sampling = const bool user_provided_sampling =
!var_map["temperature"].defaulted() || !var_map["top-p"].defaulted() || !var_map["temperature"].defaulted() || !var_map["top-p"].defaulted() ||
!var_map["top-k"].defaulted() || !var_map["n-ctx"].defaulted() || !var_map["top-k"].defaulted() || !var_map["n-ctx"].defaulted() ||
!var_map["seed"].defaulted(); !var_map["seed"].defaulted() || !var_map["n_gpu_layers"].defaulted();
if (user_provided_sampling) { if (user_provided_sampling) {
// Warn but do not fail — the run is still valid, the flags are just // Warn but do not fail — the run is still valid, the flags are just
@@ -136,6 +139,7 @@ std::optional<ApplicationOptions> ParseArguments(const int argc, char** argv) {
sampling.top_k = var_map["top-k"].as<uint32_t>(); sampling.top_k = var_map["top-k"].as<uint32_t>();
sampling.n_ctx = var_map["n-ctx"].as<uint32_t>(); sampling.n_ctx = var_map["n-ctx"].as<uint32_t>();
sampling.seed = var_map["seed"].as<int>(); sampling.seed = var_map["seed"].as<int>();
sampling.n_gpu_layers = var_map["n-gpu-layers"].as<int>();
options.generator.sampling = sampling; options.generator.sampling = sampling;
} }

View File

@@ -10,7 +10,9 @@
BiergartenDataGenerator::BiergartenDataGenerator( BiergartenDataGenerator::BiergartenDataGenerator(
std::unique_ptr<IEnrichmentService> context_service, std::unique_ptr<IEnrichmentService> context_service,
std::unique_ptr<DataGenerator> generator, std::unique_ptr<DataGenerator> generator,
std::unique_ptr<IExportService> exporter) std::unique_ptr<IExportService> exporter,
const ApplicationOptions &app_options)
: context_service_(std::move(context_service)), : context_service_(std::move(context_service)),
generator_(std::move(generator)), generator_(std::move(generator)),
exporter_(std::move(exporter)) {} exporter_(std::move(exporter)),
application_options_(app_options) {}

View File

@@ -13,8 +13,6 @@
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
#include "json_handling/json_loader.h" #include "json_handling/json_loader.h"
static constexpr size_t kBreweryAmount = 50;
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() { std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
@@ -23,7 +21,9 @@ std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
auto all_locations = JsonLoader::LoadLocations(locations_path); auto all_locations = JsonLoader::LoadLocations(locations_path);
spdlog::info(" Locations available: {}", all_locations.size()); spdlog::info(" Locations available: {}", all_locations.size());
const size_t sample_count = std::min(kBreweryAmount, all_locations.size()); const size_t sample_count = std::min(
static_cast<size_t>(application_options_.pipeline.location_count),
all_locations.size());
const auto sample_count_signed = const auto sample_count_signed =
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>( static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(

View File

@@ -21,8 +21,8 @@ bool BiergartenDataGenerator::Run() {
for (auto& city : cities) { for (auto& city : cities) {
try { try {
std::string region_context = context_service_->GetLocationContext(city); std::string region_context = context_service_->GetLocationContext(city);
spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}", // spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}",
city.city, city.country, region_context); // city.city, city.iso3166_2, region_context);
enriched.push_back( enriched.push_back(
EnrichedCity{.location = std::move(city), EnrichedCity{.location = std::move(city),

View File

@@ -89,7 +89,7 @@ LlamaGenerator::LlamaGenerator(
} }
n_ctx_ = sampling.n_ctx; n_ctx_ = sampling.n_ctx;
n_gpu_layers_ = options.generator.n_gpu_layers; n_gpu_layers_ = sampling.n_gpu_layers;
this->Load(model_path); this->Load(model_path);
} }

View File

@@ -8,11 +8,9 @@
#include <boost/di.hpp> #include <boost/di.hpp>
#include <boost/program_options.hpp> #include <boost/program_options.hpp>
#include <exception> #include <exception>
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <string> #include <string>
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
@@ -21,12 +19,13 @@
#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h" #include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h"
#include "data_model/models.h" #include "data_model/models.h"
#include "llama_backend_state.h" #include "llama_backend_state.h"
#include "services/enrichment/enrichment_service.h"
#include "services/database/export_service.h" #include "services/database/export_service.h"
#include "services/prompting/prompt_directory.h"
#include "services/database/sqlite_export_service.h" #include "services/database/sqlite_export_service.h"
#include "services/datetime/timer.h" #include "services/datetime/timer.h"
#include "services/enrichment/enrichment_service.h"
#include "services/enrichment/mock_enrichment.h"
#include "services/enrichment/wikipedia_service.h" #include "services/enrichment/wikipedia_service.h"
#include "services/prompting/prompt_directory.h"
#include "web_client/http_web_client.h" #include "web_client/http_web_client.h"
namespace di = boost::di; namespace di = boost::di;
@@ -43,7 +42,9 @@ int main(const int argc, char** argv) {
spdlog::set_level(spdlog::level::debug); spdlog::set_level(spdlog::level::debug);
#endif #endif
const auto parsed_options = ParseArguments(argc, argv); const std::optional<ApplicationOptions> parsed_options =
ParseArguments(argc, argv);
if (!parsed_options.has_value()) { if (!parsed_options.has_value()) {
return 0; return 0;
} }
@@ -65,12 +66,20 @@ int main(const int argc, char** argv) {
} }
const auto injector = di::make_injector( const auto injector = di::make_injector(
di::bind<WebClient>().to<HttpWebClient>(),
di::bind<ApplicationOptions>().to(options), di::bind<ApplicationOptions>().to(options),
di::bind<IEnrichmentService>().to<WikipediaService>(), di::bind<std::string>().to(model_path),
di::bind<WebClient>().to<HttpWebClient>(),
di::bind<IExportService>().to<SqliteExportService>(), di::bind<IExportService>().to<SqliteExportService>(),
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(), di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
di::bind<std::string>().to(model_path), di::bind<IEnrichmentService>().to(
[options](const auto& inj) -> std::unique_ptr<IEnrichmentService> {
if (options.generator.use_mocked) {
return std::make_unique<MockEnrichmentService>();
}
return std::make_unique<WikipediaEnrichmentService>(
inj.template create<std::unique_ptr<WebClient>>());
}),
di::bind<DataGenerator>().to( di::bind<DataGenerator>().to(
[options, model_path, sampling, &prompt_directory]( [options, model_path, sampling, &prompt_directory](
const auto& inj) -> std::unique_ptr<DataGenerator> { const auto& inj) -> std::unique_ptr<DataGenerator> {
@@ -89,9 +98,11 @@ int main(const int argc, char** argv) {
options, model_path, options, model_path,
inj.template create<std::unique_ptr<IPromptFormatter>>(), inj.template create<std::unique_ptr<IPromptFormatter>>(),
std::move(prompt_directory)); std::move(prompt_directory));
})); })
auto generator = );
const auto generator =
injector.create<std::unique_ptr<BiergartenDataGenerator>>(); injector.create<std::unique_ptr<BiergartenDataGenerator>>();
if (!generator->Run()) { if (!generator->Run()) {

View File

@@ -0,0 +1,112 @@
/**
* @file wikipedia/fetch_extract.cc
*/
#include <spdlog/spdlog.h>
#include <boost/json.hpp>
#include <chrono>
#include <format>
#include <string>
#include <string_view>
#include <thread>
#include "services/enrichment/wikipedia_service.h"
using namespace boost;
std::string WikipediaEnrichmentService::FetchExtract(std::string_view query) {
const std::string cache_key(query);
// 1. Cache Lookup
if (const auto cache_it = this->extract_cache_.find(cache_key);
cache_it != this->extract_cache_.end()) {
spdlog::debug("Wikipedia: Cache hit for {}!", cache_key);
return cache_it->second;
}
const std::string encoded = this->client_->EncodeURL(cache_key);
const std::string url = std::format(
"https://en.wikipedia.org/w/"
"api.php?action=query&titles={}&prop=extracts&explaintext=1&format=json",
encoded);
const std::string body = this->client_->Get(url);
{
using namespace std::literals::chrono_literals;
std::this_thread::sleep_for(1s);
}
// 2. Parse JSON
system::error_code ec;
json::value doc = json::parse(body, ec);
if (ec) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
ec.message());
return {};
}
// 3. Safe Extraction
const json::object* obj = doc.if_object();
if (obj == nullptr) {
spdlog::warn("WikipediaService: Expected root object for '{}'", query);
return {};
}
const json::value* query_ptr = obj->if_contains("query");
const json::value* pages_ptr =
((query_ptr != nullptr) && query_ptr->is_object())
? query_ptr->get_object().if_contains("pages")
: nullptr;
if ((pages_ptr == nullptr) || !pages_ptr->is_object()) {
spdlog::warn("WikipediaService: Missing query.pages for '{}'", query);
return {};
}
const json::object& pages = pages_ptr->get_object();
if (pages.empty()) {
spdlog::warn("WikipediaService: No pages returned for '{}'", query);
this->extract_cache_.emplace(cache_key, "");
return {};
}
// Wikipedia returns the page under a dynamic ID key; we just want the first
// one
const json::value& page_val = pages.begin()->value();
if (!page_val.is_object()) {
spdlog::warn("WikipediaService: Unexpected page format for '{}'", query);
return {};
}
const json::object& page = page_val.get_object();
// Handle 404/Missing status
if (page.contains("missing")) {
spdlog::warn("WikipediaService: Page '{}' does not exist", query);
this->extract_cache_.emplace(cache_key, "");
return {};
}
const json::value* extract_ptr = page.if_contains("extract");
if ((extract_ptr == nullptr) || !extract_ptr->is_string()) {
spdlog::warn("WikipediaService: No extract string found for '{}'", query);
this->extract_cache_.emplace(cache_key, "");
return {};
}
// 4. Success
std::string extract(extract_ptr->as_string());
spdlog::info("WikipediaService: Fetched {} chars for '{}'", extract.size(),
query);
this->extract_cache_.insert_or_assign(cache_key, extract);
return extract;
}

View File

@@ -0,0 +1,58 @@
/**
* @file wikipedia/get_summary.cc
* @brief WikipediaService::GetLocationContext() implementation.
*/
#include <spdlog/spdlog.h>
#include <chrono>
#include <format>
#include <string>
#include <thread>
#include "services/enrichment/wikipedia_service.h"
std::string WikipediaEnrichmentService::GetLocationContext(const Location& loc) {
using namespace std::literals::chrono_literals;
if (!this->client_) {
spdlog::warn("Client is nullptr.");
return {};
}
std::string result;
// std::string region_query(loc.city);
// if (!loc.country.empty()) {
// region_query += loc.state_province,
// region_query += ", ";
// region_query += loc.country;
// }
constexpr std::string_view brewing_query = "brewing";
const std::string location_query =
std::format("{}, {}", loc.city, loc.iso3166_2);
const std::string beer_query = std::format("beer in {}", loc.country);
auto append_extract = [&result](const std::string& extract) -> void {
if (extract.empty()) {
return;
}
if (!result.empty()) {
result += "\n\n";
}
result += extract;
};
try {
append_extract(FetchExtract(brewing_query));
append_extract(FetchExtract(beer_query));
spdlog::info("Done fetching for {}. Sleeping for 10 seconds.",
location_query);
std::this_thread::sleep_for(10s);
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", location_query,
e.what());
}
return result;
}

View File

@@ -7,5 +7,6 @@
#include <utility> #include <utility>
WikipediaService::WikipediaService(std::unique_ptr<WebClient> client) WikipediaEnrichmentService::WikipediaEnrichmentService(
std::unique_ptr<WebClient> client)
: client_(std::move(client)) {} : client_(std::move(client)) {}

View File

@@ -1,61 +0,0 @@
/**
* @file wikipedia/fetch_extract.cc
* @brief WikipediaService::FetchExtract() implementation.
*/
#include <spdlog/spdlog.h>
#include <boost/json.hpp>
#include <string>
#include <string_view>
#include "services/enrichment/wikipedia_service.h"
std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string cache_key(query);
const auto cache_it = this->extract_cache_.find(cache_key);
if (cache_it != this->extract_cache_.end()) {
return cache_it->second;
}
const std::string encoded = this->client_->UrlEncode(cache_key);
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=1&format=json";
const std::string body = this->client_->Get(url);
boost::system::error_code parse_error;
boost::json::value doc = boost::json::parse(body, parse_error);
if (!parse_error && doc.is_object()) {
try {
auto& pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto& page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
const std::string_view extract_view = page.at("extract").as_string();
std::string extract(extract_view);
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
this->extract_cache_.emplace(cache_key, extract);
return extract;
}
}
this->extract_cache_.emplace(cache_key, std::string{});
} catch (const std::exception& e) {
spdlog::warn(
"WikipediaService: failed to parse response structure for '{}': "
"{}",
query, e.what());
return {};
}
} else if (parse_error) {
spdlog::warn("WikipediaService: JSON parse error for '{}': {}", query,
parse_error.message());
}
return {};
}

View File

@@ -1,47 +0,0 @@
/**
* @file wikipedia/get_summary.cc
* @brief WikipediaService::GetLocationContext() implementation.
*/
#include <spdlog/spdlog.h>
#include <string>
#include "services/enrichment/wikipedia_service.h"
std::string WikipediaService::GetLocationContext(const Location& loc) {
if (!client_) {
return {};
}
std::string result;
std::string region_query(loc.city);
if (!loc.country.empty()) {
region_query += ", ";
region_query += loc.country;
}
const std::string beer_query = "beer in " + loc.country;
const std::string city_beer_query = "beer in " + loc.city;
auto append_extract = [&result](const std::string& extract) -> void {
if (extract.empty()) {
return;
}
if (!result.empty()) {
result += "\n\n";
}
result += extract;
};
try {
append_extract(FetchExtract(region_query));
append_extract(FetchExtract(beer_query));
append_extract(FetchExtract(city_beer_query));
} catch (const std::runtime_error& e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", region_query,
e.what());
}
return result;
}

View File

@@ -12,6 +12,8 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "spdlog/spdlog.h"
namespace { namespace {
constexpr time_t kConnectionTimeoutSeconds = 5; constexpr time_t kConnectionTimeoutSeconds = 5;
constexpr time_t kReadTimeoutSeconds = 10; constexpr time_t kReadTimeoutSeconds = 10;
@@ -38,8 +40,12 @@ std::string HttpWebClient::Get(const std::string& url) {
client.set_follow_location(true); client.set_follow_location(true);
client.set_connection_timeout(kConnectionTimeoutSeconds); client.set_connection_timeout(kConnectionTimeoutSeconds);
client.set_read_timeout(kReadTimeoutSeconds); client.set_read_timeout(kReadTimeoutSeconds);
client.set_default_headers({
{"Accept", "application/json"},
{"User-Agent", "biergarten-pipeline/1.0"}
});
const auto result = client.Get(path); const httplib::Result result = client.Get(path);
if (!result) { if (!result) {
throw std::runtime_error( throw std::runtime_error(
@@ -48,6 +54,7 @@ std::string HttpWebClient::Get(const std::string& url) {
} }
if (result->status < kSuccessMin || result->status >= kSuccessMax) { if (result->status < kSuccessMin || result->status >= kSuccessMax) {
spdlog::error("[HttpWebClient] Request failed for URL: " + url);
throw std::runtime_error( throw std::runtime_error(
"[HttpWebClient] HTTP " + std::to_string(result->status) + "[HttpWebClient] HTTP " + std::to_string(result->status) +
" for URL: " + url); " for URL: " + url);
@@ -56,6 +63,6 @@ std::string HttpWebClient::Get(const std::string& url) {
return result->body; return result->body;
} }
std::string HttpWebClient::UrlEncode(const std::string& value) { std::string HttpWebClient::EncodeURL(const std::string& value) {
return httplib::encode_uri_component(value); return httplib::encode_uri_component(value);
} }