mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-05-31 17:53:59 +00:00
Merge branch 'main-2.0' into feat/add-sqllite-to-cpp-pipeline
This commit is contained in:
38
pipeline/format.sh
Executable file
38
pipeline/format.sh
Executable file
@@ -0,0 +1,38 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Check for -y flag
|
||||||
|
SKIP_CONFIRM=false
|
||||||
|
if [[ "$1" == "-y" ]]; then
|
||||||
|
SKIP_CONFIRM=true
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "WARNING: This script will format all .cpp, .h, .cxx, .cc .c, .hpp files in the src and includes directories."
|
||||||
|
echo "This script will overwrite the files with the formatted version."
|
||||||
|
|
||||||
|
if [[ "$SKIP_CONFIRM" == false ]]; then
|
||||||
|
read -p "Do you want to continue? (y/n) " -n 1 -r
|
||||||
|
echo
|
||||||
|
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||||
|
echo "Aborted."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f .clang-format ]; then
|
||||||
|
echo "ERROR: .clang-format file not found."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! command -v clang-format &>/dev/null; then
|
||||||
|
echo "ERROR: clang-format not found."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Formatting files..."
|
||||||
|
|
||||||
|
find includes src \( -name "*.cpp" -o -name "*.hpp" -o -name "*.h" -o -name "*.c" -o -name "*.cc" -o -name "*.cxx" \) | xargs clang-format -i
|
||||||
|
|
||||||
|
echo "Done."
|
||||||
|
|
||||||
|
exit 0
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
*
|
*
|
||||||
* @param options Parsed application options.
|
* @param options Parsed application options.
|
||||||
* @param model_path Filesystem path to GGUF model assets.
|
* @param model_path Filesystem path to GGUF model assets.
|
||||||
* @param prompt_formatter Formatter that produces model-specific prompts.
|
* @param prompt_formatter Formatter that produces model-specific prompts.
|
||||||
*/
|
*/
|
||||||
LlamaGenerator(const ApplicationOptions& options,
|
LlamaGenerator(const ApplicationOptions& options,
|
||||||
const std::string& model_path,
|
const std::string& model_path,
|
||||||
@@ -100,24 +100,24 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
* @param system_prompt System role prompt.
|
* @param system_prompt System role prompt.
|
||||||
* @param prompt User prompt.
|
* @param prompt User prompt.
|
||||||
* @param max_tokens Maximum tokens to generate.
|
* @param max_tokens Maximum tokens to generate.
|
||||||
* @param grammar Optional GBNF grammar constraining generated output.
|
* @param grammar Optional GBNF grammar constraining generated output.
|
||||||
* @return Generated text.
|
* @return Generated text.
|
||||||
*/
|
*/
|
||||||
std::string Infer(const std::string& system_prompt, const std::string& prompt,
|
std::string Infer(const std::string& system_prompt, const std::string& prompt,
|
||||||
int max_tokens = kDefaultMaxTokens,
|
int max_tokens = kDefaultMaxTokens,
|
||||||
std::string_view grammar = {});
|
std::string_view grammar = {});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Runs inference on an already-formatted prompt.
|
* @brief Runs inference on an already-formatted prompt.
|
||||||
*
|
*
|
||||||
* @param formatted_prompt Prompt preformatted for model chat template.
|
* @param formatted_prompt Prompt preformatted for model chat template.
|
||||||
* @param max_tokens Maximum tokens to generate.
|
* @param max_tokens Maximum tokens to generate.
|
||||||
* @param grammar Optional GBNF grammar constraining generated output.
|
* @param grammar Optional GBNF grammar constraining generated output.
|
||||||
* @return Generated text.
|
* @return Generated text.
|
||||||
*/
|
*/
|
||||||
std::string InferFormatted(const std::string& formatted_prompt,
|
std::string InferFormatted(const std::string& formatted_prompt,
|
||||||
int max_tokens = kDefaultMaxTokens,
|
int max_tokens = kDefaultMaxTokens,
|
||||||
std::string_view grammar = {});
|
std::string_view grammar = {});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Loads the brewery system prompt from disk.
|
* @brief Loads the brewery system prompt from disk.
|
||||||
@@ -125,7 +125,8 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
* @param prompt_file_path Prompt file path to try first.
|
* @param prompt_file_path Prompt file path to try first.
|
||||||
* @return Loaded prompt text.
|
* @return Loaded prompt text.
|
||||||
*/
|
*/
|
||||||
std::string LoadBrewerySystemPrompt(const std::filesystem::path& prompt_file_path);
|
std::string LoadBrewerySystemPrompt(
|
||||||
|
const std::filesystem::path& prompt_file_path);
|
||||||
|
|
||||||
ModelHandle model_;
|
ModelHandle model_;
|
||||||
ContextHandle context_;
|
ContextHandle context_;
|
||||||
|
|||||||
@@ -13,6 +13,5 @@ class IPromptFormatter {
|
|||||||
virtual ~IPromptFormatter() = default;
|
virtual ~IPromptFormatter() = default;
|
||||||
|
|
||||||
[[nodiscard]] virtual std::string Format(
|
[[nodiscard]] virtual std::string Format(
|
||||||
std::string_view system_prompt,
|
std::string_view system_prompt, std::string_view user_prompt) const = 0;
|
||||||
std::string_view user_prompt) const = 0;
|
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ void BiergartenDataGenerator::LogResults() const {
|
|||||||
index, location.city, location.country, location.state_province,
|
index, location.city, location.country, location.state_province,
|
||||||
location.iso3166_2, location.latitude, location.longitude);
|
location.iso3166_2, location.latitude, location.longitude);
|
||||||
spdlog::info(" brewery_name_en=\"{}\"", brewery.name_en);
|
spdlog::info(" brewery_name_en=\"{}\"", brewery.name_en);
|
||||||
spdlog::info(" brewery_description_en=\"{}\"",
|
spdlog::info(" brewery_description_en=\"{}\"", brewery.description_en);
|
||||||
brewery.description_en);
|
|
||||||
spdlog::info(" brewery_name_local=\"{}\"", brewery.name_local);
|
spdlog::info(" brewery_name_local=\"{}\"", brewery.name_local);
|
||||||
spdlog::info(" brewery_description_local=\"{}\"",
|
spdlog::info(" brewery_description_local=\"{}\"",
|
||||||
brewery.description_local);
|
brewery.description_local);
|
||||||
|
|||||||
@@ -3,10 +3,10 @@
|
|||||||
* @brief BiergartenDataGenerator::Run() implementation.
|
* @brief BiergartenDataGenerator::Run() implementation.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include <spdlog/spdlog.h>
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "biergarten_data_generator.h"
|
#include "biergarten_data_generator.h"
|
||||||
|
|
||||||
bool BiergartenDataGenerator::Run() {
|
bool BiergartenDataGenerator::Run() {
|
||||||
@@ -20,7 +20,7 @@ bool BiergartenDataGenerator::Run() {
|
|||||||
try {
|
try {
|
||||||
std::string region_context = context_service_->GetLocationContext(city);
|
std::string region_context = context_service_->GetLocationContext(city);
|
||||||
spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}",
|
spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}",
|
||||||
city.city, city.country, region_context);
|
city.city, city.country, region_context);
|
||||||
|
|
||||||
enriched.push_back(
|
enriched.push_back(
|
||||||
EnrichedCity{.location = std::move(city),
|
EnrichedCity{.location = std::move(city),
|
||||||
|
|||||||
@@ -122,8 +122,8 @@ static bool ReadRequiredTrimmedStringField(const boost::json::object& obj,
|
|||||||
const boost::json::value* field = obj.if_contains(key);
|
const boost::json::value* field = obj.if_contains(key);
|
||||||
if (field == nullptr || !field->is_string()) {
|
if (field == nullptr || !field->is_string()) {
|
||||||
if (error_out != nullptr) {
|
if (error_out != nullptr) {
|
||||||
*error_out = "JSON field '" + std::string(key) +
|
*error_out =
|
||||||
"' is missing or not a string";
|
"JSON field '" + std::string(key) + "' is missing or not a string";
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -192,8 +192,7 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
|||||||
return validation_error;
|
return validation_error;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ReadRequiredTrimmedStringField(obj, "name_local",
|
if (!ReadRequiredTrimmedStringField(obj, "name_local", brewery_out.name_local,
|
||||||
brewery_out.name_local,
|
|
||||||
&validation_error)) {
|
&validation_error)) {
|
||||||
return validation_error;
|
return validation_error;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ static constexpr size_t kPromptTokenSlack = 8;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using SamplerHandle = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
using SamplerHandle =
|
||||||
|
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
||||||
|
|
||||||
struct SamplerConfig {
|
struct SamplerConfig {
|
||||||
float temperature;
|
float temperature;
|
||||||
@@ -117,17 +118,10 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
std::vector<llama_token> prompt_tokens(formatted_prompt.size() +
|
std::vector<llama_token> prompt_tokens(formatted_prompt.size() +
|
||||||
kPromptTokenSlack);
|
kPromptTokenSlack);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int32_t token_count = llama_tokenize(
|
int32_t token_count = llama_tokenize(
|
||||||
vocab,
|
vocab, formatted_prompt.c_str(),
|
||||||
formatted_prompt.c_str(),
|
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
|
||||||
static_cast<int32_t>(formatted_prompt.size()),
|
static_cast<int32_t>(prompt_tokens.size()), true, true);
|
||||||
prompt_tokens.data(),
|
|
||||||
static_cast<int32_t>(prompt_tokens.size()),
|
|
||||||
true,
|
|
||||||
true);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* If buffer too small, negative return indicates required size
|
* If buffer too small, negative return indicates required size
|
||||||
@@ -135,7 +129,6 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
if (token_count < 0) {
|
if (token_count < 0) {
|
||||||
prompt_tokens.resize(static_cast<size_t>(-token_count));
|
prompt_tokens.resize(static_cast<size_t>(-token_count));
|
||||||
|
|
||||||
|
|
||||||
token_count = llama_tokenize(
|
token_count = llama_tokenize(
|
||||||
vocab, formatted_prompt.c_str(),
|
vocab, formatted_prompt.c_str(),
|
||||||
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
|
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
|
||||||
|
|||||||
@@ -5,11 +5,11 @@
|
|||||||
|
|
||||||
#include "data_generation/llama_generator.h"
|
#include "data_generation/llama_generator.h"
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <filesystem>
|
|
||||||
|
|
||||||
#include "data_model/application_options.h"
|
#include "data_model/application_options.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
@@ -30,9 +30,9 @@ void LlamaGenerator::ContextDeleter::operator()(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
LlamaGenerator::LlamaGenerator(
|
||||||
const std::string& model_path,
|
const ApplicationOptions& options, const std::string& model_path,
|
||||||
std::unique_ptr<IPromptFormatter> prompt_formatter)
|
std::unique_ptr<IPromptFormatter> prompt_formatter)
|
||||||
: rng_(std::random_device{}()),
|
: rng_(std::random_device{}()),
|
||||||
prompt_formatter_(std::move(prompt_formatter)) {
|
prompt_formatter_(std::move(prompt_formatter)) {
|
||||||
if (model_path.empty()) {
|
if (model_path.empty()) {
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ std::string LlamaGenerator::LoadBrewerySystemPrompt(
|
|||||||
return brewery_system_prompt_;
|
return brewery_system_prompt_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
std::ifstream prompt_file(prompt_file_path);
|
std::ifstream prompt_file(prompt_file_path);
|
||||||
if (!prompt_file.is_open()) {
|
if (!prompt_file.is_open()) {
|
||||||
spdlog::error(
|
spdlog::error(
|
||||||
|
|||||||
@@ -6,15 +6,15 @@
|
|||||||
|
|
||||||
#include "json_handling/json_loader.h"
|
#include "json_handling/json_loader.h"
|
||||||
|
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
|
#include <boost/json.hpp>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
|
|
||||||
#include <boost/json.hpp>
|
|
||||||
#include <spdlog/spdlog.h>
|
|
||||||
|
|
||||||
static std::string ReadRequiredString(const boost::json::object& object,
|
static std::string ReadRequiredString(const boost::json::object& object,
|
||||||
const char* key) {
|
const char* key) {
|
||||||
const boost::json::value* value = object.if_contains(key);
|
const boost::json::value* value = object.if_contains(key);
|
||||||
@@ -40,8 +40,8 @@ static std::vector<std::string> ReadRequiredStringArray(
|
|||||||
const boost::json::object& object, const char* key) {
|
const boost::json::object& object, const char* key) {
|
||||||
const boost::json::value* value = object.if_contains(key);
|
const boost::json::value* value = object.if_contains(key);
|
||||||
if (value == nullptr || !value->is_array()) {
|
if (value == nullptr || !value->is_array()) {
|
||||||
throw std::runtime_error(std::string("Missing or invalid string array field: ") +
|
throw std::runtime_error(
|
||||||
key);
|
std::string("Missing or invalid string array field: ") + key);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& array = value->as_array();
|
const auto& array = value->as_array();
|
||||||
@@ -49,8 +49,8 @@ static std::vector<std::string> ReadRequiredStringArray(
|
|||||||
items.reserve(array.size());
|
items.reserve(array.size());
|
||||||
for (const auto& item : array) {
|
for (const auto& item : array) {
|
||||||
if (!item.is_string()) {
|
if (!item.is_string()) {
|
||||||
throw std::runtime_error(std::string("Missing or invalid string array field: ") +
|
throw std::runtime_error(
|
||||||
key);
|
std::string("Missing or invalid string array field: ") + key);
|
||||||
}
|
}
|
||||||
items.emplace_back(item.as_string());
|
items.emplace_back(item.as_string());
|
||||||
}
|
}
|
||||||
@@ -98,8 +98,7 @@ std::vector<Location> JsonLoader::LoadLocations(
|
|||||||
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
|
.iso3166_2 = ReadRequiredString(object, "iso3166_2"),
|
||||||
.country = ReadRequiredString(object, "country"),
|
.country = ReadRequiredString(object, "country"),
|
||||||
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
|
.iso3166_1 = ReadRequiredString(object, "iso3166_1"),
|
||||||
.local_languages =
|
.local_languages = ReadRequiredStringArray(object, "local_languages"),
|
||||||
ReadRequiredStringArray(object, "local_languages"),
|
|
||||||
.latitude = ReadRequiredNumber(object, "latitude"),
|
.latitude = ReadRequiredNumber(object, "latitude"),
|
||||||
.longitude = ReadRequiredNumber(object, "longitude"),
|
.longitude = ReadRequiredNumber(object, "longitude"),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
* @brief CURLWebClient::Get() implementation.
|
* @brief CURLWebClient::Get() implementation.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "web_client/curl_web_client.h"
|
#include <curl/curl.h>
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
@@ -11,7 +11,7 @@
|
|||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include <curl/curl.h>
|
#include "web_client/curl_web_client.h"
|
||||||
|
|
||||||
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
|
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user