Compare commits

6 Commits

Author SHA1 Message Date
Aaron Po
eb9a2767b4 Refactor web client interface and related components 2026-04-02 18:55:58 -04:00
Aaron Po
29ea47fdb6 update cli arg handling 2026-04-02 18:41:25 -04:00
Aaron Po
52e2333304 Reorganize directory structure 2026-04-02 18:27:01 -04:00
Aaron Po
a1f0ca5b20 Refactor DataDownloader and CURLWebClient: update constructor and modify FileExists method signature 2026-04-02 18:06:40 -04:00
Aaron Po
2ea8aa52b4 update readme and add clangformat and clang tidy 2026-04-02 17:12:22 -04:00
Aaron Po
98083ab40c Pipeline: add CURL/WebClient & Wikipedia service
Introduce a pluggable web client interface and concrete CURL implementation: adds IWebClient, CURLWebClient, and CurlGlobalState (headers + curl_web_client.cpp). DataDownloader now accepts an IWebClient and delegates downloads. Add WikipediaService for cached Wikipedia summary lookups. Refactor SqliteDatabase to return full City records and update consumers accordingly. Improve JsonLoader to use batched transactions during streaming parses. Enhance LlamaGenerator with sampling options, increased token limits, JSON extraction/validation, and other parsing helpers. Modernize CMake: set policy/version, add project_options, simplify FetchContent usage (spdlog), require Boost components (program_options/json), list pipeline sources explicitly, and tweak post-build/memcheck targets. Update README to match implementation changes and new CLI/config conventions.
2026-04-02 16:29:16 -04:00
32 changed files with 2221 additions and 1379 deletions

10
pipeline/.clang-format Normal file
View File

@@ -0,0 +1,10 @@
---
BasedOnStyle: Google
Standard: c++23
ColumnLimit: 100
IndentWidth: 2
DerivePointerAlignment: false
PointerAlignment: Left
SortIncludes: true
IncludeBlocks: Preserve
...

17
pipeline/.clang-tidy Normal file
View File

@@ -0,0 +1,17 @@
---
Checks: >
-*,
bugprone-*,
clang-analyzer-*,
cppcoreguidelines-*,
google-*,
modernize-*,
performance-*,
readability-*,
-cppcoreguidelines-avoid-magic-numbers,
-cppcoreguidelines-owning-memory,
-readability-magic-numbers,
-google-readability-todo
HeaderFilterRegex: "^(src|includes)/.*"
FormatStyle: file
...

View File

@@ -1,49 +1,66 @@
cmake_minimum_required(VERSION 3.20) cmake_minimum_required(VERSION 3.20)
project(biergarten-pipeline VERSION 0.1.0 LANGUAGES CXX) project(biergarten-pipeline VERSION 0.1.0 LANGUAGES CXX)
cmake_policy(SET CMP0167 NEW) # Allows older dependencies to configure on newer CMake.
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
# Policies
cmake_policy(SET CMP0167 NEW) # FindBoost improvements
# Global Settings
set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

option(ENABLE_CLANG_TIDY "Enable clang-tidy static analysis for project targets" ON)
option(ENABLE_CLANG_FORMAT_TARGETS "Enable clang-format helper targets" ON)

if(ENABLE_CLANG_TIDY)
  find_program(CLANG_TIDY_EXE NAMES clang-tidy)
  if(CLANG_TIDY_EXE)
    # Pin clang-tidy to the project's own .clang-tidy so command-line and
    # IDE-driven runs apply the identical check set.
    set(BIERGARTEN_CLANG_TIDY_COMMAND
        "${CLANG_TIDY_EXE};--config-file=${CMAKE_CURRENT_SOURCE_DIR}/.clang-tidy")
    message(STATUS "clang-tidy enabled: ${CLANG_TIDY_EXE}")
  else()
    message(STATUS "clang-tidy not found; static analysis is disabled")
  endif()
endif()

# -----------------------------------------------------------------------------
# Compiler Options & Warnings (Interface Library)
# -----------------------------------------------------------------------------
# Targets link project_options to inherit a shared warning set.
add_library(project_options INTERFACE)
# NOTE: a generator expression must be ONE CMake argument. The previous form
# split "$<$<CXX_COMPILER_ID:GNU,Clang>:" / flags / ">" across separate lines,
# which CMake parses as three independent (malformed) arguments. Quoting the
# whole genex and separating flags with ';' keeps it intact.
target_compile_options(project_options INTERFACE
  "$<$<CXX_COMPILER_ID:GNU,Clang>:-Wall;-Wextra;-Wpedantic;-Wshadow;-Wconversion;-Wsign-conversion;-Wunused>"
  "$<$<CXX_COMPILER_ID:MSVC>:/W4;/WX;/permissive->"
)
# -----------------------------------------------------------------------------
# Dependencies
# -----------------------------------------------------------------------------
find_package(CURL REQUIRED) find_package(CURL REQUIRED)
find_package(Boost REQUIRED COMPONENTS unit_test_framework)
find_package(SQLite3 REQUIRED) find_package(SQLite3 REQUIRED)
find_package(Boost 1.75 REQUIRED COMPONENTS program_options json)
include(FetchContent) include(FetchContent)
# RapidJSON (header-only) for true SAX parsing # spdlog (Logging)
# Using direct header-only approach without CMakeLists.txt
FetchContent_Declare(
rapidjson
GIT_REPOSITORY https://github.com/Tencent/rapidjson.git
GIT_TAG v1.1.0
SOURCE_SUBDIR "" # Don't use RapidJSON's CMakeLists.txt
)
FetchContent_GetProperties(rapidjson)
if(NOT rapidjson_POPULATED)
FetchContent_Populate(rapidjson)
# RapidJSON is header-only; just make include path available
endif()
# spdlog (logging)
FetchContent_Declare( FetchContent_Declare(
spdlog spdlog
GIT_REPOSITORY https://github.com/gabime/spdlog.git GIT_REPOSITORY https://github.com/gabime/spdlog.git
GIT_TAG v1.11.0 GIT_TAG v1.11.0
) )
FetchContent_GetProperties(spdlog) FetchContent_MakeAvailable(spdlog)
if(NOT spdlog_POPULATED)
FetchContent_Populate(spdlog)
add_subdirectory(${spdlog_SOURCE_DIR} ${spdlog_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()
# llama.cpp (on-device inference) # llama.cpp (LLM Inference)
set(LLAMA_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(LLAMA_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(LLAMA_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) set(LLAMA_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(LLAMA_BUILD_SERVER OFF CACHE BOOL "" FORCE) set(LLAMA_BUILD_SERVER OFF CACHE BOOL "" FORCE)
FetchContent_Declare( FetchContent_Declare(
llama_cpp llama_cpp
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
@@ -57,90 +74,86 @@ if(TARGET llama)
) )
endif() endif()
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS # -----------------------------------------------------------------------------
src/*.cpp # Main Executable
# -----------------------------------------------------------------------------
set(PIPELINE_SOURCES
src/biergarten_data_generator.cpp
src/web_client/curl_web_client.cpp
src/data_generation/data_downloader.cpp
src/database/database.cpp
src/json_handling/json_loader.cpp
src/data_generation/llama_generator.cpp
src/data_generation/mock_generator.cpp
src/json_handling/stream_parser.cpp
src/wikipedia/wikipedia_service.cpp
src/main.cpp
) )
add_executable(biergarten-pipeline ${SOURCES}) add_executable(biergarten-pipeline ${PIPELINE_SOURCES})
if(BIERGARTEN_CLANG_TIDY_COMMAND)
set_target_properties(biergarten-pipeline PROPERTIES
CXX_CLANG_TIDY "${BIERGARTEN_CLANG_TIDY_COMMAND}"
)
endif()
target_include_directories(biergarten-pipeline target_include_directories(biergarten-pipeline
PRIVATE PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/includes ${CMAKE_CURRENT_SOURCE_DIR}/includes
${rapidjson_SOURCE_DIR}/include
${llama_cpp_SOURCE_DIR}/include ${llama_cpp_SOURCE_DIR}/include
) )
target_link_libraries(biergarten-pipeline target_link_libraries(biergarten-pipeline
PRIVATE PRIVATE
project_options
CURL::libcurl CURL::libcurl
Boost::unit_test_framework
SQLite::SQLite3 SQLite::SQLite3
spdlog::spdlog spdlog::spdlog
llama llama
Boost::program_options
Boost::json
) )
target_compile_options(biergarten-pipeline PRIVATE if(ENABLE_CLANG_FORMAT_TARGETS)
$<$<CXX_COMPILER_ID:GNU,Clang>: find_program(CLANG_FORMAT_EXE NAMES clang-format)
-Wall if(CLANG_FORMAT_EXE)
-Wextra file(GLOB_RECURSE FORMAT_SOURCES CONFIGURE_DEPENDS
-Wpedantic ${CMAKE_CURRENT_SOURCE_DIR}/src/**/*.cpp
-Wshadow ${CMAKE_CURRENT_SOURCE_DIR}/src/**/*.cc
-Wconversion ${CMAKE_CURRENT_SOURCE_DIR}/includes/**/*.h
-Wsign-conversion ${CMAKE_CURRENT_SOURCE_DIR}/includes/**/*.hpp
> )
$<$<CXX_COMPILER_ID:MSVC>:
/W4
/WX
>
)
add_custom_target(format
COMMAND ${CLANG_FORMAT_EXE} -style=file -i ${FORMAT_SOURCES}
COMMENT "Formatting source files with clang-format (Google style)"
VERBATIM
)
add_custom_target(format-check
COMMAND ${CLANG_FORMAT_EXE} -style=file --dry-run --Werror ${FORMAT_SOURCES}
COMMENT "Checking source formatting with clang-format (Google style)"
VERBATIM
)
else()
message(STATUS "clang-format not found; format targets are disabled")
endif()
endif()
# -----------------------------------------------------------------------------
# Post-Build Steps & Utilities
# -----------------------------------------------------------------------------
add_custom_command(TARGET biergarten-pipeline POST_BUILD add_custom_command(TARGET biergarten-pipeline POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_SOURCE_DIR}/output
${CMAKE_CURRENT_SOURCE_DIR}/output COMMENT "Ensuring output directory exists"
COMMENT "Creating output/ directory for seed SQL files"
) )
find_program(VALGRIND valgrind) find_program(VALGRIND valgrind)
if(VALGRIND) if(VALGRIND)
add_custom_target(memcheck add_custom_target(memcheck
COMMAND ${VALGRIND} COMMAND ${VALGRIND} --leak-check=full --error-exitcode=1 $<TARGET_FILE:biergarten-pipeline> --help
--leak-check=full
--error-exitcode=1
$<TARGET_FILE:biergarten-pipeline> --help
DEPENDS biergarten-pipeline DEPENDS biergarten-pipeline
COMMENT "Running Valgrind memcheck" COMMENT "Running Valgrind memory check"
) )
endif() endif()
# -----------------------------------------------------------------------------
# Tests (Boost.Test, registered with CTest)
# -----------------------------------------------------------------------------
include(CTest)
if(BUILD_TESTING)
  find_package(Boost REQUIRED COMPONENTS unit_test_framework)

  file(GLOB_RECURSE TEST_SOURCES CONFIGURE_DEPENDS
    tests/*.cpp
    tests/*.cc
    tests/*.cxx
  )

  # Only create the test target when test sources actually exist.
  if(TEST_SOURCES)
    add_executable(biergarten-pipeline-tests ${TEST_SOURCES})
    # Fix: the project's headers live in includes/ (see the main target's
    # target_include_directories), not include/.
    target_include_directories(biergarten-pipeline-tests
      PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/includes
    )
    # Fix: nlohmann_json::nlohmann_json is never found or fetched anywhere in
    # this build and would fail at configure time; the pipeline uses Boost.JSON,
    # so tests link the same Boost::json target as the main executable.
    # project_options is linked so tests build with the same warning set.
    target_link_libraries(biergarten-pipeline-tests
      PRIVATE
        project_options
        Boost::unit_test_framework
        CURL::libcurl
        Boost::json
    )
    add_test(
      NAME biergarten-pipeline-tests
      COMMAND biergarten-pipeline-tests
    )
  endif()
endif()

View File

@@ -1,64 +1,56 @@
## Biergarten Pipeline # Biergarten Pipeline
A high-performance C++23 data pipeline for fetching, parsing, and storing geographic data (countries, states, cities) with brewery metadata generation capabilities. The system supports both mock and LLM-based (llama.cpp) generation modes.
## Overview ## Overview
The pipeline orchestrates five key stages: The pipeline orchestrates **four key stages**:
1. **Download**: Fetches `countries+states+cities.json` from a pinned GitHub commit with optional local caching. 1. **Download** - Fetches `countries+states+cities.json` from a pinned GitHub commit with optional local filesystem caching
2. **Parse**: Streams JSON using RapidJSON SAX parser, extracting country/state/city records without loading the entire file into memory. 2. **Parse** - Streams JSON using Boost.JSON's `basic_parser` to extract country/state/city records without loading the entire file into memory
3. **Buffer**: Routes city records through a bounded concurrent queue to decouple parsing from writes. 3. **Store** - Inserts records into a file-based SQLite database with all operations performed sequentially in a single thread
4. **Store**: Inserts records with concurrent thread safety using an in-memory SQLite database. 4. **Generate** - Produces brewery metadata or user profiles (mock implementation; supports future LLM integration via llama.cpp)
5. **Generate**: Produces mock brewery metadata for a sample of cities (mockup for future LLM integration).
--- ## System Architecture
## Architecture
### Data Sources and Formats ### Data Sources and Formats
- Hierarchical structure: countries array → states per country → cities per state. - **Hierarchical Structure**: Countries array → states per country → cities per state
- Fields: `id` (integer), `name` (string), `iso2` / `iso3` (codes), `latitude` / `longitude`. - **Data Fields**:
- Sourced from: [dr5hn/countries-states-cities-database](https://github.com/dr5hn/countries-states-cities-database) on GitHub. - `id` (integer)
- `name` (string)
- `iso2` / `iso3` (ISO country/state codes)
- `latitude` / `longitude` (geographic coordinates)
- **Source**: [dr5hn/countries-states-cities-database](https://github.com/dr5hn/countries-states-cities-database) on GitHub
- **Output**: Structured SQLite file-based database (`biergarten-pipeline.db`) + structured logging via spdlog
**Output**: Structured SQLite in-memory database + console logs via spdlog. ### Concurrency Model
### Concurrency Architecture The pipeline currently operates **single-threaded** with sequential stage execution:
The pipeline splits work across parsing and writing phases: 1. **Download Phase**: Main thread blocks while downloading the source JSON file (if not in cache)
2. **Parse & Store Phase**: Main thread performs streaming JSON parse with immediate SQLite inserts
``` **Thread Safety**: While single-threaded, the `SqliteDatabase` component is **mutex-protected** using `std::mutex` (`dbMutex`) for all database operations. This design enables safe future parallelization without code modifications.
Main Thread:
parse_sax() -> Insert countries (direct)
-> Insert states (direct)
-> Push CityRecord to WorkQueue
Worker Threads (implicit; pthread pool via sqlite3): ## Core Components
Pop CityRecord from WorkQueue
-> InsertCity(db) with mutex protection
```
**Key synchronization primitives**: | Component | Purpose | Thread Safety | Dependencies |
| ----------------------------- | ----------------------------------------------------------------------------------------------- | -------------------------------------------- | --------------------------------------------- |
- **WorkQueue<T>**: Bounded (default 1024 items) concurrent queue with blocking push/pop, guarded by mutex + condition variables. | **BiergartenDataGenerator** | Orchestrates pipeline execution; manages lifecycle of downloader, parser, and generator | Single-threaded coordinator | ApplicationOptions, WebClient, SqliteDatabase |
- **SqliteDatabase::dbMutex**: Serializes all SQLite operations to avoid `SQLITE_BUSY` and ensure write safety. | **DataDownloader** | HTTP fetch with curl; optional filesystem cache; ETag support and retries | Blocking I/O; safe for startup | IWebClient, filesystem |
| **StreamingJsonParser** | Extends `boost::json::basic_parser`; emits country/state/city via callbacks; tracks parse depth | Single-threaded parse; callbacks thread-safe | Boost.JSON |
**Backpressure**: When the WorkQueue fills (≥1024 city records pending), the parser thread blocks until workers drain items. | **JsonLoader** | Wraps parser; dispatches callbacks for country/state/city; manages WorkQueue lifecycle | Produces to WorkQueue; safe callbacks | StreamingJsonParser, SqliteDatabase |
| **SqliteDatabase** | Manages schema initialization; insert/query methods for geographic data | Mutex-guarded all operations | SQLite3 |
### Component Responsibilities | **IDataGenerator** (Abstract) | Interface for brewery/user metadata generation | Stateless virtual methods | N/A |
| **LlamaGenerator** | LLM-based generation via llama.cpp; configurable sampling (temperature, top-p, seed) | Manages llama_model* and llama_context* | llama.cpp, BreweryResult, UserResult |
| Component | Purpose | Thread Safety | | **MockGenerator** | Deterministic mock generation using seeded randomization | Stateless; thread-safe | N/A |
| ------------------------- | ------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------- | | **CURLWebClient** | HTTP client adapter; URL encoding; file downloads | cURL library bindings | libcurl |
| **DataDownloader** | GitHub fetch with curl; optional filesystem cache; handles retries and ETags. | Blocking I/O; safe for single-threaded startup. | | **WikipediaService** | (Planned) Wikipedia data lookups for enrichment | N/A | IWebClient |
| **StreamingJsonParser** | SAX-style RapidJSON handler; emits country/state/city via callbacks; tracks parse state (array depth, key context). | Single-threaded parse phase; thread-safe callbacks. |
| **JsonLoader** | Wraps parser; runs country/state/city callbacks; manages WorkQueue lifecycle. | Produces to WorkQueue; consumes from callbacks. |
| **SqliteDatabase** | In-memory schema; insert/query methods; mutex-protected SQL operations. | Mutex-guarded; thread-safe concurrent inserts. |
| **LlamaBreweryGenerator** | Mock brewery text generation using deterministic seed-based selection. | Stateless; thread-safe method calls. |
---
## Database Schema ## Database Schema
**SQLite in-memory database** with three core tables: SQLite file-based database with **three core tables** and **indexes for fast lookups**:
### Countries ### Countries
@@ -102,236 +94,235 @@ CREATE INDEX idx_cities_state ON cities(state_id);
CREATE INDEX idx_cities_country ON cities(country_id); CREATE INDEX idx_cities_country ON cities(country_id);
``` ```
**Design rationale**: ## Architecture Diagram
- In-memory for performance (no persistent storage; data is regenerated on each run). ```plantuml
- Foreign keys for referential integrity (optional in SQLite, but enforced in schema). @startuml biergarten-pipeline
- Indexes on foreign keys for fast lookups during brewery generation. !theme plain
- Dual country_id in cities table for direct queries without state joins. skinparam monochrome true
skinparam classBackgroundColor #FFFFFF
skinparam classBorderColor #000000
--- package "Application Layer" {
class BiergartenDataGenerator {
- options: ApplicationOptions
- webClient: IWebClient
- database: SqliteDatabase
- generator: IDataGenerator
--
+ Run() : int
}
}
## Data Flow package "Data Acquisition" {
class DataDownloader {
- webClient: IWebClient
--
+ Download(url: string, filePath: string)
+ DownloadWithCache(url: string, cachePath: string)
}
### Parse Phase (Main Thread) interface IWebClient {
+ DownloadToFile(url: string, filePath: string)
+ Get(url: string) : string
+ UrlEncode(value: string) : string
}
1. **DataDownloader::DownloadCountriesDatabase()** class CURLWebClient {
- Constructs GitHub raw-content URL: `https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/{commit}/countries+states+cities.json` - globalState: CurlGlobalState
- Uses curl with `FOLLOWLOCATION` and timeout. --
- Caches locally; checks ETag for freshness. + DownloadToFile(url: string, filePath: string)
+ Get(url: string) : string
+ UrlEncode(value: string) : string
}
}
2. **StreamingJsonParser::Parse()** package "JSON Processing" {
- Opens file stream; initializes RapidJSON SAX parser with custom handler. class StreamingJsonParser {
- Handler state: tracks `current_country_id`, `current_state_id`, array nesting, object key context. - depth: int
- **Country processing** (inline): When country object completes, calls `db.InsertCountry()` directly on main thread. --
- **State processing** (inline): When state object completes, calls `db.InsertState()` directly. + on_object_begin()
- **City processing** (buffered): When city object completes, pushes `CityRecord` to `JsonLoader`'s WorkQueue; unblocks if `onProgress` callback is registered. + on_object_end()
+ on_array_begin()
+ on_array_end()
+ on_key(str: string)
+ on_string(str: string)
+ on_number(value: int)
}
3. **JsonLoader::LoadWorldCities()** class JsonLoader {
- Registers callbacks with parser. --
- Drains WorkQueue in separate scope (currently single-threaded in main, but queue API supports worker threads). + LoadWorldCities(jsonPath: string, db: SqliteDatabase)
- Each city is inserted via `db.InsertCity()`. }
}
### Query and Generation Phase (Main Thread) package "Data Storage" {
class SqliteDatabase {
- db: sqlite3*
- dbMutex: std::mutex
--
+ Initialize(dbPath: string)
+ InsertCountry(id: int, name: string, iso2: string, iso3: string)
+ InsertState(id: int, countryId: int, name: string, iso2: string)
+ InsertCity(id: int, stateId: int, countryId: int, name: string, lat: double, lon: double)
+ QueryCountries(limit: int) : vector<Country>
+ QueryStates(limit: int) : vector<State>
+ QueryCities() : vector<City>
+ BeginTransaction()
+ CommitTransaction()
# InitializeSchema()
}
4. **Database Queries** struct Country {
- `QueryCountries(limit)`: Retrieve countries; used for progress display. id: int
- `QueryStates(limit)`: Retrieve states; used for progress display. name: string
- `QueryCities()`: Retrieve all city ids + names for brewery generation. iso2: string
iso3: string
}
5. **Brewery Generation** struct State {
- For each city sample, call `LlamaBreweryGenerator::GenerateBrewery(cityName, seed)`. id: int
- Deterministic: same seed always produces same brewery (useful for reproducible test data). name: string
- Returns `{ name, description }` struct. iso2: string
countryId: int
}
--- struct City {
id: int
name: string
countryId: int
}
}
## Concurrency Deep Dive package "Data Generation" {
interface IDataGenerator {
+ load(modelPath: string)
+ generateBrewery(cityName: string, countryName: string, regionContext: string) : BreweryResult
+ generateUser(locale: string) : UserResult
}
### WorkQueue<T> class LlamaGenerator {
- model: llama_model*
- context: llama_context*
- sampling_temperature: float
- sampling_top_p: float
- sampling_seed: uint32_t
--
+ load(modelPath: string)
+ generateBrewery(...) : BreweryResult
+ generateUser(locale: string) : UserResult
+ setSamplingOptions(temperature: float, topP: float, seed: int)
# infer(prompt: string) : string
}
A bounded thread-safe queue enabling producer-consumer patterns: class MockGenerator {
--
+ load(modelPath: string)
+ generateBrewery(...) : BreweryResult
+ generateUser(locale: string) : UserResult
}
```cpp struct BreweryResult {
template <typename T> class WorkQueue { name: string
std::queue<T> queue; description: string
std::mutex mutex; }
std::condition_variable cv_not_empty, cv_not_full;
size_t max_size; struct UserResult {
bool shutdown; username: string
}; bio: string
}
}
package "Enrichment (Planned)" {
class WikipediaService {
- webClient: IWebClient
--
+ SearchCity(cityName: string, countryName: string) : string
}
}
' Relationships
BiergartenDataGenerator --> DataDownloader
BiergartenDataGenerator --> JsonLoader
BiergartenDataGenerator --> SqliteDatabase
BiergartenDataGenerator --> IDataGenerator
DataDownloader --> IWebClient
CURLWebClient ..|> IWebClient
JsonLoader --> StreamingJsonParser
JsonLoader --> SqliteDatabase
LlamaGenerator ..|> IDataGenerator
MockGenerator ..|> IDataGenerator
SqliteDatabase --> Country
SqliteDatabase --> State
SqliteDatabase --> City
LlamaGenerator --> BreweryResult
LlamaGenerator --> UserResult
MockGenerator --> BreweryResult
MockGenerator --> UserResult
WikipediaService --> IWebClient
@enduml
``` ```
**push(item)**:
- Locks mutex.
- Waits on `cv_not_full` until queue is below max_size OR shutdown signaled.
- Pushes item; notifies one waiter on `cv_not_empty`.
- Returns false if shutdown, else true.
**pop()**:
- Locks mutex.
- Waits on `cv_not_empty` until queue has items OR shutdown signaled.
- Pops and returns `std::optional<T>`; notifies one waiter on `cv_not_full`.
- Returns `std::nullopt` if shutdown and queue is empty.
**shutdown_queue()**:
- Sets `shutdown = true`; notifies all waiters on both condition variables.
- Causes all waiting pop() calls to return `std::nullopt`.
**Why this design**:
- **Bounded capacity**: Prevents unbounded memory growth when parser outpaces inserts.
- **Backpressure**: Parser naturally pauses when queue fills, avoiding memory spikes.
- **Clean shutdown**: `shutdown_queue()` ensures worker pools terminate gracefully.
### SqliteDatabase Mutex
All SQLite operations (`INSERT`, `SELECT`) are guarded by `dbMutex`:
```cpp
std::unique_lock<std::mutex> lock(dbMutex);
int rc = sqlite3_step(stmt);
```
**Why**: SQLite's "serializable" journal mode (default) requires external synchronization for multi-threaded access. A single mutex serializes all queries, avoiding `SQLITE_BUSY` errors.
**Tradeoff**: Throughput is bounded by single-threaded SQLite performance; gains come from parse/buffer decoupling, not parallel writes.
---
## Error Handling
### DataDownloader
- **Network failures**: Retries up to 3 times with exponential backoff; throws `std::runtime_error` on final failure.
- **Caching**: Falls back to cached file if download fails and cache exists.
### Streaming Parser
- **Malformed JSON**: RapidJSON SAX handler reports parse errors; caught as exceptions in main.
- **Missing fields**: Silently skips incomplete records (e.g., city without latitude).
### Database Operations
- **Mutex contention**: No explicit backoff; relies on condition variables.
- **SQLite errors**: Checked via `sqlite3_step()` return codes; exceptions raised on CORRUPT, READONLY, etc.
### Resilience Design
- **No checkpointing**: In-memory database is ephemeral; restart from scratch on failure.
- **Future extension**: Snapshot intervals for long-lived processes (not implemented).
---
## Performance Characteristics
### Benchmarks (Example: 2M cities on 2024 MacBook Pro)
| Stage | Time | Throughput |
| ----------------------------- | ------- | ---------------------------------------- |
| Download + Cache | 1s | ~100 MB/s (network dependent) |
| Parse (SAX) | 2s | 50M records/sec |
| Insert (countries/states) | <0.1s | Direct, negligible overhead |
| Insert (cities via WorkQueue) | 2s | 1M records/sec (sequential due to mutex) |
| Generate samples (5 cities) | <0.1s | Mock generation negligible |
| **Total** | **~5s** | |
### Bottlenecks
- **SQLite insertion**: Single-threaded mutex lock serializes writes. Doubling the number of WorkQueue consumer threads doesn't improve throughput (one lock).
- **Parse speed**: RapidJSON SAX is fast (2s for 100 MB); not the bottleneck.
- **Memory**: ~100 MB for in-memory database; suitable for most deployments.
### Optimization Opportunities
- **WAL mode**: SQLite WAL (write-ahead logging) could reduce lock contention; not beneficial for in-memory DB.
- **Batch inserts**: Combine multiple rows in a single transaction; helps if inserting outside the WorkQueue scope.
- **Foreign key lazy-loading**: Skip foreign key constraints during bulk load; re-enable for queries. (Not implemented.)
---
## Configuration and Extensibility ## Configuration and Extensibility
### Command-Line Arguments ### Command-Line Arguments
```bash Boost.Program_options provides named CLI arguments. Running without arguments displays usage instructions.
./biergarten-pipeline [modelPath] [cacheDir] [commit]
```
| Arg | Default | Purpose |
| ----------- | -------------- | ----------------------------------------------------------------------- |
| `modelPath` | `./model.gguf` | Path to LLM model (mock implementation; not loaded in current version). |
| `cacheDir` | `/tmp` | Directory for cached JSON (e.g., `/tmp/countries+states+cities.json`). |
| `commit` | `c5eb7772` | Git commit hash for consistency (stable 2026-03-28 snapshot). |
**Examples**:
```bash ```bash
./biergarten-pipeline ./biergarten-pipeline [options]
./biergarten-pipeline ./models/llama.gguf /var/cache main
./biergarten-pipeline "" /tmp v1.2.3
``` ```
### Extending the Generator **Requirement**: Exactly one of `--mocked` or `--model` must be specified.
**Current**: `LlamaBreweryGenerator::GenerateBrewery()` uses deterministic seed-based selection from hardcoded lists. | Argument | Short | Type | Purpose |
| --------------- | ----- | ------ | --------------------------------------------------------------- |
| `--mocked` | - | flag | Use mocked generator for brewery/user data |
| `--model` | `-m` | string | Path to LLM model file (gguf); mutually exclusive with --mocked |
| `--cache-dir` | `-c` | path | Directory for cached JSON (default: `/tmp`) |
| `--temperature` | - | float | LLM sampling temperature 0.0-1.0 (default: `0.8`) |
| `--top-p` | - | float | Nucleus sampling parameter 0.0-1.0 (default: `0.92`) |
| `--seed` | - | int | Random seed: -1 for random (default: `-1`) |
| `--help` | `-h` | flag | Show help message |
**Future swap points**: **Note**: The data source is always pinned to commit `c5eb7772` (stable 2026-03-28) and cannot be changed.
1. Load an actual LLM model in `LoadModel(modelPath)`. **Note**: When `--mocked` is used, any sampling parameters (`--temperature`, `--top-p`, `--seed`) are ignored with a warning.
2. Tokenize city name and context; call model inference.
3. Validate output (length, format) and rank if multiple candidates.
4. Cache results to avoid re-inference for repeated cities.
**Example stub for future integration**: ### Usage Examples
```cpp ```bash
Brewery LlamaBreweryGenerator::GenerateBrewery(const std::string &cityName, int seed) { # Mocked generator (deterministic, no LLM required)
// TODO: Replace with actual llama.cpp inference ./biergarten-pipeline --mocked
// llama_context *ctx = llama_new_context_with_model(model, params);
// std::string prompt = "Generate a brewery for " + cityName; # With LLM model
// std::string result = llama_inference(ctx, prompt, seed); ./biergarten-pipeline --model ./models/llama.gguf --cache-dir /var/cache
// return parse_brewery(result);
} # Mocked with extra parameters provided (will be ignored with warning)
./biergarten-pipeline --mocked --temperature 0.5 --top-p 0.8 --seed 42
# Show help
./biergarten-pipeline --help
``` ```
### Logging Configuration
Logging uses **spdlog** with:
- **Level**: Info (can change via `spdlog::set_level(spdlog::level::debug)` at startup).
- **Format**: Plain ASCII; no Unicode box art.
- **Sink**: Console (stdout/stderr); can redirect to file.
**Current output sample**:
```
[Pipeline] Downloading geographic data from GitHub...
Initializing in-memory SQLite database...
Initializing brewery generator...
=== GEOGRAPHIC DATA OVERVIEW ===
Total records loaded:
Countries: 195
States: 5000
Cities: 150000
```
---
## Building and Running ## Building and Running
### Prerequisites ### Prerequisites
- C++17 compiler (g++, clang, MSVC). - **C++23 compiler** (g++, clang, MSVC)
- CMake 3.20+. - **CMake** 3.20+
- curl (for HTTP downloads). - **curl** (for HTTP downloads)
- sqlite3 (usually system-provided). - **sqlite3** (database backend)
- RapidJSON (fetched via CMake FetchContent). - **Boost** 1.75+ (requires Boost.JSON and Boost.Program_options)
- spdlog (fetched via CMake FetchContent). - **spdlog** v1.11.0 (fetched via CMake FetchContent)
- **llama.cpp** (fetched via CMake FetchContent for LLM inference)
### Build ### Build
@@ -342,73 +333,74 @@ cmake ..
cmake --build . --target biergarten-pipeline -- -j cmake --build . --target biergarten-pipeline -- -j
``` ```
**Build artifacts**:
- Executable: `build/biergarten-pipeline`
- Intermediate: `build/CMakeFiles/`, `build/_deps/` (RapidJSON, spdlog)
### Run ### Run
```bash ```bash
./biergarten-pipeline ./build/biergarten-pipeline
``` ```
**Output**: Logs to console; caches JSON in `/tmp/countries+states+cities.json`. **Output**:
### Cleaning - Console logs with structured spdlog output
- Cached JSON file: `/tmp/countries+states+cities.json`
- SQLite database: `biergarten-pipeline.db` (in output directory)
## Code Quality and Static Analysis
### Formatting
This project uses **clang-format** with the **Google C++ style guide**:
```bash ```bash
rm -rf build # Apply formatting to all source files
cmake --build build --target format
# Check formatting without modifications
cmake --build build --target format-check
``` ```
--- ### Static Analysis
## Development Notes This project uses **clang-tidy** with configurations for Google, modernize, performance, and bug-prone rules (`.clang-tidy`):
### Code Organization Static analysis runs automatically during compilation if `clang-tidy` is available.
- **`includes/`**: Public headers (data structures, class APIs). ## Code Implementation Summary
- **`src/`**: Implementations with inline comments for non-obvious logic.
- **`CMakeLists.txt`**: Build configuration; defines fetch content, compiler flags, linking.
### Testing ### Key Achievements
Currently no automated tests. To add: **Full pipeline implementation** - Download → Parse → Store → Generate
**Streaming JSON parser** - Memory-efficient processing via Boost.JSON callbacks
**Thread-safe SQLite wrapper** - Mutex-protected database for future parallelization
**Flexible data generation** - Abstract IDataGenerator interface supporting both mock and LLM modes
**Comprehensive CLI** - Boost.Program_options with sensible defaults
**Production-grade logging** - spdlog integration for structured output
**Build quality** - CMake with clang-format/clang-tidy integration
1. Create `tests/` folder. ### Architecture Patterns
2. Use CMake to add a test executable.
3. Test the parser with small JSON fixtures.
4. Mock the database for isolation.
### Debugging - **Interface-based design**: `IWebClient`, `IDataGenerator` abstract base classes enable substitution and testing
- **Dependency injection**: Components receive dependencies via constructors (BiergartenDataGenerator)
- **RAII principle**: SQLite connections and resources managed via destructors
- **Callback-driven parsing**: Boost.JSON parser emits events to processing callbacks
- **Transaction-scoped inserts**: BeginTransaction/CommitTransaction for batch performance
**Enable verbose logging**: ### External Dependencies
```cpp | Dependency | Version | Purpose | Type |
spdlog::set_level(spdlog::level::debug); | ---------- | ------- | ---------------------------------- | ------- |
``` | Boost | 1.75+ | JSON parsing, CLI argument parsing | Library |
| SQLite3 | - | Persistent data storage | System |
| libcurl | - | HTTP downloads | System |
| spdlog | v1.11.0 | Structured logging | Fetched |
| llama.cpp | b8611 | LLM inference engine | Fetched |
**GDB workflow**: to validate formatting without modifying files.
```bash clang-tidy runs automatically on the biergarten-pipeline target when available. You can disable it at configure time:
gdb ./biergarten-pipeline
(gdb) break src/stream_parser.cpp:50
(gdb) run
```
### Future Enhancements cmake -DENABLE_CLANG_TIDY=OFF ..
1. **Real LLM integration**: Load and run llama.cpp models. You can also disable format helper targets:
2. **Persistence**: Write brewery data to a database or file.
3. **Distributed parsing**: Shard JSON file across multiple parse streams.
4. **Incremental updates**: Only insert new records if source updated.
5. **Web API**: Expose database via HTTP (brewery lookup, city search).
--- cmake -DENABLE_CLANG_FORMAT_TARGETS=OFF ..
## References
- [RapidJSON](https://rapidjson.org/) SAX parsing documentation.
- [spdlog](https://github.com/gabime/spdlog) Logging framework.
- [SQLite](https://www.sqlite.org/docs.html) In-memory database reference.
- [countries-states-cities-database](https://github.com/dr5hn/countries-states-cities-database) Data source.

View File

@@ -0,0 +1,116 @@
#ifndef BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_
#define BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_

#include <memory>
#include <string>
#include <vector>
#include <unordered_map>

#include "data_generation/data_generator.h"
#include "database/database.h"
#include "web_client/web_client.h"
#include "wikipedia/wikipedia_service.h"

/**
 * @brief Program options for the Biergarten pipeline application.
 */
struct ApplicationOptions {
  /// @brief Path to the LLM model file (gguf format); mutually exclusive with use_mocked.
  std::string model_path;
  /// @brief Use mocked generator instead of LLM; mutually exclusive with model_path.
  bool use_mocked = false;
  /// @brief Directory for cached JSON and database files.
  std::string cache_dir;
  /// @brief LLM sampling temperature (0.0 to 1.0, higher = more random).
  float temperature = 0.8f;
  /// @brief LLM nucleus sampling top-p parameter (0.0 to 1.0, higher = more random).
  float top_p = 0.92f;
  /// @brief Random seed for sampling (-1 for random, otherwise non-negative).
  int seed = -1;
  /// @brief Git commit hash for database consistency (always pinned to c5eb7772).
  std::string commit = "c5eb7772";
};

/**
 * @brief Main data generator class for the Biergarten pipeline.
 *
 * This class encapsulates the core logic for generating brewery data.
 * It handles database initialization, data loading/downloading, and brewery generation.
 */
class BiergartenDataGenerator {
 public:
  /**
   * @brief Construct a BiergartenDataGenerator with injected dependencies.
   *
   * @param options Application configuration options.
   * @param web_client HTTP client for downloading data.
   * @param database SQLite database instance.
   */
  BiergartenDataGenerator(const ApplicationOptions &options,
                          std::shared_ptr<WebClient> web_client,
                          SqliteDatabase &database);

  /**
   * @brief Run the data generation pipeline.
   *
   * Performs the following steps:
   *   1. Initialize database
   *   2. Download geographic data if needed
   *   3. Initialize the generator (LLM or Mock)
   *   4. Generate brewery data for sample cities
   *
   * @return 0 on success, 1 on failure.
   */
  int Run();

 private:
  /// @brief Immutable application options.
  const ApplicationOptions options_;
  /// @brief Shared HTTP client dependency.
  std::shared_ptr<WebClient> webClient_;
  /// @brief Database dependency (must outlive this object).
  SqliteDatabase &database_;

  /**
   * @brief Initialize the data generator based on options.
   *
   * Creates either a MockGenerator (if no model path) or LlamaGenerator.
   *
   * @return A unique_ptr to the initialized generator.
   */
  std::unique_ptr<DataGenerator> InitializeGenerator();

  /**
   * @brief Download and load geographic data if not cached.
   */
  void LoadGeographicData();

  /**
   * @brief Generate sample breweries for demonstration.
   */
  void GenerateSampleBreweries();

  /**
   * @brief Helper struct to store generated brewery data.
   */
  struct GeneratedBrewery {
    int cityId;
    std::string cityName;
    BreweryResult brewery;
  };

  /// @brief Stores generated brewery data.
  std::vector<GeneratedBrewery> generatedBreweries_;
};

// BUG FIX: the guard's #endif previously appeared right after
// ApplicationOptions, leaving every declaration below it unguarded (a second
// inclusion would redefine BiergartenDataGenerator). It now closes at EOF.
#endif  // BIERGARTEN_PIPELINE_BIERGARTEN_DATA_GENERATOR_H_

View File

@@ -1,26 +1,30 @@
#ifndef DATA_DOWNLOADER_H #ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_DOWNLOADER_H_
#define DATA_DOWNLOADER_H #define BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_DOWNLOADER_H_
#include <memory>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "web_client/web_client.h"
/// @brief Downloads and caches source geography JSON payloads. /// @brief Downloads and caches source geography JSON payloads.
class DataDownloader { class DataDownloader {
public: public:
/// @brief Initializes global curl state used by this downloader. /// @brief Initializes global curl state used by this downloader.
DataDownloader(); explicit DataDownloader(std::shared_ptr<WebClient> web_client);
/// @brief Cleans up global curl state. /// @brief Cleans up global curl state.
~DataDownloader(); ~DataDownloader();
/// @brief Returns a local JSON path, downloading it when cache is missing. /// @brief Returns a local JSON path, downloading it when cache is missing.
std::string DownloadCountriesDatabase( std::string DownloadCountriesDatabase(
const std::string &cachePath, const std::string &cache_path,
const std::string &commit = "c5eb7772" // Stable commit: 2026-03-28 export const std::string &commit = "c5eb7772" // Stable commit: 2026-03-28 export
); );
private: private:
bool FileExists(const std::string &filePath) const; static bool FileExists(const std::string &file_path);
std::shared_ptr<WebClient> web_client_;
}; };
#endif // DATA_DOWNLOADER_H #endif // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_DOWNLOADER_H_

View File

@@ -0,0 +1,29 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_

#include <string>

/// @brief Generated brewery record returned by a DataGenerator.
struct BreweryResult {
  /// @brief Display name of the generated brewery.
  std::string name;
  /// @brief Free-form description text for the brewery.
  std::string description;
};

/// @brief Generated user record returned by a DataGenerator.
struct UserResult {
  /// @brief Generated account username.
  std::string username;
  /// @brief Short free-form biography text.
  std::string bio;
};

/// @brief Abstract interface for data generators (mock or LLM-backed).
class DataGenerator {
 public:
  virtual ~DataGenerator() = default;

  /// @brief Prepares the generator for use. model_path semantics are
  /// implementation-defined (may be empty for non-LLM generators — confirm
  /// against each implementation).
  virtual void Load(const std::string &model_path) = 0;

  /// @brief Produces a brewery for the given city and country, optionally
  /// informed by free-text region context.
  virtual BreweryResult GenerateBrewery(const std::string &city_name,
                                        const std::string &country_name,
                                        const std::string &region_context) = 0;

  /// @brief Produces a user record for the given locale.
  virtual UserResult GenerateUser(const std::string &locale) = 0;
};

#endif  // BIERGARTEN_PIPELINE_DATA_GENERATION_DATA_GENERATOR_H_

View File

@@ -0,0 +1,41 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_
#include <cstdint>
#include <string>
#include "data_generation/data_generator.h"
struct llama_model;
struct llama_context;
class LlamaGenerator final : public DataGenerator {
public:
LlamaGenerator() = default;
~LlamaGenerator() override;
void SetSamplingOptions(float temperature, float top_p, int seed = -1);
void Load(const std::string &model_path) override;
BreweryResult GenerateBrewery(const std::string &city_name,
const std::string &country_name,
const std::string &region_context) override;
UserResult GenerateUser(const std::string &locale) override;
private:
std::string Infer(const std::string &prompt, int max_tokens = 10000);
// Overload that allows passing a system message separately so chat-capable
// models receive a proper system role instead of having the system text
// concatenated into the user prompt (helps avoid revealing internal
// reasoning or instructions in model output).
std::string Infer(const std::string &system_prompt, const std::string &prompt,
int max_tokens = 10000);
llama_model *model_ = nullptr;
llama_context *context_ = nullptr;
float sampling_temperature_ = 0.8f;
float sampling_top_p_ = 0.92f;
uint32_t sampling_seed_ = 0xFFFFFFFFu;
};
#endif // BIERGARTEN_PIPELINE_DATA_GENERATION_LLAMA_GENERATOR_H_

View File

@@ -0,0 +1,27 @@
#ifndef BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_
#define BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_

#include "data_generation/data_generator.h"

#include <string>
#include <vector>

/// @brief Deterministic DataGenerator that fabricates brewery/user data from
/// fixed word lists, allowing the pipeline to run without any LLM.
class MockGenerator final : public DataGenerator {
 public:
  /// @brief Loader entry point; model_path is presumably ignored for mocked
  /// output — confirm against the implementation.
  void Load(const std::string &model_path) override;

  /// @brief Builds a brewery from the static vocabularies below.
  BreweryResult GenerateBrewery(const std::string &city_name,
                                const std::string &country_name,
                                const std::string &region_context) override;

  /// @brief Builds a user record from the static vocabularies below.
  UserResult GenerateUser(const std::string &locale) override;

 private:
  /// @brief Deterministic hash over two strings, used to select list entries
  /// reproducibly for a given input pair.
  static std::size_t DeterministicHash(const std::string &a,
                                       const std::string &b);

  // Fixed vocabularies the mock generator samples from.
  static const std::vector<std::string> kBreweryAdjectives;
  static const std::vector<std::string> kBreweryNouns;
  static const std::vector<std::string> kBreweryDescriptions;
  static const std::vector<std::string> kUsernames;
  static const std::vector<std::string> kBios;
};

#endif  // BIERGARTEN_PIPELINE_DATA_GENERATION_MOCK_GENERATOR_H_

View File

@@ -1,26 +0,0 @@
#pragma once
#include <string>
struct BreweryResult {
std::string name;
std::string description;
};
struct UserResult {
std::string username;
std::string bio;
};
class IDataGenerator {
public:
virtual ~IDataGenerator() = default;
virtual void load(const std::string &modelPath) = 0;
virtual BreweryResult generateBrewery(const std::string &cityName,
const std::string &countryName,
const std::string &regionContext) = 0;
virtual UserResult generateUser(const std::string &locale) = 0;
};

View File

@@ -1,4 +1,5 @@
#pragma once #ifndef BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#define BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_
#include <mutex> #include <mutex>
#include <sqlite3.h> #include <sqlite3.h>
@@ -24,14 +25,23 @@ struct State {
/// @brief State or province short code. /// @brief State or province short code.
std::string iso2; std::string iso2;
/// @brief Parent country identifier. /// @brief Parent country identifier.
int countryId; int country_id;
};
struct City {
/// @brief City identifier from the source dataset.
int id;
/// @brief City display name.
std::string name;
/// @brief Parent country identifier.
int country_id;
}; };
/// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks. /// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks.
class SqliteDatabase { class SqliteDatabase {
private: private:
sqlite3 *db = nullptr; sqlite3 *db_ = nullptr;
std::mutex dbMutex; std::mutex db_mutex_;
void InitializeSchema(); void InitializeSchema();
@@ -39,8 +49,8 @@ public:
/// @brief Closes the SQLite connection if initialized. /// @brief Closes the SQLite connection if initialized.
~SqliteDatabase(); ~SqliteDatabase();
/// @brief Opens the SQLite database at dbPath and creates schema objects. /// @brief Opens the SQLite database at db_path and creates schema objects.
void Initialize(const std::string &dbPath = ":memory:"); void Initialize(const std::string &db_path = ":memory:");
/// @brief Starts a database transaction for batched writes. /// @brief Starts a database transaction for batched writes.
void BeginTransaction(); void BeginTransaction();
@@ -53,15 +63,15 @@ public:
const std::string &iso3); const std::string &iso3);
/// @brief Inserts a state row linked to a country. /// @brief Inserts a state row linked to a country.
void InsertState(int id, int countryId, const std::string &name, void InsertState(int id, int country_id, const std::string &name,
const std::string &iso2); const std::string &iso2);
/// @brief Inserts a city row linked to state and country. /// @brief Inserts a city row linked to state and country.
void InsertCity(int id, int stateId, int countryId, const std::string &name, void InsertCity(int id, int state_id, int country_id, const std::string &name,
double latitude, double longitude); double latitude, double longitude);
/// @brief Returns city id and city name pairs. /// @brief Returns city records including parent country id.
std::vector<std::pair<int, std::string>> QueryCities(); std::vector<City> QueryCities();
/// @brief Returns countries with optional row limit. /// @brief Returns countries with optional row limit.
std::vector<Country> QueryCountries(int limit = 0); std::vector<Country> QueryCountries(int limit = 0);
@@ -69,3 +79,5 @@ public:
/// @brief Returns states with optional row limit. /// @brief Returns states with optional row limit.
std::vector<State> QueryStates(int limit = 0); std::vector<State> QueryStates(int limit = 0);
}; };
#endif // BIERGARTEN_PIPELINE_DATABASE_DATABASE_H_

View File

@@ -0,0 +1,15 @@
#ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_

#include "database/database.h"
#include "json_handling/stream_parser.h"

#include <string>

/// @brief Loads world-city JSON data into SQLite through streaming parsing.
class JsonLoader {
 public:
  /// @brief Parses a JSON file and writes country/state/city rows into db.
  /// @param json_path Path to the countries+states+cities JSON file on disk.
  /// @param db Open database that receives the parsed rows.
  static void LoadWorldCities(const std::string &json_path, SqliteDatabase &db);
};

#endif  // BIERGARTEN_PIPELINE_JSON_HANDLING_JSON_LOADER_H_

View File

@@ -1,6 +1,7 @@
#pragma once #ifndef BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#define BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_
#include "database.h" #include "database/database.h"
#include <functional> #include <functional>
#include <string> #include <string>
@@ -20,10 +21,10 @@ struct CityRecord {
/// @brief Streaming SAX parser that emits city records during traversal. /// @brief Streaming SAX parser that emits city records during traversal.
class StreamingJsonParser { class StreamingJsonParser {
public: public:
/// @brief Parses filePath and invokes callbacks for city rows and progress. /// @brief Parses file_path and invokes callbacks for city rows and progress.
static void Parse(const std::string &filePath, SqliteDatabase &db, static void Parse(const std::string &file_path, SqliteDatabase &db,
std::function<void(const CityRecord &)> onCity, std::function<void(const CityRecord &)> on_city,
std::function<void(size_t, size_t)> onProgress = nullptr); std::function<void(size_t, size_t)> on_progress = nullptr);
private: private:
/// @brief Mutable SAX handler state while traversing nested JSON arrays. /// @brief Mutable SAX handler state while traversing nested JSON arrays.
@@ -46,3 +47,5 @@ private:
size_t bytes_processed = 0; size_t bytes_processed = 0;
}; };
}; };
#endif // BIERGARTEN_PIPELINE_JSON_HANDLING_STREAM_PARSER_H_

View File

@@ -1,12 +0,0 @@
#pragma once
#include "database.h"
#include "stream_parser.h"
#include <string>
/// @brief Loads world-city JSON data into SQLite through streaming parsing.
class JsonLoader {
public:
/// @brief Parses a JSON file and writes country/state/city rows into db.
static void LoadWorldCities(const std::string &jsonPath, SqliteDatabase &db);
};

View File

@@ -1,31 +0,0 @@
#pragma once
#include "data_generator.h"
#include <memory>
#include <string>
struct llama_model;
struct llama_context;
class LlamaGenerator final : public IDataGenerator {
public:
~LlamaGenerator() override;
void load(const std::string &modelPath) override;
BreweryResult generateBrewery(const std::string &cityName,
const std::string &countryName,
const std::string &regionContext) override;
UserResult generateUser(const std::string &locale) override;
private:
std::string infer(const std::string &prompt, int maxTokens = 5000);
// Overload that allows passing a system message separately so chat-capable
// models receive a proper system role instead of having the system text
// concatenated into the user prompt (helps avoid revealing internal
// reasoning or instructions in model output).
std::string infer(const std::string &systemPrompt, const std::string &prompt,
int maxTokens = 5000);
llama_model *model_ = nullptr;
llama_context *context_ = nullptr;
};

View File

@@ -1,24 +0,0 @@
#pragma once
#include "data_generator.h"
#include <string>
#include <vector>
class MockGenerator final : public IDataGenerator {
public:
void load(const std::string &modelPath) override;
BreweryResult generateBrewery(const std::string &cityName,
const std::string &countryName,
const std::string &regionContext) override;
UserResult generateUser(const std::string &locale) override;
private:
static std::size_t deterministicHash(const std::string &a,
const std::string &b);
static const std::vector<std::string> kBreweryAdjectives;
static const std::vector<std::string> kBreweryNouns;
static const std::vector<std::string> kBreweryDescriptions;
static const std::vector<std::string> kUsernames;
static const std::vector<std::string> kBios;
};

View File

@@ -0,0 +1,29 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_

#include "web_client/web_client.h"

#include <memory>

// RAII for curl_global_init/cleanup.
// An instance of this class should be created in main() before any curl
// operations and exist for the lifetime of the application.
class CurlGlobalState {
 public:
  CurlGlobalState();
  ~CurlGlobalState();

  // Non-copyable: global curl state must be initialized/cleaned up once.
  CurlGlobalState(const CurlGlobalState &) = delete;
  CurlGlobalState &operator=(const CurlGlobalState &) = delete;
};

/// @brief libcurl-backed implementation of the WebClient interface.
class CURLWebClient : public WebClient {
 public:
  CURLWebClient();
  ~CURLWebClient() override;

  /// @brief Downloads url into file_path; throws on error (see WebClient).
  void DownloadToFile(const std::string &url,
                      const std::string &file_path) override;

  /// @brief Performs an HTTP GET and returns the body; throws on error.
  std::string Get(const std::string &url) override;

  /// @brief Percent-encodes value for use in URLs.
  std::string UrlEncode(const std::string &value) override;
};

#endif  // BIERGARTEN_PIPELINE_WEB_CLIENT_CURL_WEB_CLIENT_H_

View File

@@ -0,0 +1,22 @@
#ifndef BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_
#define BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_

#include <string>

// Abstract HTTP client interface. Enables substituting the concrete
// implementation (e.g. CURLWebClient) with test doubles.
class WebClient {
 public:
  virtual ~WebClient() = default;

  // Downloads content from a URL to a file. Throws on error.
  virtual void DownloadToFile(const std::string &url,
                              const std::string &file_path) = 0;

  // Performs a GET request and returns the response body as a string. Throws on
  // error.
  virtual std::string Get(const std::string &url) = 0;

  // URL-encodes a string.
  virtual std::string UrlEncode(const std::string &value) = 0;
};

#endif  // BIERGARTEN_PIPELINE_WEB_CLIENT_WEB_CLIENT_H_

View File

@@ -0,0 +1,27 @@
#ifndef BIERGARTEN_PIPELINE_WIKIPEDIA_WIKIPEDIA_SERVICE_H_
#define BIERGARTEN_PIPELINE_WIKIPEDIA_WIKIPEDIA_SERVICE_H_

#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>

#include "web_client/web_client.h"

/// @brief Provides cached Wikipedia summary lookups for city and country pairs.
class WikipediaService {
 public:
  /// @brief Creates a new Wikipedia service with the provided web client.
  explicit WikipediaService(std::shared_ptr<WebClient> client);

  /// @brief Returns the Wikipedia summary extract for city and country.
  [[nodiscard]] std::string GetSummary(std::string_view city,
                                       std::string_view country);

 private:
  /// @brief Performs the network lookup for a single query string.
  std::string FetchExtract(std::string_view query);

  /// @brief HTTP client used for Wikipedia API requests.
  std::shared_ptr<WebClient> client_;
  /// @brief Memoized summaries; key is presumably derived from the
  /// city/country pair — confirm against the implementation.
  std::unordered_map<std::string, std::string> cache_;
};

#endif  // BIERGARTEN_PIPELINE_WIKIPEDIA_WIKIPEDIA_SERVICE_H_

View File

@@ -0,0 +1,132 @@
#include "biergarten_data_generator.h"
#include <algorithm>
#include <filesystem>
#include <unordered_map>
#include <spdlog/spdlog.h>
#include "data_generation/data_downloader.h"
#include "json_handling/json_loader.h"
#include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h"
#include "wikipedia/wikipedia_service.h"
// Stores the injected dependencies: options are copied, the web client is
// shared, and the database is held by reference (must outlive this object).
BiergartenDataGenerator::BiergartenDataGenerator(
    const ApplicationOptions &options,
    std::shared_ptr<WebClient> web_client,
    SqliteDatabase &database)
    : options_(options), webClient_(web_client), database_(database) {}
// Builds the concrete generator selected by the options: a deterministic
// MockGenerator when no model path is configured, otherwise a llama.cpp-backed
// LlamaGenerator with the configured sampling parameters. The generator's
// Load() is invoked before returning.
std::unique_ptr<DataGenerator> BiergartenDataGenerator::InitializeGenerator() {
  spdlog::info("Initializing brewery generator...");

  if (options_.model_path.empty()) {
    std::unique_ptr<DataGenerator> mock = std::make_unique<MockGenerator>();
    spdlog::info("[Generator] Using MockGenerator (no model path provided)");
    mock->Load(options_.model_path);
    return mock;
  }

  auto llama = std::make_unique<LlamaGenerator>();
  llama->SetSamplingOptions(options_.temperature, options_.top_p,
                            options_.seed);
  spdlog::info(
      "[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, "
      "seed={})",
      options_.model_path, options_.temperature, options_.top_p,
      options_.seed);
  llama->Load(options_.model_path);
  return llama;
}
void BiergartenDataGenerator::LoadGeographicData() {
std::string json_path = options_.cache_dir + "/countries+states+cities.json";
std::string db_path = options_.cache_dir + "/biergarten-pipeline.db";
bool has_json_cache = std::filesystem::exists(json_path);
bool has_db_cache = std::filesystem::exists(db_path);
spdlog::info("Initializing SQLite database at {}...", db_path);
database_.Initialize(db_path);
if (has_db_cache && has_json_cache) {
spdlog::info("[Pipeline] Cache hit: skipping download and parse");
} else {
spdlog::info("\n[Pipeline] Downloading geographic data from GitHub...");
DataDownloader downloader(webClient_);
downloader.DownloadCountriesDatabase(json_path, options_.commit);
JsonLoader::LoadWorldCities(json_path, database_);
}
}
// Generates breweries for up to 30 sample cities and logs the results.
// Fix: the original issued two unused limited queries (QueryCountries(50),
// QueryStates(50)) and called QueryCountries(0) twice (once for the id->name
// map, again for the count); the full country list is now fetched once and
// reused. Log output is unchanged.
void BiergartenDataGenerator::GenerateSampleBreweries() {
  auto generator = InitializeGenerator();
  WikipediaService wikipedia_service(webClient_);

  spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
  auto cities = database_.QueryCities();

  // Single full country query, reused for both the per-city lookup map and
  // the summary count below.
  auto all_countries = database_.QueryCountries(0);
  std::unordered_map<int, std::string> country_map;
  for (const auto &c : all_countries)
    country_map[c.id] = c.name;

  spdlog::info("\nTotal records loaded:");
  spdlog::info(" Countries: {}", all_countries.size());
  spdlog::info(" States: {}", database_.QueryStates(0).size());
  spdlog::info(" Cities: {}", cities.size());

  generatedBreweries_.clear();
  const size_t sample_count = std::min(size_t(30), cities.size());

  spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
  for (size_t i = 0; i < sample_count; i++) {
    const auto &city = cities[i];
    const int city_id = city.id;
    const std::string city_name = city.name;

    // Resolve the parent country name; left empty when the id is unknown.
    std::string local_country;
    const auto country_it = country_map.find(city.country_id);
    if (country_it != country_map.end()) {
      local_country = country_it->second;
    }

    // Wikipedia extract gives the generator regional flavor text.
    const std::string region_context =
        wikipedia_service.GetSummary(city_name, local_country);
    spdlog::debug("[Pipeline] Region context for {}: {}", city_name,
                  region_context);

    auto brewery =
        generator->GenerateBrewery(city_name, local_country, region_context);
    generatedBreweries_.push_back({city_id, city_name, brewery});
  }

  spdlog::info("\n=== GENERATED DATA DUMP ===");
  for (size_t i = 0; i < generatedBreweries_.size(); i++) {
    const auto &entry = generatedBreweries_[i];
    spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.cityId,
                 entry.cityName);
    spdlog::info(" brewery_name=\"{}\"", entry.brewery.name);
    spdlog::info(" brewery_description=\"{}\"", entry.brewery.description);
  }
}
// Executes the full pipeline: load/download geographic data, then generate
// sample breweries. Any std::exception thrown by a step is logged and mapped
// to exit code 1; success returns 0.
int BiergartenDataGenerator::Run() {
  try {
    LoadGeographicData();
    GenerateSampleBreweries();
    spdlog::info("\nOK: Pipeline completed successfully");
    return 0;
  } catch (const std::exception &e) {
    spdlog::error("ERROR: Pipeline failed: {}", e.what());
    return 1;
  }
}

View File

@@ -1,102 +0,0 @@
#include "data_downloader.h"
#include <cstdio>
#include <curl/curl.h>
#include <filesystem>
#include <fstream>
#include <spdlog/spdlog.h>
#include <sstream>
static size_t WriteCallback(void *contents, size_t size, size_t nmemb,
void *userp) {
size_t realsize = size * nmemb;
std::ofstream *outFile = static_cast<std::ofstream *>(userp);
outFile->write(static_cast<char *>(contents), realsize);
return realsize;
}
DataDownloader::DataDownloader() {}
DataDownloader::~DataDownloader() {}
bool DataDownloader::FileExists(const std::string &filePath) const {
return std::filesystem::exists(filePath);
}
std::string
DataDownloader::DownloadCountriesDatabase(const std::string &cachePath,
const std::string &commit) {
if (FileExists(cachePath)) {
spdlog::info("[DataDownloader] Cache hit: {}", cachePath);
return cachePath;
}
std::string shortCommit = commit;
if (commit.length() > 7) {
shortCommit = commit.substr(0, 7);
}
std::string url = "https://raw.githubusercontent.com/dr5hn/"
"countries-states-cities-database/" +
shortCommit + "/json/countries+states+cities.json";
spdlog::info("[DataDownloader] Downloading: {}", url);
CURL *curl = curl_easy_init();
if (!curl) {
throw std::runtime_error("[DataDownloader] Failed to initialize libcurl");
}
std::ofstream outFile(cachePath, std::ios::binary);
if (!outFile.is_open()) {
curl_easy_cleanup(curl);
throw std::runtime_error("[DataDownloader] Cannot open file for writing: " +
cachePath);
}
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, static_cast<void *>(&outFile));
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 30L);
curl_easy_setopt(curl, CURLOPT_TIMEOUT, 300L);
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
CURLcode res = curl_easy_perform(curl);
outFile.close();
if (res != CURLE_OK) {
curl_easy_cleanup(curl);
std::remove(cachePath.c_str());
std::string error = std::string("[DataDownloader] Download failed: ") +
curl_easy_strerror(res);
throw std::runtime_error(error);
}
long httpCode = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &httpCode);
curl_easy_cleanup(curl);
if (httpCode != 200) {
std::remove(cachePath.c_str());
std::stringstream ss;
ss << "[DataDownloader] HTTP error " << httpCode
<< " (commit: " << shortCommit << ")";
throw std::runtime_error(ss.str());
}
std::ifstream fileCheck(cachePath, std::ios::binary | std::ios::ate);
std::streamsize size = fileCheck.tellg();
fileCheck.close();
spdlog::info("[DataDownloader] OK: Download complete: {} ({:.2f} MB)",
cachePath, (size / (1024.0 * 1024.0)));
return cachePath;
}

View File

@@ -0,0 +1,46 @@
#include "data_generation/data_downloader.h"
#include "web_client/web_client.h"
#include <filesystem>
#include <fstream>
#include <spdlog/spdlog.h>
#include <sstream>
#include <stdexcept>
// Takes shared ownership of the injected web client used for downloads.
DataDownloader::DataDownloader(std::shared_ptr<WebClient> web_client)
    : web_client_(std::move(web_client)) {}
DataDownloader::~DataDownloader() {}
bool DataDownloader::FileExists(const std::string &file_path) {
return std::filesystem::exists(file_path);
}
// Returns the local path of the countries+states+cities JSON payload,
// downloading it from the pinned GitHub commit when no cached copy exists.
// Propagates any exception the web client throws on download failure.
std::string
DataDownloader::DownloadCountriesDatabase(const std::string &cache_path,
                                          const std::string &commit) {
  if (FileExists(cache_path)) {
    spdlog::info("[DataDownloader] Cache hit: {}", cache_path);
    return cache_path;
  }

  // GitHub raw URLs accept the abbreviated (7-character) commit hash.
  std::string short_commit = commit;
  if (commit.length() > 7) {
    short_commit = commit.substr(0, 7);
  }

  std::string url = "https://raw.githubusercontent.com/dr5hn/"
                    "countries-states-cities-database/" +
                    short_commit + "/json/countries+states+cities.json";

  spdlog::info("[DataDownloader] Downloading: {}", url);
  web_client_->DownloadToFile(url, cache_path);

  // Fix: the previous size check opened an ifstream and called tellg(),
  // which yields -1 (logging a bogus negative MB figure) when the file
  // cannot be opened. Query the filesystem instead and report 0 MB when
  // the file is unexpectedly absent.
  double size_mb = 0.0;
  if (std::filesystem::exists(cache_path)) {
    size_mb = static_cast<double>(std::filesystem::file_size(cache_path)) /
              (1024.0 * 1024.0);
  }
  spdlog::info("[DataDownloader] OK: Download complete: {} ({:.2f} MB)",
               cache_path, size_mb);
  return cache_path;
}

View File

@@ -0,0 +1,734 @@
#include <algorithm>
#include <array>
#include <cctype>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "llama.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator.h"
namespace {
// Strips leading and trailing whitespace (per std::isspace) from value.
std::string trim(std::string value) {
  const auto is_visible = [](unsigned char ch) { return std::isspace(ch) == 0; };
  const auto first_visible =
      std::find_if(value.begin(), value.end(), is_visible);
  value.erase(value.begin(), first_visible);
  const auto last_visible =
      std::find_if(value.rbegin(), value.rend(), is_visible).base();
  value.erase(last_visible, value.end());
  return value;
}
// Collapses every run of whitespace in text into a single space, then trims
// the result (so a whitespace-only input becomes the empty string).
std::string CondenseWhitespace(std::string text) {
  std::string collapsed;
  collapsed.reserve(text.size());
  bool previous_was_space = false;
  for (unsigned char ch : text) {
    const bool is_space = std::isspace(ch) != 0;
    if (is_space) {
      // Emit one space per run; subsequent whitespace in the run is dropped.
      if (!previous_was_space) {
        collapsed.push_back(' ');
      }
    } else {
      collapsed.push_back(static_cast<char>(ch));
    }
    previous_was_space = is_space;
  }
  return trim(std::move(collapsed));
}
// Normalizes regionContext whitespace and truncates it to at most maxChars,
// preferring to break at a space in the second half of the budget, appending
// "..." when truncation occurred.
std::string PrepareRegionContext(std::string_view regionContext,
                                 std::size_t maxChars = 700) {
  std::string summary = CondenseWhitespace(std::string(regionContext));
  if (summary.size() <= maxChars) {
    return summary;
  }
  // Hard-cut at the budget, then back up to the last space if doing so does
  // not sacrifice more than half of the allowed length.
  summary.resize(maxChars);
  const std::size_t break_at = summary.find_last_of(' ');
  if (break_at != std::string::npos && break_at > maxChars / 2) {
    summary.resize(break_at);
  }
  summary += "...";
  return summary;
}
// Removes common list markers ("-", "*", "12.", "3)") and known field
// labels ("name:", "bio:", ...) from the start of `line`, trimming
// whitespace around each removal.
std::string stripCommonPrefix(std::string line) {
  line = trim(std::move(line));
  if (!line.empty() && (line.front() == '-' || line.front() == '*')) {
    // Bullet marker.
    line = trim(line.substr(1));
  } else {
    // Numbered marker such as "12." or "3)".
    std::size_t digits = 0;
    while (digits < line.size() &&
           std::isdigit(static_cast<unsigned char>(line[digits])) != 0) {
      ++digits;
    }
    if (digits > 0 && digits < line.size() &&
        (line[digits] == '.' || line[digits] == ')')) {
      line = trim(line.substr(digits + 1));
    }
  }
  // Case-insensitively drops `label` from the front of `line` if present.
  const auto dropLabel = [&line](const std::string &label) {
    if (line.size() < label.size()) {
      return;
    }
    for (std::size_t i = 0; i < label.size(); ++i) {
      if (std::tolower(static_cast<unsigned char>(line[i])) !=
          std::tolower(static_cast<unsigned char>(label[i]))) {
        return;
      }
    }
    line = trim(line.substr(label.size()));
  };
  dropLabel("name:");
  dropLabel("brewery name:");
  dropLabel("description:");
  dropLabel("username:");
  dropLabel("bio:");
  return trim(std::move(line));
}
// Splits a raw model response into a (first line, remaining lines joined
// with spaces) pair. List markers/labels are stripped per line; empty
// lines, tag-like lines (e.g. "<think>"), and filler openers ("okay,",
// "hmm") are discarded. Throws std::runtime_error(errorMessage) when
// fewer than two usable lines remain.
std::pair<std::string, std::string>
parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) {
  std::string text = raw;
  std::replace(text.begin(), text.end(), '\r', '\n');

  std::vector<std::string> kept;
  std::stringstream input(text);
  for (std::string current; std::getline(input, current);) {
    current = stripCommonPrefix(std::move(current));
    if (current.empty()) {
      continue;
    }
    std::string lowered = current;
    std::transform(lowered.begin(), lowered.end(), lowered.begin(),
                   [](unsigned char c) {
                     return static_cast<char>(std::tolower(c));
                   });
    // Skip pseudo-tags such as "<think>".
    if (current.front() == '<' && lowered.back() == '>') {
      continue;
    }
    // Skip conversational filler the model sometimes prepends.
    if (lowered.rfind("okay,", 0) == 0 || lowered.rfind("hmm", 0) == 0) {
      continue;
    }
    kept.push_back(std::move(current));
  }

  if (kept.size() < 2) {
    throw std::runtime_error(errorMessage);
  }
  std::string first = trim(kept.front());
  std::string rest;
  for (std::size_t i = 1; i < kept.size(); ++i) {
    if (!rest.empty()) {
      rest += ' ';
    }
    rest += kept[i];
  }
  rest = trim(std::move(rest));
  if (first.empty() || rest.empty()) {
    throw std::runtime_error(errorMessage);
  }
  return {first, rest};
}
// Renders `userPrompt` through the model's chat template as a single user
// turn. Returns the raw prompt unchanged when the model ships no template;
// throws std::runtime_error when template expansion fails.
std::string toChatPrompt(const llama_model *model,
                         const std::string &userPrompt) {
  const char *tmpl = llama_model_chat_template(model, nullptr);
  if (tmpl == nullptr) {
    return userPrompt;
  }
  const llama_chat_message message{"user", userPrompt.c_str()};
  // Start with a generous size guess; grow once if the template needs more.
  std::vector<char> out(std::max<std::size_t>(1024, userPrompt.size() * 4));
  int32_t written =
      llama_chat_apply_template(tmpl, &message, 1, true, out.data(),
                                static_cast<int32_t>(out.size()));
  if (written >= static_cast<int32_t>(out.size())) {
    // The return value is the exact size required; retry with that buffer.
    out.resize(static_cast<std::size_t>(written) + 1);
    written = llama_chat_apply_template(tmpl, &message, 1, true, out.data(),
                                        static_cast<int32_t>(out.size()));
  }
  if (written < 0) {
    throw std::runtime_error("LlamaGenerator: failed to apply chat template");
  }
  return std::string(out.data(), static_cast<std::size_t>(written));
}
// Renders a system + user message pair through the model's chat template.
// Falls back to plain "system\n\nuser" concatenation when the model ships
// no template; throws std::runtime_error when template expansion fails.
std::string toChatPrompt(const llama_model *model,
                         const std::string &system_prompt,
                         const std::string &userPrompt) {
  const char *tmpl = llama_model_chat_template(model, nullptr);
  if (tmpl == nullptr) {
    return system_prompt + "\n\n" + userPrompt;
  }
  const llama_chat_message messages[2] = {{"system", system_prompt.c_str()},
                                          {"user", userPrompt.c_str()}};
  // BUGFIX: this previously read `systemPrompt.size()`, which does not
  // compile — the parameter is named `system_prompt`.
  std::vector<char> buffer(std::max<std::size_t>(
      1024, (system_prompt.size() + userPrompt.size()) * 4));
  int32_t required =
      llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
                                static_cast<int32_t>(buffer.size()));
  if (required < 0) {
    throw std::runtime_error("LlamaGenerator: failed to apply chat template");
  }
  if (required >= static_cast<int32_t>(buffer.size())) {
    // A non-negative return value is the exact size required; retry once.
    buffer.resize(static_cast<std::size_t>(required) + 1);
    required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
                                         static_cast<int32_t>(buffer.size()));
    if (required < 0) {
      throw std::runtime_error("LlamaGenerator: failed to apply chat template");
    }
  }
  return std::string(buffer.data(), static_cast<std::size_t>(required));
}
// Decodes `token` to its text piece and appends it to `output`. Uses a
// 256-byte stack buffer first and falls back to a heap buffer when the
// piece is larger; throws std::runtime_error when decoding fails outright.
void appendTokenPiece(const llama_vocab *vocab, llama_token token,
                      std::string &output) {
  std::array<char, 256> stackBuffer{};
  int32_t written =
      llama_token_to_piece(vocab, token, stackBuffer.data(),
                           static_cast<int32_t>(stackBuffer.size()), 0, true);
  if (written >= 0) {
    output.append(stackBuffer.data(), static_cast<std::size_t>(written));
    return;
  }
  // A negative result encodes the required buffer size; retry on the heap.
  std::vector<char> heapBuffer(static_cast<std::size_t>(-written));
  written =
      llama_token_to_piece(vocab, token, heapBuffer.data(),
                           static_cast<int32_t>(heapBuffer.size()), 0, true);
  if (written < 0) {
    throw std::runtime_error(
        "LlamaGenerator: failed to decode sampled token piece");
  }
  output.append(heapBuffer.data(), static_cast<std::size_t>(written));
}
// Scans `text` for the first balanced top-level JSON object, matching
// braces while ignoring any braces that appear inside string literals
// (including escaped quotes). On success copies the object (braces
// included) into `jsonOut` and returns true; otherwise returns false.
bool extractFirstJsonObject(const std::string &text, std::string &jsonOut) {
  std::size_t objectStart = std::string::npos;
  int braceDepth = 0;
  bool insideString = false;
  bool escapeNext = false;
  for (std::size_t pos = 0; pos < text.size(); ++pos) {
    const char ch = text[pos];
    if (insideString) {
      if (escapeNext) {
        escapeNext = false;
      } else if (ch == '\\') {
        escapeNext = true;
      } else if (ch == '"') {
        insideString = false;
      }
      continue;
    }
    switch (ch) {
      case '"':
        insideString = true;
        break;
      case '{':
        if (braceDepth == 0) {
          objectStart = pos;
        }
        ++braceDepth;
        break;
      case '}':
        // Stray closing braces before any opener are ignored.
        if (braceDepth > 0) {
          --braceDepth;
          if (braceDepth == 0 && objectStart != std::string::npos) {
            jsonOut = text.substr(objectStart, pos - objectStart + 1);
            return true;
          }
        }
        break;
      default:
        break;
    }
  }
  return false;
}
// Validates `raw` as the JSON object {"name": ..., "description": ...}.
// On success, writes the trimmed field values into `nameOut` and
// `descriptionOut` and returns an empty string; on failure returns a
// human-readable error message (used verbatim in the repair re-prompt).
// If `raw` is not valid JSON as-is, a second attempt parses the first
// embedded JSON object found in the text (models often wrap JSON in prose).
std::string ValidateBreweryJson(const std::string &raw, std::string &nameOut,
                                std::string &descriptionOut) {
  // Checks one parsed JSON value; sets errorOut and returns false on the
  // first violated requirement. NOTE: nameOut/descriptionOut are assigned
  // before the emptiness/placeholder checks, so they may hold values even
  // when validation ultimately fails.
  auto validateObject = [&](const boost::json::value &jv,
                            std::string &errorOut) -> bool {
    if (!jv.is_object()) {
      errorOut = "JSON root must be an object";
      return false;
    }
    const auto &obj = jv.get_object();
    if (!obj.contains("name") || !obj.at("name").is_string()) {
      errorOut = "JSON field 'name' is missing or not a string";
      return false;
    }
    if (!obj.contains("description") || !obj.at("description").is_string()) {
      errorOut = "JSON field 'description' is missing or not a string";
      return false;
    }
    nameOut = trim(std::string(obj.at("name").as_string().c_str()));
    descriptionOut =
        trim(std::string(obj.at("description").as_string().c_str()));
    if (nameOut.empty()) {
      errorOut = "JSON field 'name' must not be empty";
      return false;
    }
    if (descriptionOut.empty()) {
      errorOut = "JSON field 'description' must not be empty";
      return false;
    }
    // Reject the model echoing the schema itself ("name": "string").
    std::string nameLower = nameOut;
    std::string descriptionLower = descriptionOut;
    std::transform(
        nameLower.begin(), nameLower.end(), nameLower.begin(),
        [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    std::transform(descriptionLower.begin(), descriptionLower.end(),
                   descriptionLower.begin(), [](unsigned char c) {
                     return static_cast<char>(std::tolower(c));
                   });
    if (nameLower == "string" || descriptionLower == "string") {
      errorOut = "JSON appears to be a schema placeholder, not content";
      return false;
    }
    errorOut.clear();
    return true;
  };
  // First attempt: parse the whole response as JSON.
  boost::system::error_code ec;
  boost::json::value jv = boost::json::parse(raw, ec);
  std::string validationError;
  if (ec) {
    // Second attempt: salvage the first balanced JSON object in the text.
    std::string extracted;
    if (!extractFirstJsonObject(raw, extracted)) {
      return "JSON parse error: " + ec.message();
    }
    ec.clear();
    jv = boost::json::parse(extracted, ec);
    if (ec) {
      return "JSON parse error: " + ec.message();
    }
    if (!validateObject(jv, validationError)) {
      return validationError;
    }
    return {};
  }
  if (!validateObject(jv, validationError)) {
    return validationError;
  }
  return {};
}
} // namespace
// Releases the llama context before the model (contexts reference their
// model), then shuts the llama backend down.
LlamaGenerator::~LlamaGenerator() {
  llama_context *ctx = context_;
  context_ = nullptr;
  if (ctx != nullptr) {
    llama_free(ctx);
  }
  llama_model *mdl = model_;
  model_ = nullptr;
  if (mdl != nullptr) {
    llama_model_free(mdl);
  }
  llama_backend_free();
}
// Configures sampling parameters for subsequent Infer() calls.
// `temperature` must be >= 0, `top_p` must lie in (0, 1], and `seed` must
// be >= 0 or exactly -1 to request llama's default (random) seed. Throws
// std::runtime_error on any out-of-range value.
void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
                                        int seed) {
  if (temperature < 0.0f) {
    throw std::runtime_error(
        "LlamaGenerator: sampling temperature must be >= 0");
  }
  // Written as a negated conjunction so NaN is also rejected.
  if (!(top_p > 0.0f && top_p <= 1.0f)) {
    throw std::runtime_error(
        "LlamaGenerator: sampling top-p must be in (0, 1]");
  }
  if (seed < -1) {
    throw std::runtime_error(
        "LlamaGenerator: seed must be >= 0, or -1 for random");
  }
  sampling_temperature_ = temperature;
  sampling_top_p_ = top_p;
  // A negative seed maps to llama's sentinel for "pick a random seed".
  sampling_seed_ = seed < 0 ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
                            : static_cast<uint32_t>(seed);
}
// Loads a model from `model_path` and creates a fresh context with a
// 2048-token window. Any previously loaded model/context is released
// first, so Load() may be called repeatedly. Throws std::runtime_error on
// an empty path, a failed model load, or a failed context creation.
void LlamaGenerator::Load(const std::string &model_path) {
  if (model_path.empty()) {
    throw std::runtime_error("LlamaGenerator: model path must not be empty");
  }
  // Release state from a previous Load() so reloads do not leak.
  if (context_ != nullptr) {
    llama_free(context_);
    context_ = nullptr;
  }
  if (model_ != nullptr) {
    llama_model_free(model_);
    model_ = nullptr;
  }
  llama_backend_init();
  llama_model_params model_params = llama_model_default_params();
  model_ = llama_model_load_from_file(model_path.c_str(), model_params);
  if (model_ == nullptr) {
    throw std::runtime_error(
        "LlamaGenerator: failed to load model from path: " + model_path);
  }
  llama_context_params context_params = llama_context_default_params();
  context_params.n_ctx = 2048;
  context_ = llama_init_from_model(model_, context_params);
  if (context_ == nullptr) {
    // Never keep a model around without a usable context.
    llama_model_free(model_);
    model_ = nullptr;
    throw std::runtime_error("LlamaGenerator: failed to create context");
  }
  spdlog::info("[LlamaGenerator] Loaded model: {}", model_path);
}
// Runs a single-turn chat completion for `prompt`, generating at most
// `max_tokens` tokens (clamped to fit inside the context window), and
// returns the decoded text. Throws std::runtime_error when no model is
// loaded or when tokenization/decoding fails.
std::string LlamaGenerator::Infer(const std::string &prompt, int max_tokens) {
  if (model_ == nullptr || context_ == nullptr)
    throw std::runtime_error("LlamaGenerator: model not loaded");
  const llama_vocab *vocab = llama_model_get_vocab(model_);
  if (vocab == nullptr)
    throw std::runtime_error("LlamaGenerator: vocab unavailable");
  // Clear the KV memory so earlier calls cannot leak into this completion.
  llama_memory_clear(llama_get_memory(context_), true);
  const std::string formatted_prompt = toChatPrompt(model_, prompt);
  // Tokenize with a generous size guess; a negative return value encodes
  // the exact token count needed, so resize and retry once.
  std::vector<llama_token> promptTokens(formatted_prompt.size() + 8);
  int32_t tokenCount = llama_tokenize(
      vocab, formatted_prompt.c_str(),
      static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
      static_cast<int32_t>(promptTokens.size()), true, true);
  if (tokenCount < 0) {
    promptTokens.resize(static_cast<std::size_t>(-tokenCount));
    tokenCount = llama_tokenize(
        vocab, formatted_prompt.c_str(),
        static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
        static_cast<int32_t>(promptTokens.size()), true, true);
  }
  if (tokenCount < 0)
    throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
  const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
  const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
  if (nCtx <= 1 || nBatch <= 0) {
    throw std::runtime_error("LlamaGenerator: invalid context or batch size");
  }
  const int32_t effective_max_tokens =
      std::max(1, std::min(max_tokens, nCtx - 1));
  // BUGFIX: prompt_budget was declared const and then reassigned on the
  // next line, which does not compile. It is mutable here, matching the
  // two-prompt Infer overload.
  int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens);
  prompt_budget = std::max<int32_t>(1, prompt_budget);
  promptTokens.resize(static_cast<std::size_t>(tokenCount));
  if (tokenCount > prompt_budget) {
    spdlog::warn(
        "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
        "to fit n_batch/n_ctx limits",
        tokenCount, prompt_budget);
    promptTokens.resize(static_cast<std::size_t>(prompt_budget));
    tokenCount = prompt_budget;
  }
  const llama_batch promptBatch = llama_batch_get_one(
      promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
  if (llama_decode(context_, promptBatch) != 0)
    throw std::runtime_error("LlamaGenerator: prompt decode failed");
  llama_sampler_chain_params sampler_params =
      llama_sampler_chain_default_params();
  // unique_ptr guarantees the sampler chain is freed on every exit path.
  using SamplerPtr =
      std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
  SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
                     &llama_sampler_free);
  if (!sampler)
    throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
  llama_sampler_chain_add(sampler.get(),
                          llama_sampler_init_temp(sampling_temperature_));
  llama_sampler_chain_add(sampler.get(),
                          llama_sampler_init_top_p(sampling_top_p_, 1));
  llama_sampler_chain_add(sampler.get(),
                          llama_sampler_init_dist(sampling_seed_));
  std::vector<llama_token> generated_tokens;
  // Reserve the clamped budget (raw max_tokens may be zero or negative).
  generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
  for (int i = 0; i < effective_max_tokens; ++i) {
    const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
    if (llama_vocab_is_eog(vocab, next))
      break;  // end-of-generation: stop without emitting the EOG token
    generated_tokens.push_back(next);
    llama_token token = next;
    const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
    if (llama_decode(context_, one_token_batch) != 0)
      throw std::runtime_error(
          "LlamaGenerator: decode failed during generation");
  }
  std::string output;
  for (const llama_token token : generated_tokens)
    appendTokenPiece(vocab, token, output);
  return output;
}
// Runs a chat completion with a system prompt plus a user prompt,
// generating at most `max_tokens` tokens (clamped to the context window),
// and returns the decoded text. Throws std::runtime_error when no model
// is loaded or when tokenization/decoding fails.
// NOTE(review): this duplicates the single-prompt Infer overload almost
// line-for-line; consider extracting a shared private helper.
std::string LlamaGenerator::Infer(const std::string &system_prompt,
                                  const std::string &prompt, int max_tokens) {
  if (model_ == nullptr || context_ == nullptr)
    throw std::runtime_error("LlamaGenerator: model not loaded");
  const llama_vocab *vocab = llama_model_get_vocab(model_);
  if (vocab == nullptr)
    throw std::runtime_error("LlamaGenerator: vocab unavailable");
  // Clear the KV memory so earlier calls cannot leak into this completion.
  llama_memory_clear(llama_get_memory(context_), true);
  const std::string formatted_prompt =
      toChatPrompt(model_, system_prompt, prompt);
  // Tokenize with a generous size guess; a negative return value encodes
  // the exact token count needed, so resize and retry once.
  std::vector<llama_token> promptTokens(formatted_prompt.size() + 8);
  int32_t tokenCount = llama_tokenize(
      vocab, formatted_prompt.c_str(),
      static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
      static_cast<int32_t>(promptTokens.size()), true, true);
  if (tokenCount < 0) {
    promptTokens.resize(static_cast<std::size_t>(-tokenCount));
    tokenCount = llama_tokenize(
        vocab, formatted_prompt.c_str(),
        static_cast<int32_t>(formatted_prompt.size()), promptTokens.data(),
        static_cast<int32_t>(promptTokens.size()), true, true);
  }
  if (tokenCount < 0)
    throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
  const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
  const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
  if (nCtx <= 1 || nBatch <= 0) {
    throw std::runtime_error("LlamaGenerator: invalid context or batch size");
  }
  // Clamp generation budget, then cap the prompt so prompt + generation
  // fits inside both the batch size and the context window.
  const int32_t effective_max_tokens = std::max(1, std::min(max_tokens, nCtx - 1));
  int32_t prompt_budget = std::min(nBatch, nCtx - effective_max_tokens);
  prompt_budget = std::max<int32_t>(1, prompt_budget);
  promptTokens.resize(static_cast<std::size_t>(tokenCount));
  if (tokenCount > prompt_budget) {
    spdlog::warn(
        "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
        "to fit n_batch/n_ctx limits",
        tokenCount, prompt_budget);
    promptTokens.resize(static_cast<std::size_t>(prompt_budget));
    tokenCount = prompt_budget;
  }
  // Decode the whole prompt in one batch to prime the KV cache.
  const llama_batch promptBatch = llama_batch_get_one(
      promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
  if (llama_decode(context_, promptBatch) != 0)
    throw std::runtime_error("LlamaGenerator: prompt decode failed");
  llama_sampler_chain_params sampler_params =
      llama_sampler_chain_default_params();
  // unique_ptr guarantees the sampler chain is freed on every exit path.
  using SamplerPtr =
      std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
  SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
                     &llama_sampler_free);
  if (!sampler)
    throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
  // Sampling chain: temperature -> top-p -> final distribution draw.
  llama_sampler_chain_add(sampler.get(),
                          llama_sampler_init_temp(sampling_temperature_));
  llama_sampler_chain_add(sampler.get(),
                          llama_sampler_init_top_p(sampling_top_p_, 1));
  llama_sampler_chain_add(sampler.get(),
                          llama_sampler_init_dist(sampling_seed_));
  std::vector<llama_token> generated_tokens;
  generated_tokens.reserve(static_cast<std::size_t>(max_tokens));
  // Sample one token at a time, feeding each back through the decoder.
  for (int i = 0; i < effective_max_tokens; ++i) {
    const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
    if (llama_vocab_is_eog(vocab, next))
      break;  // end-of-generation: stop without emitting the EOG token
    generated_tokens.push_back(next);
    llama_token token = next;
    const llama_batch one_token_batch = llama_batch_get_one(&token, 1);
    if (llama_decode(context_, one_token_batch) != 0)
      throw std::runtime_error(
          "LlamaGenerator: decode failed during generation");
  }
  // Detokenize the accumulated tokens into the final response text.
  std::string output;
  for (const llama_token token : generated_tokens)
    appendTokenPiece(vocab, token, output);
  return output;
}
// Generates a brewery {name, description} for the given city, validating
// the model output as strict JSON ({"name": ..., "description": ...}).
// Retries up to three times; each retry re-prompts with the previous
// validation error so the model can repair its output. Throws
// std::runtime_error when all attempts produce malformed responses.
BreweryResult
LlamaGenerator::GenerateBrewery(const std::string &city_name,
                                const std::string &country_name,
                                const std::string &region_context) {
  // Whitespace-normalized and truncated so the regional context cannot
  // blow past the prompt token budget.
  const std::string safe_region_context = PrepareRegionContext(region_context);
  const std::string system_prompt =
      "You are a copywriter for a craft beer travel guide. "
      "Your writing is vivid, specific to place, and avoids generic beer "
      "cliches. "
      "You must output ONLY valid JSON. "
      "The JSON schema must be exactly: {\"name\": \"string\", "
      "\"description\": \"string\"}. "
      "Do not include markdown formatting or backticks.";
  // Country and regional-context fragments are optional and only appended
  // when non-empty.
  std::string prompt =
      "Write a brewery name and place-specific description for a craft "
      "brewery in " +
      city_name +
      (country_name.empty() ? std::string("")
                            : std::string(", ") + country_name) +
      (safe_region_context.empty()
           ? std::string(".")
           : std::string(". Regional context: ") + safe_region_context);
  const int maxAttempts = 3;
  std::string raw;
  std::string lastError;
  for (int attempt = 0; attempt < maxAttempts; ++attempt) {
    raw = Infer(system_prompt, prompt, 384);
    spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
                  raw);
    std::string name;
    std::string description;
    // Empty return value means success; otherwise it is the error message.
    const std::string validationError =
        ValidateBreweryJson(raw, name, description);
    if (validationError.empty()) {
      return {std::move(name), std::move(description)};
    }
    lastError = validationError;
    spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
                 attempt + 1, validationError);
    // Repair prompt: feed the validation error back to the model along
    // with the schema and the original location details.
    prompt = "Your previous response was invalid. Error: " + validationError +
             "\nReturn ONLY valid JSON with this exact schema: "
             "{\"name\": \"string\", \"description\": \"string\"}."
             "\nDo not include markdown, comments, or extra keys."
             "\n\nLocation: " +
             city_name +
             (country_name.empty() ? std::string("")
                                   : std::string(", ") + country_name) +
             (safe_region_context.empty()
                  ? std::string("")
                  : std::string("\nRegional context: ") + safe_region_context);
  }
  spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: "
                "{}",
                maxAttempts, lastError.empty() ? raw : lastError);
  throw std::runtime_error("LlamaGenerator: malformed brewery response");
}
// Generates a {username, bio} profile for the given locale, retrying up
// to three times when the model output cannot be parsed into two usable
// lines. Throws std::runtime_error when every attempt fails.
UserResult LlamaGenerator::GenerateUser(const std::string &locale) {
  const std::string system_prompt =
      "You generate plausible social media profiles for craft beer "
      "enthusiasts. "
      "Respond with exactly two lines: "
      "the first line is a username (lowercase, no spaces, 8-20 characters), "
      "the second line is a one-sentence bio (20-40 words). "
      "The profile should feel consistent with the locale. "
      "No preamble, no labels.";
  const std::string prompt =
      "Generate a craft beer enthusiast profile. Locale: " + locale;
  constexpr int kMaxAttempts = 3;
  std::string raw;
  for (int attempt = 1; attempt <= kMaxAttempts; ++attempt) {
    raw = Infer(system_prompt, prompt, 128);
    spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
                  attempt, raw);
    try {
      auto [username, bio] =
          parseTwoLineResponse(raw, "LlamaGenerator: malformed user response");
      // Usernames must never contain whitespace; drop any that slipped in.
      username.erase(
          std::remove_if(username.begin(), username.end(),
                         [](unsigned char ch) { return std::isspace(ch); }),
          username.end());
      if (username.empty() || bio.empty()) {
        throw std::runtime_error("LlamaGenerator: malformed user response");
      }
      // Clamp bios to the 200-character limit.
      if (bio.size() > 200) {
        bio.resize(200);
      }
      return {username, bio};
    } catch (const std::exception &e) {
      spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}",
                   attempt, e.what());
    }
  }
  spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}",
                kMaxAttempts, raw);
  throw std::runtime_error("LlamaGenerator: malformed user response");
}

View File

@@ -1,4 +1,4 @@
#include "mock_generator.h" #include "data_generation/mock_generator.h"
#include <functional> #include <functional>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
@@ -64,11 +64,11 @@ const std::vector<std::string> MockGenerator::kBios = {
"Always ready to trade recommendations for underrated local breweries.", "Always ready to trade recommendations for underrated local breweries.",
"Keeping a running list of must-try collab releases and tap takeovers."}; "Keeping a running list of must-try collab releases and tap takeovers."};
void MockGenerator::load(const std::string & /*modelPath*/) { void MockGenerator::Load(const std::string & /*modelPath*/) {
spdlog::info("[MockGenerator] No model needed"); spdlog::info("[MockGenerator] No model needed");
} }
std::size_t MockGenerator::deterministicHash(const std::string &a, std::size_t MockGenerator::DeterministicHash(const std::string &a,
const std::string &b) { const std::string &b) {
std::size_t seed = std::hash<std::string>{}(a); std::size_t seed = std::hash<std::string>{}(a);
const std::size_t mixed = std::hash<std::string>{}(b); const std::size_t mixed = std::hash<std::string>{}(b);
@@ -77,14 +77,14 @@ std::size_t MockGenerator::deterministicHash(const std::string &a,
return seed; return seed;
} }
BreweryResult MockGenerator::generateBrewery(const std::string &cityName, BreweryResult MockGenerator::GenerateBrewery(const std::string &city_name,
const std::string &countryName, const std::string &country_name,
const std::string &regionContext) { const std::string &region_context) {
const std::string locationKey = const std::string location_key =
countryName.empty() ? cityName : cityName + "," + countryName; country_name.empty() ? city_name : city_name + "," + country_name;
const std::size_t hash = regionContext.empty() const std::size_t hash = region_context.empty()
? std::hash<std::string>{}(locationKey) ? std::hash<std::string>{}(location_key)
: deterministicHash(locationKey, regionContext); : DeterministicHash(location_key, region_context);
BreweryResult result; BreweryResult result;
result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " + result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " +
@@ -94,7 +94,7 @@ BreweryResult MockGenerator::generateBrewery(const std::string &cityName,
return result; return result;
} }
UserResult MockGenerator::generateUser(const std::string &locale) { UserResult MockGenerator::GenerateUser(const std::string &locale) {
const std::size_t hash = std::hash<std::string>{}(locale); const std::size_t hash = std::hash<std::string>{}(locale);
UserResult result; UserResult result;

View File

@@ -1,9 +1,9 @@
#include "database.h" #include "database/database.h"
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
void SqliteDatabase::InitializeSchema() { void SqliteDatabase::InitializeSchema() {
std::lock_guard<std::mutex> lock(dbMutex); std::lock_guard<std::mutex> lock(db_mutex_);
const char *schema = R"( const char *schema = R"(
CREATE TABLE IF NOT EXISTS countries ( CREATE TABLE IF NOT EXISTS countries (
@@ -34,7 +34,7 @@ void SqliteDatabase::InitializeSchema() {
)"; )";
char *errMsg = nullptr; char *errMsg = nullptr;
int rc = sqlite3_exec(db, schema, nullptr, nullptr, &errMsg); int rc = sqlite3_exec(db_, schema, nullptr, nullptr, &errMsg);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
std::string error = errMsg ? std::string(errMsg) : "Unknown error"; std::string error = errMsg ? std::string(errMsg) : "Unknown error";
sqlite3_free(errMsg); sqlite3_free(errMsg);
@@ -43,24 +43,24 @@ void SqliteDatabase::InitializeSchema() {
} }
SqliteDatabase::~SqliteDatabase() { SqliteDatabase::~SqliteDatabase() {
if (db) { if (db_) {
sqlite3_close(db); sqlite3_close(db_);
} }
} }
void SqliteDatabase::Initialize(const std::string &dbPath) { void SqliteDatabase::Initialize(const std::string &db_path) {
int rc = sqlite3_open(dbPath.c_str(), &db); int rc = sqlite3_open(db_path.c_str(), &db_);
if (rc) { if (rc) {
throw std::runtime_error("Failed to open SQLite database: " + dbPath); throw std::runtime_error("Failed to open SQLite database: " + db_path);
} }
spdlog::info("OK: SQLite database opened: {}", dbPath); spdlog::info("OK: SQLite database opened: {}", db_path);
InitializeSchema(); InitializeSchema();
} }
void SqliteDatabase::BeginTransaction() { void SqliteDatabase::BeginTransaction() {
std::lock_guard<std::mutex> lock(dbMutex); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char *err = nullptr;
if (sqlite3_exec(db, "BEGIN TRANSACTION", nullptr, nullptr, &err) != if (sqlite3_exec(db_, "BEGIN TRANSACTION", nullptr, nullptr, &err) !=
SQLITE_OK) { SQLITE_OK) {
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
sqlite3_free(err); sqlite3_free(err);
@@ -69,9 +69,9 @@ void SqliteDatabase::BeginTransaction() {
} }
void SqliteDatabase::CommitTransaction() { void SqliteDatabase::CommitTransaction() {
std::lock_guard<std::mutex> lock(dbMutex); std::lock_guard<std::mutex> lock(db_mutex_);
char *err = nullptr; char *err = nullptr;
if (sqlite3_exec(db, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) { if (sqlite3_exec(db_, "COMMIT", nullptr, nullptr, &err) != SQLITE_OK) {
std::string msg = err ? err : "unknown"; std::string msg = err ? err : "unknown";
sqlite3_free(err); sqlite3_free(err);
throw std::runtime_error("CommitTransaction failed: " + msg); throw std::runtime_error("CommitTransaction failed: " + msg);
@@ -81,7 +81,7 @@ void SqliteDatabase::CommitTransaction() {
void SqliteDatabase::InsertCountry(int id, const std::string &name, void SqliteDatabase::InsertCountry(int id, const std::string &name,
const std::string &iso2, const std::string &iso2,
const std::string &iso3) { const std::string &iso3) {
std::lock_guard<std::mutex> lock(dbMutex); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char *query = R"(
INSERT OR IGNORE INTO countries (id, name, iso2, iso3) INSERT OR IGNORE INTO countries (id, name, iso2, iso3)
@@ -89,7 +89,7 @@ void SqliteDatabase::InsertCountry(int id, const std::string &name,
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt *stmt;
int rc = sqlite3_prepare_v2(db, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare country insert"); throw std::runtime_error("Failed to prepare country insert");
@@ -104,9 +104,9 @@ void SqliteDatabase::InsertCountry(int id, const std::string &name,
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertState(int id, int countryId, const std::string &name, void SqliteDatabase::InsertState(int id, int country_id, const std::string &name,
const std::string &iso2) { const std::string &iso2) {
std::lock_guard<std::mutex> lock(dbMutex); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char *query = R"(
INSERT OR IGNORE INTO states (id, country_id, name, iso2) INSERT OR IGNORE INTO states (id, country_id, name, iso2)
@@ -114,12 +114,12 @@ void SqliteDatabase::InsertState(int id, int countryId, const std::string &name,
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt *stmt;
int rc = sqlite3_prepare_v2(db, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare state insert"); throw std::runtime_error("Failed to prepare state insert");
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, countryId); sqlite3_bind_int(stmt, 2, country_id);
sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 3, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, iso2.c_str(), -1, SQLITE_STATIC);
@@ -129,10 +129,10 @@ void SqliteDatabase::InsertState(int id, int countryId, const std::string &name,
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
void SqliteDatabase::InsertCity(int id, int stateId, int countryId, void SqliteDatabase::InsertCity(int id, int state_id, int country_id,
const std::string &name, double latitude, const std::string &name, double latitude,
double longitude) { double longitude) {
std::lock_guard<std::mutex> lock(dbMutex); std::lock_guard<std::mutex> lock(db_mutex_);
const char *query = R"( const char *query = R"(
INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude) INSERT OR IGNORE INTO cities (id, state_id, country_id, name, latitude, longitude)
@@ -140,13 +140,13 @@ void SqliteDatabase::InsertCity(int id, int stateId, int countryId,
)"; )";
sqlite3_stmt *stmt; sqlite3_stmt *stmt;
int rc = sqlite3_prepare_v2(db, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) if (rc != SQLITE_OK)
throw std::runtime_error("Failed to prepare city insert"); throw std::runtime_error("Failed to prepare city insert");
sqlite3_bind_int(stmt, 1, id); sqlite3_bind_int(stmt, 1, id);
sqlite3_bind_int(stmt, 2, stateId); sqlite3_bind_int(stmt, 2, state_id);
sqlite3_bind_int(stmt, 3, countryId); sqlite3_bind_int(stmt, 3, country_id);
sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 4, name.c_str(), -1, SQLITE_STATIC);
sqlite3_bind_double(stmt, 5, latitude); sqlite3_bind_double(stmt, 5, latitude);
sqlite3_bind_double(stmt, 6, longitude); sqlite3_bind_double(stmt, 6, longitude);
@@ -157,14 +157,13 @@ void SqliteDatabase::InsertCity(int id, int stateId, int countryId,
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
} }
std::vector<std::pair<int, std::string>> SqliteDatabase::QueryCities() { std::vector<City> SqliteDatabase::QueryCities() {
std::lock_guard<std::mutex> lock(dbMutex); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<City> cities;
std::vector<std::pair<int, std::string>> cities;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt *stmt = nullptr;
const char *query = "SELECT id, name FROM cities ORDER BY name"; const char *query = "SELECT id, name, country_id FROM cities ORDER BY name";
int rc = sqlite3_prepare_v2(db, query, -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query, -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
throw std::runtime_error("Failed to prepare query"); throw std::runtime_error("Failed to prepare query");
@@ -174,7 +173,8 @@ std::vector<std::pair<int, std::string>> SqliteDatabase::QueryCities() {
int id = sqlite3_column_int(stmt, 0); int id = sqlite3_column_int(stmt, 0);
const char *name = const char *name =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1));
cities.push_back({id, name ? std::string(name) : ""}); int country_id = sqlite3_column_int(stmt, 2);
cities.push_back({id, name ? std::string(name) : "", country_id});
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);
@@ -182,7 +182,7 @@ std::vector<std::pair<int, std::string>> SqliteDatabase::QueryCities() {
} }
std::vector<Country> SqliteDatabase::QueryCountries(int limit) { std::vector<Country> SqliteDatabase::QueryCountries(int limit) {
std::lock_guard<std::mutex> lock(dbMutex); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<Country> countries; std::vector<Country> countries;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt *stmt = nullptr;
@@ -193,7 +193,7 @@ std::vector<Country> SqliteDatabase::QueryCountries(int limit) {
query += " LIMIT " + std::to_string(limit); query += " LIMIT " + std::to_string(limit);
} }
int rc = sqlite3_prepare_v2(db, query.c_str(), -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
throw std::runtime_error("Failed to prepare countries query"); throw std::runtime_error("Failed to prepare countries query");
@@ -217,7 +217,7 @@ std::vector<Country> SqliteDatabase::QueryCountries(int limit) {
} }
std::vector<State> SqliteDatabase::QueryStates(int limit) { std::vector<State> SqliteDatabase::QueryStates(int limit) {
std::lock_guard<std::mutex> lock(dbMutex); std::lock_guard<std::mutex> lock(db_mutex_);
std::vector<State> states; std::vector<State> states;
sqlite3_stmt *stmt = nullptr; sqlite3_stmt *stmt = nullptr;
@@ -228,7 +228,7 @@ std::vector<State> SqliteDatabase::QueryStates(int limit) {
query += " LIMIT " + std::to_string(limit); query += " LIMIT " + std::to_string(limit);
} }
int rc = sqlite3_prepare_v2(db, query.c_str(), -1, &stmt, nullptr); int rc = sqlite3_prepare_v2(db_, query.c_str(), -1, &stmt, nullptr);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
throw std::runtime_error("Failed to prepare states query"); throw std::runtime_error("Failed to prepare states query");
@@ -240,9 +240,9 @@ std::vector<State> SqliteDatabase::QueryStates(int limit) {
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1)); reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1));
const char *iso2 = const char *iso2 =
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2)); reinterpret_cast<const char *>(sqlite3_column_text(stmt, 2));
int countryId = sqlite3_column_int(stmt, 3); int country_id = sqlite3_column_int(stmt, 3);
states.push_back({id, name ? std::string(name) : "", states.push_back({id, name ? std::string(name) : "",
iso2 ? std::string(iso2) : "", countryId}); iso2 ? std::string(iso2) : "", country_id});
} }
sqlite3_finalize(stmt); sqlite3_finalize(stmt);

View File

@@ -0,0 +1,65 @@
#include <chrono>
#include <spdlog/spdlog.h>
#include "json_handling/json_loader.h"
#include "json_handling/stream_parser.h"
/**
 * @brief Stream-load the world-cities JSON dump into SQLite.
 *
 * Parses @p json_path with the streaming SAX parser and inserts each city
 * row, committing in batches so a single giant transaction does not grow
 * the journal unboundedly while still amortizing fsync cost.
 *
 * @param json_path Path to the countries/states/cities JSON file.
 * @param db        Open database; receives InsertCity calls plus
 *                  Begin/CommitTransaction batching.
 * @throws Propagates any exception from StreamingJsonParser::Parse after
 *         committing the rows parsed so far.
 */
void JsonLoader::LoadWorldCities(const std::string &json_path,
                                 SqliteDatabase &db) {
  constexpr size_t kBatchSize = 10000;
  auto start_time = std::chrono::high_resolution_clock::now();
  // Fixed: the parser is Boost.JSON now (see stream_parser.cpp), the old
  // message still said "RapidJSON SAX".
  spdlog::info("\nLoading {} (streaming Boost.JSON SAX)...", json_path);
  db.BeginTransaction();
  bool transaction_open = true;
  size_t cities_processed = 0;
  try {
    StreamingJsonParser::Parse(
        json_path, db,
        [&](const CityRecord &record) {
          db.InsertCity(record.id, record.state_id, record.country_id,
                        record.name, record.latitude, record.longitude);
          ++cities_processed;
          // Batch boundary: flush and immediately reopen a transaction.
          if (cities_processed % kBatchSize == 0) {
            db.CommitTransaction();
            db.BeginTransaction();
          }
        },
        [&](size_t current, size_t /*total*/) {
          if (current % kBatchSize == 0 && current > 0) {
            spdlog::info(" [Progress] Parsed {} cities...", current);
          }
        });
    spdlog::info(" OK: Parsed all cities from JSON");
    if (transaction_open) {
      db.CommitTransaction();
      transaction_open = false;
    }
  } catch (...) {
    // Commit (rather than roll back) so rows parsed before the failure are
    // preserved; the caller still observes the original exception.
    if (transaction_open) {
      db.CommitTransaction();
    }
    throw;
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
      end_time - start_time);
  spdlog::info("\n=== World City Data Loading Summary ===\n");
  spdlog::info("Cities inserted: {}", cities_processed);
  spdlog::info("Elapsed time: {} ms", duration.count());
  // Guard against divide-by-zero on empty inputs or sub-millisecond runs.
  long long throughput =
      (cities_processed > 0 && duration.count() > 0)
          ? (1000LL * static_cast<long long>(cities_processed)) /
                static_cast<long long>(duration.count())
          : 0LL;
  spdlog::info("Throughput: {} cities/sec", throughput);
  spdlog::info("=======================================\n");
}

View File

@@ -1,15 +1,22 @@
#include "stream_parser.h"
#include "database.h"
#include <cstdio> #include <cstdio>
#include <rapidjson/filereadstream.h> #include <stdexcept>
#include <rapidjson/reader.h>
#include <rapidjson/stringbuffer.h> #include <boost/json.hpp>
#include <boost/json/basic_parser_impl.hpp>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
using namespace rapidjson; #include "database/database.h"
#include "json_handling/stream_parser.h"
class CityRecordHandler {
friend class boost::json::basic_parser<CityRecordHandler>;
class CityRecordHandler : public BaseReaderHandler<UTF8<>, CityRecordHandler> {
public: public:
static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1);
static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1);
struct ParseContext { struct ParseContext {
SqliteDatabase *db = nullptr; SqliteDatabase *db = nullptr;
std::function<void(const CityRecord &)> on_city; std::function<void(const CityRecord &)> on_city;
@@ -20,11 +27,35 @@ public:
int states_inserted = 0; int states_inserted = 0;
}; };
CityRecordHandler(ParseContext &ctx) : context(ctx) {} explicit CityRecordHandler(ParseContext &ctx) : context(ctx) {}
bool StartArray() { private:
ParseContext &context;
int depth = 0;
bool in_countries_array = false;
bool in_country_object = false;
bool in_states_array = false;
bool in_state_object = false;
bool in_cities_array = false;
bool building_city = false;
int current_country_id = 0;
int current_state_id = 0;
CityRecord current_city = {};
std::string current_key;
std::string current_key_val;
std::string current_string_val;
std::string country_info[3];
std::string state_info[2];
// Boost.JSON SAX Hooks
bool on_document_begin(boost::system::error_code &) { return true; }
bool on_document_end(boost::system::error_code &) { return true; }
bool on_array_begin(boost::system::error_code &) {
depth++; depth++;
if (depth == 1) { if (depth == 1) {
in_countries_array = true; in_countries_array = true;
} else if (depth == 3 && current_key == "states") { } else if (depth == 3 && current_key == "states") {
@@ -35,7 +66,7 @@ public:
return true; return true;
} }
bool EndArray(SizeType /*elementCount*/) { bool on_array_end(std::size_t, boost::system::error_code &) {
if (depth == 1) { if (depth == 1) {
in_countries_array = false; in_countries_array = false;
} else if (depth == 3) { } else if (depth == 3) {
@@ -47,9 +78,8 @@ public:
return true; return true;
} }
bool StartObject() { bool on_object_begin(boost::system::error_code &) {
depth++; depth++;
if (depth == 2 && in_countries_array) { if (depth == 2 && in_countries_array) {
in_country_object = true; in_country_object = true;
current_country_id = 0; current_country_id = 0;
@@ -68,7 +98,7 @@ public:
return true; return true;
} }
bool EndObject(SizeType /*memberCount*/) { bool on_object_end(std::size_t, boost::system::error_code &) {
if (depth == 6 && building_city) { if (depth == 6 && building_city) {
if (current_city.id > 0 && current_state_id > 0 && if (current_city.id > 0 && current_state_id > 0 &&
current_country_id > 0) { current_country_id > 0) {
@@ -84,7 +114,7 @@ public:
context.total_file_size); context.total_file_size);
} }
} catch (const std::exception &e) { } catch (const std::exception &e) {
spdlog::warn(" WARN: Failed to emit city: {}", e.what()); spdlog::warn("Record parsing failed: {}", e.what());
} }
} }
building_city = false; building_city = false;
@@ -95,7 +125,7 @@ public:
state_info[0], state_info[1]); state_info[0], state_info[1]);
context.states_inserted++; context.states_inserted++;
} catch (const std::exception &e) { } catch (const std::exception &e) {
spdlog::warn(" WARN: Failed to insert state: {}", e.what()); spdlog::warn("Record parsing failed: {}", e.what());
} }
} }
in_state_object = false; in_state_object = false;
@@ -106,7 +136,7 @@ public:
country_info[1], country_info[2]); country_info[1], country_info[2]);
context.countries_inserted++; context.countries_inserted++;
} catch (const std::exception &e) { } catch (const std::exception &e) {
spdlog::warn(" WARN: Failed to insert country: {}", e.what()); spdlog::warn("Record parsing failed: {}", e.what());
} }
} }
in_country_object = false; in_country_object = false;
@@ -116,46 +146,71 @@ public:
return true; return true;
} }
bool Key(const char *str, SizeType len, bool /*copy*/) { bool on_key_part(boost::json::string_view s, std::size_t,
current_key.assign(str, len); boost::system::error_code &) {
current_key_val.append(s.data(), s.size());
return true; return true;
} }
bool String(const char *str, SizeType len, bool /*copy*/) { bool on_key(boost::json::string_view s, std::size_t,
boost::system::error_code &) {
current_key_val.append(s.data(), s.size());
current_key = current_key_val;
current_key_val.clear();
return true;
}
bool on_string_part(boost::json::string_view s, std::size_t,
boost::system::error_code &) {
current_string_val.append(s.data(), s.size());
return true;
}
bool on_string(boost::json::string_view s, std::size_t,
boost::system::error_code &) {
current_string_val.append(s.data(), s.size());
if (building_city && current_key == "name") { if (building_city && current_key == "name") {
current_city.name.assign(str, len); current_city.name = current_string_val;
} else if (in_state_object && current_key == "name") { } else if (in_state_object && current_key == "name") {
state_info[0].assign(str, len); state_info[0] = current_string_val;
} else if (in_state_object && current_key == "iso2") { } else if (in_state_object && current_key == "iso2") {
state_info[1].assign(str, len); state_info[1] = current_string_val;
} else if (in_country_object && current_key == "name") { } else if (in_country_object && current_key == "name") {
country_info[0].assign(str, len); country_info[0] = current_string_val;
} else if (in_country_object && current_key == "iso2") { } else if (in_country_object && current_key == "iso2") {
country_info[1].assign(str, len); country_info[1] = current_string_val;
} else if (in_country_object && current_key == "iso3") { } else if (in_country_object && current_key == "iso3") {
country_info[2].assign(str, len); country_info[2] = current_string_val;
} }
current_string_val.clear();
return true; return true;
} }
bool Int(int i) { bool on_number_part(boost::json::string_view, boost::system::error_code &) {
return true;
}
bool on_int64(int64_t i, boost::json::string_view,
boost::system::error_code &) {
if (building_city && current_key == "id") { if (building_city && current_key == "id") {
current_city.id = i; current_city.id = static_cast<int>(i);
} else if (in_state_object && current_key == "id") { } else if (in_state_object && current_key == "id") {
current_state_id = i; current_state_id = static_cast<int>(i);
} else if (in_country_object && current_key == "id") { } else if (in_country_object && current_key == "id") {
current_country_id = i; current_country_id = static_cast<int>(i);
} }
return true; return true;
} }
bool Uint(unsigned i) { return Int(static_cast<int>(i)); } bool on_uint64(uint64_t u, boost::json::string_view,
boost::system::error_code &ec) {
return on_int64(static_cast<int64_t>(u), "", ec);
}
bool Int64(int64_t i) { return Int(static_cast<int>(i)); } bool on_double(double d, boost::json::string_view,
boost::system::error_code &) {
bool Uint64(uint64_t i) { return Int(static_cast<int>(i)); }
bool Double(double d) {
if (building_city) { if (building_city) {
if (current_key == "latitude") { if (current_key == "latitude") {
current_city.latitude = d; current_city.latitude = d;
@@ -166,39 +221,26 @@ public:
return true; return true;
} }
bool Bool(bool /*b*/) { return true; } bool on_bool(bool, boost::system::error_code &) { return true; }
bool Null() { return true; } bool on_null(boost::system::error_code &) { return true; }
bool on_comment_part(boost::json::string_view, boost::system::error_code &) {
private: return true;
ParseContext &context; }
bool on_comment(boost::json::string_view, boost::system::error_code &) {
int depth = 0; return true;
bool in_countries_array = false; }
bool in_country_object = false;
bool in_states_array = false;
bool in_state_object = false;
bool in_cities_array = false;
bool building_city = false;
int current_country_id = 0;
int current_state_id = 0;
CityRecord current_city = {};
std::string current_key;
std::string country_info[3];
std::string state_info[2];
}; };
void StreamingJsonParser::Parse( void StreamingJsonParser::Parse(
const std::string &filePath, SqliteDatabase &db, const std::string &file_path, SqliteDatabase &db,
std::function<void(const CityRecord &)> onCity, std::function<void(const CityRecord &)> on_city,
std::function<void(size_t, size_t)> onProgress) { std::function<void(size_t, size_t)> on_progress) {
spdlog::info(" Streaming parse of {}...", filePath); spdlog::info(" Streaming parse of {} (Boost.JSON)...", file_path);
FILE *file = std::fopen(filePath.c_str(), "rb"); FILE *file = std::fopen(file_path.c_str(), "rb");
if (!file) { if (!file) {
throw std::runtime_error("Failed to open JSON file: " + filePath); throw std::runtime_error("Failed to open JSON file: " + file_path);
} }
size_t total_size = 0; size_t total_size = 0;
@@ -210,25 +252,37 @@ void StreamingJsonParser::Parse(
std::rewind(file); std::rewind(file);
} }
CityRecordHandler::ParseContext ctx{&db, onCity, onProgress, 0, CityRecordHandler::ParseContext ctx{&db, on_city, on_progress, 0,
total_size, 0, 0}; total_size, 0, 0};
CityRecordHandler handler(ctx); boost::json::basic_parser<CityRecordHandler> parser(
boost::json::parse_options{}, ctx);
Reader reader;
char buf[65536]; char buf[65536];
FileReadStream frs(file, buf, sizeof(buf)); size_t bytes_read;
boost::system::error_code ec;
if (!reader.Parse(frs, handler)) { while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) {
ParseErrorCode errCode = reader.GetParseErrorCode(); char const *p = buf;
size_t errOffset = reader.GetErrorOffset(); std::size_t remain = bytes_read;
std::fclose(file);
throw std::runtime_error(std::string("JSON parse error at offset ") + while (remain > 0) {
std::to_string(errOffset) + std::size_t consumed = parser.write_some(true, p, remain, ec);
" (code: " + std::to_string(errCode) + ")"); if (ec) {
std::fclose(file);
throw std::runtime_error("JSON parse error: " + ec.message());
}
p += consumed;
remain -= consumed;
}
} }
parser.write_some(false, nullptr, 0, ec); // Signal EOF
std::fclose(file); std::fclose(file);
if (ec) {
throw std::runtime_error("JSON parse error at EOF: " + ec.message());
}
spdlog::info(" OK: Parsed {} countries, {} states, {} cities", spdlog::info(" OK: Parsed {} countries, {} states, {} cities",
ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted); ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted);
} }

View File

@@ -1,45 +0,0 @@
#include "json_loader.h"
#include "stream_parser.h"
#include <chrono>
#include <spdlog/spdlog.h>
// Streams the world-cities JSON dump into SQLite inside one transaction
// and logs a throughput summary when done.
//
// jsonPath  path to the countries/states/cities JSON file.
// db        open database; receives one InsertCity per parsed city.
//
// NOTE(review): if StreamingJsonParser::Parse throws, the transaction
// opened by BeginTransaction() is never committed/closed — there is no
// try/catch around the parse. Confirm callers treat that as fatal.
void JsonLoader::LoadWorldCities(const std::string &jsonPath,
                                 SqliteDatabase &db) {
  auto startTime = std::chrono::high_resolution_clock::now();
  spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", jsonPath);
  // One transaction for the whole load; per-city inserts happen in the
  // callback below.
  db.BeginTransaction();
  size_t citiesProcessed = 0;
  StreamingJsonParser::Parse(
      jsonPath, db,
      [&](const CityRecord &record) {
        db.InsertCity(record.id, record.state_id, record.country_id,
                      record.name, record.latitude, record.longitude);
        citiesProcessed++;
      },
      [&](size_t current, size_t total) {
        // Progress line every 10k cities.
        if (current % 10000 == 0 && current > 0) {
          spdlog::info(" [Progress] Parsed {} cities...", current);
        }
      });
  spdlog::info(" OK: Parsed all cities from JSON");
  db.CommitTransaction();
  auto endTime = std::chrono::high_resolution_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
      endTime - startTime);
  spdlog::info("\n=== World City Data Loading Summary ===\n");
  spdlog::info("Cities inserted: {}", citiesProcessed);
  spdlog::info("Elapsed time: {} ms", duration.count());
  // Avoid divide-by-zero for empty inputs or sub-millisecond runs.
  long long throughput =
      (citiesProcessed > 0 && duration.count() > 0)
          ? (1000LL * static_cast<long long>(citiesProcessed)) /
                static_cast<long long>(duration.count())
          : 0LL;
  spdlog::info("Throughput: {} cities/sec", throughput);
  spdlog::info("=======================================\n");
}

View File

@@ -1,534 +0,0 @@
#include "llama_generator.h"
#include "llama.h"
#include <algorithm>
#include <array>
#include <cctype>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <spdlog/spdlog.h>
namespace {
// Returns `value` with leading and trailing whitespace removed.
// Whitespace is the C-locale set: space, \t, \n, \v, \f, \r.
std::string trim(std::string value) {
  static const char *const kWhitespace = " \t\n\v\f\r";
  const std::size_t first = value.find_first_not_of(kWhitespace);
  if (first == std::string::npos) {
    return std::string();  // all-whitespace (or empty) input
  }
  const std::size_t last = value.find_last_not_of(kWhitespace);
  return value.substr(first, last - first + 1);
}
// Normalizes one generated output line: strips bullet markers ("- ", "* "),
// numbered markers ("1. ", "2) ") and well-known field labels ("name:",
// "bio:", ...) case-insensitively, trimming whitespace at each step.
std::string stripCommonPrefix(std::string line) {
  line = trim(std::move(line));
  if (!line.empty() && (line.front() == '-' || line.front() == '*')) {
    // Bulleted list marker.
    line = trim(line.substr(1));
  } else {
    // Numbered list marker: a non-empty run of digits followed by '.'
    // or ')'. A line that is all digits is left untouched.
    const std::size_t digits = line.find_first_not_of("0123456789");
    if (digits != std::string::npos && digits > 0 &&
        (line[digits] == '.' || line[digits] == ')')) {
      line = trim(line.substr(digits + 1));
    }
  }
  // Labels are tried in this fixed order; each one is removed only when it
  // is a (case-insensitive) prefix of the current line, so an earlier strip
  // can expose a later label.
  static const char *const kLabels[] = {"name:", "brewery name:",
                                        "description:", "username:", "bio:"};
  for (const char *label : kLabels) {
    const std::size_t len = std::char_traits<char>::length(label);
    if (line.size() < len) {
      continue;
    }
    const bool matches =
        std::equal(label, label + len, line.begin(), [](char a, char b) {
          return std::tolower(static_cast<unsigned char>(a)) ==
                 std::tolower(static_cast<unsigned char>(b));
        });
    if (matches) {
      line = trim(line.substr(len));
    }
  }
  return trim(std::move(line));
}
// Wraps a single user message in the model's built-in chat template.
// Returns the raw prompt unchanged when the model ships no template.
// Throws std::runtime_error if template expansion fails.
std::string toChatPrompt(const llama_model *model,
                         const std::string &userPrompt) {
  const char *tmpl = llama_model_chat_template(model, nullptr);
  if (tmpl == nullptr) {
    return userPrompt;
  }
  const llama_chat_message message{
      "user",
      userPrompt.c_str(),
  };
  // Heuristic first-guess buffer; the return value is treated as the
  // required length, so on overflow we resize once to that exact size
  // and retry (see the resize branch below).
  std::vector<char> buffer(std::max<std::size_t>(1024, userPrompt.size() * 4));
  int32_t required =
      llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
                                static_cast<int32_t>(buffer.size()));
  if (required < 0) {
    throw std::runtime_error("LlamaGenerator: failed to apply chat template");
  }
  if (required >= static_cast<int32_t>(buffer.size())) {
    buffer.resize(static_cast<std::size_t>(required) + 1);
    required = llama_chat_apply_template(tmpl, &message, 1, true, buffer.data(),
                                         static_cast<int32_t>(buffer.size()));
    if (required < 0) {
      throw std::runtime_error("LlamaGenerator: failed to apply chat template");
    }
  }
  return std::string(buffer.data(), static_cast<std::size_t>(required));
}
// Wraps a system + user message pair in the model's chat template.
// Without a template, falls back to "<system>\n\n<user>" concatenation.
// Throws std::runtime_error if template expansion fails.
std::string toChatPrompt(const llama_model *model,
                         const std::string &systemPrompt,
                         const std::string &userPrompt) {
  const char *tmpl = llama_model_chat_template(model, nullptr);
  if (tmpl == nullptr) {
    // Fall back to concatenating but keep system and user parts distinct.
    return systemPrompt + "\n\n" + userPrompt;
  }
  const llama_chat_message messages[2] = {
      {"system", systemPrompt.c_str()},
      {"user", userPrompt.c_str()},
  };
  // Heuristic first-guess buffer, then a single exact-size retry when the
  // reported required length exceeds it (same pattern as the 1-message
  // overload above).
  std::vector<char> buffer(std::max<std::size_t>(
      1024, (systemPrompt.size() + userPrompt.size()) * 4));
  int32_t required =
      llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
                                static_cast<int32_t>(buffer.size()));
  if (required < 0) {
    throw std::runtime_error("LlamaGenerator: failed to apply chat template");
  }
  if (required >= static_cast<int32_t>(buffer.size())) {
    buffer.resize(static_cast<std::size_t>(required) + 1);
    required = llama_chat_apply_template(tmpl, messages, 2, true, buffer.data(),
                                         static_cast<int32_t>(buffer.size()));
    if (required < 0) {
      throw std::runtime_error("LlamaGenerator: failed to apply chat template");
    }
  }
  return std::string(buffer.data(), static_cast<std::size_t>(required));
}
// Decodes `token` to its text piece and appends it to `output`.
// A 256-byte stack buffer is tried first; a negative return from
// llama_token_to_piece is treated as the negated required size, so we
// retry once with a heap buffer of that size. Throws std::runtime_error
// if decoding still fails.
void appendTokenPiece(const llama_vocab *vocab, llama_token token,
                      std::string &output) {
  std::array<char, 256> buffer{};
  int32_t bytes =
      llama_token_to_piece(vocab, token, buffer.data(),
                           static_cast<int32_t>(buffer.size()), 0, true);
  if (bytes < 0) {
    // Stack buffer too small: -bytes is the size actually needed.
    std::vector<char> dynamicBuffer(static_cast<std::size_t>(-bytes));
    bytes = llama_token_to_piece(vocab, token, dynamicBuffer.data(),
                                 static_cast<int32_t>(dynamicBuffer.size()), 0,
                                 true);
    if (bytes < 0) {
      throw std::runtime_error(
          "LlamaGenerator: failed to decode sampled token piece");
    }
    output.append(dynamicBuffer.data(), static_cast<std::size_t>(bytes));
    return;
  }
  output.append(buffer.data(), static_cast<std::size_t>(bytes));
}
// Splits a raw model completion into a {first-line, remainder} pair.
// Lines are normalized via stripCommonPrefix, empty lines dropped, and
// obvious chain-of-thought / meta lines filtered out; the first surviving
// line becomes .first, and all later lines are joined with single spaces
// into .second. Throws std::runtime_error(errorMessage) when fewer than
// two usable lines remain or either part ends up empty.
std::pair<std::string, std::string>
parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) {
  // Treat CR as a line separator so CRLF output splits cleanly.
  std::string normalized = raw;
  std::replace(normalized.begin(), normalized.end(), '\r', '\n');
  std::vector<std::string> lines;
  std::stringstream stream(normalized);
  std::string line;
  while (std::getline(stream, line)) {
    line = stripCommonPrefix(std::move(line));
    if (!line.empty()) {
      lines.push_back(std::move(line));
    }
  }
  // Filter out obvious internal-thought / meta lines that sometimes leak from
  // models (e.g. "<think>", "Okay, so the user is asking me...").
  std::vector<std::string> filtered;
  for (auto &l : lines) {
    std::string low = l;
    std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
      return static_cast<char>(std::tolower(c));
    });
    // Skip single-token angle-bracket markers like <think> or <...>
    if (!l.empty() && l.front() == '<' && l.back() == '>') {
      continue;
    }
    // Skip short internal commentary that starts with common discourse markers
    if (low.rfind("okay,", 0) == 0 || low.rfind("wait,", 0) == 0 ||
        low.rfind("hmm", 0) == 0) {
      continue;
    }
    // Skip lines that look like self-descriptions of what the model is doing
    // NOTE(review): these substring tests are aggressive — a legitimate
    // description containing "parse" or "protocol" is also dropped; confirm
    // that is acceptable for generated brewery/user text.
    if (low.find("user is asking") != std::string::npos ||
        low.find("protocol") != std::string::npos ||
        low.find("parse") != std::string::npos ||
        low.find("return only") != std::string::npos) {
      continue;
    }
    filtered.push_back(std::move(l));
  }
  if (filtered.size() < 2) {
    throw std::runtime_error(errorMessage);
  }
  std::string first = trim(filtered.front());
  // Everything after the first surviving line is flattened into one
  // space-separated string.
  std::string second;
  for (std::size_t i = 1; i < filtered.size(); ++i) {
    if (!second.empty()) {
      second += ' ';
    }
    second += filtered[i];
  }
  second = trim(std::move(second));
  if (first.empty() || second.empty()) {
    throw std::runtime_error(errorMessage);
  }
  return {first, second};
}
} // namespace
// Releases the context before the model (the context was created from
// model_ in load()), then shuts down the llama backend.
LlamaGenerator::~LlamaGenerator() {
  if (context_ != nullptr) {
    llama_free(context_);
    context_ = nullptr;
  }
  if (model_ != nullptr) {
    llama_model_free(model_);
    model_ = nullptr;
  }
  llama_backend_free();
}
// Loads (or reloads) a GGUF model from `modelPath` and creates a fresh
// context with a 2048-token window. Any previously loaded context/model
// is released first, so load() is safe to call more than once.
// Throws std::runtime_error on an empty path, load failure, or context
// creation failure.
// NOTE(review): llama_backend_init() runs on every load() while
// llama_backend_free() runs only in the destructor — confirm the backend
// tolerates repeated init calls.
void LlamaGenerator::load(const std::string &modelPath) {
  if (modelPath.empty()) {
    throw std::runtime_error("LlamaGenerator: model path must not be empty");
  }
  // Tear down any prior state before re-loading.
  if (context_ != nullptr) {
    llama_free(context_);
    context_ = nullptr;
  }
  if (model_ != nullptr) {
    llama_model_free(model_);
    model_ = nullptr;
  }
  llama_backend_init();
  llama_model_params modelParams = llama_model_default_params();
  model_ = llama_load_model_from_file(modelPath.c_str(), modelParams);
  if (model_ == nullptr) {
    throw std::runtime_error(
        "LlamaGenerator: failed to load model from path: " + modelPath);
  }
  llama_context_params contextParams = llama_context_default_params();
  contextParams.n_ctx = 2048;  // context window (tokens)
  context_ = llama_init_from_model(model_, contextParams);
  if (context_ == nullptr) {
    // Don't leak the model when context creation fails.
    llama_model_free(model_);
    model_ = nullptr;
    throw std::runtime_error("LlamaGenerator: failed to create context");
  }
  spdlog::info("[LlamaGenerator] Loaded model: {}", modelPath);
}
// Runs one greedy completion of `prompt` (wrapped in the model's chat
// template) and returns the generated text.
//
// prompt     user message; formatted via toChatPrompt().
// maxTokens  hard cap on generated tokens; generation also stops early
//            when an end-of-generation token is sampled.
// Throws std::runtime_error when the model is not loaded or any llama
// tokenize/decode step fails.
std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) {
  if (model_ == nullptr || context_ == nullptr) {
    throw std::runtime_error("LlamaGenerator: model not loaded");
  }
  const llama_vocab *vocab = llama_model_get_vocab(model_);
  if (vocab == nullptr) {
    throw std::runtime_error("LlamaGenerator: vocab unavailable");
  }
  // Clear cached state from any previous call so prompts don't leak into
  // each other.
  llama_memory_clear(llama_get_memory(context_), true);
  const std::string formattedPrompt = toChatPrompt(model_, prompt);
  // Tokenize; a negative return encodes the required token count, so
  // resize and retry once.
  std::vector<llama_token> promptTokens(formattedPrompt.size() + 8);
  int32_t tokenCount = llama_tokenize(
      vocab, formattedPrompt.c_str(),
      static_cast<int32_t>(formattedPrompt.size()), promptTokens.data(),
      static_cast<int32_t>(promptTokens.size()), true, true);
  if (tokenCount < 0) {
    promptTokens.resize(static_cast<std::size_t>(-tokenCount));
    tokenCount = llama_tokenize(
        vocab, formattedPrompt.c_str(),
        static_cast<int32_t>(formattedPrompt.size()), promptTokens.data(),
        static_cast<int32_t>(promptTokens.size()), true, true);
  }
  if (tokenCount < 0) {
    throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
  }
  promptTokens.resize(static_cast<std::size_t>(tokenCount));
  // Evaluate the whole prompt in one batch.
  const llama_batch promptBatch = llama_batch_get_one(
      promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
  if (llama_decode(context_, promptBatch) != 0) {
    throw std::runtime_error("LlamaGenerator: prompt decode failed");
  }
  // Deterministic (greedy) sampler chain, released via unique_ptr.
  llama_sampler_chain_params samplerParams =
      llama_sampler_chain_default_params();
  using SamplerPtr =
      std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
  SamplerPtr sampler(llama_sampler_chain_init(samplerParams),
                     &llama_sampler_free);
  if (!sampler) {
    throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
  }
  llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy());
  std::vector<llama_token> generatedTokens;
  generatedTokens.reserve(static_cast<std::size_t>(maxTokens));
  // Generation loop: sample one token, stop on end-of-generation, then
  // feed the token back so the next sample sees it.
  for (int i = 0; i < maxTokens; ++i) {
    const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
    if (llama_vocab_is_eog(vocab, next)) {
      break;
    }
    generatedTokens.push_back(next);
    llama_token token = next;
    const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1);
    if (llama_decode(context_, oneTokenBatch) != 0) {
      throw std::runtime_error(
          "LlamaGenerator: decode failed during generation");
    }
  }
  // Detokenize the sampled tokens into the final string.
  std::string output;
  for (const llama_token token : generatedTokens) {
    appendTokenPiece(vocab, token, output);
  }
  return output;
}
// Generates a brewery {name, description} for a city via the LLM.
//
// cityName       city the brewery is "located in" (required).
// countryName    appended as ", <country>" when non-empty.
// regionContext  free-form extra context spliced into the prompt.
// Returns the parsed two-line result; throws std::runtime_error when the
// completion cannot be split into a name line + description.
BreweryResult
LlamaGenerator::generateBrewery(const std::string &cityName,
                                const std::string &countryName,
                                const std::string &regionContext) {
  // Verbatim system prompt — the exact wording is part of the model
  // contract; keep byte-for-byte.
  std::string systemPrompt =
      R"(# SYSTEM PROTOCOL: ZERO-CHATTER DETERMINISTIC OUTPUT
**MODALITY:** DATA-RETURN ENGINE ONLY
**ROLE:** Your response must contain 0% metadata and 100% signal.
---
## MANDATORY CONSTRAINTS
1. **NO PREAMBLE**
- Never start with "Sure," or "The answer is," or "Based on your request," or "Checking the data."
- Do not acknowledge the user's prompt or provide status updates.
2. **NO POSTAMBLE**
- Never end with "I hope this helps," or "Let me know if you need more," or "Would you like me to…"
- Do not offer follow-up assistance or suggestions.
3. **NO SENTENCE FRAMING**
- Provide only the raw value, date, number, or name.
- Do not wrap the answer in a sentence. (e.g., return 1997, NOT The year was 1997).
- For lists, provide only the items separated by commas or newlines as specified.
4. **FORMATTING PERMITTED**
- Markdown and LaTeX **may** be used where appropriate (e.g., tables, equations).
- Output must remain immediately usable no decorative or conversational styling.
5. **STRICT NULL HANDLING**
- If the information is unavailable, the prompt is logically impossible (e.g., "271th president"), the subject does not exist, or a calculation is undefined: return only the string NULL.
- If the prompt is too ambiguous to provide a single value: return NULL.
---
## EXECUTION LOGIC
1. **Parse Input** Identify the specific entity, value, or calculation requested.
2. **Verify Factuality** Access internal knowledge or tools.
3. **Filter for Signal** Strip all surrounding prose.
4. **Format Check** Apply Markdown or LaTeX only where it serves the data.
5. **Output** Return the raw value only.
---
## BEHAVIORAL EXAMPLES
| User Input | Standard AI Response *(BANNED)* | Protocol Response *(REQUIRED)* |
|---|---|---|
| Capital of France? | The capital of France is Paris. | Paris |
| 15% of 200 | 15% of 200 is 30. | 30 |
| Who wrote '1984'? | George Orwell wrote that novel. | George Orwell |
| ISO code for Japan | The code is JP. | JP |
| $\sqrt{x}$ where $x$ is a potato | A potato has no square root. | NULL |
| 500th US President | There haven't been that many. | NULL |
| Pythagorean theorem | The theorem states... | $a^2 + b^2 = c^2$ |
---
## FINAL INSTRUCTION
Total silence is preferred over conversational error. Any deviation from the raw-value-only format is a protocol failure. Proceed with next input.)";
  // User prompt asks for exactly two lines: name then description.
  std::string prompt =
      "Generate a craft brewery name and 1000 character description for a "
      "brewery located in " +
      cityName +
      (countryName.empty() ? std::string("")
                           : std::string(", ") + countryName) +
      ". " + regionContext +
      " Respond with exactly two lines: first line is the name, second line is "
      "the description. Do not include bullets, numbering, or any extra text.";
  const std::string raw = infer(systemPrompt, prompt, 512);
  auto [name, description] =
      parseTwoLineResponse(raw, "LlamaGenerator: malformed brewery response");
  return {name, description};
}
// System+user variant of infer(): formats both messages through the chat
// template, then runs the same greedy generation loop.
//
// systemPrompt  system message (instructions/persona).
// prompt        user message.
// maxTokens     hard cap on generated tokens; stops early at EOG.
// Throws std::runtime_error when the model is not loaded or any llama
// tokenize/decode step fails.
// NOTE(review): near-duplicate of infer(prompt, maxTokens) — only the
// toChatPrompt call differs; consider folding both into one helper.
std::string LlamaGenerator::infer(const std::string &systemPrompt,
                                  const std::string &prompt, int maxTokens) {
  if (model_ == nullptr || context_ == nullptr) {
    throw std::runtime_error("LlamaGenerator: model not loaded");
  }
  const llama_vocab *vocab = llama_model_get_vocab(model_);
  if (vocab == nullptr) {
    throw std::runtime_error("LlamaGenerator: vocab unavailable");
  }
  // Clear state from any previous call.
  llama_memory_clear(llama_get_memory(context_), true);
  const std::string formattedPrompt =
      toChatPrompt(model_, systemPrompt, prompt);
  // Tokenize; negative return encodes the required size → resize + retry.
  std::vector<llama_token> promptTokens(formattedPrompt.size() + 8);
  int32_t tokenCount = llama_tokenize(
      vocab, formattedPrompt.c_str(),
      static_cast<int32_t>(formattedPrompt.size()), promptTokens.data(),
      static_cast<int32_t>(promptTokens.size()), true, true);
  if (tokenCount < 0) {
    promptTokens.resize(static_cast<std::size_t>(-tokenCount));
    tokenCount = llama_tokenize(
        vocab, formattedPrompt.c_str(),
        static_cast<int32_t>(formattedPrompt.size()), promptTokens.data(),
        static_cast<int32_t>(promptTokens.size()), true, true);
  }
  if (tokenCount < 0) {
    throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
  }
  promptTokens.resize(static_cast<std::size_t>(tokenCount));
  const llama_batch promptBatch = llama_batch_get_one(
      promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
  if (llama_decode(context_, promptBatch) != 0) {
    throw std::runtime_error("LlamaGenerator: prompt decode failed");
  }
  // Greedy sampler chain, released via unique_ptr.
  llama_sampler_chain_params samplerParams =
      llama_sampler_chain_default_params();
  using SamplerPtr =
      std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
  SamplerPtr sampler(llama_sampler_chain_init(samplerParams),
                     &llama_sampler_free);
  if (!sampler) {
    throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
  }
  llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy());
  std::vector<llama_token> generatedTokens;
  generatedTokens.reserve(static_cast<std::size_t>(maxTokens));
  // Sample → check EOG → feed back, up to maxTokens.
  for (int i = 0; i < maxTokens; ++i) {
    const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
    if (llama_vocab_is_eog(vocab, next)) {
      break;
    }
    generatedTokens.push_back(next);
    llama_token token = next;
    const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1);
    if (llama_decode(context_, oneTokenBatch) != 0) {
      throw std::runtime_error(
          "LlamaGenerator: decode failed during generation");
    }
  }
  std::string output;
  for (const llama_token token : generatedTokens) {
    appendTokenPiece(vocab, token, output);
  }
  return output;
}
// Generates a {username, bio} pair for a plausible craft-beer enthusiast.
//
// locale  seeds the prompt so the generated identity fits the region.
// The username is additionally stripped of all whitespace (the model is
// asked for a space-free name but may not comply). Throws
// std::runtime_error on a malformed completion.
UserResult LlamaGenerator::generateUser(const std::string &locale) {
  std::string prompt =
      "Generate a plausible craft beer enthusiast username and a one-sentence "
      "bio. Locale: " +
      locale +
      ". Respond with exactly two lines: first line is the username (no "
      "spaces), second line is the bio. Do not include bullets, numbering, "
      "or any extra text.";
  const std::string raw = infer(prompt, 128);
  auto [username, bio] =
      parseTwoLineResponse(raw, "LlamaGenerator: malformed user response");
  // Defensively remove any whitespace the model left inside the username.
  username.erase(
      std::remove_if(username.begin(), username.end(),
                     [](unsigned char ch) { return std::isspace(ch); }),
      username.end());
  if (username.empty() || bio.empty()) {
    throw std::runtime_error("LlamaGenerator: malformed user response");
  }
  return {username, bio};
}

View File

@@ -1,106 +1,118 @@
#include "data_downloader.h" #include <iostream>
#include "data_generator.h"
#include "database.h"
#include "json_loader.h"
#include "llama_generator.h"
#include "mock_generator.h"
#include <curl/curl.h>
#include <filesystem>
#include <memory> #include <memory>
#include <spdlog/spdlog.h>
#include <vector>
static bool FileExists(const std::string &filePath) { #include <boost/program_options.hpp>
return std::filesystem::exists(filePath); #include <spdlog/spdlog.h>
#include "biergarten_data_generator.h"
#include "web_client/curl_web_client.h"
#include "database/database.h"
namespace po = boost::program_options;
/**
* @brief Parse command-line arguments into ApplicationOptions.
*
* @param argc Command-line argument count.
* @param argv Command-line arguments.
* @param options Output ApplicationOptions struct.
* @return true if parsing succeeded and should proceed, false otherwise.
*/
bool ParseArguments(int argc, char **argv, ApplicationOptions &options) {
  // If no arguments provided, display usage and exit
  // (hand-rolled banner instead of streaming `desc` so the note about
  // mutual exclusion and the pinned commit can be shown).
  if (argc == 1) {
    std::cout << "Biergarten Pipeline - Geographic Data Pipeline with Brewery Generation\n\n";
    std::cout << "Usage: biergarten-pipeline [options]\n\n";
    std::cout << "Options:\n";
    std::cout << " --mocked Use mocked generator for brewery/user data\n";
    std::cout << " --model, -m PATH Path to LLM model file (gguf) for generation\n";
    std::cout << " --cache-dir, -c DIR Directory for cached JSON (default: /tmp)\n";
    std::cout << " --temperature TEMP LLM sampling temperature 0.0-1.0 (default: 0.8)\n";
    std::cout << " --top-p VALUE Nucleus sampling parameter 0.0-1.0 (default: 0.92)\n";
    std::cout << " --seed SEED Random seed: -1 for random (default: -1)\n";
    std::cout << " --help, -h Show this help message\n\n";
    std::cout << "Note: --mocked and --model are mutually exclusive. Exactly one must be provided.\n";
    std::cout << "Data source is always pinned to commit c5eb7772 (stable 2026-03-28).\n";
    return false;
  }
  // Declare the accepted options; defaults here must stay in sync with the
  // usage banner above.
  po::options_description desc("Pipeline Options");
  desc.add_options()("help,h", "Produce help message")(
      "mocked", po::bool_switch(),
      "Use mocked generator for brewery/user data")(
      "model,m", po::value<std::string>()->default_value(""),
      "Path to LLM model (gguf)")(
      "cache-dir,c", po::value<std::string>()->default_value("/tmp"),
      "Directory for cached JSON")(
      "temperature", po::value<float>()->default_value(0.8f),
      "Sampling temperature (higher = more random)")(
      "top-p", po::value<float>()->default_value(0.92f),
      "Nucleus sampling top-p in (0,1] (higher = more random)")(
      "seed", po::value<int>()->default_value(-1),
      "Sampler seed: -1 for random, otherwise non-negative integer");
  // NOTE(review): po::store/po::notify throw on malformed input; presumably
  // the caller's try/catch handles that — confirm against main().
  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, desc), vm);
  po::notify(vm);
  if (vm.count("help")) {
    std::cout << desc << "\n";
    return false;
  }
  // Check for mutually exclusive --mocked and --model flags:
  // exactly one of the two generator sources must be selected.
  bool use_mocked = vm["mocked"].as<bool>();
  std::string model_path = vm["model"].as<std::string>();
  if (use_mocked && !model_path.empty()) {
    spdlog::error("ERROR: --mocked and --model are mutually exclusive");
    return false;
  }
  if (!use_mocked && model_path.empty()) {
    spdlog::error("ERROR: Either --mocked or --model must be specified");
    return false;
  }
  // Warn if sampling parameters are provided with --mocked
  // (defaulted() is false only when the user supplied the flag explicitly).
  if (use_mocked) {
    bool hasTemperature = vm["temperature"].defaulted() == false;
    bool hasTopP = vm["top-p"].defaulted() == false;
    bool hasSeed = vm["seed"].defaulted() == false;
    if (hasTemperature || hasTopP || hasSeed) {
      spdlog::warn("WARNING: Sampling parameters (--temperature, --top-p, --seed) are ignored when using --mocked");
    }
  }
  // Copy parsed values into the output struct for the rest of the pipeline.
  options.use_mocked = use_mocked;
  options.model_path = model_path;
  options.cache_dir = vm["cache-dir"].as<std::string>();
  options.temperature = vm["temperature"].as<float>();
  options.top_p = vm["top-p"].as<float>();
  options.seed = vm["seed"].as<int>();
  // commit is always pinned to c5eb7772
  return true;
} }
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
try { try {
curl_global_init(CURL_GLOBAL_DEFAULT); const CurlGlobalState curl_state;
std::string modelPath = argc > 1 ? argv[1] : ""; ApplicationOptions options;
std::string cacheDir = argc > 2 ? argv[2] : "/tmp"; if (!ParseArguments(argc, argv, options)) {
std::string commit = return 0;
argc > 3 ? argv[3] : "c5eb7772"; // Default: stable 2026-03-28
std::string countryName = argc > 4 ? argv[4] : "";
std::string jsonPath = cacheDir + "/countries+states+cities.json";
std::string dbPath = cacheDir + "/biergarten-pipeline.db";
bool hasJsonCache = FileExists(jsonPath);
bool hasDbCache = FileExists(dbPath);
SqliteDatabase db;
spdlog::info("Initializing SQLite database at {}...", dbPath);
db.Initialize(dbPath);
if (hasDbCache && hasJsonCache) {
spdlog::info("[Pipeline] Cache hit: skipping download and parse");
} else {
spdlog::info("\n[Pipeline] Downloading geographic data from GitHub...");
DataDownloader downloader;
downloader.DownloadCountriesDatabase(jsonPath, commit);
JsonLoader::LoadWorldCities(jsonPath, db);
} }
spdlog::info("Initializing brewery generator..."); auto webClient = std::make_shared<CURLWebClient>();
std::unique_ptr<IDataGenerator> generator; SqliteDatabase database;
if (modelPath.empty()) {
generator = std::make_unique<MockGenerator>();
spdlog::info("[Generator] Using MockGenerator (no model path provided)");
} else {
generator = std::make_unique<LlamaGenerator>();
spdlog::info("[Generator] Using LlamaGenerator: {}", modelPath);
}
generator->load(modelPath);
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); BiergartenDataGenerator generator(options, webClient, database);
return generator.Run();
auto countries = db.QueryCountries(50);
auto states = db.QueryStates(50);
auto cities = db.QueryCities();
spdlog::info("\nTotal records loaded:");
spdlog::info(" Countries: {}", db.QueryCountries(0).size());
spdlog::info(" States: {}", db.QueryStates(0).size());
spdlog::info(" Cities: {}", cities.size());
struct GeneratedBrewery {
int cityId;
std::string cityName;
BreweryResult brewery;
};
std::vector<GeneratedBrewery> generatedBreweries;
const size_t sampleCount = std::min(size_t(30), cities.size());
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
for (size_t i = 0; i < sampleCount; i++) {
const auto &[cityId, cityName] = cities[i];
auto brewery = generator->generateBrewery(cityName, countryName, "");
generatedBreweries.push_back({cityId, cityName, brewery});
}
spdlog::info("\n=== GENERATED DATA DUMP ===");
for (size_t i = 0; i < generatedBreweries.size(); i++) {
const auto &entry = generatedBreweries[i];
spdlog::info("{}. city_id={} city=\"{}\"", i + 1, entry.cityId,
entry.cityName);
spdlog::info(" brewery_name=\"{}\"", entry.brewery.name);
spdlog::info(" brewery_description=\"{}\"", entry.brewery.description);
}
spdlog::info("\nOK: Pipeline completed successfully");
curl_global_cleanup();
return 0;
} catch (const std::exception &e) { } catch (const std::exception &e) {
spdlog::error("ERROR: Pipeline failed: {}", e.what()); spdlog::error("ERROR: Application failed: {}", e.what());
curl_global_cleanup();
return 1; return 1;
} }
} }

View File

@@ -0,0 +1,139 @@
#include "web_client/curl_web_client.h"
#include <cstdio>
#include <curl/curl.h>
#include <fstream>
#include <memory>
#include <sstream>
#include <stdexcept>
// RAII guard: bring up libcurl's process-wide state. Throws if libcurl's
// global initialization fails.
CurlGlobalState::CurlGlobalState() {
  const CURLcode status = curl_global_init(CURL_GLOBAL_DEFAULT);
  if (status != CURLE_OK) {
    throw std::runtime_error(
        "[CURLWebClient] Failed to initialize libcurl globally");
  }
}
CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); }
namespace {
// curl write callback that appends response data into a std::string.
// Returns the number of bytes consumed; libcurl aborts the transfer if
// this differs from size * nmemb.
size_t WriteCallbackString(void *contents, size_t size, size_t nmemb,
                           void *userp) {
  const size_t realsize = size * nmemb;
  auto *s = static_cast<std::string *>(userp);
  s->append(static_cast<char *>(contents), realsize);
  return realsize;
}
// curl write callback that writes to a file stream.
// Fix: if the stream enters a failed state (e.g. disk full), return 0 so
// libcurl aborts the transfer with CURLE_WRITE_ERROR instead of silently
// "succeeding" with a truncated file on disk.
size_t WriteCallbackFile(void *contents, size_t size, size_t nmemb,
                         void *userp) {
  const size_t realsize = size * nmemb;
  auto *outFile = static_cast<std::ofstream *>(userp);
  outFile->write(static_cast<char *>(contents), realsize);
  // A short return count signals failure to libcurl.
  return outFile->good() ? realsize : 0;
}
// RAII wrapper for CURL handle using unique_ptr; curl_easy_cleanup runs on
// scope exit.
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
// Create a managed CURL easy handle, or throw if libcurl cannot allocate one.
CurlHandle create_handle() {
  CURL *handle = curl_easy_init();
  if (!handle) {
    throw std::runtime_error(
        "[CURLWebClient] Failed to initialize libcurl handle");
  }
  return CurlHandle(handle, &curl_easy_cleanup);
}
// Apply the GET options shared by all requests: target URL, UA string,
// bounded redirect following, connect/total timeouts (seconds), and
// gzip-compressed transfer encoding.
void set_common_get_options(CURL *curl, const std::string &url,
                            long connect_timeout, long total_timeout) {
  curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
  curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
  curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
  curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
  curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
  curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
  curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
}
}  // namespace
// Defaulted special members (modernize-use-equals-default): the client holds
// no per-instance curl state; global init/cleanup is owned by CurlGlobalState.
CURLWebClient::CURLWebClient() = default;
CURLWebClient::~CURLWebClient() = default;
/**
 * @brief Download the resource at @p url into @p file_path.
 *
 * On any failure (transport error, non-200 status, or a failed disk write)
 * the partially-written file is removed and std::runtime_error is thrown.
 *
 * @param url       Resource to fetch.
 * @param file_path Destination path, opened in binary mode.
 * @throws std::runtime_error on open, transfer, write, or HTTP-status errors.
 */
void CURLWebClient::DownloadToFile(const std::string &url,
                                   const std::string &file_path) {
  auto curl = create_handle();
  std::ofstream outFile(file_path, std::ios::binary);
  if (!outFile.is_open()) {
    throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " +
                             file_path);
  }
  // Generous timeouts: downloads here can be large (connect 30s, total 300s).
  set_common_get_options(curl.get(), url, 30L, 300L);
  curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
  curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
                   static_cast<void *>(&outFile));
  CURLcode res = curl_easy_perform(curl.get());
  outFile.close();
  if (res != CURLE_OK) {
    std::remove(file_path.c_str());
    std::string error = std::string("[CURLWebClient] Download failed: ") +
                        curl_easy_strerror(res);
    throw std::runtime_error(error);
  }
  // Fix: a failed stream (disk full, I/O error) means the file on disk is
  // incomplete even when curl reported CURLE_OK — previously this case was
  // silently treated as success.
  if (outFile.fail()) {
    std::remove(file_path.c_str());
    throw std::runtime_error(
        "[CURLWebClient] Failed writing downloaded data to: " + file_path);
  }
  long httpCode = 0;
  curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
  if (httpCode != 200) {
    std::remove(file_path.c_str());
    std::stringstream ss;
    ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
    throw std::runtime_error(ss.str());
  }
}
std::string CURLWebClient::Get(const std::string &url) {
auto curl = create_handle();
std::string response_string;
set_common_get_options(curl.get(), url, 10L, 20L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
std::string error =
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
throw std::runtime_error(error);
}
long httpCode = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
if (httpCode != 200) {
std::stringstream ss;
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
throw std::runtime_error(ss.str());
}
return response_string;
}
/**
 * @brief Percent-encode @p value for safe inclusion in a URL.
 *
 * Fix: pass the explicit byte length instead of 0. With length 0 libcurl
 * falls back to strlen(), which would silently truncate values containing
 * embedded NUL bytes.
 *
 * @param value Raw string to encode.
 * @return Percent-encoded copy of @p value.
 * @throws std::runtime_error if libcurl fails to allocate the encoding.
 */
std::string CURLWebClient::UrlEncode(const std::string &value) {
  // NOTE(review): passing a NULL handle is only documented as safe for
  // libcurl >= 7.82.0; older versions require a live easy handle — confirm
  // the project's minimum supported curl version.
  char *encoded =
      curl_easy_escape(nullptr, value.c_str(), static_cast<int>(value.size()));
  if (!encoded) {
    throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
  }
  std::string result(encoded);
  curl_free(encoded);
  return result;
}

View File

@@ -0,0 +1,77 @@
#include "wikipedia/wikipedia_service.h"
#include <boost/json.hpp>
#include <spdlog/spdlog.h>
// Store the injected web client. A null client is tolerated: GetSummary
// then skips network lookups and returns (and caches) empty summaries.
WikipediaService::WikipediaService(std::shared_ptr<IWebClient> client)
    : client_(std::move(client)) {}
std::string WikipediaService::FetchExtract(std::string_view query) {
const std::string encoded = client_->UrlEncode(std::string(query));
const std::string url =
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
"&prop=extracts&explaintext=true&format=json";
const std::string body = client_->Get(url);
boost::system::error_code ec;
boost::json::value doc = boost::json::parse(body, ec);
if (!ec && doc.is_object()) {
auto &pages = doc.at("query").at("pages").get_object();
if (!pages.empty()) {
auto &page = pages.begin()->value().get_object();
if (page.contains("extract") && page.at("extract").is_string()) {
std::string extract(page.at("extract").as_string().c_str());
spdlog::debug("WikipediaService fetched {} chars for '{}'",
extract.size(), query);
return extract;
}
}
}
return {};
}
std::string WikipediaService::GetSummary(std::string_view city,
std::string_view country) {
const std::string key = std::string(city) + "|" + std::string(country);
const auto cacheIt = cache_.find(key);
if (cacheIt != cache_.end()) {
return cacheIt->second;
}
std::string result;
if (!client_) {
cache_.emplace(key, result);
return result;
}
std::string regionQuery(city);
if (!country.empty()) {
regionQuery += ", ";
regionQuery += country;
}
const std::string beerQuery = "beer in " + std::string(city);
try {
const std::string regionExtract = FetchExtract(regionQuery);
const std::string beerExtract = FetchExtract(beerQuery);
if (!regionExtract.empty()) {
result += regionExtract;
}
if (!beerExtract.empty()) {
if (!result.empty())
result += "\n\n";
result += beerExtract;
}
} catch (const std::runtime_error &e) {
spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery,
e.what());
}
cache_.emplace(key, result);
return result;
}