diff --git a/pipeline/format.sh b/pipeline/format.sh new file mode 100755 index 0000000..8aee13a --- /dev/null +++ b/pipeline/format.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Check for -y flag +SKIP_CONFIRM=false +if [[ "$1" == "-y" ]]; then + SKIP_CONFIRM=true +fi + +echo "WARNING: This script will format all .cpp, .h, .cxx, .cc .c, .hpp files in the src and includes directories." +echo "This script will overwrite the files with the formatted version." + +if [[ "$SKIP_CONFIRM" == false ]]; then + read -p "Do you want to continue? (y/n) " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Aborted." + exit 1 + fi +fi + +if [ ! -f .clang-format ]; then + echo "ERROR: .clang-format file not found." + exit 1 +fi + +if ! command -v clang-format &>/dev/null; then + echo "ERROR: clang-format not found." + exit 1 +fi + +echo "Formatting files..." + +find includes src \( -name "*.cpp" -o -name "*.hpp" -o -name "*.h" -o -name "*.c" -o -name "*.cc" -o -name "*.cxx" \) | xargs clang-format -i + +echo "Done." + +exit 0 + diff --git a/pipeline/includes/data_generation/llama_generator.h b/pipeline/includes/data_generation/llama_generator.h index 07da334..1e648a3 100644 --- a/pipeline/includes/data_generation/llama_generator.h +++ b/pipeline/includes/data_generation/llama_generator.h @@ -32,7 +32,7 @@ class LlamaGenerator final : public DataGenerator { * * @param options Parsed application options. * @param model_path Filesystem path to GGUF model assets. - * @param prompt_formatter Formatter that produces model-specific prompts. + * @param prompt_formatter Formatter that produces model-specific prompts. */ LlamaGenerator(const ApplicationOptions& options, const std::string& model_path, @@ -100,24 +100,24 @@ class LlamaGenerator final : public DataGenerator { * @param system_prompt System role prompt. * @param prompt User prompt. * @param max_tokens Maximum tokens to generate. - * @param grammar Optional GBNF grammar constraining generated output. + * @param grammar Optional GBNF grammar constraining generated output. * @return Generated text. */ std::string Infer(const std::string& system_prompt, const std::string& prompt, - int max_tokens = kDefaultMaxTokens, - std::string_view grammar = {}); + int max_tokens = kDefaultMaxTokens, + std::string_view grammar = {}); /** * @brief Runs inference on an already-formatted prompt. * * @param formatted_prompt Prompt preformatted for model chat template. * @param max_tokens Maximum tokens to generate. - * @param grammar Optional GBNF grammar constraining generated output. + * @param grammar Optional GBNF grammar constraining generated output. * @return Generated text. */ std::string InferFormatted(const std::string& formatted_prompt, - int max_tokens = kDefaultMaxTokens, - std::string_view grammar = {}); + int max_tokens = kDefaultMaxTokens, + std::string_view grammar = {}); /** * @brief Loads the brewery system prompt from disk. @@ -125,7 +125,8 @@ class LlamaGenerator final : public DataGenerator { * @param prompt_file_path Prompt file path to try first. * @return Loaded prompt text. */ - std::string LoadBrewerySystemPrompt(const std::filesystem::path& prompt_file_path); + std::string LoadBrewerySystemPrompt( + const std::filesystem::path& prompt_file_path); ModelHandle model_; ContextHandle context_; diff --git a/pipeline/includes/data_generation/prompt_formatting/prompt_formatter.h b/pipeline/includes/data_generation/prompt_formatting/prompt_formatter.h index 7498397..df1ffff 100644 --- a/pipeline/includes/data_generation/prompt_formatting/prompt_formatter.h +++ b/pipeline/includes/data_generation/prompt_formatting/prompt_formatter.h @@ -13,6 +13,5 @@ class IPromptFormatter { virtual ~IPromptFormatter() = default; [[nodiscard]] virtual std::string Format( - std::string_view system_prompt, - std::string_view user_prompt) const = 0; + std::string_view system_prompt, std::string_view user_prompt) const = 0; }; diff --git a/pipeline/src/biergarten_data_generator/log_results.cc b/pipeline/src/biergarten_data_generator/log_results.cc index f44c875..975729e 100644 --- a/pipeline/src/biergarten_data_generator/log_results.cc +++ b/pipeline/src/biergarten_data_generator/log_results.cc @@ -17,8 +17,7 @@ void BiergartenDataGenerator::LogResults() const { index, location.city, location.country, location.state_province, location.iso3166_2, location.latitude, location.longitude); spdlog::info(" brewery_name_en=\"{}\"", brewery.name_en); - spdlog::info(" brewery_description_en=\"{}\"", - brewery.description_en); + spdlog::info(" brewery_description_en=\"{}\"", brewery.description_en); spdlog::info(" brewery_name_local=\"{}\"", brewery.name_local); spdlog::info(" brewery_description_local=\"{}\"", brewery.description_local); diff --git a/pipeline/src/biergarten_data_generator/run.cc b/pipeline/src/biergarten_data_generator/run.cc index 48f91ac..609c3e8 100644 --- a/pipeline/src/biergarten_data_generator/run.cc +++ b/pipeline/src/biergarten_data_generator/run.cc @@ -3,10 +3,10 @@ * @brief BiergartenDataGenerator::Run() implementation. */ -#include - #include +#include + #include "biergarten_data_generator.h" bool BiergartenDataGenerator::Run() { @@ -20,7 +20,7 @@ bool BiergartenDataGenerator::Run() { try { std::string region_context = context_service_->GetLocationContext(city); spdlog::debug("[Pipeline] Context for '{}' ({}) gathered:\n{}", - city.city, city.country, region_context); + city.city, city.country, region_context); enriched.push_back( EnrichedCity{.location = std::move(city), diff --git a/pipeline/src/data_generation/llama/helpers.cc b/pipeline/src/data_generation/llama/helpers.cc index 66f0c8e..8556b8d 100644 --- a/pipeline/src/data_generation/llama/helpers.cc +++ b/pipeline/src/data_generation/llama/helpers.cc @@ -122,8 +122,8 @@ static bool ReadRequiredTrimmedStringField(const boost::json::object& obj, const boost::json::value* field = obj.if_contains(key); if (field == nullptr || !field->is_string()) { if (error_out != nullptr) { - *error_out = "JSON field '" + std::string(key) + - "' is missing or not a string"; + *error_out = + "JSON field '" + std::string(key) + "' is missing or not a string"; } return false; } @@ -192,8 +192,7 @@ std::optional ValidateBreweryJson(const std::string& raw, return validation_error; } - if (!ReadRequiredTrimmedStringField(obj, "name_local", - brewery_out.name_local, + if (!ReadRequiredTrimmedStringField(obj, "name_local", brewery_out.name_local, &validation_error)) { return validation_error; } diff --git a/pipeline/src/data_generation/llama/infer.cc b/pipeline/src/data_generation/llama/infer.cc index 81a3754..2a2c116 100644 --- a/pipeline/src/data_generation/llama/infer.cc +++ b/pipeline/src/data_generation/llama/infer.cc @@ -22,7 +22,8 @@ static constexpr size_t kPromptTokenSlack = 8; namespace { -using SamplerHandle = std::unique_ptr; +using SamplerHandle = + std::unique_ptr; struct SamplerConfig { float temperature; @@ -117,17 +118,10 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, std::vector prompt_tokens(formatted_prompt.size() + kPromptTokenSlack); - - - int32_t token_count = llama_tokenize( - vocab, - formatted_prompt.c_str(), - static_cast(formatted_prompt.size()), - prompt_tokens.data(), - static_cast(prompt_tokens.size()), - true, - true); + vocab, formatted_prompt.c_str(), + static_cast(formatted_prompt.size()), prompt_tokens.data(), + static_cast(prompt_tokens.size()), true, true); /** * If buffer too small, negative return indicates required size @@ -135,7 +129,6 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt, if (token_count < 0) { prompt_tokens.resize(static_cast(-token_count)); - token_count = llama_tokenize( vocab, formatted_prompt.c_str(), static_cast(formatted_prompt.size()), prompt_tokens.data(), diff --git a/pipeline/src/data_generation/llama/llama_generator.cc b/pipeline/src/data_generation/llama/llama_generator.cc index 1903396..a854f48 100644 --- a/pipeline/src/data_generation/llama/llama_generator.cc +++ b/pipeline/src/data_generation/llama/llama_generator.cc @@ -5,11 +5,11 @@ #include "data_generation/llama_generator.h" +#include #include #include #include #include -#include #include "data_model/application_options.h" #include "llama.h" @@ -30,9 +30,9 @@ void LlamaGenerator::ContextDeleter::operator()( } } -LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, - const std::string& model_path, - std::unique_ptr prompt_formatter) +LlamaGenerator::LlamaGenerator( + const ApplicationOptions& options, const std::string& model_path, + std::unique_ptr prompt_formatter) : rng_(std::random_device{}()), prompt_formatter_(std::move(prompt_formatter)) { if (model_path.empty()) { diff --git a/pipeline/src/data_generation/llama/load_brewery_prompt.cc b/pipeline/src/data_generation/llama/load_brewery_prompt.cc index 242eda8..f59d590 100644 --- a/pipeline/src/data_generation/llama/load_brewery_prompt.cc +++ b/pipeline/src/data_generation/llama/load_brewery_prompt.cc @@ -25,7 +25,6 @@ std::string LlamaGenerator::LoadBrewerySystemPrompt( return brewery_system_prompt_; } - std::ifstream prompt_file(prompt_file_path); if (!prompt_file.is_open()) { spdlog::error( diff --git a/pipeline/src/json_handling/json_loader.cc b/pipeline/src/json_handling/json_loader.cc index 16ed1af..5ca70e9 100644 --- a/pipeline/src/json_handling/json_loader.cc +++ b/pipeline/src/json_handling/json_loader.cc @@ -6,15 +6,15 @@ #include "json_handling/json_loader.h" +#include + +#include #include #include #include #include #include -#include -#include - static std::string ReadRequiredString(const boost::json::object& object, const char* key) { const boost::json::value* value = object.if_contains(key); @@ -40,8 +40,8 @@ static std::vector ReadRequiredStringArray( const boost::json::object& object, const char* key) { const boost::json::value* value = object.if_contains(key); if (value == nullptr || !value->is_array()) { - throw std::runtime_error(std::string("Missing or invalid string array field: ") + - key); + throw std::runtime_error( + std::string("Missing or invalid string array field: ") + key); } const auto& array = value->as_array(); @@ -49,8 +49,8 @@ static std::vector ReadRequiredStringArray( items.reserve(array.size()); for (const auto& item : array) { if (!item.is_string()) { - throw std::runtime_error(std::string("Missing or invalid string array field: ") + - key); + throw std::runtime_error( + std::string("Missing or invalid string array field: ") + key); } items.emplace_back(item.as_string()); } @@ -98,8 +98,7 @@ std::vector JsonLoader::LoadLocations( .iso3166_2 = ReadRequiredString(object, "iso3166_2"), .country = ReadRequiredString(object, "country"), .iso3166_1 = ReadRequiredString(object, "iso3166_1"), - .local_languages = - ReadRequiredStringArray(object, "local_languages"), + .local_languages = ReadRequiredStringArray(object, "local_languages"), .latitude = ReadRequiredNumber(object, "latitude"), .longitude = ReadRequiredNumber(object, "longitude"), }); diff --git a/pipeline/src/web_client/curl_web_client_get.cc b/pipeline/src/web_client/curl_web_client_get.cc index 334f0d6..2e178f7 100644 --- a/pipeline/src/web_client/curl_web_client_get.cc +++ b/pipeline/src/web_client/curl_web_client_get.cc @@ -3,7 +3,7 @@ * @brief CURLWebClient::Get() implementation. */ -#include "web_client/curl_web_client.h" +#include #include #include @@ -11,7 +11,7 @@ #include #include -#include +#include "web_client/curl_web_client.h" using CurlHandle = std::unique_ptr;