Add formatting script for c++

This commit is contained in:
Aaron Po
2026-04-19 15:27:26 -04:00
parent 898cc8971b
commit 9a2ecfea82
11 changed files with 74 additions and 47 deletions

38
pipeline/format.sh Executable file
View File

@@ -0,0 +1,38 @@
#!/bin/bash
# Formats every C/C++ source and header file under ./includes and ./src in
# place with clang-format, using the .clang-format file found in the current
# working directory.
#
# Usage: format.sh [-y]
#   -y   Skip the interactive confirmation prompt.
#
# Exit status:
#   0  all matching files were formatted
#   1  the user declined, .clang-format is missing, clang-format is not on
#      PATH, or find/clang-format reported a failure

# Check for -y flag
SKIP_CONFIRM=false
if [[ "${1:-}" == "-y" ]]; then
    SKIP_CONFIRM=true
fi

echo "WARNING: This script will format all .cpp, .h, .cxx, .cc, .c, .hpp files in the src and includes directories."
echo "This script will overwrite the files with the formatted version."

if [[ "$SKIP_CONFIRM" == false ]]; then
    read -p "Do you want to continue? (y/n) " -n 1 -r
    echo
    if [[ ! $REPLY =~ ^[Yy]$ ]]; then
        echo "Aborted."
        exit 1
    fi
fi

if [ ! -f .clang-format ]; then
    echo "ERROR: .clang-format file not found." >&2
    exit 1
fi

if ! command -v clang-format &>/dev/null; then
    echo "ERROR: clang-format not found." >&2
    exit 1
fi

echo "Formatting files..."
# Use -exec ... {} + instead of piping into xargs: file names are handed to
# clang-format as separate arguments (safe for spaces and quotes), nothing is
# run when no file matches, and find's exit status reflects any failed
# clang-format invocation.
if ! find includes src \
    \( -name "*.cpp" -o -name "*.hpp" -o -name "*.h" \
    -o -name "*.c" -o -name "*.cc" -o -name "*.cxx" \) \
    -exec clang-format -i {} +; then
    echo "ERROR: Formatting failed." >&2
    exit 1
fi

echo "Done."
exit 0

View File

@@ -125,7 +125,8 @@ class LlamaGenerator final : public DataGenerator {
* @param prompt_file_path Prompt file path to try first. * @param prompt_file_path Prompt file path to try first.
* @return Loaded prompt text. * @return Loaded prompt text.
*/ */
std::string LoadBrewerySystemPrompt(const std::filesystem::path& prompt_file_path); std::string LoadBrewerySystemPrompt(
const std::filesystem::path& prompt_file_path);
ModelHandle model_; ModelHandle model_;
ContextHandle context_; ContextHandle context_;

View File

@@ -13,6 +13,5 @@ class IPromptFormatter {
virtual ~IPromptFormatter() = default; virtual ~IPromptFormatter() = default;
[[nodiscard]] virtual std::string Format( [[nodiscard]] virtual std::string Format(
std::string_view system_prompt, std::string_view system_prompt, std::string_view user_prompt) const = 0;
std::string_view user_prompt) const = 0;
}; };

View File

@@ -17,8 +17,7 @@ void BiergartenDataGenerator::LogResults() const {
index, location.city, location.country, location.state_province, index, location.city, location.country, location.state_province,
location.iso3166_2, location.latitude, location.longitude); location.iso3166_2, location.latitude, location.longitude);
spdlog::info(" brewery_name_en=\"{}\"", brewery.name_en); spdlog::info(" brewery_name_en=\"{}\"", brewery.name_en);
spdlog::info(" brewery_description_en=\"{}\"", spdlog::info(" brewery_description_en=\"{}\"", brewery.description_en);
brewery.description_en);
spdlog::info(" brewery_name_local=\"{}\"", brewery.name_local); spdlog::info(" brewery_name_local=\"{}\"", brewery.name_local);
spdlog::info(" brewery_description_local=\"{}\"", spdlog::info(" brewery_description_local=\"{}\"",
brewery.description_local); brewery.description_local);

View File

@@ -3,10 +3,10 @@
* @brief BiergartenDataGenerator::Run() implementation. * @brief BiergartenDataGenerator::Run() implementation.
*/ */
#include <utility>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <utility>
#include "biergarten_data_generator.h" #include "biergarten_data_generator.h"
bool BiergartenDataGenerator::Run() { bool BiergartenDataGenerator::Run() {

View File

@@ -122,8 +122,8 @@ static bool ReadRequiredTrimmedStringField(const boost::json::object& obj,
const boost::json::value* field = obj.if_contains(key); const boost::json::value* field = obj.if_contains(key);
if (field == nullptr || !field->is_string()) { if (field == nullptr || !field->is_string()) {
if (error_out != nullptr) { if (error_out != nullptr) {
*error_out = "JSON field '" + std::string(key) + *error_out =
"' is missing or not a string"; "JSON field '" + std::string(key) + "' is missing or not a string";
} }
return false; return false;
} }
@@ -192,8 +192,7 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
return validation_error; return validation_error;
} }
if (!ReadRequiredTrimmedStringField(obj, "name_local", if (!ReadRequiredTrimmedStringField(obj, "name_local", brewery_out.name_local,
brewery_out.name_local,
&validation_error)) { &validation_error)) {
return validation_error; return validation_error;
} }

View File

@@ -22,7 +22,8 @@ static constexpr size_t kPromptTokenSlack = 8;
namespace { namespace {
using SamplerHandle = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>; using SamplerHandle =
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
struct SamplerConfig { struct SamplerConfig {
float temperature; float temperature;
@@ -117,17 +118,10 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
std::vector<llama_token> prompt_tokens(formatted_prompt.size() + std::vector<llama_token> prompt_tokens(formatted_prompt.size() +
kPromptTokenSlack); kPromptTokenSlack);
int32_t token_count = llama_tokenize( int32_t token_count = llama_tokenize(
vocab, vocab, formatted_prompt.c_str(),
formatted_prompt.c_str(), static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
static_cast<int32_t>(formatted_prompt.size()), static_cast<int32_t>(prompt_tokens.size()), true, true);
prompt_tokens.data(),
static_cast<int32_t>(prompt_tokens.size()),
true,
true);
/** /**
* If buffer too small, negative return indicates required size * If buffer too small, negative return indicates required size
@@ -135,7 +129,6 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
if (token_count < 0) { if (token_count < 0) {
prompt_tokens.resize(static_cast<size_t>(-token_count)); prompt_tokens.resize(static_cast<size_t>(-token_count));
token_count = llama_tokenize( token_count = llama_tokenize(
vocab, formatted_prompt.c_str(), vocab, formatted_prompt.c_str(),
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(), static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),

View File

@@ -5,11 +5,11 @@
#include "data_generation/llama_generator.h" #include "data_generation/llama_generator.h"
#include <filesystem>
#include <memory> #include <memory>
#include <random> #include <random>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <filesystem>
#include "data_model/application_options.h" #include "data_model/application_options.h"
#include "llama.h" #include "llama.h"
@@ -30,8 +30,8 @@ void LlamaGenerator::ContextDeleter::operator()(
} }
} }
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options, LlamaGenerator::LlamaGenerator(
const std::string& model_path, const ApplicationOptions& options, const std::string& model_path,
std::unique_ptr<IPromptFormatter> prompt_formatter) std::unique_ptr<IPromptFormatter> prompt_formatter)
: rng_(std::random_device{}()), : rng_(std::random_device{}()),
prompt_formatter_(std::move(prompt_formatter)) { prompt_formatter_(std::move(prompt_formatter)) {

View File

@@ -25,7 +25,6 @@ std::string LlamaGenerator::LoadBrewerySystemPrompt(
return brewery_system_prompt_; return brewery_system_prompt_;
} }
std::ifstream prompt_file(prompt_file_path); std::ifstream prompt_file(prompt_file_path);
if (!prompt_file.is_open()) { if (!prompt_file.is_open()) {
spdlog::error( spdlog::error(

View File

@@ -6,15 +6,15 @@
#include "json_handling/json_loader.h" #include "json_handling/json_loader.h"
#include <spdlog/spdlog.h>
#include <boost/json.hpp>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <boost/json.hpp>
#include <spdlog/spdlog.h>
static std::string ReadRequiredString(const boost::json::object& object, static std::string ReadRequiredString(const boost::json::object& object,
const char* key) { const char* key) {
const boost::json::value* value = object.if_contains(key); const boost::json::value* value = object.if_contains(key);
@@ -40,8 +40,8 @@ static std::vector<std::string> ReadRequiredStringArray(
const boost::json::object& object, const char* key) { const boost::json::object& object, const char* key) {
const boost::json::value* value = object.if_contains(key); const boost::json::value* value = object.if_contains(key);
if (value == nullptr || !value->is_array()) { if (value == nullptr || !value->is_array()) {
throw std::runtime_error(std::string("Missing or invalid string array field: ") + throw std::runtime_error(
key); std::string("Missing or invalid string array field: ") + key);
} }
const auto& array = value->as_array(); const auto& array = value->as_array();
@@ -49,8 +49,8 @@ static std::vector<std::string> ReadRequiredStringArray(
items.reserve(array.size()); items.reserve(array.size());
for (const auto& item : array) { for (const auto& item : array) {
if (!item.is_string()) { if (!item.is_string()) {
throw std::runtime_error(std::string("Missing or invalid string array field: ") + throw std::runtime_error(
key); std::string("Missing or invalid string array field: ") + key);
} }
items.emplace_back(item.as_string()); items.emplace_back(item.as_string());
} }
@@ -98,8 +98,7 @@ std::vector<Location> JsonLoader::LoadLocations(
.iso3166_2 = ReadRequiredString(object, "iso3166_2"), .iso3166_2 = ReadRequiredString(object, "iso3166_2"),
.country = ReadRequiredString(object, "country"), .country = ReadRequiredString(object, "country"),
.iso3166_1 = ReadRequiredString(object, "iso3166_1"), .iso3166_1 = ReadRequiredString(object, "iso3166_1"),
.local_languages = .local_languages = ReadRequiredStringArray(object, "local_languages"),
ReadRequiredStringArray(object, "local_languages"),
.latitude = ReadRequiredNumber(object, "latitude"), .latitude = ReadRequiredNumber(object, "latitude"),
.longitude = ReadRequiredNumber(object, "longitude"), .longitude = ReadRequiredNumber(object, "longitude"),
}); });

View File

@@ -3,7 +3,7 @@
* @brief CURLWebClient::Get() implementation. * @brief CURLWebClient::Get() implementation.
*/ */
#include "web_client/curl_web_client.h" #include <curl/curl.h>
#include <cstdint> #include <cstdint>
#include <limits> #include <limits>
@@ -11,7 +11,7 @@
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <curl/curl.h> #include "web_client/curl_web_client.h"
using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>; using CurlHandle = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;