mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-04-05 10:09:03 +00:00
Pipeline: add CURL/WebClient & Wikipedia service
Introduce a pluggable web client interface and concrete CURL implementation: adds IWebClient, CURLWebClient, and CurlGlobalState (headers + curl_web_client.cpp). DataDownloader now accepts an IWebClient and delegates downloads. Add WikipediaService for cached Wikipedia summary lookups. Refactor SqliteDatabase to return full City records and update consumers accordingly. Improve JsonLoader to use batched transactions during streaming parses. Enhance LlamaGenerator with sampling options, increased token limits, JSON extraction/validation, and other parsing helpers. Modernize CMake: set policy/version, add project_options, simplify FetchContent usage (spdlog), require Boost components (program_options/json), list pipeline sources explicitly, and tweak post-build/memcheck targets. Update README to match implementation changes and new CLI/config conventions.
This commit is contained in:
@@ -1,49 +1,52 @@
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
project(biergarten-pipeline VERSION 0.1.0 LANGUAGES CXX)
|
||||
|
||||
cmake_policy(SET CMP0167 NEW)
|
||||
# Allows older dependencies to configure on newer CMake.
|
||||
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
|
||||
|
||||
# Policies
|
||||
cmake_policy(SET CMP0167 NEW) # FindBoost improvements
|
||||
|
||||
# Global Settings
|
||||
set(CMAKE_CXX_STANDARD 23)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Compiler Options & Warnings (Interface Library)
|
||||
# -----------------------------------------------------------------------------
|
||||
add_library(project_options INTERFACE)
|
||||
target_compile_options(project_options INTERFACE
|
||||
$<$<CXX_COMPILER_ID:GNU,Clang>:
|
||||
-Wall -Wextra -Wpedantic -Wshadow -Wconversion -Wsign-conversion -Wunused
|
||||
>
|
||||
$<$<CXX_COMPILER_ID:MSVC>:
|
||||
/W4 /WX /permissive-
|
||||
>
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies
|
||||
# -----------------------------------------------------------------------------
|
||||
find_package(CURL REQUIRED)
|
||||
find_package(Boost REQUIRED COMPONENTS unit_test_framework)
|
||||
find_package(SQLite3 REQUIRED)
|
||||
find_package(Boost 1.75 REQUIRED COMPONENTS program_options json)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
# RapidJSON (header-only) for true SAX parsing
|
||||
# Using direct header-only approach without CMakeLists.txt
|
||||
FetchContent_Declare(
|
||||
rapidjson
|
||||
GIT_REPOSITORY https://github.com/Tencent/rapidjson.git
|
||||
GIT_TAG v1.1.0
|
||||
SOURCE_SUBDIR "" # Don't use RapidJSON's CMakeLists.txt
|
||||
)
|
||||
FetchContent_GetProperties(rapidjson)
|
||||
if(NOT rapidjson_POPULATED)
|
||||
FetchContent_Populate(rapidjson)
|
||||
# RapidJSON is header-only; just make include path available
|
||||
endif()
|
||||
|
||||
# spdlog (logging)
|
||||
# spdlog (Logging)
|
||||
FetchContent_Declare(
|
||||
spdlog
|
||||
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
||||
GIT_TAG v1.11.0
|
||||
)
|
||||
FetchContent_GetProperties(spdlog)
|
||||
if(NOT spdlog_POPULATED)
|
||||
FetchContent_Populate(spdlog)
|
||||
add_subdirectory(${spdlog_SOURCE_DIR} ${spdlog_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
endif()
|
||||
FetchContent_MakeAvailable(spdlog)
|
||||
|
||||
# llama.cpp (on-device inference)
|
||||
# llama.cpp (LLM Inference)
|
||||
set(LLAMA_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(LLAMA_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
|
||||
set(LLAMA_BUILD_SERVER OFF CACHE BOOL "" FORCE)
|
||||
|
||||
FetchContent_Declare(
|
||||
llama_cpp
|
||||
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
|
||||
@@ -57,90 +60,53 @@ if(TARGET llama)
|
||||
)
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS
|
||||
src/*.cpp
|
||||
# -----------------------------------------------------------------------------
|
||||
# Main Executable
|
||||
# -----------------------------------------------------------------------------
|
||||
set(PIPELINE_SOURCES
|
||||
src/curl_web_client.cpp
|
||||
src/data_downloader.cpp
|
||||
src/database.cpp
|
||||
src/json_loader.cpp
|
||||
src/llama_generator.cpp
|
||||
src/mock_generator.cpp
|
||||
src/stream_parser.cpp
|
||||
src/wikipedia_service.cpp
|
||||
src/main.cpp
|
||||
)
|
||||
|
||||
add_executable(biergarten-pipeline ${SOURCES})
|
||||
add_executable(biergarten-pipeline ${PIPELINE_SOURCES})
|
||||
|
||||
target_include_directories(biergarten-pipeline
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/includes
|
||||
${rapidjson_SOURCE_DIR}/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/includes
|
||||
${llama_cpp_SOURCE_DIR}/include
|
||||
)
|
||||
|
||||
target_link_libraries(biergarten-pipeline
|
||||
PRIVATE
|
||||
project_options
|
||||
CURL::libcurl
|
||||
Boost::unit_test_framework
|
||||
SQLite::SQLite3
|
||||
spdlog::spdlog
|
||||
llama
|
||||
Boost::program_options
|
||||
Boost::json
|
||||
)
|
||||
|
||||
target_compile_options(biergarten-pipeline PRIVATE
|
||||
$<$<CXX_COMPILER_ID:GNU,Clang>:
|
||||
-Wall
|
||||
-Wextra
|
||||
-Wpedantic
|
||||
-Wshadow
|
||||
-Wconversion
|
||||
-Wsign-conversion
|
||||
>
|
||||
$<$<CXX_COMPILER_ID:MSVC>:
|
||||
/W4
|
||||
/WX
|
||||
>
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Post-Build Steps & Utilities
|
||||
# -----------------------------------------------------------------------------
|
||||
add_custom_command(TARGET biergarten-pipeline POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/output
|
||||
COMMENT "Creating output/ directory for seed SQL files"
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_SOURCE_DIR}/output
|
||||
COMMENT "Ensuring output directory exists"
|
||||
)
|
||||
|
||||
find_program(VALGRIND valgrind)
|
||||
if(VALGRIND)
|
||||
add_custom_target(memcheck
|
||||
COMMAND ${VALGRIND}
|
||||
--leak-check=full
|
||||
--error-exitcode=1
|
||||
$<TARGET_FILE:biergarten-pipeline> --help
|
||||
COMMAND ${VALGRIND} --leak-check=full --error-exitcode=1 $<TARGET_FILE:biergarten-pipeline> --help
|
||||
DEPENDS biergarten-pipeline
|
||||
COMMENT "Running Valgrind memcheck"
|
||||
COMMENT "Running Valgrind memory check"
|
||||
)
|
||||
endif()
|
||||
|
||||
include(CTest)
|
||||
|
||||
if(BUILD_TESTING)
|
||||
find_package(Boost REQUIRED COMPONENTS unit_test_framework)
|
||||
|
||||
file(GLOB_RECURSE TEST_SOURCES CONFIGURE_DEPENDS
|
||||
tests/*.cpp
|
||||
tests/*.cc
|
||||
tests/*.cxx
|
||||
)
|
||||
|
||||
if(TEST_SOURCES)
|
||||
add_executable(biergarten-pipeline-tests ${TEST_SOURCES})
|
||||
|
||||
target_include_directories(biergarten-pipeline-tests
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
)
|
||||
|
||||
target_link_libraries(biergarten-pipeline-tests
|
||||
PRIVATE
|
||||
Boost::unit_test_framework
|
||||
CURL::libcurl
|
||||
nlohmann_json::nlohmann_json
|
||||
)
|
||||
|
||||
add_test(
|
||||
NAME biergarten-pipeline-tests
|
||||
COMMAND biergarten-pipeline-tests
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,414 +1,199 @@
|
||||
## Biergarten Pipeline
|
||||
Biergarten Pipeline
|
||||
|
||||
## Overview
|
||||
Overview
|
||||
|
||||
The pipeline orchestrates five key stages:
|
||||
|
||||
1. **Download**: Fetches `countries+states+cities.json` from a pinned GitHub commit with optional local caching.
|
||||
2. **Parse**: Streams JSON using RapidJSON SAX parser, extracting country/state/city records without loading the entire file into memory.
|
||||
3. **Buffer**: Routes city records through a bounded concurrent queue to decouple parsing from writes.
|
||||
4. **Store**: Inserts records with concurrent thread safety using an in-memory SQLite database.
|
||||
5. **Generate**: Produces mock brewery metadata for a sample of cities (mockup for future LLM integration).
|
||||
Download: Fetches countries+states+cities.json from a pinned GitHub commit with optional local caching.
|
||||
|
||||
---
|
||||
Parse: Streams JSON using Boost.JSON's basic_parser to extract country/state/city records without loading the entire file into memory.
|
||||
|
||||
## Architecture
|
||||
Buffer: Routes city records through a bounded concurrent queue to decouple parsing from writes.
|
||||
|
||||
### Data Sources and Formats
|
||||
Store: Inserts records with concurrent thread safety using an in-memory SQLite database.
|
||||
|
||||
- Hierarchical structure: countries array → states per country → cities per state.
|
||||
- Fields: `id` (integer), `name` (string), `iso2` / `iso3` (codes), `latitude` / `longitude`.
|
||||
- Sourced from: [dr5hn/countries-states-cities-database](https://github.com/dr5hn/countries-states-cities-database) on GitHub.
|
||||
Generate: Produces mock brewery metadata for a sample of cities (mockup for future LLM integration).
|
||||
|
||||
**Output**: Structured SQLite in-memory database + console logs via spdlog.
|
||||
Architecture
|
||||
|
||||
### Concurrency Architecture
|
||||
Data Sources and Formats
|
||||
|
||||
Hierarchical structure: countries array → states per country → cities per state.
|
||||
|
||||
Fields: id (integer), name (string), iso2 / iso3 (codes), latitude / longitude.
|
||||
|
||||
Sourced from: dr5hn/countries-states-cities-database on GitHub.
|
||||
|
||||
Output: Structured SQLite in-memory database + console logs via spdlog.
|
||||
|
||||
Concurrency Architecture
|
||||
|
||||
The pipeline splits work across parsing and writing phases:
|
||||
|
||||
```
|
||||
Main Thread:
|
||||
parse_sax() -> Insert countries (direct)
|
||||
-> Insert states (direct)
|
||||
-> Push CityRecord to WorkQueue
|
||||
parse_sax() -> Insert countries (direct)
|
||||
-> Insert states (direct)
|
||||
-> Push CityRecord to WorkQueue
|
||||
|
||||
Worker Threads (implicit; pthread pool via sqlite3):
|
||||
Pop CityRecord from WorkQueue
|
||||
-> InsertCity(db) with mutex protection
|
||||
```
|
||||
Pop CityRecord from WorkQueue
|
||||
-> InsertCity(db) with mutex protection
|
||||
|
||||
**Key synchronization primitives**:
|
||||
Key synchronization primitives:
|
||||
|
||||
- **WorkQueue<T>**: Bounded (default 1024 items) concurrent queue with blocking push/pop, guarded by mutex + condition variables.
|
||||
- **SqliteDatabase::dbMutex**: Serializes all SQLite operations to avoid `SQLITE_BUSY` and ensure write safety.
|
||||
WorkQueue<T>: Bounded (default 1024 items) concurrent queue with blocking push/pop, guarded by mutex + condition variables.
|
||||
|
||||
**Backpressure**: When the WorkQueue fills (≥1024 city records pending), the parser thread blocks until workers drain items.
|
||||
SqliteDatabase::dbMutex: Serializes all SQLite operations to avoid SQLITE_BUSY and ensure write safety.
|
||||
|
||||
### Component Responsibilities
|
||||
Backpressure: When the WorkQueue fills (≥1024 city records pending), the parser thread blocks until workers drain items.
|
||||
|
||||
| Component | Purpose | Thread Safety |
|
||||
| ------------------------- | ------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------- |
|
||||
| **DataDownloader** | GitHub fetch with curl; optional filesystem cache; handles retries and ETags. | Blocking I/O; safe for single-threaded startup. |
|
||||
| **StreamingJsonParser** | SAX-style RapidJSON handler; emits country/state/city via callbacks; tracks parse state (array depth, key context). | Single-threaded parse phase; thread-safe callbacks. |
|
||||
| **JsonLoader** | Wraps parser; runs country/state/city callbacks; manages WorkQueue lifecycle. | Produces to WorkQueue; consumes from callbacks. |
|
||||
| **SqliteDatabase** | In-memory schema; insert/query methods; mutex-protected SQL operations. | Mutex-guarded; thread-safe concurrent inserts. |
|
||||
| **LlamaBreweryGenerator** | Mock brewery text generation using deterministic seed-based selection. | Stateless; thread-safe method calls. |
|
||||
Component Responsibilities
|
||||
|
||||
---
|
||||
Component
|
||||
|
||||
## Database Schema
|
||||
Purpose
|
||||
|
||||
**SQLite in-memory database** with three core tables:
|
||||
Thread Safety
|
||||
|
||||
### Countries
|
||||
DataDownloader
|
||||
|
||||
GitHub fetch with curl; optional filesystem cache; handles retries and ETags.
|
||||
|
||||
Blocking I/O; safe for single-threaded startup.
|
||||
|
||||
StreamingJsonParser
|
||||
|
||||
Subclasses boost::json::basic_parser; emits country/state/city via callbacks; tracking parse depth.
|
||||
|
||||
Single-threaded parse phase; thread-safe callbacks.
|
||||
|
||||
JsonLoader
|
||||
|
||||
Wraps parser; runs country/state/city callbacks; manages WorkQueue lifecycle.
|
||||
|
||||
Produces to WorkQueue; consumes from callbacks.
|
||||
|
||||
SqliteDatabase
|
||||
|
||||
In-memory schema; insert/query methods; mutex-protected SQL operations.
|
||||
|
||||
Mutex-guarded; thread-safe concurrent inserts.
|
||||
|
||||
LlamaBreweryGenerator
|
||||
|
||||
Mock brewery text generation using deterministic seed-based selection.
|
||||
|
||||
Stateless; thread-safe method calls.
|
||||
|
||||
Database Schema
|
||||
|
||||
SQLite in-memory database with three core tables:
|
||||
|
||||
Countries
|
||||
|
||||
```sql
|
||||
CREATE TABLE countries (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
iso2 TEXT,
|
||||
iso3 TEXT
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
iso2 TEXT,
|
||||
iso3 TEXT
|
||||
);
|
||||
CREATE INDEX idx_countries_iso2 ON countries(iso2);
|
||||
```
|
||||
|
||||
### States
|
||||
States
|
||||
|
||||
```sql
|
||||
CREATE TABLE states (
|
||||
id INTEGER PRIMARY KEY,
|
||||
country_id INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
iso2 TEXT,
|
||||
FOREIGN KEY (country_id) REFERENCES countries(id)
|
||||
id INTEGER PRIMARY KEY,
|
||||
country_id INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
iso2 TEXT,
|
||||
FOREIGN KEY (country_id) REFERENCES countries(id)
|
||||
);
|
||||
CREATE INDEX idx_states_country ON states(country_id);
|
||||
```
|
||||
|
||||
### Cities
|
||||
Cities
|
||||
|
||||
```sql
|
||||
CREATE TABLE cities (
|
||||
id INTEGER PRIMARY KEY,
|
||||
state_id INTEGER NOT NULL,
|
||||
country_id INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
latitude REAL,
|
||||
longitude REAL,
|
||||
FOREIGN KEY (state_id) REFERENCES states(id),
|
||||
FOREIGN KEY (country_id) REFERENCES countries(id)
|
||||
id INTEGER PRIMARY KEY,
|
||||
state_id INTEGER NOT NULL,
|
||||
country_id INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
latitude REAL,
|
||||
longitude REAL,
|
||||
FOREIGN KEY (state_id) REFERENCES states(id),
|
||||
FOREIGN KEY (country_id) REFERENCES countries(id)
|
||||
);
|
||||
CREATE INDEX idx_cities_state ON cities(state_id);
|
||||
CREATE INDEX idx_cities_country ON cities(country_id);
|
||||
```
|
||||
|
||||
**Design rationale**:
|
||||
Configuration and Extensibility
|
||||
|
||||
- In-memory for performance (no persistent storage; data is regenerated on each run).
|
||||
- Foreign keys for referential integrity (optional in SQLite, but enforced in schema).
|
||||
- Indexes on foreign keys for fast lookups during brewery generation.
|
||||
- Dual country_id in cities table for direct queries without state joins.
|
||||
Command-Line Arguments
|
||||
|
||||
---
|
||||
Boost.Program_options provides named CLI arguments:
|
||||
|
||||
## Data Flow
|
||||
./biergarten-pipeline [options]
|
||||
|
||||
### Parse Phase (Main Thread)
|
||||
Arg
|
||||
|
||||
1. **DataDownloader::DownloadCountriesDatabase()**
|
||||
- Constructs GitHub raw-content URL: `https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/{commit}/countries+states+cities.json`
|
||||
- Uses curl with `FOLLOWLOCATION` and timeout.
|
||||
- Caches locally; checks ETag for freshness.
|
||||
Default
|
||||
|
||||
2. **StreamingJsonParser::Parse()**
|
||||
- Opens file stream; initializes RapidJSON SAX parser with custom handler.
|
||||
- Handler state: tracks `current_country_id`, `current_state_id`, array nesting, object key context.
|
||||
- **Country processing** (inline): When country object completes, calls `db.InsertCountry()` directly on main thread.
|
||||
- **State processing** (inline): When state object completes, calls `db.InsertState()` directly.
|
||||
- **City processing** (buffered): When city object completes, pushes `CityRecord` to `JsonLoader`'s WorkQueue; unblocks if `onProgress` callback is registered.
|
||||
Purpose
|
||||
|
||||
3. **JsonLoader::LoadWorldCities()**
|
||||
- Registers callbacks with parser.
|
||||
- Drains WorkQueue in separate scope (currently single-threaded in main, but queue API supports worker threads).
|
||||
- Each city is inserted via `db.InsertCity()`.
|
||||
--model, -m
|
||||
|
||||
### Query and Generation Phase (Main Thread)
|
||||
""
|
||||
|
||||
4. **Database Queries**
|
||||
- `QueryCountries(limit)`: Retrieve countries; used for progress display.
|
||||
- `QueryStates(limit)`: Retrieve states; used for progress display.
|
||||
- `QueryCities()`: Retrieve all city ids + names for brewery generation.
|
||||
Path to LLM model (mock implementation used if left blank).
|
||||
|
||||
5. **Brewery Generation**
|
||||
- For each city sample, call `LlamaBreweryGenerator::GenerateBrewery(cityName, seed)`.
|
||||
- Deterministic: same seed always produces same brewery (useful for reproducible test data).
|
||||
- Returns `{ name, description }` struct.
|
||||
--cache-dir, -c
|
||||
|
||||
---
|
||||
/tmp
|
||||
|
||||
## Concurrency Deep Dive
|
||||
Directory for cached JSON DB.
|
||||
|
||||
### WorkQueue<T>
|
||||
--commit
|
||||
|
||||
A bounded thread-safe queue enabling producer-consumer patterns:
|
||||
c5eb7772
|
||||
|
||||
```cpp
|
||||
template <typename T> class WorkQueue {
|
||||
std::queue<T> queue;
|
||||
std::mutex mutex;
|
||||
std::condition_variable cv_not_empty, cv_not_full;
|
||||
size_t max_size;
|
||||
bool shutdown;
|
||||
};
|
||||
```
|
||||
Git commit hash for consistency (stable 2026-03-28 snapshot).
|
||||
|
||||
**push(item)**:
|
||||
--help, -h
|
||||
|
||||
- Locks mutex.
|
||||
- Waits on `cv_not_full` until queue is below max_size OR shutdown signaled.
|
||||
- Pushes item; notifies one waiter on `cv_not_empty`.
|
||||
- Returns false if shutdown, else true.
|
||||
-
|
||||
|
||||
**pop()**:
|
||||
Show help menu.
|
||||
|
||||
- Locks mutex.
|
||||
- Waits on `cv_not_empty` until queue has items OR shutdown signaled.
|
||||
- Pops and returns `std::optional<T>`; notifies one waiter on `cv_not_full`.
|
||||
- Returns `std::nullopt` if shutdown and queue is empty.
|
||||
Examples:
|
||||
|
||||
**shutdown_queue()**:
|
||||
|
||||
- Sets `shutdown = true`; notifies all waiters on both condition variables.
|
||||
- Causes all waiting pop() calls to return `std::nullopt`.
|
||||
|
||||
**Why this design**:
|
||||
|
||||
- **Bounded capacity**: Prevents unbounded memory growth when parser outpaces inserts.
|
||||
- **Backpressure**: Parser naturally pauses when queue fills, avoiding memory spikes.
|
||||
- **Clean shutdown**: `shutdown_queue()` ensures worker pools terminate gracefully.
|
||||
|
||||
### SqliteDatabase Mutex
|
||||
|
||||
All SQLite operations (`INSERT`, `SELECT`) are guarded by `dbMutex`:
|
||||
|
||||
```cpp
|
||||
std::unique_lock<std::mutex> lock(dbMutex);
|
||||
int rc = sqlite3_step(stmt);
|
||||
```
|
||||
|
||||
**Why**: SQLite's "serializable" journal mode (default) requires external synchronization for multi-threaded access. A single mutex serializes all queries, avoiding `SQLITE_BUSY` errors.
|
||||
|
||||
**Tradeoff**: Throughput is bounded by single-threaded SQLite performance; gains come from parse/buffer decoupling, not parallel writes.
|
||||
|
||||
---
|
||||
|
||||
## Error Handling
|
||||
|
||||
### DataDownloader
|
||||
|
||||
- **Network failures**: Retries up to 3 times with exponential backoff; throws `std::runtime_error` on final failure.
|
||||
- **Caching**: Falls back to cached file if download fails and cache exists.
|
||||
|
||||
### Streaming Parser
|
||||
|
||||
- **Malformed JSON**: RapidJSON SAX handler reports parse errors; caught as exceptions in main.
|
||||
- **Missing fields**: Silently skips incomplete records (e.g., city without latitude).
|
||||
|
||||
### Database Operations
|
||||
|
||||
- **Mutex contention**: No explicit backoff; relies on condition variables.
|
||||
- **SQLite errors**: Checked via `sqlite3_step()` return codes; exceptions raised on CORRUPT, READONLY, etc.
|
||||
|
||||
### Resilience Design
|
||||
|
||||
- **No checkpointing**: In-memory database is ephemeral; restart from scratch on failure.
|
||||
- **Future extension**: Snapshot intervals for long-lived processes (not implemented).
|
||||
|
||||
---
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
### Benchmarks (Example: 2M cities on 2024 MacBook Pro)
|
||||
|
||||
| Stage | Time | Throughput |
|
||||
| ----------------------------- | ------- | ---------------------------------------- |
|
||||
| Download + Cache | 1s | ~100 MB/s (network dependent) |
|
||||
| Parse (SAX) | 2s | 50M records/sec |
|
||||
| Insert (countries/states) | <0.1s | Direct, negligible overhead |
|
||||
| Insert (cities via WorkQueue) | 2s | 1M records/sec (sequential due to mutex) |
|
||||
| Generate samples (5 cities) | <0.1s | Mock generation negligible |
|
||||
| **Total** | **~5s** | |
|
||||
|
||||
### Bottlenecks
|
||||
|
||||
- **SQLite insertion**: Single-threaded mutex lock serializes writes. Doubling the number of WorkQueue consumer threads doesn't improve throughput (one lock).
|
||||
- **Parse speed**: RapidJSON SAX is fast (2s for 100 MB); not the bottleneck.
|
||||
- **Memory**: ~100 MB for in-memory database; suitable for most deployments.
|
||||
|
||||
### Optimization Opportunities
|
||||
|
||||
- **WAL mode**: SQLite WAL (write-ahead logging) could reduce lock contention; not beneficial for in-memory DB.
|
||||
- **Batch inserts**: Combine multiple rows in a single transaction; helps if inserting outside the WorkQueue scope.
|
||||
- **Foreign key lazy-loading**: Skip foreign key constraints during bulk load; re-enable for queries. (Not implemented.)
|
||||
|
||||
---
|
||||
|
||||
## Configuration and Extensibility
|
||||
|
||||
### Command-Line Arguments
|
||||
|
||||
```bash
|
||||
./biergarten-pipeline [modelPath] [cacheDir] [commit]
|
||||
```
|
||||
|
||||
| Arg | Default | Purpose |
|
||||
| ----------- | -------------- | ----------------------------------------------------------------------- |
|
||||
| `modelPath` | `./model.gguf` | Path to LLM model (mock implementation; not loaded in current version). |
|
||||
| `cacheDir` | `/tmp` | Directory for cached JSON (e.g., `/tmp/countries+states+cities.json`). |
|
||||
| `commit` | `c5eb7772` | Git commit hash for consistency (stable 2026-03-28 snapshot). |
|
||||
|
||||
**Examples**:
|
||||
|
||||
```bash
|
||||
./biergarten-pipeline
|
||||
./biergarten-pipeline ./models/llama.gguf /var/cache main
|
||||
./biergarten-pipeline "" /tmp v1.2.3
|
||||
```
|
||||
./biergarten-pipeline --model ./models/llama.gguf --cache-dir /var/cache
|
||||
./biergarten-pipeline -c /tmp --commit v1.2.3
|
||||
|
||||
### Extending the Generator
|
||||
Building and Running
|
||||
|
||||
**Current**: `LlamaBreweryGenerator::GenerateBrewery()` uses deterministic seed-based selection from hardcoded lists.
|
||||
Prerequisites
|
||||
|
||||
**Future swap points**:
|
||||
C++23 compiler (g++, clang, MSVC).
|
||||
|
||||
1. Load an actual LLM model in `LoadModel(modelPath)`.
|
||||
2. Tokenize city name and context; call model inference.
|
||||
3. Validate output (length, format) and rank if multiple candidates.
|
||||
4. Cache results to avoid re-inference for repeated cities.
|
||||
CMake 3.20+.
|
||||
|
||||
**Example stub for future integration**:
|
||||
curl (for HTTP downloads).
|
||||
|
||||
```cpp
|
||||
Brewery LlamaBreweryGenerator::GenerateBrewery(const std::string &cityName, int seed) {
|
||||
// TODO: Replace with actual llama.cpp inference
|
||||
// llama_context *ctx = llama_new_context_with_model(model, params);
|
||||
// std::string prompt = "Generate a brewery for " + cityName;
|
||||
// std::string result = llama_inference(ctx, prompt, seed);
|
||||
// return parse_brewery(result);
|
||||
}
|
||||
```
|
||||
sqlite3.
|
||||
|
||||
### Logging Configuration
|
||||
Boost 1.75+ (requires Boost.JSON and Boost.Program_options).
|
||||
|
||||
Logging uses **spdlog** with:
|
||||
spdlog (fetched via CMake FetchContent).
|
||||
|
||||
- **Level**: Info (can change via `spdlog::set_level(spdlog::level::debug)` at startup).
|
||||
- **Format**: Plain ASCII; no Unicode box art.
|
||||
- **Sink**: Console (stdout/stderr); can redirect to file.
|
||||
Build
|
||||
|
||||
**Current output sample**:
|
||||
|
||||
```
|
||||
[Pipeline] Downloading geographic data from GitHub...
|
||||
Initializing in-memory SQLite database...
|
||||
Initializing brewery generator...
|
||||
|
||||
=== GEOGRAPHIC DATA OVERVIEW ===
|
||||
Total records loaded:
|
||||
Countries: 195
|
||||
States: 5000
|
||||
Cities: 150000
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Building and Running
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- C++23 compiler (g++, clang, MSVC).
|
||||
- CMake 3.20+.
|
||||
- curl (for HTTP downloads).
|
||||
- sqlite3 (usually system-provided).
|
||||
- RapidJSON (fetched via CMake FetchContent).
|
||||
- spdlog (fetched via CMake FetchContent).
|
||||
|
||||
### Build
|
||||
|
||||
```bash
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake ..
|
||||
cmake --build . --target biergarten-pipeline -- -j
|
||||
```
|
||||
|
||||
**Build artifacts**:
|
||||
Run
|
||||
|
||||
- Executable: `build/biergarten-pipeline`
|
||||
- Intermediate: `build/CMakeFiles/`, `build/_deps/` (RapidJSON, spdlog)
|
||||
|
||||
### Run
|
||||
|
||||
```bash
|
||||
./biergarten-pipeline
|
||||
```
|
||||
|
||||
**Output**: Logs to console; caches JSON in `/tmp/countries+states+cities.json`.
|
||||
|
||||
### Cleaning
|
||||
|
||||
```bash
|
||||
rm -rf build
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Development Notes
|
||||
|
||||
### Code Organization
|
||||
|
||||
- **`includes/`**: Public headers (data structures, class APIs).
|
||||
- **`src/`**: Implementations with inline comments for non-obvious logic.
|
||||
- **`CMakeLists.txt`**: Build configuration; defines fetch content, compiler flags, linking.
|
||||
|
||||
### Testing
|
||||
|
||||
Currently no automated tests. To add:
|
||||
|
||||
1. Create `tests/` folder.
|
||||
2. Use CMake to add a test executable.
|
||||
3. Test the parser with small JSON fixtures.
|
||||
4. Mock the database for isolation.
|
||||
|
||||
### Debugging
|
||||
|
||||
**Enable verbose logging**:
|
||||
|
||||
```cpp
|
||||
spdlog::set_level(spdlog::level::debug);
|
||||
```
|
||||
|
||||
**GDB workflow**:
|
||||
|
||||
```bash
|
||||
gdb ./biergarten-pipeline
|
||||
(gdb) break src/stream_parser.cpp:50
|
||||
(gdb) run
|
||||
```
|
||||
|
||||
### Future Enhancements
|
||||
|
||||
1. **Real LLM integration**: Load and run llama.cpp models.
|
||||
2. **Persistence**: Write brewery data to a database or file.
|
||||
3. **Distributed parsing**: Shard JSON file across multiple parse streams.
|
||||
4. **Incremental updates**: Only insert new records if source updated.
|
||||
5. **Web API**: Expose database via HTTP (brewery lookup, city search).
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
- [RapidJSON](https://rapidjson.org/) – SAX parsing documentation.
|
||||
- [spdlog](https://github.com/gabime/spdlog) – Logging framework.
|
||||
- [SQLite](https://www.sqlite.org/docs.html) – In-memory database reference.
|
||||
- [countries-states-cities-database](https://github.com/dr5hn/countries-states-cities-database) – Data source.
|
||||
Output: Logs to console; caches JSON in /tmp/countries+states+cities.json.
|
||||
|
||||
26
pipeline/includes/curl_web_client.h
Normal file
26
pipeline/includes/curl_web_client.h
Normal file
@@ -0,0 +1,26 @@
|
||||
#pragma once

#include "web_client.h"
#include <memory>

/// @brief RAII guard for curl_global_init/curl_global_cleanup.
///
/// An instance of this class should be created in main() before any curl
/// operations and exist for the lifetime of the application. Copying is
/// deleted so global init/cleanup cannot run more than once per instance;
/// declaring the deleted copy operations also suppresses the implicit moves.
class CurlGlobalState {
public:
  CurlGlobalState();
  ~CurlGlobalState();
  CurlGlobalState(const CurlGlobalState &) = delete;
  CurlGlobalState &operator=(const CurlGlobalState &) = delete;
};

/// @brief libcurl-backed implementation of IWebClient.
///
/// Requires a live CurlGlobalState for the duration of its use.
/// NOTE(review): no data members are declared here, so construction cost and
/// per-request handle management live in the .cpp — confirm thread-safety
/// expectations against curl_web_client.cpp before sharing one instance
/// across threads.
class CURLWebClient : public IWebClient {
public:
  CURLWebClient();
  ~CURLWebClient() override;

  /// @brief Downloads content from a URL to a file. Throws on error.
  void DownloadToFile(const std::string &url,
                      const std::string &filePath) override;
  /// @brief Performs a GET request and returns the body. Throws on error.
  std::string Get(const std::string &url) override;
  /// @brief URL-encodes a string.
  std::string UrlEncode(const std::string &value) override;
};
|
||||
@@ -1,14 +1,17 @@
|
||||
#ifndef DATA_DOWNLOADER_H
|
||||
#define DATA_DOWNLOADER_H
|
||||
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "web_client.h"
|
||||
|
||||
/// @brief Downloads and caches source geography JSON payloads.
|
||||
class DataDownloader {
|
||||
public:
|
||||
/// @brief Initializes global curl state used by this downloader.
|
||||
DataDownloader();
|
||||
DataDownloader(std::shared_ptr<IWebClient> webClient);
|
||||
|
||||
/// @brief Cleans up global curl state.
|
||||
~DataDownloader();
|
||||
@@ -21,6 +24,7 @@ public:
|
||||
|
||||
private:
|
||||
bool FileExists(const std::string &filePath) const;
|
||||
std::shared_ptr<IWebClient> m_webClient;
|
||||
};
|
||||
|
||||
#endif // DATA_DOWNLOADER_H
|
||||
|
||||
@@ -27,6 +27,15 @@ struct State {
|
||||
int countryId;
|
||||
};
|
||||
|
||||
/// @brief City record as returned by SqliteDatabase::QueryCities().
/// NOTE(review): the parent state id is not carried here — confirm that
/// downstream consumers only need countryId.
struct City {
  /// @brief City identifier from the source dataset.
  int id;
  /// @brief City display name.
  std::string name;
  /// @brief Parent country identifier.
  int countryId;
};
|
||||
|
||||
/// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks.
|
||||
class SqliteDatabase {
|
||||
private:
|
||||
@@ -60,8 +69,8 @@ public:
|
||||
void InsertCity(int id, int stateId, int countryId, const std::string &name,
|
||||
double latitude, double longitude);
|
||||
|
||||
/// @brief Returns city id and city name pairs.
|
||||
std::vector<std::pair<int, std::string>> QueryCities();
|
||||
/// @brief Returns city records including parent country id.
|
||||
std::vector<City> QueryCities();
|
||||
|
||||
/// @brief Returns countries with optional row limit.
|
||||
std::vector<Country> QueryCountries(int limit = 0);
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "data_generator.h"
|
||||
#include <memory>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "data_generator.h"
|
||||
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
class LlamaGenerator final : public IDataGenerator {
|
||||
public:
|
||||
LlamaGenerator() = default;
|
||||
~LlamaGenerator() override;
|
||||
|
||||
void setSamplingOptions(float temperature, float topP, int seed = -1);
|
||||
|
||||
void load(const std::string &modelPath) override;
|
||||
BreweryResult generateBrewery(const std::string &cityName,
|
||||
const std::string &countryName,
|
||||
@@ -18,14 +22,17 @@ public:
|
||||
UserResult generateUser(const std::string &locale) override;
|
||||
|
||||
private:
|
||||
std::string infer(const std::string &prompt, int maxTokens = 5000);
|
||||
std::string infer(const std::string &prompt, int maxTokens = 10000);
|
||||
// Overload that allows passing a system message separately so chat-capable
|
||||
// models receive a proper system role instead of having the system text
|
||||
// concatenated into the user prompt (helps avoid revealing internal
|
||||
// reasoning or instructions in model output).
|
||||
std::string infer(const std::string &systemPrompt, const std::string &prompt,
|
||||
int maxTokens = 5000);
|
||||
int maxTokens = 10000);
|
||||
|
||||
llama_model *model_ = nullptr;
|
||||
llama_context *context_ = nullptr;
|
||||
float sampling_temperature_ = 0.8f;
|
||||
float sampling_top_p_ = 0.92f;
|
||||
uint32_t sampling_seed_ = 0xFFFFFFFFu;
|
||||
};
|
||||
|
||||
19
pipeline/includes/web_client.h
Normal file
19
pipeline/includes/web_client.h
Normal file
@@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
class IWebClient {
|
||||
public:
|
||||
virtual ~IWebClient() = default;
|
||||
|
||||
// Downloads content from a URL to a file. Throws on error.
|
||||
virtual void DownloadToFile(const std::string &url,
|
||||
const std::string &filePath) = 0;
|
||||
|
||||
// Performs a GET request and returns the response body as a string. Throws on
|
||||
// error.
|
||||
virtual std::string Get(const std::string &url) = 0;
|
||||
|
||||
// URL-encodes a string.
|
||||
virtual std::string UrlEncode(const std::string &value) = 0;
|
||||
};
|
||||
24
pipeline/includes/wikipedia_service.h
Normal file
24
pipeline/includes/wikipedia_service.h
Normal file
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "web_client.h"
|
||||
|
||||
/// @brief Provides cached Wikipedia summary lookups for city and country pairs.
|
||||
class WikipediaService {
|
||||
public:
|
||||
/// @brief Creates a new Wikipedia service with the provided web client.
|
||||
explicit WikipediaService(std::shared_ptr<IWebClient> client);
|
||||
|
||||
/// @brief Returns the Wikipedia summary extract for city and country.
|
||||
[[nodiscard]] std::string GetSummary(std::string_view city,
|
||||
std::string_view country);
|
||||
|
||||
private:
|
||||
std::string FetchExtract(std::string_view query);
|
||||
std::shared_ptr<IWebClient> client_;
|
||||
std::unordered_map<std::string, std::string> cache_;
|
||||
};
|
||||
139
pipeline/src/curl_web_client.cpp
Normal file
139
pipeline/src/curl_web_client.cpp
Normal file
@@ -0,0 +1,139 @@
|
||||
#include "curl_web_client.h"
|
||||
#include <cstdio>
|
||||
#include <curl/curl.h>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
CurlGlobalState::CurlGlobalState() {
|
||||
if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) {
|
||||
throw std::runtime_error(
|
||||
"[CURLWebClient] Failed to initialize libcurl globally");
|
||||
}
|
||||
}
|
||||
|
||||
CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); }
|
||||
|
||||
namespace {
|
||||
// curl write callback that appends response data into a std::string
|
||||
static size_t WriteCallbackString(void *contents, size_t size, size_t nmemb,
|
||||
void *userp) {
|
||||
size_t realsize = size * nmemb;
|
||||
auto *s = static_cast<std::string *>(userp);
|
||||
s->append(static_cast<char *>(contents), realsize);
|
||||
return realsize;
|
||||
}
|
||||
|
||||
// curl write callback that writes to a file stream
|
||||
static size_t WriteCallbackFile(void *contents, size_t size, size_t nmemb,
|
||||
void *userp) {
|
||||
size_t realsize = size * nmemb;
|
||||
auto *outFile = static_cast<std::ofstream *>(userp);
|
||||
outFile->write(static_cast<char *>(contents), realsize);
|
||||
return realsize;
|
||||
}
|
||||
|
||||
// RAII wrapper for CURL handle using unique_ptr
|
||||
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
|
||||
|
||||
CurlHandle create_handle() {
|
||||
CURL *handle = curl_easy_init();
|
||||
if (!handle) {
|
||||
throw std::runtime_error(
|
||||
"[CURLWebClient] Failed to initialize libcurl handle");
|
||||
}
|
||||
return CurlHandle(handle, &curl_easy_cleanup);
|
||||
}
|
||||
|
||||
void set_common_get_options(CURL *curl, const std::string &url,
|
||||
long connect_timeout, long total_timeout) {
|
||||
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
|
||||
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout);
|
||||
curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout);
|
||||
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
CURLWebClient::CURLWebClient() {}
|
||||
|
||||
CURLWebClient::~CURLWebClient() {}
|
||||
|
||||
void CURLWebClient::DownloadToFile(const std::string &url,
|
||||
const std::string &filePath) {
|
||||
auto curl = create_handle();
|
||||
|
||||
std::ofstream outFile(filePath, std::ios::binary);
|
||||
if (!outFile.is_open()) {
|
||||
throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " +
|
||||
filePath);
|
||||
}
|
||||
|
||||
set_common_get_options(curl.get(), url, 30L, 300L);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA,
|
||||
static_cast<void *>(&outFile));
|
||||
|
||||
CURLcode res = curl_easy_perform(curl.get());
|
||||
outFile.close();
|
||||
|
||||
if (res != CURLE_OK) {
|
||||
std::remove(filePath.c_str());
|
||||
std::string error = std::string("[CURLWebClient] Download failed: ") +
|
||||
curl_easy_strerror(res);
|
||||
throw std::runtime_error(error);
|
||||
}
|
||||
|
||||
long httpCode = 0;
|
||||
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
|
||||
|
||||
if (httpCode != 200) {
|
||||
std::remove(filePath.c_str());
|
||||
std::stringstream ss;
|
||||
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
std::string CURLWebClient::Get(const std::string &url) {
|
||||
auto curl = create_handle();
|
||||
|
||||
std::string response_string;
|
||||
set_common_get_options(curl.get(), url, 10L, 20L);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string);
|
||||
|
||||
CURLcode res = curl_easy_perform(curl.get());
|
||||
|
||||
if (res != CURLE_OK) {
|
||||
std::string error =
|
||||
std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res);
|
||||
throw std::runtime_error(error);
|
||||
}
|
||||
|
||||
long httpCode = 0;
|
||||
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode);
|
||||
|
||||
if (httpCode != 200) {
|
||||
std::stringstream ss;
|
||||
ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
return response_string;
|
||||
}
|
||||
|
||||
std::string CURLWebClient::UrlEncode(const std::string &value) {
|
||||
// A NULL handle is fine for UTF-8 encoding according to libcurl docs.
|
||||
char *output = curl_easy_escape(nullptr, value.c_str(), 0);
|
||||
|
||||
if (output) {
|
||||
std::string result(output);
|
||||
curl_free(output);
|
||||
return result;
|
||||
}
|
||||
throw std::runtime_error("[CURLWebClient] curl_easy_escape failed");
|
||||
}
|
||||
@@ -1,20 +1,13 @@
|
||||
#include "data_downloader.h"
|
||||
#include <cstdio>
|
||||
#include <curl/curl.h>
|
||||
#include "web_client.h"
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
static size_t WriteCallback(void *contents, size_t size, size_t nmemb,
|
||||
void *userp) {
|
||||
size_t realsize = size * nmemb;
|
||||
std::ofstream *outFile = static_cast<std::ofstream *>(userp);
|
||||
outFile->write(static_cast<char *>(contents), realsize);
|
||||
return realsize;
|
||||
}
|
||||
|
||||
DataDownloader::DataDownloader() {}
|
||||
DataDownloader::DataDownloader(std::shared_ptr<IWebClient> webClient)
|
||||
: m_webClient(std::move(webClient)) {}
|
||||
|
||||
DataDownloader::~DataDownloader() {}
|
||||
|
||||
@@ -41,56 +34,7 @@ DataDownloader::DownloadCountriesDatabase(const std::string &cachePath,
|
||||
|
||||
spdlog::info("[DataDownloader] Downloading: {}", url);
|
||||
|
||||
CURL *curl = curl_easy_init();
|
||||
if (!curl) {
|
||||
throw std::runtime_error("[DataDownloader] Failed to initialize libcurl");
|
||||
}
|
||||
|
||||
std::ofstream outFile(cachePath, std::ios::binary);
|
||||
if (!outFile.is_open()) {
|
||||
curl_easy_cleanup(curl);
|
||||
throw std::runtime_error("[DataDownloader] Cannot open file for writing: " +
|
||||
cachePath);
|
||||
}
|
||||
|
||||
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
|
||||
curl_easy_setopt(curl, CURLOPT_WRITEDATA, static_cast<void *>(&outFile));
|
||||
|
||||
curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 30L);
|
||||
curl_easy_setopt(curl, CURLOPT_TIMEOUT, 300L);
|
||||
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
|
||||
|
||||
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
|
||||
|
||||
curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0");
|
||||
|
||||
CURLcode res = curl_easy_perform(curl);
|
||||
outFile.close();
|
||||
|
||||
if (res != CURLE_OK) {
|
||||
curl_easy_cleanup(curl);
|
||||
std::remove(cachePath.c_str());
|
||||
|
||||
std::string error = std::string("[DataDownloader] Download failed: ") +
|
||||
curl_easy_strerror(res);
|
||||
throw std::runtime_error(error);
|
||||
}
|
||||
|
||||
long httpCode = 0;
|
||||
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &httpCode);
|
||||
curl_easy_cleanup(curl);
|
||||
|
||||
if (httpCode != 200) {
|
||||
std::remove(cachePath.c_str());
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "[DataDownloader] HTTP error " << httpCode
|
||||
<< " (commit: " << shortCommit << ")";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
m_webClient->DownloadToFile(url, cachePath);
|
||||
|
||||
std::ifstream fileCheck(cachePath, std::ios::binary | std::ios::ate);
|
||||
std::streamsize size = fileCheck.tellg();
|
||||
|
||||
@@ -157,13 +157,12 @@ void SqliteDatabase::InsertCity(int id, int stateId, int countryId,
|
||||
sqlite3_finalize(stmt);
|
||||
}
|
||||
|
||||
std::vector<std::pair<int, std::string>> SqliteDatabase::QueryCities() {
|
||||
std::vector<City> SqliteDatabase::QueryCities() {
|
||||
std::lock_guard<std::mutex> lock(dbMutex);
|
||||
|
||||
std::vector<std::pair<int, std::string>> cities;
|
||||
std::vector<City> cities;
|
||||
sqlite3_stmt *stmt = nullptr;
|
||||
|
||||
const char *query = "SELECT id, name FROM cities ORDER BY name";
|
||||
const char *query = "SELECT id, name, country_id FROM cities ORDER BY name";
|
||||
int rc = sqlite3_prepare_v2(db, query, -1, &stmt, nullptr);
|
||||
|
||||
if (rc != SQLITE_OK) {
|
||||
@@ -174,7 +173,8 @@ std::vector<std::pair<int, std::string>> SqliteDatabase::QueryCities() {
|
||||
int id = sqlite3_column_int(stmt, 0);
|
||||
const char *name =
|
||||
reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1));
|
||||
cities.push_back({id, name ? std::string(name) : ""});
|
||||
int countryId = sqlite3_column_int(stmt, 2);
|
||||
cities.push_back({id, name ? std::string(name) : "", countryId});
|
||||
}
|
||||
|
||||
sqlite3_finalize(stmt);
|
||||
|
||||
@@ -1,32 +1,52 @@
|
||||
#include <chrono>
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include "json_loader.h"
|
||||
#include "stream_parser.h"
|
||||
#include <chrono>
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
void JsonLoader::LoadWorldCities(const std::string &jsonPath,
|
||||
SqliteDatabase &db) {
|
||||
constexpr size_t kBatchSize = 10000;
|
||||
|
||||
auto startTime = std::chrono::high_resolution_clock::now();
|
||||
spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", jsonPath);
|
||||
|
||||
db.BeginTransaction();
|
||||
bool transactionOpen = true;
|
||||
|
||||
size_t citiesProcessed = 0;
|
||||
StreamingJsonParser::Parse(
|
||||
jsonPath, db,
|
||||
[&](const CityRecord &record) {
|
||||
db.InsertCity(record.id, record.state_id, record.country_id,
|
||||
record.name, record.latitude, record.longitude);
|
||||
citiesProcessed++;
|
||||
},
|
||||
[&](size_t current, size_t total) {
|
||||
if (current % 10000 == 0 && current > 0) {
|
||||
spdlog::info(" [Progress] Parsed {} cities...", current);
|
||||
}
|
||||
});
|
||||
try {
|
||||
StreamingJsonParser::Parse(
|
||||
jsonPath, db,
|
||||
[&](const CityRecord &record) {
|
||||
db.InsertCity(record.id, record.state_id, record.country_id,
|
||||
record.name, record.latitude, record.longitude);
|
||||
++citiesProcessed;
|
||||
|
||||
spdlog::info(" OK: Parsed all cities from JSON");
|
||||
if (citiesProcessed % kBatchSize == 0) {
|
||||
db.CommitTransaction();
|
||||
db.BeginTransaction();
|
||||
}
|
||||
},
|
||||
[&](size_t current, size_t /*total*/) {
|
||||
if (current % kBatchSize == 0 && current > 0) {
|
||||
spdlog::info(" [Progress] Parsed {} cities...", current);
|
||||
}
|
||||
});
|
||||
|
||||
db.CommitTransaction();
|
||||
spdlog::info(" OK: Parsed all cities from JSON");
|
||||
|
||||
if (transactionOpen) {
|
||||
db.CommitTransaction();
|
||||
transactionOpen = false;
|
||||
}
|
||||
} catch (...) {
|
||||
if (transactionOpen) {
|
||||
db.CommitTransaction();
|
||||
}
|
||||
throw;
|
||||
}
|
||||
|
||||
auto endTime = std::chrono::high_resolution_clock::now();
|
||||
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
#include "llama_generator.h"
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cctype>
|
||||
@@ -11,8 +7,12 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "llama.h"
|
||||
#include <boost/json.hpp>
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include "llama_generator.h"
|
||||
|
||||
namespace {
|
||||
|
||||
std::string trim(std::string value) {
|
||||
@@ -26,10 +26,47 @@ std::string trim(std::string value) {
|
||||
return value;
|
||||
}
|
||||
|
||||
std::string CondenseWhitespace(std::string text) {
|
||||
std::string out;
|
||||
out.reserve(text.size());
|
||||
|
||||
bool inWhitespace = false;
|
||||
for (unsigned char ch : text) {
|
||||
if (std::isspace(ch)) {
|
||||
if (!inWhitespace) {
|
||||
out.push_back(' ');
|
||||
inWhitespace = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
inWhitespace = false;
|
||||
out.push_back(static_cast<char>(ch));
|
||||
}
|
||||
|
||||
return trim(std::move(out));
|
||||
}
|
||||
|
||||
std::string PrepareRegionContext(std::string_view regionContext,
|
||||
std::size_t maxChars = 700) {
|
||||
std::string normalized = CondenseWhitespace(std::string(regionContext));
|
||||
if (normalized.size() <= maxChars) {
|
||||
return normalized;
|
||||
}
|
||||
|
||||
normalized.resize(maxChars);
|
||||
const std::size_t lastSpace = normalized.find_last_of(' ');
|
||||
if (lastSpace != std::string::npos && lastSpace > maxChars / 2) {
|
||||
normalized.resize(lastSpace);
|
||||
}
|
||||
|
||||
normalized += "...";
|
||||
return normalized;
|
||||
}
|
||||
|
||||
std::string stripCommonPrefix(std::string line) {
|
||||
line = trim(std::move(line));
|
||||
|
||||
// Strip simple list markers like "- ", "* ", "1. ", "2) ".
|
||||
if (!line.empty() && (line[0] == '-' || line[0] == '*')) {
|
||||
line = trim(line.substr(1));
|
||||
} else {
|
||||
@@ -68,6 +105,50 @@ std::string stripCommonPrefix(std::string line) {
|
||||
return trim(std::move(line));
|
||||
}
|
||||
|
||||
std::pair<std::string, std::string>
|
||||
parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) {
|
||||
std::string normalized = raw;
|
||||
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
|
||||
|
||||
std::vector<std::string> lines;
|
||||
std::stringstream stream(normalized);
|
||||
std::string line;
|
||||
while (std::getline(stream, line)) {
|
||||
line = stripCommonPrefix(std::move(line));
|
||||
if (!line.empty())
|
||||
lines.push_back(std::move(line));
|
||||
}
|
||||
|
||||
std::vector<std::string> filtered;
|
||||
for (auto &l : lines) {
|
||||
std::string low = l;
|
||||
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
|
||||
return static_cast<char>(std::tolower(c));
|
||||
});
|
||||
if (!l.empty() && l.front() == '<' && low.back() == '>')
|
||||
continue;
|
||||
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0)
|
||||
continue;
|
||||
filtered.push_back(std::move(l));
|
||||
}
|
||||
|
||||
if (filtered.size() < 2)
|
||||
throw std::runtime_error(errorMessage);
|
||||
|
||||
std::string first = trim(filtered.front());
|
||||
std::string second;
|
||||
for (size_t i = 1; i < filtered.size(); ++i) {
|
||||
if (!second.empty())
|
||||
second += ' ';
|
||||
second += filtered[i];
|
||||
}
|
||||
second = trim(std::move(second));
|
||||
|
||||
if (first.empty() || second.empty())
|
||||
throw std::runtime_error(errorMessage);
|
||||
return {first, second};
|
||||
}
|
||||
|
||||
std::string toChatPrompt(const llama_model *model,
|
||||
const std::string &userPrompt) {
|
||||
const char *tmpl = llama_model_chat_template(model, nullptr);
|
||||
@@ -75,10 +156,7 @@ std::string toChatPrompt(const llama_model *model,
|
||||
return userPrompt;
|
||||
}
|
||||
|
||||
const llama_chat_message message{
|
||||
"user",
|
||||
userPrompt.c_str(),
|
||||
};
|
||||
const llama_chat_message message{"user", userPrompt.c_str()};
|
||||
|
||||
std::vector<char> buffer(std::max<std::size_t>(1024, userPrompt.size() * 4));
|
||||
int32_t required =
|
||||
@@ -106,14 +184,11 @@ std::string toChatPrompt(const llama_model *model,
|
||||
const std::string &userPrompt) {
|
||||
const char *tmpl = llama_model_chat_template(model, nullptr);
|
||||
if (tmpl == nullptr) {
|
||||
// Fall back to concatenating but keep system and user parts distinct.
|
||||
return systemPrompt + "\n\n" + userPrompt;
|
||||
}
|
||||
|
||||
const llama_chat_message messages[2] = {
|
||||
{"system", systemPrompt.c_str()},
|
||||
{"user", userPrompt.c_str()},
|
||||
};
|
||||
const llama_chat_message messages[2] = {{"system", systemPrompt.c_str()},
|
||||
{"user", userPrompt.c_str()}};
|
||||
|
||||
std::vector<char> buffer(std::max<std::size_t>(
|
||||
1024, (systemPrompt.size() + userPrompt.size()) * 4));
|
||||
@@ -161,73 +236,135 @@ void appendTokenPiece(const llama_vocab *vocab, llama_token token,
|
||||
output.append(buffer.data(), static_cast<std::size_t>(bytes));
|
||||
}
|
||||
|
||||
std::pair<std::string, std::string>
|
||||
parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) {
|
||||
std::string normalized = raw;
|
||||
std::replace(normalized.begin(), normalized.end(), '\r', '\n');
|
||||
bool extractFirstJsonObject(const std::string &text, std::string &jsonOut) {
|
||||
std::size_t start = std::string::npos;
|
||||
int depth = 0;
|
||||
bool inString = false;
|
||||
bool escaped = false;
|
||||
|
||||
std::vector<std::string> lines;
|
||||
std::stringstream stream(normalized);
|
||||
std::string line;
|
||||
while (std::getline(stream, line)) {
|
||||
line = stripCommonPrefix(std::move(line));
|
||||
if (!line.empty()) {
|
||||
lines.push_back(std::move(line));
|
||||
}
|
||||
}
|
||||
for (std::size_t i = 0; i < text.size(); ++i) {
|
||||
const char ch = text[i];
|
||||
|
||||
// Filter out obvious internal-thought / meta lines that sometimes leak from
|
||||
// models (e.g. "<think>", "Okay, so the user is asking me...").
|
||||
std::vector<std::string> filtered;
|
||||
for (auto &l : lines) {
|
||||
std::string low = l;
|
||||
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
|
||||
return static_cast<char>(std::tolower(c));
|
||||
});
|
||||
|
||||
// Skip single-token angle-bracket markers like <think> or <...>
|
||||
if (!l.empty() && l.front() == '<' && l.back() == '>') {
|
||||
if (inString) {
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
} else if (ch == '\\') {
|
||||
escaped = true;
|
||||
} else if (ch == '"') {
|
||||
inString = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip short internal commentary that starts with common discourse markers
|
||||
if (low.rfind("okay,", 0) == 0 || low.rfind("wait,", 0) == 0 ||
|
||||
low.rfind("hmm", 0) == 0) {
|
||||
if (ch == '"') {
|
||||
inString = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip lines that look like self-descriptions of what the model is doing
|
||||
if (low.find("user is asking") != std::string::npos ||
|
||||
low.find("protocol") != std::string::npos ||
|
||||
low.find("parse") != std::string::npos ||
|
||||
low.find("return only") != std::string::npos) {
|
||||
if (ch == '{') {
|
||||
if (depth == 0) {
|
||||
start = i;
|
||||
}
|
||||
++depth;
|
||||
continue;
|
||||
}
|
||||
|
||||
filtered.push_back(std::move(l));
|
||||
}
|
||||
|
||||
if (filtered.size() < 2) {
|
||||
throw std::runtime_error(errorMessage);
|
||||
}
|
||||
|
||||
std::string first = trim(filtered.front());
|
||||
std::string second;
|
||||
for (std::size_t i = 1; i < filtered.size(); ++i) {
|
||||
if (!second.empty()) {
|
||||
second += ' ';
|
||||
if (ch == '}') {
|
||||
if (depth == 0) {
|
||||
continue;
|
||||
}
|
||||
--depth;
|
||||
if (depth == 0 && start != std::string::npos) {
|
||||
jsonOut = text.substr(start, i - start + 1);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
second += filtered[i];
|
||||
}
|
||||
second = trim(std::move(second));
|
||||
|
||||
if (first.empty() || second.empty()) {
|
||||
throw std::runtime_error(errorMessage);
|
||||
}
|
||||
|
||||
return {first, second};
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string ValidateBreweryJson(const std::string &raw, std::string &nameOut,
|
||||
std::string &descriptionOut) {
|
||||
auto validateObject = [&](const boost::json::value &jv,
|
||||
std::string &errorOut) -> bool {
|
||||
if (!jv.is_object()) {
|
||||
errorOut = "JSON root must be an object";
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto &obj = jv.get_object();
|
||||
if (!obj.contains("name") || !obj.at("name").is_string()) {
|
||||
errorOut = "JSON field 'name' is missing or not a string";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!obj.contains("description") || !obj.at("description").is_string()) {
|
||||
errorOut = "JSON field 'description' is missing or not a string";
|
||||
return false;
|
||||
}
|
||||
|
||||
nameOut = trim(std::string(obj.at("name").as_string().c_str()));
|
||||
descriptionOut =
|
||||
trim(std::string(obj.at("description").as_string().c_str()));
|
||||
|
||||
if (nameOut.empty()) {
|
||||
errorOut = "JSON field 'name' must not be empty";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (descriptionOut.empty()) {
|
||||
errorOut = "JSON field 'description' must not be empty";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string nameLower = nameOut;
|
||||
std::string descriptionLower = descriptionOut;
|
||||
std::transform(
|
||||
nameLower.begin(), nameLower.end(), nameLower.begin(),
|
||||
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||
std::transform(descriptionLower.begin(), descriptionLower.end(),
|
||||
descriptionLower.begin(), [](unsigned char c) {
|
||||
return static_cast<char>(std::tolower(c));
|
||||
});
|
||||
|
||||
if (nameLower == "string" || descriptionLower == "string") {
|
||||
errorOut = "JSON appears to be a schema placeholder, not content";
|
||||
return false;
|
||||
}
|
||||
|
||||
errorOut.clear();
|
||||
return true;
|
||||
};
|
||||
|
||||
boost::system::error_code ec;
|
||||
boost::json::value jv = boost::json::parse(raw, ec);
|
||||
std::string validationError;
|
||||
if (ec) {
|
||||
std::string extracted;
|
||||
if (!extractFirstJsonObject(raw, extracted)) {
|
||||
return "JSON parse error: " + ec.message();
|
||||
}
|
||||
|
||||
ec.clear();
|
||||
jv = boost::json::parse(extracted, ec);
|
||||
if (ec) {
|
||||
return "JSON parse error: " + ec.message();
|
||||
}
|
||||
|
||||
if (!validateObject(jv, validationError)) {
|
||||
return validationError;
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
if (!validateObject(jv, validationError)) {
|
||||
return validationError;
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
LlamaGenerator::~LlamaGenerator() {
|
||||
@@ -244,10 +381,30 @@ LlamaGenerator::~LlamaGenerator() {
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
void LlamaGenerator::load(const std::string &modelPath) {
|
||||
if (modelPath.empty()) {
|
||||
throw std::runtime_error("LlamaGenerator: model path must not be empty");
|
||||
void LlamaGenerator::setSamplingOptions(float temperature, float topP,
|
||||
int seed) {
|
||||
if (temperature < 0.0f) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: sampling temperature must be >= 0");
|
||||
}
|
||||
if (!(topP > 0.0f && topP <= 1.0f)) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: sampling top-p must be in (0, 1]");
|
||||
}
|
||||
if (seed < -1) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: seed must be >= 0, or -1 for random");
|
||||
}
|
||||
|
||||
sampling_temperature_ = temperature;
|
||||
sampling_top_p_ = topP;
|
||||
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
|
||||
: static_cast<uint32_t>(seed);
|
||||
}
|
||||
|
||||
void LlamaGenerator::load(const std::string &modelPath) {
|
||||
if (modelPath.empty())
|
||||
throw std::runtime_error("LlamaGenerator: model path must not be empty");
|
||||
|
||||
if (context_ != nullptr) {
|
||||
llama_free(context_);
|
||||
@@ -261,7 +418,7 @@ void LlamaGenerator::load(const std::string &modelPath) {
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params modelParams = llama_model_default_params();
|
||||
model_ = llama_load_model_from_file(modelPath.c_str(), modelParams);
|
||||
model_ = llama_model_load_from_file(modelPath.c_str(), modelParams);
|
||||
if (model_ == nullptr) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: failed to load model from path: " + modelPath);
|
||||
@@ -281,14 +438,12 @@ void LlamaGenerator::load(const std::string &modelPath) {
|
||||
}
|
||||
|
||||
std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) {
|
||||
if (model_ == nullptr || context_ == nullptr) {
|
||||
if (model_ == nullptr || context_ == nullptr)
|
||||
throw std::runtime_error("LlamaGenerator: model not loaded");
|
||||
}
|
||||
|
||||
const llama_vocab *vocab = llama_model_get_vocab(model_);
|
||||
if (vocab == nullptr) {
|
||||
if (vocab == nullptr)
|
||||
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
||||
}
|
||||
|
||||
llama_memory_clear(llama_get_memory(context_), true);
|
||||
|
||||
@@ -308,17 +463,33 @@ std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) {
|
||||
static_cast<int32_t>(promptTokens.size()), true, true);
|
||||
}
|
||||
|
||||
if (tokenCount < 0) {
|
||||
if (tokenCount < 0)
|
||||
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
|
||||
|
||||
const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
|
||||
const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
|
||||
if (nCtx <= 1 || nBatch <= 0) {
|
||||
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
|
||||
}
|
||||
|
||||
const int32_t effectiveMaxTokens = std::max(1, std::min(maxTokens, nCtx - 1));
|
||||
int32_t promptBudget = std::min(nBatch, nCtx - effectiveMaxTokens);
|
||||
promptBudget = std::max<int32_t>(1, promptBudget);
|
||||
|
||||
promptTokens.resize(static_cast<std::size_t>(tokenCount));
|
||||
if (tokenCount > promptBudget) {
|
||||
spdlog::warn(
|
||||
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
|
||||
"to fit n_batch/n_ctx limits",
|
||||
tokenCount, promptBudget);
|
||||
promptTokens.resize(static_cast<std::size_t>(promptBudget));
|
||||
tokenCount = promptBudget;
|
||||
}
|
||||
|
||||
const llama_batch promptBatch = llama_batch_get_one(
|
||||
promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
|
||||
if (llama_decode(context_, promptBatch) != 0) {
|
||||
if (llama_decode(context_, promptBatch) != 0)
|
||||
throw std::runtime_error("LlamaGenerator: prompt decode failed");
|
||||
}
|
||||
|
||||
llama_sampler_chain_params samplerParams =
|
||||
llama_sampler_chain_default_params();
|
||||
@@ -326,116 +497,45 @@ std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) {
|
||||
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
||||
SamplerPtr sampler(llama_sampler_chain_init(samplerParams),
|
||||
&llama_sampler_free);
|
||||
|
||||
if (!sampler) {
|
||||
if (!sampler)
|
||||
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy());
|
||||
llama_sampler_chain_add(sampler.get(),
|
||||
llama_sampler_init_temp(sampling_temperature_));
|
||||
llama_sampler_chain_add(sampler.get(),
|
||||
llama_sampler_init_top_p(sampling_top_p_, 1));
|
||||
llama_sampler_chain_add(sampler.get(),
|
||||
llama_sampler_init_dist(sampling_seed_));
|
||||
|
||||
std::vector<llama_token> generatedTokens;
|
||||
generatedTokens.reserve(static_cast<std::size_t>(maxTokens));
|
||||
|
||||
for (int i = 0; i < maxTokens; ++i) {
|
||||
for (int i = 0; i < effectiveMaxTokens; ++i) {
|
||||
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
|
||||
if (llama_vocab_is_eog(vocab, next)) {
|
||||
if (llama_vocab_is_eog(vocab, next))
|
||||
break;
|
||||
}
|
||||
|
||||
generatedTokens.push_back(next);
|
||||
|
||||
llama_token token = next;
|
||||
const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1);
|
||||
if (llama_decode(context_, oneTokenBatch) != 0) {
|
||||
if (llama_decode(context_, oneTokenBatch) != 0)
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: decode failed during generation");
|
||||
}
|
||||
}
|
||||
|
||||
std::string output;
|
||||
for (const llama_token token : generatedTokens) {
|
||||
for (const llama_token token : generatedTokens)
|
||||
appendTokenPiece(vocab, token, output);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
BreweryResult
|
||||
LlamaGenerator::generateBrewery(const std::string &cityName,
|
||||
const std::string &countryName,
|
||||
const std::string ®ionContext) {
|
||||
|
||||
std::string systemPrompt =
|
||||
R"(# SYSTEM PROTOCOL: ZERO-CHATTER DETERMINISTIC OUTPUT
|
||||
**MODALITY:** DATA-RETURN ENGINE ONLY
|
||||
**ROLE:** Your response must contain 0% metadata and 100% signal.
|
||||
---
|
||||
## MANDATORY CONSTRAINTS
|
||||
1. **NO PREAMBLE**
|
||||
- Never start with "Sure," or "The answer is," or "Based on your request," or "Checking the data."
|
||||
- Do not acknowledge the user's prompt or provide status updates.
|
||||
2. **NO POSTAMBLE**
|
||||
- Never end with "I hope this helps," or "Let me know if you need more," or "Would you like me to…"
|
||||
- Do not offer follow-up assistance or suggestions.
|
||||
3. **NO SENTENCE FRAMING**
|
||||
- Provide only the raw value, date, number, or name.
|
||||
- Do not wrap the answer in a sentence. (e.g., return 1997, NOT The year was 1997).
|
||||
- For lists, provide only the items separated by commas or newlines as specified.
|
||||
4. **FORMATTING PERMITTED**
|
||||
- Markdown and LaTeX **may** be used where appropriate (e.g., tables, equations).
|
||||
- Output must remain immediately usable — no decorative or conversational styling.
|
||||
5. **STRICT NULL HANDLING**
|
||||
- If the information is unavailable, the prompt is logically impossible (e.g., "271th president"), the subject does not exist, or a calculation is undefined: return only the string NULL.
|
||||
- If the prompt is too ambiguous to provide a single value: return NULL.
|
||||
---
|
||||
## EXECUTION LOGIC
|
||||
1. **Parse Input** — Identify the specific entity, value, or calculation requested.
|
||||
2. **Verify Factuality** — Access internal knowledge or tools.
|
||||
3. **Filter for Signal** — Strip all surrounding prose.
|
||||
4. **Format Check** — Apply Markdown or LaTeX only where it serves the data.
|
||||
5. **Output** — Return the raw value only.
|
||||
---
|
||||
## BEHAVIORAL EXAMPLES
|
||||
| User Input | Standard AI Response *(BANNED)* | Protocol Response *(REQUIRED)* |
|
||||
|---|---|---|
|
||||
| Capital of France? | The capital of France is Paris. | Paris |
|
||||
| 15% of 200 | 15% of 200 is 30. | 30 |
|
||||
| Who wrote '1984'? | George Orwell wrote that novel. | George Orwell |
|
||||
| ISO code for Japan | The code is JP. | JP |
|
||||
| $\sqrt{x}$ where $x$ is a potato | A potato has no square root. | NULL |
|
||||
| 500th US President | There haven't been that many. | NULL |
|
||||
| Pythagorean theorem | The theorem states... | $a^2 + b^2 = c^2$ |
|
||||
---
|
||||
## FINAL INSTRUCTION
|
||||
Total silence is preferred over conversational error. Any deviation from the raw-value-only format is a protocol failure. Proceed with next input.)";
|
||||
|
||||
std::string prompt =
|
||||
"Generate a craft brewery name and 1000 character description for a "
|
||||
"brewery located in " +
|
||||
cityName +
|
||||
(countryName.empty() ? std::string("")
|
||||
: std::string(", ") + countryName) +
|
||||
". " + regionContext +
|
||||
" Respond with exactly two lines: first line is the name, second line is "
|
||||
"the description. Do not include bullets, numbering, or any extra text.";
|
||||
|
||||
const std::string raw = infer(systemPrompt, prompt, 512);
|
||||
auto [name, description] =
|
||||
parseTwoLineResponse(raw, "LlamaGenerator: malformed brewery response");
|
||||
|
||||
return {name, description};
|
||||
}
|
||||
|
||||
std::string LlamaGenerator::infer(const std::string &systemPrompt,
|
||||
const std::string &prompt, int maxTokens) {
|
||||
if (model_ == nullptr || context_ == nullptr) {
|
||||
if (model_ == nullptr || context_ == nullptr)
|
||||
throw std::runtime_error("LlamaGenerator: model not loaded");
|
||||
}
|
||||
|
||||
const llama_vocab *vocab = llama_model_get_vocab(model_);
|
||||
if (vocab == nullptr) {
|
||||
if (vocab == nullptr)
|
||||
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
||||
}
|
||||
|
||||
llama_memory_clear(llama_get_memory(context_), true);
|
||||
|
||||
@@ -456,17 +556,33 @@ std::string LlamaGenerator::infer(const std::string &systemPrompt,
|
||||
static_cast<int32_t>(promptTokens.size()), true, true);
|
||||
}
|
||||
|
||||
if (tokenCount < 0) {
|
||||
if (tokenCount < 0)
|
||||
throw std::runtime_error("LlamaGenerator: prompt tokenization failed");
|
||||
|
||||
const int32_t nCtx = static_cast<int32_t>(llama_n_ctx(context_));
|
||||
const int32_t nBatch = static_cast<int32_t>(llama_n_batch(context_));
|
||||
if (nCtx <= 1 || nBatch <= 0) {
|
||||
throw std::runtime_error("LlamaGenerator: invalid context or batch size");
|
||||
}
|
||||
|
||||
const int32_t effectiveMaxTokens = std::max(1, std::min(maxTokens, nCtx - 1));
|
||||
int32_t promptBudget = std::min(nBatch, nCtx - effectiveMaxTokens);
|
||||
promptBudget = std::max<int32_t>(1, promptBudget);
|
||||
|
||||
promptTokens.resize(static_cast<std::size_t>(tokenCount));
|
||||
if (tokenCount > promptBudget) {
|
||||
spdlog::warn(
|
||||
"LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens "
|
||||
"to fit n_batch/n_ctx limits",
|
||||
tokenCount, promptBudget);
|
||||
promptTokens.resize(static_cast<std::size_t>(promptBudget));
|
||||
tokenCount = promptBudget;
|
||||
}
|
||||
|
||||
const llama_batch promptBatch = llama_batch_get_one(
|
||||
promptTokens.data(), static_cast<int32_t>(promptTokens.size()));
|
||||
if (llama_decode(context_, promptBatch) != 0) {
|
||||
if (llama_decode(context_, promptBatch) != 0)
|
||||
throw std::runtime_error("LlamaGenerator: prompt decode failed");
|
||||
}
|
||||
|
||||
llama_sampler_chain_params samplerParams =
|
||||
llama_sampler_chain_default_params();
|
||||
@@ -474,61 +590,145 @@ std::string LlamaGenerator::infer(const std::string &systemPrompt,
|
||||
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
||||
SamplerPtr sampler(llama_sampler_chain_init(samplerParams),
|
||||
&llama_sampler_free);
|
||||
|
||||
if (!sampler) {
|
||||
if (!sampler)
|
||||
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy());
|
||||
llama_sampler_chain_add(sampler.get(),
|
||||
llama_sampler_init_temp(sampling_temperature_));
|
||||
llama_sampler_chain_add(sampler.get(),
|
||||
llama_sampler_init_top_p(sampling_top_p_, 1));
|
||||
llama_sampler_chain_add(sampler.get(),
|
||||
llama_sampler_init_dist(sampling_seed_));
|
||||
|
||||
std::vector<llama_token> generatedTokens;
|
||||
generatedTokens.reserve(static_cast<std::size_t>(maxTokens));
|
||||
|
||||
for (int i = 0; i < maxTokens; ++i) {
|
||||
for (int i = 0; i < effectiveMaxTokens; ++i) {
|
||||
const llama_token next = llama_sampler_sample(sampler.get(), context_, -1);
|
||||
if (llama_vocab_is_eog(vocab, next)) {
|
||||
if (llama_vocab_is_eog(vocab, next))
|
||||
break;
|
||||
}
|
||||
|
||||
generatedTokens.push_back(next);
|
||||
|
||||
llama_token token = next;
|
||||
const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1);
|
||||
if (llama_decode(context_, oneTokenBatch) != 0) {
|
||||
if (llama_decode(context_, oneTokenBatch) != 0)
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: decode failed during generation");
|
||||
}
|
||||
}
|
||||
|
||||
std::string output;
|
||||
for (const llama_token token : generatedTokens) {
|
||||
for (const llama_token token : generatedTokens)
|
||||
appendTokenPiece(vocab, token, output);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
UserResult LlamaGenerator::generateUser(const std::string &locale) {
|
||||
BreweryResult
|
||||
LlamaGenerator::generateBrewery(const std::string &cityName,
|
||||
const std::string &countryName,
|
||||
const std::string ®ionContext) {
|
||||
const std::string safeRegionContext = PrepareRegionContext(regionContext);
|
||||
|
||||
const std::string systemPrompt =
|
||||
"You are a copywriter for a craft beer travel guide. "
|
||||
"Your writing is vivid, specific to place, and avoids generic beer "
|
||||
"cliches. "
|
||||
"You must output ONLY valid JSON. "
|
||||
"The JSON schema must be exactly: {\"name\": \"string\", "
|
||||
"\"description\": \"string\"}. "
|
||||
"Do not include markdown formatting or backticks.";
|
||||
|
||||
std::string prompt =
|
||||
"Generate a plausible craft beer enthusiast username and a one-sentence "
|
||||
"bio. Locale: " +
|
||||
locale +
|
||||
". Respond with exactly two lines: first line is the username (no "
|
||||
"spaces), second line is the bio. Do not include bullets, numbering, "
|
||||
"or any extra text.";
|
||||
"Write a brewery name and place-specific description for a craft "
|
||||
"brewery in " +
|
||||
cityName +
|
||||
(countryName.empty() ? std::string("")
|
||||
: std::string(", ") + countryName) +
|
||||
(safeRegionContext.empty()
|
||||
? std::string(".")
|
||||
: std::string(". Regional context: ") + safeRegionContext);
|
||||
|
||||
const std::string raw = infer(prompt, 128);
|
||||
auto [username, bio] =
|
||||
parseTwoLineResponse(raw, "LlamaGenerator: malformed user response");
|
||||
const int maxAttempts = 3;
|
||||
std::string raw;
|
||||
std::string lastError;
|
||||
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
|
||||
raw = infer(systemPrompt, prompt, 384);
|
||||
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
|
||||
raw);
|
||||
|
||||
username.erase(
|
||||
std::remove_if(username.begin(), username.end(),
|
||||
[](unsigned char ch) { return std::isspace(ch); }),
|
||||
username.end());
|
||||
std::string name;
|
||||
std::string description;
|
||||
const std::string validationError =
|
||||
ValidateBreweryJson(raw, name, description);
|
||||
if (validationError.empty()) {
|
||||
return {std::move(name), std::move(description)};
|
||||
}
|
||||
|
||||
if (username.empty() || bio.empty()) {
|
||||
throw std::runtime_error("LlamaGenerator: malformed user response");
|
||||
lastError = validationError;
|
||||
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
|
||||
attempt + 1, validationError);
|
||||
|
||||
prompt = "Your previous response was invalid. Error: " + validationError +
|
||||
"\nReturn ONLY valid JSON with this exact schema: "
|
||||
"{\"name\": \"string\", \"description\": \"string\"}."
|
||||
"\nDo not include markdown, comments, or extra keys."
|
||||
"\n\nLocation: " +
|
||||
cityName +
|
||||
(countryName.empty() ? std::string("")
|
||||
: std::string(", ") + countryName) +
|
||||
(safeRegionContext.empty()
|
||||
? std::string("")
|
||||
: std::string("\nRegional context: ") + safeRegionContext);
|
||||
}
|
||||
|
||||
return {username, bio};
|
||||
spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: "
|
||||
"{}",
|
||||
maxAttempts, lastError.empty() ? raw : lastError);
|
||||
throw std::runtime_error("LlamaGenerator: malformed brewery response");
|
||||
}
|
||||
|
||||
UserResult LlamaGenerator::generateUser(const std::string &locale) {
|
||||
const std::string systemPrompt =
|
||||
"You generate plausible social media profiles for craft beer "
|
||||
"enthusiasts. "
|
||||
"Respond with exactly two lines: "
|
||||
"the first line is a username (lowercase, no spaces, 8-20 characters), "
|
||||
"the second line is a one-sentence bio (20-40 words). "
|
||||
"The profile should feel consistent with the locale. "
|
||||
"No preamble, no labels.";
|
||||
|
||||
std::string prompt =
|
||||
"Generate a craft beer enthusiast profile. Locale: " + locale;
|
||||
|
||||
const int maxAttempts = 3;
|
||||
std::string raw;
|
||||
for (int attempt = 0; attempt < maxAttempts; ++attempt) {
|
||||
raw = infer(systemPrompt, prompt, 128);
|
||||
spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}",
|
||||
attempt + 1, raw);
|
||||
|
||||
try {
|
||||
auto [username, bio] =
|
||||
parseTwoLineResponse(raw, "LlamaGenerator: malformed user response");
|
||||
|
||||
username.erase(
|
||||
std::remove_if(username.begin(), username.end(),
|
||||
[](unsigned char ch) { return std::isspace(ch); }),
|
||||
username.end());
|
||||
|
||||
if (username.empty() || bio.empty()) {
|
||||
throw std::runtime_error("LlamaGenerator: malformed user response");
|
||||
}
|
||||
|
||||
if (bio.size() > 200)
|
||||
bio = bio.substr(0, 200);
|
||||
|
||||
return {username, bio};
|
||||
} catch (const std::exception &e) {
|
||||
spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}",
|
||||
attempt + 1, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}",
|
||||
maxAttempts, raw);
|
||||
throw std::runtime_error("LlamaGenerator: malformed user response");
|
||||
}
|
||||
|
||||
@@ -1,35 +1,66 @@
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <boost/program_options.hpp>
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include "curl_web_client.h"
|
||||
#include "data_downloader.h"
|
||||
#include "data_generator.h"
|
||||
#include "database.h"
|
||||
#include "json_loader.h"
|
||||
#include "llama_generator.h"
|
||||
#include "mock_generator.h"
|
||||
#include <curl/curl.h>
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <vector>
|
||||
#include "wikipedia_service.h"
|
||||
|
||||
static bool FileExists(const std::string &filePath) {
|
||||
return std::filesystem::exists(filePath);
|
||||
}
|
||||
namespace po = boost::program_options;
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
try {
|
||||
curl_global_init(CURL_GLOBAL_DEFAULT);
|
||||
const CurlGlobalState curl_state;
|
||||
|
||||
std::string modelPath = argc > 1 ? argv[1] : "";
|
||||
std::string cacheDir = argc > 2 ? argv[2] : "/tmp";
|
||||
std::string commit =
|
||||
argc > 3 ? argv[3] : "c5eb7772"; // Default: stable 2026-03-28
|
||||
po::options_description desc("Pipeline Options");
|
||||
desc.add_options()("help,h", "Produce help message")(
|
||||
"model,m", po::value<std::string>()->default_value(""),
|
||||
"Path to LLM model (gguf)")(
|
||||
"cache-dir,c", po::value<std::string>()->default_value("/tmp"),
|
||||
"Directory for cached JSON")(
|
||||
"temperature", po::value<float>()->default_value(0.8f),
|
||||
"Sampling temperature (higher = more random)")(
|
||||
"top-p", po::value<float>()->default_value(0.92f),
|
||||
"Nucleus sampling top-p in (0,1] (higher = more random)")(
|
||||
"seed", po::value<int>()->default_value(-1),
|
||||
"Sampler seed: -1 for random, otherwise non-negative integer")(
|
||||
"commit", po::value<std::string>()->default_value("c5eb7772"),
|
||||
"Git commit hash for DB consistency");
|
||||
|
||||
std::string countryName = argc > 4 ? argv[4] : "";
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
po::notify(vm);
|
||||
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc << "\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string modelPath = vm["model"].as<std::string>();
|
||||
std::string cacheDir = vm["cache-dir"].as<std::string>();
|
||||
float temperature = vm["temperature"].as<float>();
|
||||
float topP = vm["top-p"].as<float>();
|
||||
int seed = vm["seed"].as<int>();
|
||||
std::string commit = vm["commit"].as<std::string>();
|
||||
|
||||
std::string jsonPath = cacheDir + "/countries+states+cities.json";
|
||||
std::string dbPath = cacheDir + "/biergarten-pipeline.db";
|
||||
|
||||
bool hasJsonCache = FileExists(jsonPath);
|
||||
bool hasDbCache = FileExists(dbPath);
|
||||
bool hasJsonCache = std::filesystem::exists(jsonPath);
|
||||
bool hasDbCache = std::filesystem::exists(dbPath);
|
||||
|
||||
auto webClient = std::make_shared<CURLWebClient>();
|
||||
|
||||
SqliteDatabase db;
|
||||
|
||||
@@ -40,7 +71,7 @@ int main(int argc, char *argv[]) {
|
||||
spdlog::info("[Pipeline] Cache hit: skipping download and parse");
|
||||
} else {
|
||||
spdlog::info("\n[Pipeline] Downloading geographic data from GitHub...");
|
||||
DataDownloader downloader;
|
||||
DataDownloader downloader(webClient);
|
||||
downloader.DownloadCountriesDatabase(jsonPath, commit);
|
||||
|
||||
JsonLoader::LoadWorldCities(jsonPath, db);
|
||||
@@ -52,17 +83,30 @@ int main(int argc, char *argv[]) {
|
||||
generator = std::make_unique<MockGenerator>();
|
||||
spdlog::info("[Generator] Using MockGenerator (no model path provided)");
|
||||
} else {
|
||||
generator = std::make_unique<LlamaGenerator>();
|
||||
spdlog::info("[Generator] Using LlamaGenerator: {}", modelPath);
|
||||
auto llamaGenerator = std::make_unique<LlamaGenerator>();
|
||||
llamaGenerator->setSamplingOptions(temperature, topP, seed);
|
||||
spdlog::info(
|
||||
"[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, "
|
||||
"seed={})",
|
||||
modelPath, temperature, topP, seed);
|
||||
generator = std::move(llamaGenerator);
|
||||
}
|
||||
generator->load(modelPath);
|
||||
|
||||
WikipediaService wikipediaService(webClient);
|
||||
|
||||
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
|
||||
|
||||
auto countries = db.QueryCountries(50);
|
||||
auto states = db.QueryStates(50);
|
||||
auto cities = db.QueryCities();
|
||||
|
||||
// Build a quick map of country id -> name for per-city lookups.
|
||||
auto allCountries = db.QueryCountries(0);
|
||||
std::unordered_map<int, std::string> countryMap;
|
||||
for (const auto &c : allCountries)
|
||||
countryMap[c.id] = c.name;
|
||||
|
||||
spdlog::info("\nTotal records loaded:");
|
||||
spdlog::info(" Countries: {}", db.QueryCountries(0).size());
|
||||
spdlog::info(" States: {}", db.QueryStates(0).size());
|
||||
@@ -79,8 +123,23 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
spdlog::info("\n=== SAMPLE BREWERY GENERATION ===");
|
||||
for (size_t i = 0; i < sampleCount; i++) {
|
||||
const auto &[cityId, cityName] = cities[i];
|
||||
auto brewery = generator->generateBrewery(cityName, countryName, "");
|
||||
const auto &city = cities[i];
|
||||
const int cityId = city.id;
|
||||
const std::string cityName = city.name;
|
||||
|
||||
std::string localCountry;
|
||||
const auto countryIt = countryMap.find(city.countryId);
|
||||
if (countryIt != countryMap.end()) {
|
||||
localCountry = countryIt->second;
|
||||
}
|
||||
|
||||
const std::string regionContext =
|
||||
wikipediaService.GetSummary(cityName, localCountry);
|
||||
spdlog::debug("[Pipeline] Region context for {}: {}", cityName,
|
||||
regionContext);
|
||||
|
||||
auto brewery =
|
||||
generator->generateBrewery(cityName, localCountry, regionContext);
|
||||
generatedBreweries.push_back({cityId, cityName, brewery});
|
||||
}
|
||||
|
||||
@@ -95,12 +154,10 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
spdlog::info("\nOK: Pipeline completed successfully");
|
||||
|
||||
curl_global_cleanup();
|
||||
return 0;
|
||||
|
||||
} catch (const std::exception &e) {
|
||||
spdlog::error("ERROR: Pipeline failed: {}", e.what());
|
||||
curl_global_cleanup();
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
#include "stream_parser.h"
|
||||
#include "database.h"
|
||||
#include <cstdio>
|
||||
#include <rapidjson/filereadstream.h>
|
||||
#include <rapidjson/reader.h>
|
||||
#include <rapidjson/stringbuffer.h>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <boost/json.hpp>
|
||||
#include <boost/json/basic_parser_impl.hpp>
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
using namespace rapidjson;
|
||||
#include "database.h"
|
||||
#include "stream_parser.h"
|
||||
|
||||
class CityRecordHandler {
|
||||
friend class boost::json::basic_parser<CityRecordHandler>;
|
||||
|
||||
class CityRecordHandler : public BaseReaderHandler<UTF8<>, CityRecordHandler> {
|
||||
public:
|
||||
static constexpr std::size_t max_array_size = static_cast<std::size_t>(-1);
|
||||
static constexpr std::size_t max_object_size = static_cast<std::size_t>(-1);
|
||||
static constexpr std::size_t max_string_size = static_cast<std::size_t>(-1);
|
||||
static constexpr std::size_t max_key_size = static_cast<std::size_t>(-1);
|
||||
|
||||
struct ParseContext {
|
||||
SqliteDatabase *db = nullptr;
|
||||
std::function<void(const CityRecord &)> on_city;
|
||||
@@ -20,11 +27,35 @@ public:
|
||||
int states_inserted = 0;
|
||||
};
|
||||
|
||||
CityRecordHandler(ParseContext &ctx) : context(ctx) {}
|
||||
explicit CityRecordHandler(ParseContext &ctx) : context(ctx) {}
|
||||
|
||||
bool StartArray() {
|
||||
private:
|
||||
ParseContext &context;
|
||||
|
||||
int depth = 0;
|
||||
bool in_countries_array = false;
|
||||
bool in_country_object = false;
|
||||
bool in_states_array = false;
|
||||
bool in_state_object = false;
|
||||
bool in_cities_array = false;
|
||||
bool building_city = false;
|
||||
|
||||
int current_country_id = 0;
|
||||
int current_state_id = 0;
|
||||
CityRecord current_city = {};
|
||||
std::string current_key;
|
||||
std::string current_key_val;
|
||||
std::string current_string_val;
|
||||
|
||||
std::string country_info[3];
|
||||
std::string state_info[2];
|
||||
|
||||
// Boost.JSON SAX Hooks
|
||||
bool on_document_begin(boost::system::error_code &) { return true; }
|
||||
bool on_document_end(boost::system::error_code &) { return true; }
|
||||
|
||||
bool on_array_begin(boost::system::error_code &) {
|
||||
depth++;
|
||||
|
||||
if (depth == 1) {
|
||||
in_countries_array = true;
|
||||
} else if (depth == 3 && current_key == "states") {
|
||||
@@ -35,7 +66,7 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EndArray(SizeType /*elementCount*/) {
|
||||
bool on_array_end(std::size_t, boost::system::error_code &) {
|
||||
if (depth == 1) {
|
||||
in_countries_array = false;
|
||||
} else if (depth == 3) {
|
||||
@@ -47,9 +78,8 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StartObject() {
|
||||
bool on_object_begin(boost::system::error_code &) {
|
||||
depth++;
|
||||
|
||||
if (depth == 2 && in_countries_array) {
|
||||
in_country_object = true;
|
||||
current_country_id = 0;
|
||||
@@ -68,7 +98,7 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EndObject(SizeType /*memberCount*/) {
|
||||
bool on_object_end(std::size_t, boost::system::error_code &) {
|
||||
if (depth == 6 && building_city) {
|
||||
if (current_city.id > 0 && current_state_id > 0 &&
|
||||
current_country_id > 0) {
|
||||
@@ -84,7 +114,7 @@ public:
|
||||
context.total_file_size);
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
spdlog::warn(" WARN: Failed to emit city: {}", e.what());
|
||||
spdlog::warn("Record parsing failed: {}", e.what());
|
||||
}
|
||||
}
|
||||
building_city = false;
|
||||
@@ -95,7 +125,7 @@ public:
|
||||
state_info[0], state_info[1]);
|
||||
context.states_inserted++;
|
||||
} catch (const std::exception &e) {
|
||||
spdlog::warn(" WARN: Failed to insert state: {}", e.what());
|
||||
spdlog::warn("Record parsing failed: {}", e.what());
|
||||
}
|
||||
}
|
||||
in_state_object = false;
|
||||
@@ -106,7 +136,7 @@ public:
|
||||
country_info[1], country_info[2]);
|
||||
context.countries_inserted++;
|
||||
} catch (const std::exception &e) {
|
||||
spdlog::warn(" WARN: Failed to insert country: {}", e.what());
|
||||
spdlog::warn("Record parsing failed: {}", e.what());
|
||||
}
|
||||
}
|
||||
in_country_object = false;
|
||||
@@ -116,46 +146,71 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Key(const char *str, SizeType len, bool /*copy*/) {
|
||||
current_key.assign(str, len);
|
||||
bool on_key_part(boost::json::string_view s, std::size_t,
|
||||
boost::system::error_code &) {
|
||||
current_key_val.append(s.data(), s.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool String(const char *str, SizeType len, bool /*copy*/) {
|
||||
bool on_key(boost::json::string_view s, std::size_t,
|
||||
boost::system::error_code &) {
|
||||
current_key_val.append(s.data(), s.size());
|
||||
current_key = current_key_val;
|
||||
current_key_val.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool on_string_part(boost::json::string_view s, std::size_t,
|
||||
boost::system::error_code &) {
|
||||
current_string_val.append(s.data(), s.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool on_string(boost::json::string_view s, std::size_t,
|
||||
boost::system::error_code &) {
|
||||
current_string_val.append(s.data(), s.size());
|
||||
|
||||
if (building_city && current_key == "name") {
|
||||
current_city.name.assign(str, len);
|
||||
current_city.name = current_string_val;
|
||||
} else if (in_state_object && current_key == "name") {
|
||||
state_info[0].assign(str, len);
|
||||
state_info[0] = current_string_val;
|
||||
} else if (in_state_object && current_key == "iso2") {
|
||||
state_info[1].assign(str, len);
|
||||
state_info[1] = current_string_val;
|
||||
} else if (in_country_object && current_key == "name") {
|
||||
country_info[0].assign(str, len);
|
||||
country_info[0] = current_string_val;
|
||||
} else if (in_country_object && current_key == "iso2") {
|
||||
country_info[1].assign(str, len);
|
||||
country_info[1] = current_string_val;
|
||||
} else if (in_country_object && current_key == "iso3") {
|
||||
country_info[2].assign(str, len);
|
||||
country_info[2] = current_string_val;
|
||||
}
|
||||
|
||||
current_string_val.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Int(int i) {
|
||||
bool on_number_part(boost::json::string_view, boost::system::error_code &) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool on_int64(int64_t i, boost::json::string_view,
|
||||
boost::system::error_code &) {
|
||||
if (building_city && current_key == "id") {
|
||||
current_city.id = i;
|
||||
current_city.id = static_cast<int>(i);
|
||||
} else if (in_state_object && current_key == "id") {
|
||||
current_state_id = i;
|
||||
current_state_id = static_cast<int>(i);
|
||||
} else if (in_country_object && current_key == "id") {
|
||||
current_country_id = i;
|
||||
current_country_id = static_cast<int>(i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Uint(unsigned i) { return Int(static_cast<int>(i)); }
|
||||
bool on_uint64(uint64_t u, boost::json::string_view,
|
||||
boost::system::error_code &ec) {
|
||||
return on_int64(static_cast<int64_t>(u), "", ec);
|
||||
}
|
||||
|
||||
bool Int64(int64_t i) { return Int(static_cast<int>(i)); }
|
||||
|
||||
bool Uint64(uint64_t i) { return Int(static_cast<int>(i)); }
|
||||
|
||||
bool Double(double d) {
|
||||
bool on_double(double d, boost::json::string_view,
|
||||
boost::system::error_code &) {
|
||||
if (building_city) {
|
||||
if (current_key == "latitude") {
|
||||
current_city.latitude = d;
|
||||
@@ -166,27 +221,14 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Bool(bool /*b*/) { return true; }
|
||||
bool Null() { return true; }
|
||||
|
||||
private:
|
||||
ParseContext &context;
|
||||
|
||||
int depth = 0;
|
||||
bool in_countries_array = false;
|
||||
bool in_country_object = false;
|
||||
bool in_states_array = false;
|
||||
bool in_state_object = false;
|
||||
bool in_cities_array = false;
|
||||
bool building_city = false;
|
||||
|
||||
int current_country_id = 0;
|
||||
int current_state_id = 0;
|
||||
CityRecord current_city = {};
|
||||
std::string current_key;
|
||||
|
||||
std::string country_info[3];
|
||||
std::string state_info[2];
|
||||
bool on_bool(bool, boost::system::error_code &) { return true; }
|
||||
bool on_null(boost::system::error_code &) { return true; }
|
||||
bool on_comment_part(boost::json::string_view, boost::system::error_code &) {
|
||||
return true;
|
||||
}
|
||||
bool on_comment(boost::json::string_view, boost::system::error_code &) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
void StreamingJsonParser::Parse(
|
||||
@@ -194,7 +236,7 @@ void StreamingJsonParser::Parse(
|
||||
std::function<void(const CityRecord &)> onCity,
|
||||
std::function<void(size_t, size_t)> onProgress) {
|
||||
|
||||
spdlog::info(" Streaming parse of {}...", filePath);
|
||||
spdlog::info(" Streaming parse of {} (Boost.JSON)...", filePath);
|
||||
|
||||
FILE *file = std::fopen(filePath.c_str(), "rb");
|
||||
if (!file) {
|
||||
@@ -212,23 +254,35 @@ void StreamingJsonParser::Parse(
|
||||
|
||||
CityRecordHandler::ParseContext ctx{&db, onCity, onProgress, 0,
|
||||
total_size, 0, 0};
|
||||
CityRecordHandler handler(ctx);
|
||||
boost::json::basic_parser<CityRecordHandler> parser(
|
||||
boost::json::parse_options{}, ctx);
|
||||
|
||||
Reader reader;
|
||||
char buf[65536];
|
||||
FileReadStream frs(file, buf, sizeof(buf));
|
||||
size_t bytes_read;
|
||||
boost::system::error_code ec;
|
||||
|
||||
if (!reader.Parse(frs, handler)) {
|
||||
ParseErrorCode errCode = reader.GetParseErrorCode();
|
||||
size_t errOffset = reader.GetErrorOffset();
|
||||
std::fclose(file);
|
||||
throw std::runtime_error(std::string("JSON parse error at offset ") +
|
||||
std::to_string(errOffset) +
|
||||
" (code: " + std::to_string(errCode) + ")");
|
||||
while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) {
|
||||
char const *p = buf;
|
||||
std::size_t remain = bytes_read;
|
||||
|
||||
while (remain > 0) {
|
||||
std::size_t consumed = parser.write_some(true, p, remain, ec);
|
||||
if (ec) {
|
||||
std::fclose(file);
|
||||
throw std::runtime_error("JSON parse error: " + ec.message());
|
||||
}
|
||||
p += consumed;
|
||||
remain -= consumed;
|
||||
}
|
||||
}
|
||||
|
||||
parser.write_some(false, nullptr, 0, ec); // Signal EOF
|
||||
std::fclose(file);
|
||||
|
||||
if (ec) {
|
||||
throw std::runtime_error("JSON parse error at EOF: " + ec.message());
|
||||
}
|
||||
|
||||
spdlog::info(" OK: Parsed {} countries, {} states, {} cities",
|
||||
ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted);
|
||||
}
|
||||
|
||||
77
pipeline/src/wikipedia_service.cpp
Normal file
77
pipeline/src/wikipedia_service.cpp
Normal file
@@ -0,0 +1,77 @@
|
||||
#include "wikipedia_service.h"
|
||||
#include <boost/json.hpp>
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
WikipediaService::WikipediaService(std::shared_ptr<IWebClient> client)
|
||||
: client_(std::move(client)) {}
|
||||
|
||||
std::string WikipediaService::FetchExtract(std::string_view query) {
|
||||
const std::string encoded = client_->UrlEncode(std::string(query));
|
||||
const std::string url =
|
||||
"https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded +
|
||||
"&prop=extracts&explaintext=true&format=json";
|
||||
|
||||
const std::string body = client_->Get(url);
|
||||
|
||||
boost::system::error_code ec;
|
||||
boost::json::value doc = boost::json::parse(body, ec);
|
||||
|
||||
if (!ec && doc.is_object()) {
|
||||
auto &pages = doc.at("query").at("pages").get_object();
|
||||
if (!pages.empty()) {
|
||||
auto &page = pages.begin()->value().get_object();
|
||||
if (page.contains("extract") && page.at("extract").is_string()) {
|
||||
std::string extract(page.at("extract").as_string().c_str());
|
||||
spdlog::debug("WikipediaService fetched {} chars for '{}'",
|
||||
extract.size(), query);
|
||||
return extract;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::string WikipediaService::GetSummary(std::string_view city,
|
||||
std::string_view country) {
|
||||
const std::string key = std::string(city) + "|" + std::string(country);
|
||||
const auto cacheIt = cache_.find(key);
|
||||
if (cacheIt != cache_.end()) {
|
||||
return cacheIt->second;
|
||||
}
|
||||
|
||||
std::string result;
|
||||
|
||||
if (!client_) {
|
||||
cache_.emplace(key, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string regionQuery(city);
|
||||
if (!country.empty()) {
|
||||
regionQuery += ", ";
|
||||
regionQuery += country;
|
||||
}
|
||||
|
||||
const std::string beerQuery = "beer in " + std::string(city);
|
||||
|
||||
try {
|
||||
const std::string regionExtract = FetchExtract(regionQuery);
|
||||
const std::string beerExtract = FetchExtract(beerQuery);
|
||||
|
||||
if (!regionExtract.empty()) {
|
||||
result += regionExtract;
|
||||
}
|
||||
if (!beerExtract.empty()) {
|
||||
if (!result.empty())
|
||||
result += "\n\n";
|
||||
result += beerExtract;
|
||||
}
|
||||
} catch (const std::runtime_error &e) {
|
||||
spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery,
|
||||
e.what());
|
||||
}
|
||||
|
||||
cache_.emplace(key, result);
|
||||
return result;
|
||||
}
|
||||
Reference in New Issue
Block a user