From 280c9c61bd90bfe44720eefe5395a583f6ce2eb2 Mon Sep 17 00:00:00 2001 From: Aaron Po Date: Wed, 1 Apr 2026 23:29:16 -0400 Subject: [PATCH] Implement Llama-based brewery and user data generation; remove mock generator and related files --- pipeline/CMakeLists.txt | 23 ++- pipeline/includes/data_generator.h | 25 +++ pipeline/includes/generator.h | 36 ----- pipeline/includes/json_loader.h | 1 - pipeline/includes/llama_generator.h | 24 +++ pipeline/includes/mock_generator.h | 23 +++ pipeline/includes/work_queue.h | 63 -------- pipeline/src/generator.cpp | 21 --- pipeline/src/llama_generator.cpp | 236 ++++++++++++++++++++++++++++ pipeline/src/main.cpp | 20 ++- pipeline/src/mock_generator.cpp | 101 ++++++++++++ 11 files changed, 445 insertions(+), 128 deletions(-) create mode 100644 pipeline/includes/data_generator.h delete mode 100644 pipeline/includes/generator.h create mode 100644 pipeline/includes/llama_generator.h create mode 100644 pipeline/includes/mock_generator.h delete mode 100644 pipeline/includes/work_queue.h delete mode 100644 pipeline/src/generator.cpp create mode 100644 pipeline/src/llama_generator.cpp create mode 100644 pipeline/src/mock_generator.cpp diff --git a/pipeline/CMakeLists.txt b/pipeline/CMakeLists.txt index 31e5409..15f7171 100644 --- a/pipeline/CMakeLists.txt +++ b/pipeline/CMakeLists.txt @@ -39,6 +39,24 @@ if(NOT spdlog_POPULATED) add_subdirectory(${spdlog_SOURCE_DIR} ${spdlog_BINARY_DIR} EXCLUDE_FROM_ALL) endif() +# llama.cpp (on-device inference) +set(LLAMA_BUILD_TESTS OFF CACHE BOOL "" FORCE) +set(LLAMA_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) +set(LLAMA_BUILD_SERVER OFF CACHE BOOL "" FORCE) + +FetchContent_Declare( + llama_cpp + GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git + GIT_TAG b8611 +) +FetchContent_MakeAvailable(llama_cpp) + +if(TARGET llama) + target_compile_options(llama PRIVATE + $<$:-include algorithm> + ) +endif() + file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS src/*.cpp ) @@ -49,6 +67,7 @@ 
#pragma once

#include <string>

/// @brief Generated brewery payload: a name plus a one-sentence description.
struct BreweryResult {
  std::string name;
  std::string description;
};

/// @brief Generated user payload: a username plus a one-sentence bio.
struct UserResult {
  std::string username;
  std::string bio;
};

/// @brief Abstract interface for brewery/user text generation backends
/// (LLM-backed LlamaGenerator or deterministic MockGenerator).
class IDataGenerator {
public:
  virtual ~IDataGenerator() = default;

  /// @brief Prepares the backend. modelPath may be empty for backends
  /// that need no model file (e.g. the mock implementation).
  virtual void load(const std::string &modelPath) = 0;

  /// @brief Generates a brewery name and description for a city.
  /// @param cityName City the brewery is located in.
  /// @param regionContext Optional extra regional context; may be empty.
  virtual BreweryResult generateBrewery(const std::string &cityName,
                                        const std::string &regionContext) = 0;

  /// @brief Generates a username and bio for the given locale.
  virtual UserResult generateUser(const std::string &locale) = 0;
};
-class LlamaBreweryGenerator { -private: - const std::vector breweryAdjectives = { - "Craft", "Heritage", "Local", "Artisan", - "Pioneer", "Golden", "Modern", "Classic"}; - - const std::vector breweryNouns = { - "Brewing Co.", "Brewery", "Bier Haus", "Taproom", - "Works", "House", "Fermentery", "Ale Co."}; - - const std::vector descriptions = { - "Handcrafted pale ales and seasonal IPAs with local ingredients.", - "Traditional lagers and experimental sours in small batches.", - "Award-winning stouts and wildly hoppy blonde ales.", - "Craft brewery specializing in Belgian-style triples and dark porters.", - "Modern brewery blending tradition with bold experimental flavors."}; - -public: - /// @brief Generated brewery payload for one city. - struct Brewery { - std::string name; - std::string description; - }; - - /// @brief Loads model resources (mock implementation in this project). - void LoadModel(const std::string &modelPath); - - /// @brief Generates deterministic brewery text for a city and seed. - Brewery GenerateBrewery(const std::string &cityName, int seed); -}; diff --git a/pipeline/includes/json_loader.h b/pipeline/includes/json_loader.h index a201370..b85b863 100644 --- a/pipeline/includes/json_loader.h +++ b/pipeline/includes/json_loader.h @@ -2,7 +2,6 @@ #include "database.h" #include "stream_parser.h" -#include "work_queue.h" #include /// @brief Loads world-city JSON data into SQLite through streaming parsing. 
diff --git a/pipeline/includes/llama_generator.h b/pipeline/includes/llama_generator.h new file mode 100644 index 0000000..0865ecc --- /dev/null +++ b/pipeline/includes/llama_generator.h @@ -0,0 +1,24 @@ +#pragma once + +#include "data_generator.h" +#include +#include + +struct llama_model; +struct llama_context; + +class LlamaGenerator final : public IDataGenerator { +public: + ~LlamaGenerator() override; + + void load(const std::string &modelPath) override; + BreweryResult generateBrewery(const std::string &cityName, + const std::string ®ionContext) override; + UserResult generateUser(const std::string &locale) override; + +private: + std::string infer(const std::string &prompt, int maxTokens = 256); + + llama_model *model_ = nullptr; + llama_context *context_ = nullptr; +}; diff --git a/pipeline/includes/mock_generator.h b/pipeline/includes/mock_generator.h new file mode 100644 index 0000000..efc4d3f --- /dev/null +++ b/pipeline/includes/mock_generator.h @@ -0,0 +1,23 @@ +#pragma once + +#include "data_generator.h" +#include +#include + +class MockGenerator final : public IDataGenerator { +public: + void load(const std::string &modelPath) override; + BreweryResult generateBrewery(const std::string &cityName, + const std::string ®ionContext) override; + UserResult generateUser(const std::string &locale) override; + +private: + static std::size_t deterministicHash(const std::string &a, + const std::string &b); + + static const std::vector kBreweryAdjectives; + static const std::vector kBreweryNouns; + static const std::vector kBreweryDescriptions; + static const std::vector kUsernames; + static const std::vector kBios; +}; diff --git a/pipeline/includes/work_queue.h b/pipeline/includes/work_queue.h deleted file mode 100644 index 0b6feea..0000000 --- a/pipeline/includes/work_queue.h +++ /dev/null @@ -1,63 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -/// @brief Bounded thread-safe queue with blocking push/pop and shutdown. 
-template class WorkQueue { -private: - std::queue queue; - std::mutex mutex; - std::condition_variable cv_not_empty; - std::condition_variable cv_not_full; - size_t max_size; - bool shutdown = false; - -public: - /// @brief Creates a queue with fixed capacity. - explicit WorkQueue(size_t capacity) : max_size(capacity) {} - - /// @brief Pushes an item, blocking while full unless shutdown is signaled. - bool push(T item) { - std::unique_lock lock(mutex); - cv_not_full.wait(lock, - [this] { return queue.size() < max_size || shutdown; }); - - if (shutdown) - return false; - - queue.push(std::move(item)); - cv_not_empty.notify_one(); - return true; - } - - /// @brief Pops an item, blocking while empty unless shutdown is signaled. - std::optional pop() { - std::unique_lock lock(mutex); - cv_not_empty.wait(lock, [this] { return !queue.empty() || shutdown; }); - - if (queue.empty()) - return std::nullopt; - - T item = std::move(queue.front()); - queue.pop(); - cv_not_full.notify_one(); - return item; - } - - /// @brief Signals queue shutdown and wakes all waiting producers/consumers. - void shutdown_queue() { - std::unique_lock lock(mutex); - shutdown = true; - cv_not_empty.notify_all(); - cv_not_full.notify_all(); - } - - /// @brief Returns current queue size. - size_t size() const { - std::lock_guard lock(mutex); - return queue.size(); - } -}; diff --git a/pipeline/src/generator.cpp b/pipeline/src/generator.cpp deleted file mode 100644 index 85c12c3..0000000 --- a/pipeline/src/generator.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "generator.h" -#include -#include - -void LlamaBreweryGenerator::LoadModel(const std::string &modelPath) { - spdlog::info(" [Mock] Initialized llama model: {}", modelPath); - spdlog::info(" OK: Model ready"); -} - -LlamaBreweryGenerator::Brewery -LlamaBreweryGenerator::GenerateBrewery(const std::string &cityName, int seed) { - // Deterministic mock generation for stable test output. 
- size_t nameHash = std::hash{}(cityName + std::to_string(seed)); - - Brewery result; - result.name = breweryAdjectives[nameHash % breweryAdjectives.size()] + " " + - breweryNouns[(nameHash / 7) % breweryNouns.size()]; - result.description = descriptions[(nameHash / 13) % descriptions.size()]; - - return result; -} diff --git a/pipeline/src/llama_generator.cpp b/pipeline/src/llama_generator.cpp new file mode 100644 index 0000000..952f5fc --- /dev/null +++ b/pipeline/src/llama_generator.cpp @@ -0,0 +1,236 @@ +#include "llama_generator.h" + +#include "llama.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +std::string trim(std::string value) { + auto notSpace = [](unsigned char ch) { return !std::isspace(ch); }; + + value.erase(value.begin(), + std::find_if(value.begin(), value.end(), notSpace)); + value.erase(std::find_if(value.rbegin(), value.rend(), notSpace).base(), + value.end()); + + return value; +} + +void appendTokenPiece(const llama_vocab *vocab, llama_token token, + std::string &output) { + std::array buffer{}; + int32_t bytes = + llama_token_to_piece(vocab, token, buffer.data(), + static_cast(buffer.size()), 0, true); + + if (bytes < 0) { + std::vector dynamicBuffer(static_cast(-bytes)); + bytes = llama_token_to_piece(vocab, token, dynamicBuffer.data(), + static_cast(dynamicBuffer.size()), 0, + true); + if (bytes < 0) { + throw std::runtime_error( + "LlamaGenerator: failed to decode sampled token piece"); + } + + output.append(dynamicBuffer.data(), static_cast(bytes)); + return; + } + + output.append(buffer.data(), static_cast(bytes)); +} + +std::pair +parseTwoLineResponse(const std::string &raw, const std::string &errorMessage) { + const auto newlinePos = raw.find('\n'); + if (newlinePos == std::string::npos) { + throw std::runtime_error(errorMessage); + } + + std::string first = trim(raw.substr(0, newlinePos)); + std::string second = trim(raw.substr(newlinePos + 1)); + + if (first.empty() || 
second.empty()) { + throw std::runtime_error(errorMessage); + } + + return {first, second}; +} + +} // namespace + +LlamaGenerator::~LlamaGenerator() { + if (context_ != nullptr) { + llama_free(context_); + context_ = nullptr; + } + + if (model_ != nullptr) { + llama_model_free(model_); + model_ = nullptr; + } + + llama_backend_free(); +} + +void LlamaGenerator::load(const std::string &modelPath) { + if (modelPath.empty()) { + throw std::runtime_error("LlamaGenerator: model path must not be empty"); + } + + if (context_ != nullptr) { + llama_free(context_); + context_ = nullptr; + } + if (model_ != nullptr) { + llama_model_free(model_); + model_ = nullptr; + } + + llama_backend_init(); + + llama_model_params modelParams = llama_model_default_params(); + model_ = llama_load_model_from_file(modelPath.c_str(), modelParams); + if (model_ == nullptr) { + throw std::runtime_error( + "LlamaGenerator: failed to load model from path: " + modelPath); + } + + llama_context_params contextParams = llama_context_default_params(); + contextParams.n_ctx = 2048; + + context_ = llama_init_from_model(model_, contextParams); + if (context_ == nullptr) { + llama_model_free(model_); + model_ = nullptr; + throw std::runtime_error("LlamaGenerator: failed to create context"); + } + + spdlog::info("[LlamaGenerator] Loaded model: {}", modelPath); +} + +std::string LlamaGenerator::infer(const std::string &prompt, int maxTokens) { + if (model_ == nullptr || context_ == nullptr) { + throw std::runtime_error("LlamaGenerator: model not loaded"); + } + + const llama_vocab *vocab = llama_model_get_vocab(model_); + if (vocab == nullptr) { + throw std::runtime_error("LlamaGenerator: vocab unavailable"); + } + + std::vector promptTokens(prompt.size() + 8); + int32_t tokenCount = + llama_tokenize(vocab, prompt.c_str(), static_cast(prompt.size()), + promptTokens.data(), + static_cast(promptTokens.size()), true, true); + + if (tokenCount < 0) { + promptTokens.resize(static_cast(-tokenCount)); + 
tokenCount = + llama_tokenize(vocab, prompt.c_str(), + static_cast(prompt.size()), promptTokens.data(), + static_cast(promptTokens.size()), true, true); + } + + if (tokenCount < 0) { + throw std::runtime_error("LlamaGenerator: prompt tokenization failed"); + } + + promptTokens.resize(static_cast(tokenCount)); + + const llama_batch promptBatch = llama_batch_get_one( + promptTokens.data(), static_cast(promptTokens.size())); + if (llama_decode(context_, promptBatch) != 0) { + throw std::runtime_error("LlamaGenerator: prompt decode failed"); + } + + llama_sampler_chain_params samplerParams = + llama_sampler_chain_default_params(); + using SamplerPtr = + std::unique_ptr; + SamplerPtr sampler(llama_sampler_chain_init(samplerParams), + &llama_sampler_free); + + if (!sampler) { + throw std::runtime_error("LlamaGenerator: failed to initialize sampler"); + } + + llama_sampler_chain_add(sampler.get(), llama_sampler_init_greedy()); + + std::vector generatedTokens; + generatedTokens.reserve(static_cast(maxTokens)); + + for (int i = 0; i < maxTokens; ++i) { + const llama_token next = llama_sampler_sample(sampler.get(), context_, -1); + if (llama_vocab_is_eog(vocab, next)) { + break; + } + + generatedTokens.push_back(next); + + llama_token token = next; + const llama_batch oneTokenBatch = llama_batch_get_one(&token, 1); + if (llama_decode(context_, oneTokenBatch) != 0) { + throw std::runtime_error( + "LlamaGenerator: decode failed during generation"); + } + } + + std::string output; + for (const llama_token token : generatedTokens) { + appendTokenPiece(vocab, token, output); + } + + return output; +} + +BreweryResult +LlamaGenerator::generateBrewery(const std::string &cityName, + const std::string ®ionContext) { + std::string prompt = + "Generate a craft brewery name and one-sentence description for a " + "brewery located in " + + cityName + ". 
" + regionContext + + " Respond with exactly two lines: first line is the name, second " + "line is the description."; + + const std::string raw = infer(prompt, 128); + auto [name, description] = + parseTwoLineResponse(raw, "LlamaGenerator: malformed brewery response"); + + return {name, description}; +} + +UserResult LlamaGenerator::generateUser(const std::string &locale) { + std::string prompt = + "Generate a plausible craft beer enthusiast username and a one-sentence " + "bio. Locale: " + + locale + + ". Respond with exactly two lines: first line is the username (no " + "spaces), second line is the bio."; + + const std::string raw = infer(prompt, 128); + auto [username, bio] = + parseTwoLineResponse(raw, "LlamaGenerator: malformed user response"); + + username.erase( + std::remove_if(username.begin(), username.end(), + [](unsigned char ch) { return std::isspace(ch); }), + username.end()); + + if (username.empty() || bio.empty()) { + throw std::runtime_error("LlamaGenerator: malformed user response"); + } + + return {username, bio}; +} diff --git a/pipeline/src/main.cpp b/pipeline/src/main.cpp index 287dc26..4b4d4c1 100644 --- a/pipeline/src/main.cpp +++ b/pipeline/src/main.cpp @@ -1,9 +1,12 @@ #include "data_downloader.h" +#include "data_generator.h" #include "database.h" -#include "generator.h" #include "json_loader.h" +#include "llama_generator.h" +#include "mock_generator.h" #include #include +#include #include static bool FileExists(const std::string &filePath) { @@ -14,7 +17,7 @@ int main(int argc, char *argv[]) { try { curl_global_init(CURL_GLOBAL_DEFAULT); - std::string modelPath = argc > 1 ? argv[1] : "./model.gguf"; + std::string modelPath = argc > 1 ? argv[1] : ""; std::string cacheDir = argc > 2 ? argv[2] : "/tmp"; std::string commit = argc > 3 ? 
argv[3] : "c5eb7772"; // Default: stable 2026-03-28 @@ -41,8 +44,15 @@ int main(int argc, char *argv[]) { } spdlog::info("Initializing brewery generator..."); - LlamaBreweryGenerator generator; - generator.LoadModel(modelPath); + std::unique_ptr generator; + if (modelPath.empty()) { + generator = std::make_unique(); + spdlog::info("[Generator] Using MockGenerator (no model path provided)"); + } else { + generator = std::make_unique(); + spdlog::info("[Generator] Using LlamaGenerator: {}", modelPath); + } + generator->load(modelPath); spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ==="); @@ -74,7 +84,7 @@ int main(int argc, char *argv[]) { spdlog::info("\n=== SAMPLE BREWERY GENERATION ===\n"); for (size_t i = 0; i < std::min(size_t(5), cities.size()); i++) { const auto &[cityId, cityName] = cities[i]; - auto brewery = generator.GenerateBrewery(cityName, i); + auto brewery = generator->generateBrewery(cityName, ""); spdlog::info(" {}: {}", cityName, brewery.name); spdlog::info(" -> {}", brewery.description); } diff --git a/pipeline/src/mock_generator.cpp b/pipeline/src/mock_generator.cpp new file mode 100644 index 0000000..7255d27 --- /dev/null +++ b/pipeline/src/mock_generator.cpp @@ -0,0 +1,101 @@ +#include "mock_generator.h" + +#include +#include + +const std::vector MockGenerator::kBreweryAdjectives = { + "Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden", + "Modern", "Classic", "Summit", "Northern", "Riverstone", "Barrel", + "Hinterland", "Harbor", "Wild", "Granite", "Copper", "Maple"}; + +const std::vector MockGenerator::kBreweryNouns = { + "Brewing Co.", "Brewery", "Bier Haus", "Taproom", "Works", + "House", "Fermentery", "Ale Co.", "Cellars", "Collective", + "Project", "Foundry", "Malthouse", "Public House", "Co-op", + "Lab", "Beer Hall", "Guild"}; + +const std::vector MockGenerator::kBreweryDescriptions = { + "Handcrafted pale ales and seasonal IPAs with local ingredients.", + "Traditional lagers and experimental sours in small batches.", + 
"Award-winning stouts and wildly hoppy blonde ales.", + "Craft brewery specializing in Belgian-style triples and dark porters.", + "Modern brewery blending tradition with bold experimental flavors.", + "Neighborhood-focused taproom pouring crisp pilsners and citrusy pale " + "ales.", + "Small-batch brewery known for barrel-aged releases and smoky lagers.", + "Independent brewhouse pairing farmhouse ales with rotating food pop-ups.", + "Community brewpub making balanced bitters, saisons, and hazy IPAs.", + "Experimental nanobrewery exploring local yeast and regional grains.", + "Family-run brewery producing smooth amber ales and robust porters.", + "Urban brewery crafting clean lagers and bright, fruit-forward sours.", + "Riverfront brewhouse featuring oak-matured ales and seasonal blends.", + "Modern taproom focused on sessionable lagers and classic pub styles.", + "Brewery rooted in tradition with a lineup of malty reds and crisp lagers.", + "Creative brewery offering rotating collaborations and limited draft-only " + "pours.", + "Locally inspired brewery serving approachable ales with bold hop " + "character.", + "Destination taproom known for balanced IPAs and cocoa-rich stouts."}; + +const std::vector MockGenerator::kUsernames = { + "hopseeker", "malttrail", "yeastwhisper", "lagerlane", + "barrelbound", "foamfinder", "taphunter", "graingeist", + "brewscout", "aleatlas", "caskcompass", "hopsandmaps", + "mashpilot", "pintnomad", "fermentfriend", "stoutsignal", + "sessionwander", "kettlekeeper"}; + +const std::vector MockGenerator::kBios = { + "Always chasing balanced IPAs and crisp lagers across local taprooms.", + "Weekend brewery explorer with a soft spot for dark, roasty stouts.", + "Documenting tiny brewpubs, fresh pours, and unforgettable beer gardens.", + "Fan of farmhouse ales, food pairings, and long tasting flights.", + "Collecting favorite pilsners one city at a time.", + "Hops-first drinker who still saves room for classic malt-forward styles.", + 
"Finding hidden tap lists and sharing the best seasonal releases.", + "Brewery road-tripper focused on local ingredients and clean fermentation.", + "Always comparing house lagers and ranking patio pint vibes.", + "Curious about yeast strains, barrel programs, and cellar experiments.", + "Believes every neighborhood deserves a great community taproom.", + "Looking for session beers that taste great from first sip to last.", + "Belgian ale enthusiast who never skips a new saison.", + "Hazy IPA critic with deep respect for a perfectly clear pilsner.", + "Visits breweries for the stories, stays for the flagship pours.", + "Craft beer fan mapping tasting notes and favorite brew routes.", + "Always ready to trade recommendations for underrated local breweries.", + "Keeping a running list of must-try collab releases and tap takeovers."}; + +void MockGenerator::load(const std::string & /*modelPath*/) { + spdlog::info("[MockGenerator] No model needed"); +} + +std::size_t MockGenerator::deterministicHash(const std::string &a, + const std::string &b) { + std::size_t seed = std::hash{}(a); + const std::size_t mixed = std::hash{}(b); + seed ^= mixed + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); + seed = (seed << 13) | (seed >> ((sizeof(std::size_t) * 8) - 13)); + return seed; +} + +BreweryResult MockGenerator::generateBrewery(const std::string &cityName, + const std::string ®ionContext) { + const std::size_t hash = regionContext.empty() + ? 
std::hash{}(cityName) + : deterministicHash(cityName, regionContext); + + BreweryResult result; + result.name = kBreweryAdjectives[hash % kBreweryAdjectives.size()] + " " + + kBreweryNouns[(hash / 7) % kBreweryNouns.size()]; + result.description = + kBreweryDescriptions[(hash / 13) % kBreweryDescriptions.size()]; + return result; +} + +UserResult MockGenerator::generateUser(const std::string &locale) { + const std::size_t hash = std::hash{}(locale); + + UserResult result; + result.username = kUsernames[hash % kUsernames.size()]; + result.bio = kBios[(hash / 11) % kBios.size()]; + return result; +}