diff --git a/pipeline/CMakeLists.txt b/pipeline/CMakeLists.txt index 15f7171..f9d27e1 100644 --- a/pipeline/CMakeLists.txt +++ b/pipeline/CMakeLists.txt @@ -1,49 +1,52 @@ cmake_minimum_required(VERSION 3.20) project(biergarten-pipeline VERSION 0.1.0 LANGUAGES CXX) -cmake_policy(SET CMP0167 NEW) +# Allows older dependencies to configure on newer CMake. +set(CMAKE_POLICY_VERSION_MINIMUM 3.5) +# Policies +cmake_policy(SET CMP0167 NEW) # FindBoost improvements + +# Global Settings set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +# ----------------------------------------------------------------------------- +# Compiler Options & Warnings (Interface Library) +# ----------------------------------------------------------------------------- +add_library(project_options INTERFACE) +target_compile_options(project_options INTERFACE + $<$: + -Wall -Wextra -Wpedantic -Wshadow -Wconversion -Wsign-conversion -Wunused + > + $<$: + /W4 /WX /permissive- + > +) + +# ----------------------------------------------------------------------------- +# Dependencies +# ----------------------------------------------------------------------------- find_package(CURL REQUIRED) -find_package(Boost REQUIRED COMPONENTS unit_test_framework) find_package(SQLite3 REQUIRED) +find_package(Boost 1.75 REQUIRED COMPONENTS program_options json) include(FetchContent) -# RapidJSON (header-only) for true SAX parsing -# Using direct header-only approach without CMakeLists.txt -FetchContent_Declare( - rapidjson - GIT_REPOSITORY https://github.com/Tencent/rapidjson.git - GIT_TAG v1.1.0 - SOURCE_SUBDIR "" # Don't use RapidJSON's CMakeLists.txt -) -FetchContent_GetProperties(rapidjson) -if(NOT rapidjson_POPULATED) - FetchContent_Populate(rapidjson) - # RapidJSON is header-only; just make include path available -endif() - -# spdlog (logging) +# spdlog (Logging) FetchContent_Declare( spdlog GIT_REPOSITORY https://github.com/gabime/spdlog.git GIT_TAG v1.11.0 ) -FetchContent_GetProperties(spdlog) -if(NOT spdlog_POPULATED) - FetchContent_Populate(spdlog) - add_subdirectory(${spdlog_SOURCE_DIR} ${spdlog_BINARY_DIR} EXCLUDE_FROM_ALL) -endif() +FetchContent_MakeAvailable(spdlog) -# llama.cpp (on-device inference) +# llama.cpp (LLM Inference) set(LLAMA_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(LLAMA_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) set(LLAMA_BUILD_SERVER OFF CACHE BOOL "" FORCE) - FetchContent_Declare( llama_cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git @@ -57,90 +60,53 @@ if(TARGET llama) ) endif() -file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS - src/*.cpp +# ----------------------------------------------------------------------------- +# Main Executable +# ----------------------------------------------------------------------------- +set(PIPELINE_SOURCES + src/curl_web_client.cpp + src/data_downloader.cpp + src/database.cpp + src/json_loader.cpp + src/llama_generator.cpp + src/mock_generator.cpp + src/stream_parser.cpp + src/wikipedia_service.cpp + src/main.cpp ) -add_executable(biergarten-pipeline ${SOURCES}) +add_executable(biergarten-pipeline ${PIPELINE_SOURCES}) target_include_directories(biergarten-pipeline PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/includes - ${rapidjson_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/includes ${llama_cpp_SOURCE_DIR}/include ) target_link_libraries(biergarten-pipeline PRIVATE + project_options CURL::libcurl - Boost::unit_test_framework SQLite::SQLite3 spdlog::spdlog llama + Boost::program_options + Boost::json ) -target_compile_options(biergarten-pipeline PRIVATE - $<$: - -Wall - -Wextra - -Wpedantic - -Wshadow - -Wconversion - -Wsign-conversion - > - $<$: - /W4 - /WX - > -) - +# ----------------------------------------------------------------------------- +# Post-Build Steps & Utilities +# ----------------------------------------------------------------------------- add_custom_command(TARGET biergarten-pipeline POST_BUILD - COMMAND ${CMAKE_COMMAND} -E make_directory - ${CMAKE_CURRENT_SOURCE_DIR}/output - COMMENT "Creating output/ directory for seed SQL files" + COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_SOURCE_DIR}/output + COMMENT "Ensuring output directory exists" ) find_program(VALGRIND valgrind) if(VALGRIND) add_custom_target(memcheck - COMMAND ${VALGRIND} - --leak-check=full - --error-exitcode=1 - $ --help + COMMAND ${VALGRIND} --leak-check=full --error-exitcode=1 $ --help DEPENDS biergarten-pipeline - COMMENT "Running Valgrind memcheck" + COMMENT "Running Valgrind memory check" ) endif() - -include(CTest) - -if(BUILD_TESTING) - find_package(Boost REQUIRED COMPONENTS unit_test_framework) - - file(GLOB_RECURSE TEST_SOURCES CONFIGURE_DEPENDS - tests/*.cpp - tests/*.cc - tests/*.cxx - ) - - if(TEST_SOURCES) - add_executable(biergarten-pipeline-tests ${TEST_SOURCES}) - - target_include_directories(biergarten-pipeline-tests - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/include - ) - - target_link_libraries(biergarten-pipeline-tests - PRIVATE - Boost::unit_test_framework - CURL::libcurl - nlohmann_json::nlohmann_json - ) - - add_test( - NAME biergarten-pipeline-tests - COMMAND biergarten-pipeline-tests - ) - endif() -endif() diff --git a/pipeline/README.md b/pipeline/README.md index e9e0d60..9488b53 100644 --- a/pipeline/README.md +++ b/pipeline/README.md @@ -1,414 +1,199 @@ -## Biergarten Pipeline +Biergarten Pipeline -## Overview +Overview The pipeline orchestrates five key stages: -1. **Download**: Fetches `countries+states+cities.json` from a pinned GitHub commit with optional local caching. -2. **Parse**: Streams JSON using RapidJSON SAX parser, extracting country/state/city records without loading the entire file into memory. -3. **Buffer**: Routes city records through a bounded concurrent queue to decouple parsing from writes. -4. **Store**: Inserts records with concurrent thread safety using an in-memory SQLite database. -5. **Generate**: Produces mock brewery metadata for a sample of cities (mockup for future LLM integration). +Download: Fetches countries+states+cities.json from a pinned GitHub commit with optional local caching. ---- +Parse: Streams JSON using Boost.JSON's basic_parser to extract country/state/city records without loading the entire file into memory. -## Architecture +Buffer: Routes city records through a bounded concurrent queue to decouple parsing from writes. -### Data Sources and Formats +Store: Inserts records with concurrent thread safety using an in-memory SQLite database. -- Hierarchical structure: countries array → states per country → cities per state. -- Fields: `id` (integer), `name` (string), `iso2` / `iso3` (codes), `latitude` / `longitude`. -- Sourced from: [dr5hn/countries-states-cities-database](https://github.com/dr5hn/countries-states-cities-database) on GitHub. +Generate: Produces mock brewery metadata for a sample of cities (mockup for future LLM integration). -**Output**: Structured SQLite in-memory database + console logs via spdlog. +Architecture -### Concurrency Architecture +Data Sources and Formats + +Hierarchical structure: countries array → states per country → cities per state. + +Fields: id (integer), name (string), iso2 / iso3 (codes), latitude / longitude. + +Sourced from: dr5hn/countries-states-cities-database on GitHub. + +Output: Structured SQLite in-memory database + console logs via spdlog. + +Concurrency Architecture The pipeline splits work across parsing and writing phases: -``` Main Thread: - parse_sax() -> Insert countries (direct) - -> Insert states (direct) - -> Push CityRecord to WorkQueue +parse_sax() -> Insert countries (direct) +-> Insert states (direct) +-> Push CityRecord to WorkQueue Worker Threads (implicit; pthread pool via sqlite3): - Pop CityRecord from WorkQueue - -> InsertCity(db) with mutex protection -``` +Pop CityRecord from WorkQueue +-> InsertCity(db) with mutex protection -**Key synchronization primitives**: +Key synchronization primitives: -- **WorkQueue**: Bounded (default 1024 items) concurrent queue with blocking push/pop, guarded by mutex + condition variables. -- **SqliteDatabase::dbMutex**: Serializes all SQLite operations to avoid `SQLITE_BUSY` and ensure write safety. +WorkQueue: Bounded (default 1024 items) concurrent queue with blocking push/pop, guarded by mutex + condition variables. -**Backpressure**: When the WorkQueue fills (≥1024 city records pending), the parser thread blocks until workers drain items. +SqliteDatabase::dbMutex: Serializes all SQLite operations to avoid SQLITE_BUSY and ensure write safety. -### Component Responsibilities +Backpressure: When the WorkQueue fills (≥1024 city records pending), the parser thread blocks until workers drain items. -| Component | Purpose | Thread Safety | -| ------------------------- | ------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------- | -| **DataDownloader** | GitHub fetch with curl; optional filesystem cache; handles retries and ETags. | Blocking I/O; safe for single-threaded startup. | -| **StreamingJsonParser** | SAX-style RapidJSON handler; emits country/state/city via callbacks; tracks parse state (array depth, key context). | Single-threaded parse phase; thread-safe callbacks. | -| **JsonLoader** | Wraps parser; runs country/state/city callbacks; manages WorkQueue lifecycle. | Produces to WorkQueue; consumes from callbacks. | -| **SqliteDatabase** | In-memory schema; insert/query methods; mutex-protected SQL operations. | Mutex-guarded; thread-safe concurrent inserts. | -| **LlamaBreweryGenerator** | Mock brewery text generation using deterministic seed-based selection. | Stateless; thread-safe method calls. | +Component Responsibilities ---- +Component -## Database Schema +Purpose -**SQLite in-memory database** with three core tables: +Thread Safety -### Countries +DataDownloader + +GitHub fetch with curl; optional filesystem cache; handles retries and ETags. + +Blocking I/O; safe for single-threaded startup. + +StreamingJsonParser + +Subclasses boost::json::basic_parser; emits country/state/city via callbacks; tracking parse depth. + +Single-threaded parse phase; thread-safe callbacks. + +JsonLoader + +Wraps parser; runs country/state/city callbacks; manages WorkQueue lifecycle. + +Produces to WorkQueue; consumes from callbacks. + +SqliteDatabase + +In-memory schema; insert/query methods; mutex-protected SQL operations. + +Mutex-guarded; thread-safe concurrent inserts. + +LlamaBreweryGenerator + +Mock brewery text generation using deterministic seed-based selection. + +Stateless; thread-safe method calls. + +Database Schema + +SQLite in-memory database with three core tables: + +Countries -```sql CREATE TABLE countries ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL, - iso2 TEXT, - iso3 TEXT +id INTEGER PRIMARY KEY, +name TEXT NOT NULL, +iso2 TEXT, +iso3 TEXT ); CREATE INDEX idx_countries_iso2 ON countries(iso2); -``` -### States +States -```sql CREATE TABLE states ( - id INTEGER PRIMARY KEY, - country_id INTEGER NOT NULL, - name TEXT NOT NULL, - iso2 TEXT, - FOREIGN KEY (country_id) REFERENCES countries(id) +id INTEGER PRIMARY KEY, +country_id INTEGER NOT NULL, +name TEXT NOT NULL, +iso2 TEXT, +FOREIGN KEY (country_id) REFERENCES countries(id) ); CREATE INDEX idx_states_country ON states(country_id); -``` -### Cities +Cities -```sql CREATE TABLE cities ( - id INTEGER PRIMARY KEY, - state_id INTEGER NOT NULL, - country_id INTEGER NOT NULL, - name TEXT NOT NULL, - latitude REAL, - longitude REAL, - FOREIGN KEY (state_id) REFERENCES states(id), - FOREIGN KEY (country_id) REFERENCES countries(id) +id INTEGER PRIMARY KEY, +state_id INTEGER NOT NULL, +country_id INTEGER NOT NULL, +name TEXT NOT NULL, +latitude REAL, +longitude REAL, +FOREIGN KEY (state_id) REFERENCES states(id), +FOREIGN KEY (country_id) REFERENCES countries(id) ); CREATE INDEX idx_cities_state ON cities(state_id); CREATE INDEX idx_cities_country ON cities(country_id); -``` -**Design rationale**: +Configuration and Extensibility -- In-memory for performance (no persistent storage; data is regenerated on each run). -- Foreign keys for referential integrity (optional in SQLite, but enforced in schema). -- Indexes on foreign keys for fast lookups during brewery generation. -- Dual country_id in cities table for direct queries without state joins. +Command-Line Arguments ---- +Boost.Program_options provides named CLI arguments: -## Data Flow +./biergarten-pipeline [options] -### Parse Phase (Main Thread) +Arg -1. **DataDownloader::DownloadCountriesDatabase()** - - Constructs GitHub raw-content URL: `https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/{commit}/countries+states+cities.json` - - Uses curl with `FOLLOWLOCATION` and timeout. - - Caches locally; checks ETag for freshness. +Default -2. **StreamingJsonParser::Parse()** - - Opens file stream; initializes RapidJSON SAX parser with custom handler. - - Handler state: tracks `current_country_id`, `current_state_id`, array nesting, object key context. - - **Country processing** (inline): When country object completes, calls `db.InsertCountry()` directly on main thread. - - **State processing** (inline): When state object completes, calls `db.InsertState()` directly. - - **City processing** (buffered): When city object completes, pushes `CityRecord` to `JsonLoader`'s WorkQueue; unblocks if `onProgress` callback is registered. +Purpose -3. **JsonLoader::LoadWorldCities()** - - Registers callbacks with parser. - - Drains WorkQueue in separate scope (currently single-threaded in main, but queue API supports worker threads). - - Each city is inserted via `db.InsertCity()`. +--model, -m -### Query and Generation Phase (Main Thread) +"" -4. **Database Queries** - - `QueryCountries(limit)`: Retrieve countries; used for progress display. - - `QueryStates(limit)`: Retrieve states; used for progress display. - - `QueryCities()`: Retrieve all city ids + names for brewery generation. +Path to LLM model (mock implementation used if left blank). -5. **Brewery Generation** - - For each city sample, call `LlamaBreweryGenerator::GenerateBrewery(cityName, seed)`. - - Deterministic: same seed always produces same brewery (useful for reproducible test data). - - Returns `{ name, description }` struct. +--cache-dir, -c ---- +/tmp -## Concurrency Deep Dive +Directory for cached JSON DB. -### WorkQueue +--commit -A bounded thread-safe queue enabling producer-consumer patterns: +c5eb7772 -```cpp -template class WorkQueue { - std::queue queue; - std::mutex mutex; - std::condition_variable cv_not_empty, cv_not_full; - size_t max_size; - bool shutdown; -}; -``` +Git commit hash for consistency (stable 2026-03-28 snapshot). -**push(item)**: +--help, -h -- Locks mutex. -- Waits on `cv_not_full` until queue is below max_size OR shutdown signaled. -- Pushes item; notifies one waiter on `cv_not_empty`. -- Returns false if shutdown, else true. +- -**pop()**: +Show help menu. -- Locks mutex. -- Waits on `cv_not_empty` until queue has items OR shutdown signaled. -- Pops and returns `std::optional`; notifies one waiter on `cv_not_full`. -- Returns `std::nullopt` if shutdown and queue is empty. +Examples: -**shutdown_queue()**: - -- Sets `shutdown = true`; notifies all waiters on both condition variables. -- Causes all waiting pop() calls to return `std::nullopt`. - -**Why this design**: - -- **Bounded capacity**: Prevents unbounded memory growth when parser outpaces inserts. -- **Backpressure**: Parser naturally pauses when queue fills, avoiding memory spikes. -- **Clean shutdown**: `shutdown_queue()` ensures worker pools terminate gracefully. - -### SqliteDatabase Mutex - -All SQLite operations (`INSERT`, `SELECT`) are guarded by `dbMutex`: - -```cpp -std::unique_lock lock(dbMutex); -int rc = sqlite3_step(stmt); -``` - -**Why**: SQLite's "serializable" journal mode (default) requires external synchronization for multi-threaded access. A single mutex serializes all queries, avoiding `SQLITE_BUSY` errors. - -**Tradeoff**: Throughput is bounded by single-threaded SQLite performance; gains come from parse/buffer decoupling, not parallel writes. - ---- - -## Error Handling - -### DataDownloader - -- **Network failures**: Retries up to 3 times with exponential backoff; throws `std::runtime_error` on final failure. -- **Caching**: Falls back to cached file if download fails and cache exists. - -### Streaming Parser - -- **Malformed JSON**: RapidJSON SAX handler reports parse errors; caught as exceptions in main. -- **Missing fields**: Silently skips incomplete records (e.g., city without latitude). - -### Database Operations - -- **Mutex contention**: No explicit backoff; relies on condition variables. -- **SQLite errors**: Checked via `sqlite3_step()` return codes; exceptions raised on CORRUPT, READONLY, etc. - -### Resilience Design - -- **No checkpointing**: In-memory database is ephemeral; restart from scratch on failure. -- **Future extension**: Snapshot intervals for long-lived processes (not implemented). - ---- - -## Performance Characteristics - -### Benchmarks (Example: 2M cities on 2024 MacBook Pro) - -| Stage | Time | Throughput | -| ----------------------------- | ------- | ---------------------------------------- | -| Download + Cache | 1s | ~100 MB/s (network dependent) | -| Parse (SAX) | 2s | 50M records/sec | -| Insert (countries/states) | <0.1s | Direct, negligible overhead | -| Insert (cities via WorkQueue) | 2s | 1M records/sec (sequential due to mutex) | -| Generate samples (5 cities) | <0.1s | Mock generation negligible | -| **Total** | **~5s** | | - -### Bottlenecks - -- **SQLite insertion**: Single-threaded mutex lock serializes writes. Doubling the number of WorkQueue consumer threads doesn't improve throughput (one lock). -- **Parse speed**: RapidJSON SAX is fast (2s for 100 MB); not the bottleneck. -- **Memory**: ~100 MB for in-memory database; suitable for most deployments. - -### Optimization Opportunities - -- **WAL mode**: SQLite WAL (write-ahead logging) could reduce lock contention; not beneficial for in-memory DB. -- **Batch inserts**: Combine multiple rows in a single transaction; helps if inserting outside the WorkQueue scope. -- **Foreign key lazy-loading**: Skip foreign key constraints during bulk load; re-enable for queries. (Not implemented.) - ---- - -## Configuration and Extensibility - -### Command-Line Arguments - -```bash -./biergarten-pipeline [modelPath] [cacheDir] [commit] -``` - -| Arg | Default | Purpose | -| ----------- | -------------- | ----------------------------------------------------------------------- | -| `modelPath` | `./model.gguf` | Path to LLM model (mock implementation; not loaded in current version). | -| `cacheDir` | `/tmp` | Directory for cached JSON (e.g., `/tmp/countries+states+cities.json`). | -| `commit` | `c5eb7772` | Git commit hash for consistency (stable 2026-03-28 snapshot). | - -**Examples**: - -```bash ./biergarten-pipeline -./biergarten-pipeline ./models/llama.gguf /var/cache main -./biergarten-pipeline "" /tmp v1.2.3 -``` +./biergarten-pipeline --model ./models/llama.gguf --cache-dir /var/cache +./biergarten-pipeline -c /tmp --commit v1.2.3 -### Extending the Generator +Building and Running -**Current**: `LlamaBreweryGenerator::GenerateBrewery()` uses deterministic seed-based selection from hardcoded lists. +Prerequisites -**Future swap points**: +C++23 compiler (g++, clang, MSVC). -1. Load an actual LLM model in `LoadModel(modelPath)`. -2. Tokenize city name and context; call model inference. -3. Validate output (length, format) and rank if multiple candidates. -4. Cache results to avoid re-inference for repeated cities. +CMake 3.20+. -**Example stub for future integration**: +curl (for HTTP downloads). -```cpp -Brewery LlamaBreweryGenerator::GenerateBrewery(const std::string &cityName, int seed) { - // TODO: Replace with actual llama.cpp inference - // llama_context *ctx = llama_new_context_with_model(model, params); - // std::string prompt = "Generate a brewery for " + cityName; - // std::string result = llama_inference(ctx, prompt, seed); - // return parse_brewery(result); -} -``` +sqlite3. -### Logging Configuration +Boost 1.75+ (requires Boost.JSON and Boost.Program_options). -Logging uses **spdlog** with: +spdlog (fetched via CMake FetchContent). -- **Level**: Info (can change via `spdlog::set_level(spdlog::level::debug)` at startup). -- **Format**: Plain ASCII; no Unicode box art. -- **Sink**: Console (stdout/stderr); can redirect to file. +Build -**Current output sample**: - -``` -[Pipeline] Downloading geographic data from GitHub... -Initializing in-memory SQLite database... -Initializing brewery generator... - -=== GEOGRAPHIC DATA OVERVIEW === -Total records loaded: - Countries: 195 - States: 5000 - Cities: 150000 -``` - ---- - -## Building and Running - -### Prerequisites - -- C++17 compiler (g++, clang, MSVC). -- CMake 3.20+. -- curl (for HTTP downloads). -- sqlite3 (usually system-provided). -- RapidJSON (fetched via CMake FetchContent). -- spdlog (fetched via CMake FetchContent). - -### Build - -```bash mkdir -p build cd build cmake .. cmake --build . --target biergarten-pipeline -- -j -``` -**Build artifacts**: +Run -- Executable: `build/biergarten-pipeline` -- Intermediate: `build/CMakeFiles/`, `build/_deps/` (RapidJSON, spdlog) - -### Run - -```bash ./biergarten-pipeline -``` -**Output**: Logs to console; caches JSON in `/tmp/countries+states+cities.json`. - -### Cleaning - -```bash -rm -rf build -``` - ---- - -## Development Notes - -### Code Organization - -- **`includes/`**: Public headers (data structures, class APIs). -- **`src/`**: Implementations with inline comments for non-obvious logic. -- **`CMakeLists.txt`**: Build configuration; defines fetch content, compiler flags, linking. - -### Testing - -Currently no automated tests. To add: - -1. Create `tests/` folder. -2. Use CMake to add a test executable. -3. Test the parser with small JSON fixtures. -4. Mock the database for isolation. - -### Debugging - -**Enable verbose logging**: - -```cpp -spdlog::set_level(spdlog::level::debug); -``` - -**GDB workflow**: - -```bash -gdb ./biergarten-pipeline -(gdb) break src/stream_parser.cpp:50 -(gdb) run -``` - -### Future Enhancements - -1. **Real LLM integration**: Load and run llama.cpp models. -2. **Persistence**: Write brewery data to a database or file. -3. **Distributed parsing**: Shard JSON file across multiple parse streams. -4. **Incremental updates**: Only insert new records if source updated. -5. **Web API**: Expose database via HTTP (brewery lookup, city search). - ---- - -## References - -- [RapidJSON](https://rapidjson.org/) – SAX parsing documentation. -- [spdlog](https://github.com/gabime/spdlog) – Logging framework. -- [SQLite](https://www.sqlite.org/docs.html) – In-memory database reference. -- [countries-states-cities-database](https://github.com/dr5hn/countries-states-cities-database) – Data source. +Output: Logs to console; caches JSON in /tmp/countries+states+cities.json. diff --git a/pipeline/includes/curl_web_client.h b/pipeline/includes/curl_web_client.h new file mode 100644 index 0000000..ae60cf6 --- /dev/null +++ b/pipeline/includes/curl_web_client.h @@ -0,0 +1,26 @@ +#pragma once + +#include "web_client.h" +#include + +// RAII for curl_global_init/cleanup. +// An instance of this class should be created in main() before any curl +// operations and exist for the lifetime of the application. +class CurlGlobalState { +public: + CurlGlobalState(); + ~CurlGlobalState(); + CurlGlobalState(const CurlGlobalState &) = delete; + CurlGlobalState &operator=(const CurlGlobalState &) = delete; +}; + +class CURLWebClient : public IWebClient { +public: + CURLWebClient(); + ~CURLWebClient() override; + + void DownloadToFile(const std::string &url, + const std::string &filePath) override; + std::string Get(const std::string &url) override; + std::string UrlEncode(const std::string &value) override; +}; diff --git a/pipeline/includes/data_downloader.h b/pipeline/includes/data_downloader.h index 79cabed..dae783e 100644 --- a/pipeline/includes/data_downloader.h +++ b/pipeline/includes/data_downloader.h @@ -1,14 +1,17 @@ #ifndef DATA_DOWNLOADER_H #define DATA_DOWNLOADER_H +#include #include #include +#include "web_client.h" + /// @brief Downloads and caches source geography JSON payloads. class DataDownloader { public: /// @brief Initializes global curl state used by this downloader. - DataDownloader(); + DataDownloader(std::shared_ptr webClient); /// @brief Cleans up global curl state. ~DataDownloader(); @@ -21,6 +24,7 @@ public: private: bool FileExists(const std::string &filePath) const; + std::shared_ptr m_webClient; }; #endif // DATA_DOWNLOADER_H diff --git a/pipeline/includes/database.h b/pipeline/includes/database.h index df94d4c..7ebbd0d 100644 --- a/pipeline/includes/database.h +++ b/pipeline/includes/database.h @@ -27,6 +27,15 @@ struct State { int countryId; }; +struct City { + /// @brief City identifier from the source dataset. + int id; + /// @brief City display name. + std::string name; + /// @brief Parent country identifier. + int countryId; +}; + /// @brief Thread-safe SQLite wrapper for pipeline writes and readbacks. class SqliteDatabase { private: @@ -60,8 +69,8 @@ public: void InsertCity(int id, int stateId, int countryId, const std::string &name, double latitude, double longitude); - /// @brief Returns city id and city name pairs. - std::vector> QueryCities(); + /// @brief Returns city records including parent country id. + std::vector QueryCities(); /// @brief Returns countries with optional row limit. std::vector QueryCountries(int limit = 0); diff --git a/pipeline/includes/llama_generator.h b/pipeline/includes/llama_generator.h index 8696602..a5d8c06 100644 --- a/pipeline/includes/llama_generator.h +++ b/pipeline/includes/llama_generator.h @@ -1,16 +1,20 @@ #pragma once -#include "data_generator.h" -#include +#include #include +#include "data_generator.h" + struct llama_model; struct llama_context; class LlamaGenerator final : public IDataGenerator { public: + LlamaGenerator() = default; ~LlamaGenerator() override; + void setSamplingOptions(float temperature, float topP, int seed = -1); + void load(const std::string &modelPath) override; BreweryResult generateBrewery(const std::string &cityName, const std::string &countryName, @@ -18,14 +22,17 @@ public: UserResult generateUser(const std::string &locale) override; private: - std::string infer(const std::string &prompt, int maxTokens = 5000); + std::string infer(const std::string &prompt, int maxTokens = 10000); // Overload that allows passing a system message separately so chat-capable // models receive a proper system role instead of having the system text // concatenated into the user prompt (helps avoid revealing internal // reasoning or instructions in model output). std::string infer(const std::string &systemPrompt, const std::string &prompt, - int maxTokens = 5000); + int maxTokens = 10000); llama_model *model_ = nullptr; llama_context *context_ = nullptr; + float sampling_temperature_ = 0.8f; + float sampling_top_p_ = 0.92f; + uint32_t sampling_seed_ = 0xFFFFFFFFu; }; diff --git a/pipeline/includes/web_client.h b/pipeline/includes/web_client.h new file mode 100644 index 0000000..426e3c3 --- /dev/null +++ b/pipeline/includes/web_client.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +class IWebClient { +public: + virtual ~IWebClient() = default; + + // Downloads content from a URL to a file. Throws on error. + virtual void DownloadToFile(const std::string &url, + const std::string &filePath) = 0; + + // Performs a GET request and returns the response body as a string. Throws on + // error. + virtual std::string Get(const std::string &url) = 0; + + // URL-encodes a string. + virtual std::string UrlEncode(const std::string &value) = 0; +}; diff --git a/pipeline/includes/wikipedia_service.h b/pipeline/includes/wikipedia_service.h new file mode 100644 index 0000000..55c1e32 --- /dev/null +++ b/pipeline/includes/wikipedia_service.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include +#include + +#include "web_client.h" + +/// @brief Provides cached Wikipedia summary lookups for city and country pairs. +class WikipediaService { +public: + /// @brief Creates a new Wikipedia service with the provided web client. + explicit WikipediaService(std::shared_ptr client); + + /// @brief Returns the Wikipedia summary extract for city and country. + [[nodiscard]] std::string GetSummary(std::string_view city, + std::string_view country); + +private: + std::string FetchExtract(std::string_view query); + std::shared_ptr client_; + std::unordered_map cache_; +}; diff --git a/pipeline/src/curl_web_client.cpp b/pipeline/src/curl_web_client.cpp new file mode 100644 index 0000000..9a94c55 --- /dev/null +++ b/pipeline/src/curl_web_client.cpp @@ -0,0 +1,139 @@ +#include "curl_web_client.h" +#include +#include +#include +#include +#include +#include + +CurlGlobalState::CurlGlobalState() { + if (curl_global_init(CURL_GLOBAL_DEFAULT) != CURLE_OK) { + throw std::runtime_error( + "[CURLWebClient] Failed to initialize libcurl globally"); + } +} + +CurlGlobalState::~CurlGlobalState() { curl_global_cleanup(); } + +namespace { +// curl write callback that appends response data into a std::string +static size_t WriteCallbackString(void *contents, size_t size, size_t nmemb, + void *userp) { + size_t realsize = size * nmemb; + auto *s = static_cast(userp); + s->append(static_cast(contents), realsize); + return realsize; +} + +// curl write callback that writes to a file stream +static size_t WriteCallbackFile(void *contents, size_t size, size_t nmemb, + void *userp) { + size_t realsize = size * nmemb; + auto *outFile = static_cast(userp); + outFile->write(static_cast(contents), realsize); + return realsize; +} + +// RAII wrapper for CURL handle using unique_ptr +using CurlHandle = std::unique_ptr; + +CurlHandle create_handle() { + CURL *handle = curl_easy_init(); + if (!handle) { + throw std::runtime_error( + "[CURLWebClient] Failed to initialize libcurl handle"); + } + return CurlHandle(handle, &curl_easy_cleanup); +} + +void set_common_get_options(CURL *curl, const std::string &url, + long connect_timeout, long total_timeout) { + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L); + curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, connect_timeout); + curl_easy_setopt(curl, CURLOPT_TIMEOUT, total_timeout); + curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip"); +} +} // namespace + +CURLWebClient::CURLWebClient() {} + +CURLWebClient::~CURLWebClient() {} + +void CURLWebClient::DownloadToFile(const std::string &url, + const std::string &filePath) { + auto curl = create_handle(); + + std::ofstream outFile(filePath, std::ios::binary); + if (!outFile.is_open()) { + throw std::runtime_error("[CURLWebClient] Cannot open file for writing: " + + filePath); + } + + set_common_get_options(curl.get(), url, 30L, 300L); + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackFile); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, + static_cast(&outFile)); + + CURLcode res = curl_easy_perform(curl.get()); + outFile.close(); + + if (res != CURLE_OK) { + std::remove(filePath.c_str()); + std::string error = std::string("[CURLWebClient] Download failed: ") + + curl_easy_strerror(res); + throw std::runtime_error(error); + } + + long httpCode = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); + + if (httpCode != 200) { + std::remove(filePath.c_str()); + std::stringstream ss; + ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; + throw std::runtime_error(ss.str()); + } +} + +std::string CURLWebClient::Get(const std::string &url) { + auto curl = create_handle(); + + std::string response_string; + set_common_get_options(curl.get(), url, 10L, 20L); + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallbackString); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_string); + + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + std::string error = + std::string("[CURLWebClient] GET failed: ") + curl_easy_strerror(res); + throw std::runtime_error(error); + } + + long httpCode = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &httpCode); + + if (httpCode != 200) { + std::stringstream ss; + ss << "[CURLWebClient] HTTP error " << httpCode << " for URL " << url; + throw std::runtime_error(ss.str()); + } + + return response_string; +} + +std::string CURLWebClient::UrlEncode(const std::string &value) { + // A NULL handle is fine for UTF-8 encoding according to libcurl docs. + char *output = curl_easy_escape(nullptr, value.c_str(), 0); + + if (output) { + std::string result(output); + curl_free(output); + return result; + } + throw std::runtime_error("[CURLWebClient] curl_easy_escape failed"); +} diff --git a/pipeline/src/data_downloader.cpp b/pipeline/src/data_downloader.cpp index 9be4508..5060433 100644 --- a/pipeline/src/data_downloader.cpp +++ b/pipeline/src/data_downloader.cpp @@ -1,20 +1,13 @@ #include "data_downloader.h" -#include -#include +#include "web_client.h" #include #include #include #include +#include -static size_t WriteCallback(void *contents, size_t size, size_t nmemb, - void *userp) { - size_t realsize = size * nmemb; - std::ofstream *outFile = static_cast(userp); - outFile->write(static_cast(contents), realsize); - return realsize; -} - -DataDownloader::DataDownloader() {} +DataDownloader::DataDownloader(std::shared_ptr webClient) + : m_webClient(std::move(webClient)) {} DataDownloader::~DataDownloader() {} @@ -41,56 +34,7 @@ DataDownloader::DownloadCountriesDatabase(const std::string &cachePath, spdlog::info("[DataDownloader] Downloading: {}", url); - CURL *curl = curl_easy_init(); - if (!curl) { - throw std::runtime_error("[DataDownloader] Failed to initialize libcurl"); - } - - std::ofstream outFile(cachePath, std::ios::binary); - if (!outFile.is_open()) { - curl_easy_cleanup(curl); - throw std::runtime_error("[DataDownloader] Cannot open file for writing: " + - cachePath); - } - - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, static_cast(&outFile)); - - curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 30L); - curl_easy_setopt(curl, CURLOPT_TIMEOUT, 300L); - - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L); - - curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip"); - - curl_easy_setopt(curl, CURLOPT_USERAGENT, "biergarten-pipeline/0.1.0"); - - CURLcode res = curl_easy_perform(curl); - outFile.close(); - - if (res != CURLE_OK) { - curl_easy_cleanup(curl); - std::remove(cachePath.c_str()); - - std::string error = std::string("[DataDownloader] Download failed: ") + - curl_easy_strerror(res); - throw std::runtime_error(error); - } - - long httpCode = 0; - curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &httpCode); - curl_easy_cleanup(curl); - - if (httpCode != 200) { - std::remove(cachePath.c_str()); - - std::stringstream ss; - ss << "[DataDownloader] HTTP error " << httpCode - << " (commit: " << shortCommit << ")"; - throw std::runtime_error(ss.str()); - } + m_webClient->DownloadToFile(url, cachePath); std::ifstream fileCheck(cachePath, std::ios::binary | std::ios::ate); std::streamsize size = fileCheck.tellg(); diff --git a/pipeline/src/database.cpp b/pipeline/src/database.cpp index 2748ccb..a8749d2 100644 --- a/pipeline/src/database.cpp +++ b/pipeline/src/database.cpp @@ -157,13 +157,12 @@ void SqliteDatabase::InsertCity(int id, int stateId, int countryId, sqlite3_finalize(stmt); } -std::vector> SqliteDatabase::QueryCities() { +std::vector SqliteDatabase::QueryCities() { std::lock_guard lock(dbMutex); - - std::vector> cities; + std::vector cities; sqlite3_stmt *stmt = nullptr; - const char *query = "SELECT id, name FROM cities ORDER BY name"; + const char *query = "SELECT id, name, country_id FROM cities ORDER BY name"; int rc = sqlite3_prepare_v2(db, query, -1, &stmt, nullptr); if (rc != SQLITE_OK) { @@ -174,7 +173,8 @@ std::vector> SqliteDatabase::QueryCities() { int id = sqlite3_column_int(stmt, 0); const char *name = reinterpret_cast(sqlite3_column_text(stmt, 1)); - cities.push_back({id, name ? std::string(name) : ""}); + int countryId = sqlite3_column_int(stmt, 2); + cities.push_back({id, name ? std::string(name) : "", countryId}); } sqlite3_finalize(stmt); diff --git a/pipeline/src/json_loader.cpp b/pipeline/src/json_loader.cpp index 1d27176..6a7a966 100644 --- a/pipeline/src/json_loader.cpp +++ b/pipeline/src/json_loader.cpp @@ -1,32 +1,52 @@ +#include + +#include + #include "json_loader.h" #include "stream_parser.h" -#include -#include void JsonLoader::LoadWorldCities(const std::string &jsonPath, SqliteDatabase &db) { + constexpr size_t kBatchSize = 10000; + auto startTime = std::chrono::high_resolution_clock::now(); spdlog::info("\nLoading {} (streaming RapidJSON SAX)...", jsonPath); db.BeginTransaction(); + bool transactionOpen = true; size_t citiesProcessed = 0; - StreamingJsonParser::Parse( - jsonPath, db, - [&](const CityRecord &record) { - db.InsertCity(record.id, record.state_id, record.country_id, - record.name, record.latitude, record.longitude); - citiesProcessed++; - }, - [&](size_t current, size_t total) { - if (current % 10000 == 0 && current > 0) { - spdlog::info(" [Progress] Parsed {} cities...", current); - } - }); + try { + StreamingJsonParser::Parse( + jsonPath, db, + [&](const CityRecord &record) { + db.InsertCity(record.id, record.state_id, record.country_id, + record.name, record.latitude, record.longitude); + ++citiesProcessed; - spdlog::info(" OK: Parsed all cities from JSON"); + if (citiesProcessed % kBatchSize == 0) { + db.CommitTransaction(); + db.BeginTransaction(); + } + }, + [&](size_t current, size_t /*total*/) { + if (current % kBatchSize == 0 && current > 0) { + spdlog::info(" [Progress] Parsed {} cities...", current); + } + }); - db.CommitTransaction(); + spdlog::info(" OK: Parsed all cities from JSON"); + + if (transactionOpen) { + db.CommitTransaction(); + transactionOpen = false; + } + } catch (...) { + if (transactionOpen) { + db.CommitTransaction(); + } + throw; + } auto endTime = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast( diff --git a/pipeline/src/llama_generator.cpp b/pipeline/src/llama_generator.cpp index bc53ff6..f736852 100644 --- a/pipeline/src/llama_generator.cpp +++ b/pipeline/src/llama_generator.cpp @@ -1,7 +1,3 @@ -#include "llama_generator.h" - -#include "llama.h" - #include #include #include @@ -11,8 +7,12 @@ #include #include +#include "llama.h" +#include #include +#include "llama_generator.h" + namespace { std::string trim(std::string value) { @@ -26,10 +26,47 @@ std::string trim(std::string value) { return value; } +std::string CondenseWhitespace(std::string text) { + std::string out; + out.reserve(text.size()); + + bool inWhitespace = false; + for (unsigned char ch : text) { + if (std::isspace(ch)) { + if (!inWhitespace) { + out.push_back(' '); + inWhitespace = true; + } + continue; + } + + inWhitespace = false; + out.push_back(static_cast(ch)); + } + + return trim(std::move(out)); +} + +std::string PrepareRegionContext(std::string_view regionContext, + std::size_t maxChars = 700) { + std::string normalized = CondenseWhitespace(std::string(regionContext)); + if (normalized.size() <= maxChars) { + return normalized; + } + + normalized.resize(maxChars); + const std::size_t lastSpace = normalized.find_last_of(' '); + if (lastSpace != std::string::npos && lastSpace > maxChars / 2) { + normalized.resize(lastSpace); + } + + normalized += "..."; + return normalized; +} + std::string stripCommonPrefix(std::string line) { line = trim(std::move(line)); - // Strip simple list markers like "- ", "* ", "1. ", "2) ". if (!line.empty() && (line[0] == '-' || line[0] == '*')) { line = trim(line.substr(1)); } else { @@ -68,6 +105,50 @@ std::string stripCommonPrefix(std::string line) { return trim(std::move(line)); } +std::pair +parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) { + std::string normalized = raw; + std::replace(normalized.begin(), normalized.end(), '\r', '\n'); + + std::vector lines; + std::stringstream stream(normalized); + std::string line; + while (std::getline(stream, line)) { + line = stripCommonPrefix(std::move(line)); + if (!line.empty()) + lines.push_back(std::move(line)); + } + + std::vector filtered; + for (auto &l : lines) { + std::string low = l; + std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + if (!l.empty() && l.front() == '<' && low.back() == '>') + continue; + if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) + continue; + filtered.push_back(std::move(l)); + } + + if (filtered.size() < 2) + throw std::runtime_error(errorMessage); + + std::string first = trim(filtered.front()); + std::string second; + for (size_t i = 1; i < filtered.size(); ++i) { + if (!second.empty()) + second += ' '; + second += filtered[i]; + } + second = trim(std::move(second)); + + if (first.empty() || second.empty()) + throw std::runtime_error(errorMessage); + return {first, second}; +} + std::string toChatPrompt(const llama_model *model, const std::string &userPrompt) { const char *tmpl = llama_model_chat_template(model, nullptr); @@ -75,10 +156,7 @@ std::string toChatPrompt(const llama_model *model, return userPrompt; } - const llama_chat_message message{ - "user", - userPrompt.c_str(), - }; + const llama_chat_message message{"user", userPrompt.c_str()}; std::vector buffer(std::max(1024, userPrompt.size() * 4)); int32_t required = @@ -106,14 +184,11 @@ std::string toChatPrompt(const llama_model *model, const std::string &userPrompt) { const char *tmpl = llama_model_chat_template(model, nullptr); if (tmpl == nullptr) { - // Fall back to concatenating but keep system and user parts distinct. return systemPrompt + "\n\n" + userPrompt; } - const llama_chat_message messages[2] = { - {"system", systemPrompt.c_str()}, - {"user", userPrompt.c_str()}, - }; + const llama_chat_message messages[2] = {{"system", systemPrompt.c_str()}, + {"user", userPrompt.c_str()}}; std::vector buffer(std::max( 1024, (systemPrompt.size() + userPrompt.size()) * 4)); @@ -161,73 +236,135 @@ void appendTokenPiece(const llama_vocab *vocab, llama_token token, output.append(buffer.data(), static_cast(bytes)); } -std::pair -parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) { - std::string normalized = raw; - std::replace(normalized.begin(), normalized.end(), '\r', '\n'); +bool extractFirstJsonObject(const std::string &text, std::string &jsonOut) { + std::size_t start = std::string::npos; + int depth = 0; + bool inString = false; + bool escaped = false; - std::vector lines; - std::stringstream stream(normalized); - std::string line; - while (std::getline(stream, line)) { - line = stripCommonPrefix(std::move(line)); - if (!line.empty()) { - lines.push_back(std::move(line)); - } - } + for (std::size_t i = 0; i < text.size(); ++i) { + const char ch = text[i]; - // Filter out obvious internal-thought / meta lines that sometimes leak from - // models (e.g. "", "Okay, so the user is asking me..."). - std::vector filtered; - for (auto &l : lines) { - std::string low = l; - std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) { - return static_cast(std::tolower(c)); - }); - - // Skip single-token angle-bracket markers like or <...> - if (!l.empty() && l.front() == '<' && l.back() == '>') { + if (inString) { + if (escaped) { + escaped = false; + } else if (ch == '\\') { + escaped = true; + } else if (ch == '"') { + inString = false; + } continue; } - // Skip short internal commentary that starts with common discourse markers - if (low.rfind("okay,", 0) == 0 || low.rfind("wait,", 0) == 0 || - low.rfind("hmm", 0) == 0) { + if (ch == '"') { + inString = true; continue; } - // Skip lines that look like self-descriptions of what the model is doing - if (low.find("user is asking") != std::string::npos || - low.find("protocol") != std::string::npos || - low.find("parse") != std::string::npos || - low.find("return only") != std::string::npos) { + if (ch == '{') { + if (depth == 0) { + start = i; + } + ++depth; continue; } - filtered.push_back(std::move(l)); - } - - if (filtered.size() < 2) { - throw std::runtime_error(errorMessage); - } - - std::string first = trim(filtered.front()); - std::string second; - for (std::size_t i = 1; i < filtered.size(); ++i) { - if (!second.empty()) { - second += ' '; + if (ch == '}') { + if (depth == 0) { + continue; + } + --depth; + if (depth == 0 && start != std::string::npos) { + jsonOut = text.substr(start, i - start + 1); + return true; + } } - second += filtered[i]; - } - second = trim(std::move(second)); - - if (first.empty() || second.empty()) { - throw std::runtime_error(errorMessage); } - return {first, second}; + return false; } +std::string ValidateBreweryJson(const std::string &raw, std::string &nameOut, + std::string &descriptionOut) { + auto validateObject = [&](const boost::json::value &jv, + std::string &errorOut) -> bool { + if (!jv.is_object()) { + errorOut = "JSON root must be an object"; + return false; + } + + const auto &obj = jv.get_object(); + if (!obj.contains("name") || !obj.at("name").is_string()) { + errorOut = "JSON field 'name' is missing or not a string"; + return false; + } + + if (!obj.contains("description") || !obj.at("description").is_string()) { + errorOut = "JSON field 'description' is missing or not a string"; + return false; + } + + nameOut = trim(std::string(obj.at("name").as_string().c_str())); + descriptionOut = + trim(std::string(obj.at("description").as_string().c_str())); + + if (nameOut.empty()) { + errorOut = "JSON field 'name' must not be empty"; + return false; + } + + if (descriptionOut.empty()) { + errorOut = "JSON field 'description' must not be empty"; + return false; + } + + std::string nameLower = nameOut; + std::string descriptionLower = descriptionOut; + std::transform( + nameLower.begin(), nameLower.end(), nameLower.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + std::transform(descriptionLower.begin(), descriptionLower.end(), + descriptionLower.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + + if (nameLower == "string" || descriptionLower == "string") { + errorOut = "JSON appears to be a schema placeholder, not content"; + return false; + } + + errorOut.clear(); + return true; + }; + + boost::system::error_code ec; + boost::json::value jv = boost::json::parse(raw, ec); + std::string validationError; + if (ec) { + std::string extracted; + if (!extractFirstJsonObject(raw, extracted)) { + return "JSON parse error: " + ec.message(); + } + + ec.clear(); + jv = boost::json::parse(extracted, ec); + if (ec) { + return "JSON parse error: " + ec.message(); + } + + if (!validateObject(jv, validationError)) { + return validationError; + } + + return {}; + } + + if (!validateObject(jv, validationError)) { + return validationError; + } + + return {}; +} } // namespace LlamaGenerator::~LlamaGenerator() { @@ -244,10 +381,30 @@ LlamaGenerator::~LlamaGenerator() { llama_backend_free(); } -void LlamaGenerator::load(const std::string &modelPath) { - if (modelPath.empty()) { - throw std::runtime_error("LlamaGenerator: model path must not be empty"); +void LlamaGenerator::setSamplingOptions(float temperature, float topP, + int seed) { + if (temperature < 0.0f) { + throw std::runtime_error( + "LlamaGenerator: sampling temperature must be >= 0"); } + if (!(topP > 0.0f && topP <= 1.0f)) { + throw std::runtime_error( + "LlamaGenerator: sampling top-p must be in (0, 1]"); + } + if (seed < -1) { + throw std::runtime_error( + "LlamaGenerator: seed must be >= 0, or -1 for random"); + } + + sampling_temperature_ = temperature; + sampling_top_p_ = topP; + sampling_seed_ = (seed < 0) ? static_cast(LLAMA_DEFAULT_SEED) + : static_cast(seed); +} + +void LlamaGenerator::load(const std::string &modelPath) { + if (modelPath.empty()) + throw std::runtime_error("LlamaGenerator: model path must not be empty"); if (context_ != nullptr) { llama_free(context_); @@ -261,7 +418,7 @@ void LlamaGenerator::load(const std::string &modelPath) { llama_backend_init(); llama_model_params modelParams = llama_model_default_params(); - model_ = llama_load_model_from_file(modelPath.c_str(), modelParams); + model_ = llama_model_load_from_file(modelPath.c_str(), modelParams); if (model_ == nullptr) { throw std::runtime_error( "LlamaGenerator: failed to load model from path: " + modelPath); @@ -281,14 +438,12 @@ void LlamaGenerator::load(const std::string &modelPath) { } std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) { - if (model_ == nullptr || context_ == nullptr) { + if (model_ == nullptr || context_ == nullptr) throw std::runtime_error("LlamaGenerator: model not loaded"); - } const llama_vocab *vocab = llama_model_get_vocab(model_); - if (vocab == nullptr) { + if (vocab == nullptr) throw std::runtime_error("LlamaGenerator: vocab unavailable"); - } llama_memory_clear(llama_get_memory(context_), true); @@ -308,17 +463,33 @@ std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) { static_cast(promptTokens.size()), true, true); } - if (tokenCount < 0) { + if (tokenCount < 0) throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); + + const int32_t nCtx = static_cast(llama_n_ctx(context_)); + const int32_t nBatch = static_cast(llama_n_batch(context_)); + if (nCtx <= 1 || nBatch <= 0) { + throw std::runtime_error("LlamaGenerator: invalid context or batch size"); } + const int32_t effectiveMaxTokens = std::max(1, std::min(maxTokens, nCtx - 1)); + int32_t promptBudget = std::min(nBatch, nCtx - effectiveMaxTokens); + promptBudget = std::max(1, promptBudget); + promptTokens.resize(static_cast(tokenCount)); + if (tokenCount > promptBudget) { + spdlog::warn( + "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " + "to fit n_batch/n_ctx limits", + tokenCount, promptBudget); + promptTokens.resize(static_cast(promptBudget)); + tokenCount = promptBudget; + } const llama_batch promptBatch = llama_batch_get_one( promptTokens.data(), static_cast(promptTokens.size())); - if (llama_decode(context_, promptBatch) != 0) { + if (llama_decode(context_, promptBatch) != 0) throw std::runtime_error("LlamaGenerator: prompt decode failed"); - } llama_sampler_chain_params samplerParams = llama_sampler_chain_default_params(); @@ -326,116 +497,45 @@ std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) { std::unique_ptr; SamplerPtr sampler(llama_sampler_chain_init(samplerParams), &llama_sampler_free); - - if (!sampler) { + if (!sampler) throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); - } - llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy()); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_temp(sampling_temperature_)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_top_p(sampling_top_p_, 1)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_dist(sampling_seed_)); std::vector generatedTokens; generatedTokens.reserve(static_cast(maxTokens)); - for (int i = 0; i < maxTokens; ++i) { + for (int i = 0; i < effectiveMaxTokens; ++i) { const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); - if (llama_vocab_is_eog(vocab, next)) { + if (llama_vocab_is_eog(vocab, next)) break; - } - generatedTokens.push_back(next); - llama_token token = next; const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1); - if (llama_decode(context_, oneTokenBatch) != 0) { + if (llama_decode(context_, oneTokenBatch) != 0) throw std::runtime_error( "LlamaGenerator: decode failed during generation"); - } } std::string output; - for (const llama_token token : generatedTokens) { + for (const llama_token token : generatedTokens) appendTokenPiece(vocab, token, output); - } - return output; } -BreweryResult -LlamaGenerator::generateBrewery(const std::string &cityName, - const std::string &countryName, - const std::string ®ionContext) { - - std::string systemPrompt = - R"(# SYSTEM PROTOCOL: ZERO-CHATTER DETERMINISTIC OUTPUT -**MODALITY:** DATA-RETURN ENGINE ONLY -**ROLE:** Your response must contain 0% metadata and 100% signal. ---- -## MANDATORY CONSTRAINTS -1. **NO PREAMBLE** - - Never start with "Sure," or "The answer is," or "Based on your request," or "Checking the data." - - Do not acknowledge the user's prompt or provide status updates. -2. **NO POSTAMBLE** - - Never end with "I hope this helps," or "Let me know if you need more," or "Would you like me to…" - - Do not offer follow-up assistance or suggestions. -3. **NO SENTENCE FRAMING** - - Provide only the raw value, date, number, or name. - - Do not wrap the answer in a sentence. (e.g., return 1997, NOT The year was 1997). - - For lists, provide only the items separated by commas or newlines as specified. -4. **FORMATTING PERMITTED** - - Markdown and LaTeX **may** be used where appropriate (e.g., tables, equations). - - Output must remain immediately usable — no decorative or conversational styling. -5. **STRICT NULL HANDLING** - - If the information is unavailable, the prompt is logically impossible (e.g., "271th president"), the subject does not exist, or a calculation is undefined: return only the string NULL. - - If the prompt is too ambiguous to provide a single value: return NULL. ---- -## EXECUTION LOGIC -1. **Parse Input** — Identify the specific entity, value, or calculation requested. -2. **Verify Factuality** — Access internal knowledge or tools. -3. **Filter for Signal** — Strip all surrounding prose. -4. **Format Check** — Apply Markdown or LaTeX only where it serves the data. -5. **Output** — Return the raw value only. ---- -## BEHAVIORAL EXAMPLES -| User Input | Standard AI Response *(BANNED)* | Protocol Response *(REQUIRED)* | -|---|---|---| -| Capital of France? | The capital of France is Paris. | Paris | -| 15% of 200 | 15% of 200 is 30. | 30 | -| Who wrote '1984'? | George Orwell wrote that novel. | George Orwell | -| ISO code for Japan | The code is JP. | JP | -| $\sqrt{x}$ where $x$ is a potato | A potato has no square root. | NULL | -| 500th US President | There haven't been that many. | NULL | -| Pythagorean theorem | The theorem states... | $a^2 + b^2 = c^2$ | ---- -## FINAL INSTRUCTION -Total silence is preferred over conversational error. Any deviation from the raw-value-only format is a protocol failure. Proceed with next input.)"; - - std::string prompt = - "Generate a craft brewery name and 1000 character description for a " - "brewery located in " + - cityName + - (countryName.empty() ? std::string("") - : std::string(", ") + countryName) + - ". " + regionContext + - " Respond with exactly two lines: first line is the name, second line is " - "the description. Do not include bullets, numbering, or any extra text."; - - const std::string raw = infer(systemPrompt, prompt, 512); - auto [name, description] = - parseTwoLineResponse(raw, "LlamaGenerator: malformed brewery response"); - - return {name, description}; -} - std::string LlamaGenerator::infer(const std::string &systemPrompt, const std::string &prompt, int maxTokens) { - if (model_ == nullptr || context_ == nullptr) { + if (model_ == nullptr || context_ == nullptr) throw std::runtime_error("LlamaGenerator: model not loaded"); - } const llama_vocab *vocab = llama_model_get_vocab(model_); - if (vocab == nullptr) { + if (vocab == nullptr) throw std::runtime_error("LlamaGenerator: vocab unavailable"); - } llama_memory_clear(llama_get_memory(context_), true); @@ -456,17 +556,33 @@ std::string LlamaGenerator::infer(const std::string &systemPrompt, static_cast(promptTokens.size()), true, true); } - if (tokenCount < 0) { + if (tokenCount < 0) throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); + + const int32_t nCtx = static_cast(llama_n_ctx(context_)); + const int32_t nBatch = static_cast(llama_n_batch(context_)); + if (nCtx <= 1 || nBatch <= 0) { + throw std::runtime_error("LlamaGenerator: invalid context or batch size"); } + const int32_t effectiveMaxTokens = std::max(1, std::min(maxTokens, nCtx - 1)); + int32_t promptBudget = std::min(nBatch, nCtx - effectiveMaxTokens); + promptBudget = std::max(1, promptBudget); + promptTokens.resize(static_cast(tokenCount)); + if (tokenCount > promptBudget) { + spdlog::warn( + "LlamaGenerator: prompt too long ({} tokens), truncating to {} tokens " + "to fit n_batch/n_ctx limits", + tokenCount, promptBudget); + promptTokens.resize(static_cast(promptBudget)); + tokenCount = promptBudget; + } const llama_batch promptBatch = llama_batch_get_one( promptTokens.data(), static_cast(promptTokens.size())); - if (llama_decode(context_, promptBatch) != 0) { + if (llama_decode(context_, promptBatch) != 0) throw std::runtime_error("LlamaGenerator: prompt decode failed"); - } llama_sampler_chain_params samplerParams = llama_sampler_chain_default_params(); @@ -474,61 +590,145 @@ std::string LlamaGenerator::infer(const std::string &systemPrompt, std::unique_ptr; SamplerPtr sampler(llama_sampler_chain_init(samplerParams), &llama_sampler_free); - - if (!sampler) { + if (!sampler) throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); - } - llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy()); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_temp(sampling_temperature_)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_top_p(sampling_top_p_, 1)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_dist(sampling_seed_)); std::vector generatedTokens; generatedTokens.reserve(static_cast(maxTokens)); - for (int i = 0; i < maxTokens; ++i) { + for (int i = 0; i < effectiveMaxTokens; ++i) { const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); - if (llama_vocab_is_eog(vocab, next)) { + if (llama_vocab_is_eog(vocab, next)) break; - } - generatedTokens.push_back(next); - llama_token token = next; const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1); - if (llama_decode(context_, oneTokenBatch) != 0) { + if (llama_decode(context_, oneTokenBatch) != 0) throw std::runtime_error( "LlamaGenerator: decode failed during generation"); - } } std::string output; - for (const llama_token token : generatedTokens) { + for (const llama_token token : generatedTokens) appendTokenPiece(vocab, token, output); - } - return output; } -UserResult LlamaGenerator::generateUser(const std::string &locale) { +BreweryResult +LlamaGenerator::generateBrewery(const std::string &cityName, + const std::string &countryName, + const std::string ®ionContext) { + const std::string safeRegionContext = PrepareRegionContext(regionContext); + + const std::string systemPrompt = + "You are a copywriter for a craft beer travel guide. " + "Your writing is vivid, specific to place, and avoids generic beer " + "cliches. " + "You must output ONLY valid JSON. " + "The JSON schema must be exactly: {\"name\": \"string\", " + "\"description\": \"string\"}. " + "Do not include markdown formatting or backticks."; + std::string prompt = - "Generate a plausible craft beer enthusiast username and a one-sentence " - "bio. Locale: " + - locale + - ". Respond with exactly two lines: first line is the username (no " - "spaces), second line is the bio. Do not include bullets, numbering, " - "or any extra text."; + "Write a brewery name and place-specific description for a craft " + "brewery in " + + cityName + + (countryName.empty() ? std::string("") + : std::string(", ") + countryName) + + (safeRegionContext.empty() + ? std::string(".") + : std::string(". Regional context: ") + safeRegionContext); - const std::string raw = infer(prompt, 128); - auto [username, bio] = - parseTwoLineResponse(raw, "LlamaGenerator: malformed user response"); + const int maxAttempts = 3; + std::string raw; + std::string lastError; + for (int attempt = 0; attempt < maxAttempts; ++attempt) { + raw = infer(systemPrompt, prompt, 384); + spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1, + raw); - username.erase( - std::remove_if(username.begin(), username.end(), - [](unsigned char ch) { return std::isspace(ch); }), - username.end()); + std::string name; + std::string description; + const std::string validationError = + ValidateBreweryJson(raw, name, description); + if (validationError.empty()) { + return {std::move(name), std::move(description)}; + } - if (username.empty() || bio.empty()) { - throw std::runtime_error("LlamaGenerator: malformed user response"); + lastError = validationError; + spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}", + attempt + 1, validationError); + + prompt = "Your previous response was invalid. Error: " + validationError + + "\nReturn ONLY valid JSON with this exact schema: " + "{\"name\": \"string\", \"description\": \"string\"}." + "\nDo not include markdown, comments, or extra keys." + "\n\nLocation: " + + cityName + + (countryName.empty() ? std::string("") + : std::string(", ") + countryName) + + (safeRegionContext.empty() + ? std::string("") + : std::string("\nRegional context: ") + safeRegionContext); } - return {username, bio}; + spdlog::error("LlamaGenerator: malformed brewery response after {} attempts: " + "{}", + maxAttempts, lastError.empty() ? raw : lastError); + throw std::runtime_error("LlamaGenerator: malformed brewery response"); +} + +UserResult LlamaGenerator::generateUser(const std::string &locale) { + const std::string systemPrompt = + "You generate plausible social media profiles for craft beer " + "enthusiasts. " + "Respond with exactly two lines: " + "the first line is a username (lowercase, no spaces, 8-20 characters), " + "the second line is a one-sentence bio (20-40 words). " + "The profile should feel consistent with the locale. " + "No preamble, no labels."; + + std::string prompt = + "Generate a craft beer enthusiast profile. Locale: " + locale; + + const int maxAttempts = 3; + std::string raw; + for (int attempt = 0; attempt < maxAttempts; ++attempt) { + raw = infer(systemPrompt, prompt, 128); + spdlog::debug("LlamaGenerator (user): raw output (attempt {}): {}", + attempt + 1, raw); + + try { + auto [username, bio] = + parseTwoLineResponse(raw, "LlamaGenerator: malformed user response"); + + username.erase( + std::remove_if(username.begin(), username.end(), + [](unsigned char ch) { return std::isspace(ch); }), + username.end()); + + if (username.empty() || bio.empty()) { + throw std::runtime_error("LlamaGenerator: malformed user response"); + } + + if (bio.size() > 200) + bio = bio.substr(0, 200); + + return {username, bio}; + } catch (const std::exception &e) { + spdlog::warn("LlamaGenerator: malformed user response (attempt {}): {}", + attempt + 1, e.what()); + } + } + + spdlog::error("LlamaGenerator: malformed user response after {} attempts: {}", + maxAttempts, raw); + throw std::runtime_error("LlamaGenerator: malformed user response"); } diff --git a/pipeline/src/main.cpp b/pipeline/src/main.cpp index 97e058d..3f3c458 100644 --- a/pipeline/src/main.cpp +++ b/pipeline/src/main.cpp @@ -1,35 +1,66 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "curl_web_client.h" #include "data_downloader.h" #include "data_generator.h" #include "database.h" #include "json_loader.h" #include "llama_generator.h" #include "mock_generator.h" -#include -#include -#include -#include -#include +#include "wikipedia_service.h" -static bool FileExists(const std::string &filePath) { - return std::filesystem::exists(filePath); -} +namespace po = boost::program_options; int main(int argc, char *argv[]) { try { - curl_global_init(CURL_GLOBAL_DEFAULT); + const CurlGlobalState curl_state; - std::string modelPath = argc > 1 ? argv[1] : ""; - std::string cacheDir = argc > 2 ? argv[2] : "/tmp"; - std::string commit = - argc > 3 ? argv[3] : "c5eb7772"; // Default: stable 2026-03-28 + po::options_description desc("Pipeline Options"); + desc.add_options()("help,h", "Produce help message")( + "model,m", po::value()->default_value(""), + "Path to LLM model (gguf)")( + "cache-dir,c", po::value()->default_value("/tmp"), + "Directory for cached JSON")( + "temperature", po::value()->default_value(0.8f), + "Sampling temperature (higher = more random)")( + "top-p", po::value()->default_value(0.92f), + "Nucleus sampling top-p in (0,1] (higher = more random)")( + "seed", po::value()->default_value(-1), + "Sampler seed: -1 for random, otherwise non-negative integer")( + "commit", po::value()->default_value("c5eb7772"), + "Git commit hash for DB consistency"); - std::string countryName = argc > 4 ? argv[4] : ""; + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + if (vm.count("help")) { + std::cout << desc << "\n"; + return 0; + } + + std::string modelPath = vm["model"].as(); + std::string cacheDir = vm["cache-dir"].as(); + float temperature = vm["temperature"].as(); + float topP = vm["top-p"].as(); + int seed = vm["seed"].as(); + std::string commit = vm["commit"].as(); std::string jsonPath = cacheDir + "/countries+states+cities.json"; std::string dbPath = cacheDir + "/biergarten-pipeline.db"; - bool hasJsonCache = FileExists(jsonPath); - bool hasDbCache = FileExists(dbPath); + bool hasJsonCache = std::filesystem::exists(jsonPath); + bool hasDbCache = std::filesystem::exists(dbPath); + + auto webClient = std::make_shared(); SqliteDatabase db; @@ -40,7 +71,7 @@ int main(int argc, char *argv[]) { spdlog::info("[Pipeline] Cache hit: skipping download and parse"); } else { spdlog::info("\n[Pipeline] Downloading geographic data from GitHub..."); - DataDownloader downloader; + DataDownloader downloader(webClient); downloader.DownloadCountriesDatabase(jsonPath, commit); JsonLoader::LoadWorldCities(jsonPath, db); @@ -52,17 +83,30 @@ int main(int argc, char *argv[]) { generator = std::make_unique(); spdlog::info("[Generator] Using MockGenerator (no model path provided)"); } else { - generator = std::make_unique(); - spdlog::info("[Generator] Using LlamaGenerator: {}", modelPath); + auto llamaGenerator = std::make_unique(); + llamaGenerator->setSamplingOptions(temperature, topP, seed); + spdlog::info( + "[Generator] Using LlamaGenerator: {} (temperature={}, top-p={}, " + "seed={})", + modelPath, temperature, topP, seed); + generator = std::move(llamaGenerator); } generator->load(modelPath); + WikipediaService wikipediaService(webClient); + spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); auto countries = db.QueryCountries(50); auto states = db.QueryStates(50); auto cities = db.QueryCities(); + // Build a quick map of country id -> name for per-city lookups. + auto allCountries = db.QueryCountries(0); + std::unordered_map countryMap; + for (const auto &c : allCountries) + countryMap[c.id] = c.name; + spdlog::info("\nTotal records loaded:"); spdlog::info(" Countries: {}", db.QueryCountries(0).size()); spdlog::info(" States: {}", db.QueryStates(0).size()); @@ -79,8 +123,23 @@ int main(int argc, char *argv[]) { spdlog::info("\n=== SAMPLE BREWERY GENERATION ==="); for (size_t i = 0; i < sampleCount; i++) { - const auto &[cityId, cityName] = cities[i]; - auto brewery = generator->generateBrewery(cityName, countryName, ""); + const auto &city = cities[i]; + const int cityId = city.id; + const std::string cityName = city.name; + + std::string localCountry; + const auto countryIt = countryMap.find(city.countryId); + if (countryIt != countryMap.end()) { + localCountry = countryIt->second; + } + + const std::string regionContext = + wikipediaService.GetSummary(cityName, localCountry); + spdlog::debug("[Pipeline] Region context for {}: {}", cityName, + regionContext); + + auto brewery = + generator->generateBrewery(cityName, localCountry, regionContext); generatedBreweries.push_back({cityId, cityName, brewery}); } @@ -95,12 +154,10 @@ int main(int argc, char *argv[]) { spdlog::info("\nOK: Pipeline completed successfully"); - curl_global_cleanup(); return 0; } catch (const std::exception &e) { spdlog::error("ERROR: Pipeline failed: {}", e.what()); - curl_global_cleanup(); return 1; } } diff --git a/pipeline/src/stream_parser.cpp b/pipeline/src/stream_parser.cpp index 432ea50..abf722d 100644 --- a/pipeline/src/stream_parser.cpp +++ b/pipeline/src/stream_parser.cpp @@ -1,15 +1,22 @@ -#include "stream_parser.h" -#include "database.h" #include -#include -#include -#include +#include + +#include +#include #include -using namespace rapidjson; +#include "database.h" +#include "stream_parser.h" + +class CityRecordHandler { + friend class boost::json::basic_parser; -class CityRecordHandler : public BaseReaderHandler, CityRecordHandler> { public: + static constexpr std::size_t max_array_size = static_cast(-1); + static constexpr std::size_t max_object_size = static_cast(-1); + static constexpr std::size_t max_string_size = static_cast(-1); + static constexpr std::size_t max_key_size = static_cast(-1); + struct ParseContext { SqliteDatabase *db = nullptr; std::function on_city; @@ -20,11 +27,35 @@ public: int states_inserted = 0; }; - CityRecordHandler(ParseContext &ctx) : context(ctx) {} + explicit CityRecordHandler(ParseContext &ctx) : context(ctx) {} - bool StartArray() { +private: + ParseContext &context; + + int depth = 0; + bool in_countries_array = false; + bool in_country_object = false; + bool in_states_array = false; + bool in_state_object = false; + bool in_cities_array = false; + bool building_city = false; + + int current_country_id = 0; + int current_state_id = 0; + CityRecord current_city = {}; + std::string current_key; + std::string current_key_val; + std::string current_string_val; + + std::string country_info[3]; + std::string state_info[2]; + + // Boost.JSON SAX Hooks + bool on_document_begin(boost::system::error_code &) { return true; } + bool on_document_end(boost::system::error_code &) { return true; } + + bool on_array_begin(boost::system::error_code &) { depth++; - if (depth == 1) { in_countries_array = true; } else if (depth == 3 && current_key == "states") { @@ -35,7 +66,7 @@ public: return true; } - bool EndArray(SizeType /*elementCount*/) { + bool on_array_end(std::size_t, boost::system::error_code &) { if (depth == 1) { in_countries_array = false; } else if (depth == 3) { @@ -47,9 +78,8 @@ public: return true; } - bool StartObject() { + bool on_object_begin(boost::system::error_code &) { depth++; - if (depth == 2 && in_countries_array) { in_country_object = true; current_country_id = 0; @@ -68,7 +98,7 @@ public: return true; } - bool EndObject(SizeType /*memberCount*/) { + bool on_object_end(std::size_t, boost::system::error_code &) { if (depth == 6 && building_city) { if (current_city.id > 0 && current_state_id > 0 && current_country_id > 0) { @@ -84,7 +114,7 @@ public: context.total_file_size); } } catch (const std::exception &e) { - spdlog::warn(" WARN: Failed to emit city: {}", e.what()); + spdlog::warn("Record parsing failed: {}", e.what()); } } building_city = false; @@ -95,7 +125,7 @@ public: state_info[0], state_info[1]); context.states_inserted++; } catch (const std::exception &e) { - spdlog::warn(" WARN: Failed to insert state: {}", e.what()); + spdlog::warn("Record parsing failed: {}", e.what()); } } in_state_object = false; @@ -106,7 +136,7 @@ public: country_info[1], country_info[2]); context.countries_inserted++; } catch (const std::exception &e) { - spdlog::warn(" WARN: Failed to insert country: {}", e.what()); + spdlog::warn("Record parsing failed: {}", e.what()); } } in_country_object = false; @@ -116,46 +146,71 @@ public: return true; } - bool Key(const char *str, SizeType len, bool /*copy*/) { - current_key.assign(str, len); + bool on_key_part(boost::json::string_view s, std::size_t, + boost::system::error_code &) { + current_key_val.append(s.data(), s.size()); return true; } - bool String(const char *str, SizeType len, bool /*copy*/) { + bool on_key(boost::json::string_view s, std::size_t, + boost::system::error_code &) { + current_key_val.append(s.data(), s.size()); + current_key = current_key_val; + current_key_val.clear(); + return true; + } + + bool on_string_part(boost::json::string_view s, std::size_t, + boost::system::error_code &) { + current_string_val.append(s.data(), s.size()); + return true; + } + + bool on_string(boost::json::string_view s, std::size_t, + boost::system::error_code &) { + current_string_val.append(s.data(), s.size()); + if (building_city && current_key == "name") { - current_city.name.assign(str, len); + current_city.name = current_string_val; } else if (in_state_object && current_key == "name") { - state_info[0].assign(str, len); + state_info[0] = current_string_val; } else if (in_state_object && current_key == "iso2") { - state_info[1].assign(str, len); + state_info[1] = current_string_val; } else if (in_country_object && current_key == "name") { - country_info[0].assign(str, len); + country_info[0] = current_string_val; } else if (in_country_object && current_key == "iso2") { - country_info[1].assign(str, len); + country_info[1] = current_string_val; } else if (in_country_object && current_key == "iso3") { - country_info[2].assign(str, len); + country_info[2] = current_string_val; } + + current_string_val.clear(); return true; } - bool Int(int i) { + bool on_number_part(boost::json::string_view, boost::system::error_code &) { + return true; + } + + bool on_int64(int64_t i, boost::json::string_view, + boost::system::error_code &) { if (building_city && current_key == "id") { - current_city.id = i; + current_city.id = static_cast(i); } else if (in_state_object && current_key == "id") { - current_state_id = i; + current_state_id = static_cast(i); } else if (in_country_object && current_key == "id") { - current_country_id = i; + current_country_id = static_cast(i); } return true; } - bool Uint(unsigned i) { return Int(static_cast(i)); } + bool on_uint64(uint64_t u, boost::json::string_view, + boost::system::error_code &ec) { + return on_int64(static_cast(u), "", ec); + } - bool Int64(int64_t i) { return Int(static_cast(i)); } - - bool Uint64(uint64_t i) { return Int(static_cast(i)); } - - bool Double(double d) { + bool on_double(double d, boost::json::string_view, + boost::system::error_code &) { if (building_city) { if (current_key == "latitude") { current_city.latitude = d; @@ -166,27 +221,14 @@ public: return true; } - bool Bool(bool /*b*/) { return true; } - bool Null() { return true; } - -private: - ParseContext &context; - - int depth = 0; - bool in_countries_array = false; - bool in_country_object = false; - bool in_states_array = false; - bool in_state_object = false; - bool in_cities_array = false; - bool building_city = false; - - int current_country_id = 0; - int current_state_id = 0; - CityRecord current_city = {}; - std::string current_key; - - std::string country_info[3]; - std::string state_info[2]; + bool on_bool(bool, boost::system::error_code &) { return true; } + bool on_null(boost::system::error_code &) { return true; } + bool on_comment_part(boost::json::string_view, boost::system::error_code &) { + return true; + } + bool on_comment(boost::json::string_view, boost::system::error_code &) { + return true; + } }; void StreamingJsonParser::Parse( @@ -194,7 +236,7 @@ void StreamingJsonParser::Parse( std::function onCity, std::function onProgress) { - spdlog::info(" Streaming parse of {}...", filePath); + spdlog::info(" Streaming parse of {} (Boost.JSON)...", filePath); FILE *file = std::fopen(filePath.c_str(), "rb"); if (!file) { @@ -212,23 +254,35 @@ void StreamingJsonParser::Parse( CityRecordHandler::ParseContext ctx{&db, onCity, onProgress, 0, total_size, 0, 0}; - CityRecordHandler handler(ctx); + boost::json::basic_parser parser( + boost::json::parse_options{}, ctx); - Reader reader; char buf[65536]; - FileReadStream frs(file, buf, sizeof(buf)); + size_t bytes_read; + boost::system::error_code ec; - if (!reader.Parse(frs, handler)) { - ParseErrorCode errCode = reader.GetParseErrorCode(); - size_t errOffset = reader.GetErrorOffset(); - std::fclose(file); - throw std::runtime_error(std::string("JSON parse error at offset ") + - std::to_string(errOffset) + - " (code: " + std::to_string(errCode) + ")"); + while ((bytes_read = std::fread(buf, 1, sizeof(buf), file)) > 0) { + char const *p = buf; + std::size_t remain = bytes_read; + + while (remain > 0) { + std::size_t consumed = parser.write_some(true, p, remain, ec); + if (ec) { + std::fclose(file); + throw std::runtime_error("JSON parse error: " + ec.message()); + } + p += consumed; + remain -= consumed; + } } + parser.write_some(false, nullptr, 0, ec); // Signal EOF std::fclose(file); + if (ec) { + throw std::runtime_error("JSON parse error at EOF: " + ec.message()); + } + spdlog::info(" OK: Parsed {} countries, {} states, {} cities", ctx.countries_inserted, ctx.states_inserted, ctx.cities_emitted); } diff --git a/pipeline/src/wikipedia_service.cpp b/pipeline/src/wikipedia_service.cpp new file mode 100644 index 0000000..29b3092 --- /dev/null +++ b/pipeline/src/wikipedia_service.cpp @@ -0,0 +1,77 @@ +#include "wikipedia_service.h" +#include +#include + +WikipediaService::WikipediaService(std::shared_ptr client) + : client_(std::move(client)) {} + +std::string WikipediaService::FetchExtract(std::string_view query) { + const std::string encoded = client_->UrlEncode(std::string(query)); + const std::string url = + "https://en.wikipedia.org/w/api.php?action=query&titles=" + encoded + + "&prop=extracts&explaintext=true&format=json"; + + const std::string body = client_->Get(url); + + boost::system::error_code ec; + boost::json::value doc = boost::json::parse(body, ec); + + if (!ec && doc.is_object()) { + auto &pages = doc.at("query").at("pages").get_object(); + if (!pages.empty()) { + auto &page = pages.begin()->value().get_object(); + if (page.contains("extract") && page.at("extract").is_string()) { + std::string extract(page.at("extract").as_string().c_str()); + spdlog::debug("WikipediaService fetched {} chars for '{}'", + extract.size(), query); + return extract; + } + } + } + + return {}; +} + +std::string WikipediaService::GetSummary(std::string_view city, + std::string_view country) { + const std::string key = std::string(city) + "|" + std::string(country); + const auto cacheIt = cache_.find(key); + if (cacheIt != cache_.end()) { + return cacheIt->second; + } + + std::string result; + + if (!client_) { + cache_.emplace(key, result); + return result; + } + + std::string regionQuery(city); + if (!country.empty()) { + regionQuery += ", "; + regionQuery += country; + } + + const std::string beerQuery = "beer in " + std::string(city); + + try { + const std::string regionExtract = FetchExtract(regionQuery); + const std::string beerExtract = FetchExtract(beerQuery); + + if (!regionExtract.empty()) { + result += regionExtract; + } + if (!beerExtract.empty()) { + if (!result.empty()) + result += "\n\n"; + result += beerExtract; + } + } catch (const std::runtime_error &e) { + spdlog::debug("WikipediaService lookup failed for '{}': {}", regionQuery, + e.what()); + } + + cache_.emplace(key, result); + return result; +}