mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
cleanup
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
#ifndef BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_
|
#ifndef BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_
|
||||||
#define BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_
|
#define BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file data_generation/llama_generator.h
|
* @file data_generation/llama_generator.h
|
||||||
* @brief llama.cpp-backed implementation of DataGenerator.
|
* @brief llama.cpp-backed implementation of DataGenerator.
|
||||||
@@ -34,12 +36,16 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
LlamaGenerator(const ApplicationOptions& options,
|
LlamaGenerator(const ApplicationOptions& options,
|
||||||
const std::string& model_path);
|
const std::string& model_path);
|
||||||
|
|
||||||
/// @brief Releases model/context resources.
|
|
||||||
~LlamaGenerator() override;
|
~LlamaGenerator() override;
|
||||||
|
|
||||||
|
// disable copy constructor
|
||||||
LlamaGenerator(const LlamaGenerator&) = delete;
|
LlamaGenerator(const LlamaGenerator&) = delete;
|
||||||
|
|
||||||
|
// disable copy assignment operator
|
||||||
LlamaGenerator& operator=(const LlamaGenerator&) = delete;
|
LlamaGenerator& operator=(const LlamaGenerator&) = delete;
|
||||||
|
// disable move constructor
|
||||||
LlamaGenerator(LlamaGenerator&&) = delete;
|
LlamaGenerator(LlamaGenerator&&) = delete;
|
||||||
|
// disable move assignment operator
|
||||||
LlamaGenerator& operator=(LlamaGenerator&&) = delete;
|
LlamaGenerator& operator=(LlamaGenerator&&) = delete;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -61,7 +67,7 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
UserResult GenerateUser(const std::string& locale) override;
|
UserResult GenerateUser(const std::string& locale) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr int kDefaultMaxTokens = 10000;
|
static constexpr int32_t kDefaultMaxTokens = 10000;
|
||||||
static constexpr float kDefaultSamplingTopP = 0.95F;
|
static constexpr float kDefaultSamplingTopP = 0.95F;
|
||||||
static constexpr uint32_t kDefaultSamplingTopK = 64;
|
static constexpr uint32_t kDefaultSamplingTopK = 64;
|
||||||
static constexpr uint32_t kDefaultContextSize = 8192;
|
static constexpr uint32_t kDefaultContextSize = 8192;
|
||||||
@@ -69,25 +75,16 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
struct ModelDeleter {
|
struct ModelDeleter {
|
||||||
void operator()(llama_model* model) const noexcept;
|
void operator()(llama_model* model) const noexcept;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ContextDeleter {
|
struct ContextDeleter {
|
||||||
void operator()(llama_context* context) const noexcept;
|
void operator()(llama_context* context) const noexcept;
|
||||||
};
|
};
|
||||||
|
struct SamplerDeleter {
|
||||||
|
void operator()(llama_sampler* sampler) const noexcept;
|
||||||
|
};
|
||||||
|
|
||||||
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
|
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
|
||||||
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
|
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
|
||||||
|
using SamplerChainHandle = std::unique_ptr<llama_sampler, SamplerDeleter>;
|
||||||
struct SamplerState {
|
|
||||||
SamplerState() = default;
|
|
||||||
~SamplerState();
|
|
||||||
|
|
||||||
SamplerState(const SamplerState&) = delete;
|
|
||||||
SamplerState& operator=(const SamplerState&) = delete;
|
|
||||||
SamplerState(SamplerState&&) = delete;
|
|
||||||
SamplerState& operator=(SamplerState&&) = delete;
|
|
||||||
|
|
||||||
llama_sampler* chain = nullptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Loads model and prepares inference context.
|
* @brief Loads model and prepares inference context.
|
||||||
@@ -126,12 +123,12 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
* @param prompt_file_path Prompt file path to try first.
|
* @param prompt_file_path Prompt file path to try first.
|
||||||
* @return Loaded prompt text.
|
* @return Loaded prompt text.
|
||||||
*/
|
*/
|
||||||
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
|
std::string LoadBrewerySystemPrompt(const std::filesystem::path& prompt_file_path);
|
||||||
|
|
||||||
ModelHandle model_;
|
ModelHandle model_;
|
||||||
ContextHandle context_;
|
ContextHandle context_;
|
||||||
/// @brief Persistent sampler chain reused across inference calls.
|
/// @brief Persistent sampler chain reused across inference calls.
|
||||||
std::unique_ptr<SamplerState> sampler_;
|
SamplerChainHandle sampler_;
|
||||||
float sampling_temperature_ = 1.0F;
|
float sampling_temperature_ = 1.0F;
|
||||||
float sampling_top_p_ = kDefaultSamplingTopP;
|
float sampling_top_p_ = kDefaultSamplingTopP;
|
||||||
uint32_t sampling_top_k_ = kDefaultSamplingTopK;
|
uint32_t sampling_top_k_ = kDefaultSamplingTopK;
|
||||||
|
|||||||
@@ -7,14 +7,14 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
struct llama_model;
|
struct llama_model;
|
||||||
struct llama_vocab;
|
struct llama_vocab;
|
||||||
typedef int llama_token;
|
typedef int32_t llama_token;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Normalizes and truncates regional context.
|
* @brief Normalizes and truncates regional context.
|
||||||
@@ -23,18 +23,8 @@ typedef int llama_token;
|
|||||||
* @param max_chars Maximum output length.
|
* @param max_chars Maximum output length.
|
||||||
* @return Processed region context.
|
* @return Processed region context.
|
||||||
*/
|
*/
|
||||||
std::string PrepareRegionContextPublic(std::string_view region_context,
|
std::string PrepareRegionContext(std::string_view region_context,
|
||||||
std::size_t max_chars = 2000);
|
size_t max_chars = 2000);
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Parses a response expected to contain two logical lines.
|
|
||||||
*
|
|
||||||
* @param raw Raw model output.
|
|
||||||
* @param error_message Error message thrown on parse failure.
|
|
||||||
* @return Pair containing first and second parsed fields.
|
|
||||||
*/
|
|
||||||
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
|
|
||||||
const std::string& raw, const std::string& error_message);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Applies model chat template to system and user prompts.
|
* @brief Applies model chat template to system and user prompts.
|
||||||
@@ -44,9 +34,9 @@ std::pair<std::string, std::string> ParseTwoLineResponsePublic(
|
|||||||
* @param user_prompt User prompt text.
|
* @param user_prompt User prompt text.
|
||||||
* @return Model-formatted prompt.
|
* @return Model-formatted prompt.
|
||||||
*/
|
*/
|
||||||
std::string ToChatPromptPublic(const llama_model* model,
|
std::string ToChatPrompt(const llama_model* model,
|
||||||
const std::string& system_prompt,
|
const std::string& system_prompt,
|
||||||
const std::string& user_prompt);
|
const std::string& user_prompt);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Decodes a sampled token and appends it to output text.
|
* @brief Decodes a sampled token and appends it to output text.
|
||||||
@@ -55,8 +45,8 @@ std::string ToChatPromptPublic(const llama_model* model,
|
|||||||
* @param token Sampled token id.
|
* @param token Sampled token id.
|
||||||
* @param output Output text buffer.
|
* @param output Output text buffer.
|
||||||
*/
|
*/
|
||||||
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
|
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
||||||
std::string& output);
|
std::string& output);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Validates and parses brewery JSON output.
|
* @brief Validates and parses brewery JSON output.
|
||||||
@@ -66,9 +56,9 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
|
|||||||
* @param description_out Parsed brewery description.
|
* @param description_out Parsed brewery description.
|
||||||
* @return Validation error message if invalid, or std::nullopt on success.
|
* @return Validation error message if invalid, or std::nullopt on success.
|
||||||
*/
|
*/
|
||||||
std::optional<std::string> ValidateBreweryJsonPublic(
|
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||||
const std::string& raw, std::string& name_out,
|
std::string& name_out,
|
||||||
std::string& description_out);
|
std::string& description_out);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Extracts the last balanced JSON object from text.
|
* @brief Extracts the last balanced JSON object from text.
|
||||||
@@ -76,6 +66,6 @@ std::optional<std::string> ValidateBreweryJsonPublic(
|
|||||||
* @param text Input text.
|
* @param text Input text.
|
||||||
* @return Extracted JSON object or an empty string if none exists.
|
* @return Extracted JSON object or an empty string if none exists.
|
||||||
*/
|
*/
|
||||||
std::string ExtractLastJsonObjectPublic(const std::string& text);
|
std::string ExtractLastJsonObject(const std::string& text);
|
||||||
|
|
||||||
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
|
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
|
||||||
@@ -42,7 +42,7 @@ class MockGenerator final : public DataGenerator {
|
|||||||
* @param location City and country names.
|
* @param location City and country names.
|
||||||
* @return Deterministic hash value.
|
* @return Deterministic hash value.
|
||||||
*/
|
*/
|
||||||
static std::size_t DeterministicHash(const Location& location);
|
static size_t DeterministicHash(const Location& location);
|
||||||
|
|
||||||
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = {
|
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = {
|
||||||
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
|
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
|
||||||
|
|||||||
@@ -3,18 +3,17 @@
|
|||||||
* @brief BiergartenDataGenerator::QueryCitiesWithCountries() implementation.
|
* @brief BiergartenDataGenerator::QueryCitiesWithCountries() implementation.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "biergarten_data_generator.h"
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
#include <spdlog/spdlog.h>
|
#include "biergarten_data_generator.h"
|
||||||
|
|
||||||
#include "json_handling/json_loader.h"
|
#include "json_handling/json_loader.h"
|
||||||
|
|
||||||
static constexpr std::size_t kBreweryAmount = 4;
|
static constexpr size_t kBreweryAmount = 4;
|
||||||
|
|
||||||
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
|
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
|
||||||
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
|
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
|
||||||
@@ -24,11 +23,12 @@ std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
|
|||||||
auto all_locations = JsonLoader::LoadLocations(locations_path);
|
auto all_locations = JsonLoader::LoadLocations(locations_path);
|
||||||
spdlog::info(" Locations available: {}", all_locations.size());
|
spdlog::info(" Locations available: {}", all_locations.size());
|
||||||
|
|
||||||
const std::size_t sample_count =
|
const size_t sample_count = std::min(kBreweryAmount, all_locations.size());
|
||||||
std::min(kBreweryAmount, all_locations.size());
|
|
||||||
const auto sample_count_signed =
|
const auto sample_count_signed =
|
||||||
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
|
static_cast<std::iter_difference_t<decltype(all_locations.cbegin())>>(
|
||||||
sample_count);
|
sample_count);
|
||||||
|
|
||||||
std::vector<Location> sampled_locations;
|
std::vector<Location> sampled_locations;
|
||||||
sampled_locations.reserve(sample_count);
|
sampled_locations.reserve(sample_count);
|
||||||
|
|
||||||
|
|||||||
@@ -18,12 +18,12 @@
|
|||||||
|
|
||||||
static std::string ExtractFinalJsonPayload(std::string raw_response) {
|
static std::string ExtractFinalJsonPayload(std::string raw_response) {
|
||||||
auto trim = [](const std::string_view text) -> std::string_view {
|
auto trim = [](const std::string_view text) -> std::string_view {
|
||||||
const std::size_t first = text.find_first_not_of(" \t\n\r");
|
const size_t first = text.find_first_not_of(" \t\n\r");
|
||||||
if (first == std::string_view::npos) {
|
if (first == std::string_view::npos) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::size_t last = text.find_last_not_of(" \t\n\r");
|
const size_t last = text.find_last_not_of(" \t\n\r");
|
||||||
return text.substr(first, last - first + 1);
|
return text.substr(first, last - first + 1);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -31,10 +31,10 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
|
|||||||
"<|think|>", "<think|>", "<|turn|>",
|
"<|think|>", "<think|>", "<|turn|>",
|
||||||
"<turn|>", "<channel|>", "<|channel|>"};
|
"<turn|>", "<channel|>", "<|channel|>"};
|
||||||
|
|
||||||
std::size_t separator_pos = std::string::npos;
|
size_t separator_pos = std::string::npos;
|
||||||
std::size_t separator_length = 0;
|
size_t separator_length = 0;
|
||||||
for (const std::string_view token : separator_tokens) {
|
for (const std::string_view token : separator_tokens) {
|
||||||
const std::size_t candidate_pos = raw_response.rfind(token);
|
const size_t candidate_pos = raw_response.rfind(token);
|
||||||
if (candidate_pos != std::string::npos &&
|
if (candidate_pos != std::string::npos &&
|
||||||
(separator_pos == std::string::npos || candidate_pos > separator_pos)) {
|
(separator_pos == std::string::npos || candidate_pos > separator_pos)) {
|
||||||
separator_pos = candidate_pos;
|
separator_pos = candidate_pos;
|
||||||
@@ -48,10 +48,10 @@ static std::string ExtractFinalJsonPayload(std::string raw_response) {
|
|||||||
|
|
||||||
const std::string_view trimmed = trim(raw_response);
|
const std::string_view trimmed = trim(raw_response);
|
||||||
const std::string json_candidate =
|
const std::string json_candidate =
|
||||||
ExtractLastJsonObjectPublic(std::string(trimmed));
|
ExtractLastJsonObject(std::string(trimmed));
|
||||||
|
|
||||||
if (!json_candidate.empty()) {
|
if (!json_candidate.empty()) {
|
||||||
return ExtractLastJsonObjectPublic(std::string(trimmed));
|
return json_candidate;
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::string(trimmed);
|
return std::string(trimmed);
|
||||||
@@ -63,7 +63,7 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
|||||||
* Preprocess and truncate region context to manageable size
|
* Preprocess and truncate region context to manageable size
|
||||||
*/
|
*/
|
||||||
const std::string safe_region_context =
|
const std::string safe_region_context =
|
||||||
PrepareRegionContextPublic(region_context);
|
PrepareRegionContext(region_context);
|
||||||
|
|
||||||
const std::string country_suffix =
|
const std::string country_suffix =
|
||||||
location.country.empty() ? std::string{}
|
location.country.empty() ? std::string{}
|
||||||
@@ -118,7 +118,7 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
|||||||
std::string description;
|
std::string description;
|
||||||
const std::string json_only = ExtractFinalJsonPayload(raw);
|
const std::string json_only = ExtractFinalJsonPayload(raw);
|
||||||
const std::optional<std::string> validation_error =
|
const std::optional<std::string> validation_error =
|
||||||
ValidateBreweryJsonPublic(json_only, name, description);
|
ValidateBreweryJson(json_only, name, description);
|
||||||
if (!validation_error.has_value()) {
|
if (!validation_error.has_value()) {
|
||||||
// Success: return parsed brewery data
|
// Success: return parsed brewery data
|
||||||
return BreweryResult{.name = std::move(name),
|
return BreweryResult{.name = std::move(name),
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "data_generation/llama_generator.h"
|
#include "data_generation/llama_generator_helpers.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -25,12 +25,12 @@
|
|||||||
*/
|
*/
|
||||||
static std::string Trim(std::string_view value) {
|
static std::string Trim(std::string_view value) {
|
||||||
constexpr std::string_view whitespace = " \t\n\r\f\v";
|
constexpr std::string_view whitespace = " \t\n\r\f\v";
|
||||||
const std::size_t first_index = value.find_first_not_of(whitespace);
|
const size_t first_index = value.find_first_not_of(whitespace);
|
||||||
if (first_index == std::string_view::npos) {
|
if (first_index == std::string_view::npos) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::size_t last_index = value.find_last_not_of(whitespace);
|
const size_t last_index = value.find_last_not_of(whitespace);
|
||||||
return std::string(value.substr(first_index, last_index - first_index + 1));
|
return std::string(value.substr(first_index, last_index - first_index + 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ static std::string CondenseWhitespace(std::string_view text) {
|
|||||||
out.reserve(text.size());
|
out.reserve(text.size());
|
||||||
|
|
||||||
bool pending_space = false;
|
bool pending_space = false;
|
||||||
for (const unsigned char chr : text) {
|
for (const char chr : text) {
|
||||||
if (std::isspace(chr) != 0) {
|
if (std::isspace(chr) != 0) {
|
||||||
if (!out.empty()) {
|
if (!out.empty()) {
|
||||||
pending_space = true;
|
pending_space = true;
|
||||||
@@ -55,7 +55,7 @@ static std::string CondenseWhitespace(std::string_view text) {
|
|||||||
out.push_back(' ');
|
out.push_back(' ');
|
||||||
pending_space = false;
|
pending_space = false;
|
||||||
}
|
}
|
||||||
out.push_back(static_cast<char>(chr));
|
out.push_back(chr);
|
||||||
}
|
}
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
@@ -65,8 +65,8 @@ static std::string CondenseWhitespace(std::string_view text) {
|
|||||||
* Truncate region context to fit within max length while preserving word
|
* Truncate region context to fit within max length while preserving word
|
||||||
* boundaries
|
* boundaries
|
||||||
*/
|
*/
|
||||||
static std::string PrepareRegionContext(std::string_view region_context,
|
std::string PrepareRegionContext(std::string_view region_context,
|
||||||
const size_t max_chars) {
|
const size_t max_chars) {
|
||||||
std::string normalized = CondenseWhitespace(region_context);
|
std::string normalized = CondenseWhitespace(region_context);
|
||||||
if (normalized.size() <= max_chars) {
|
if (normalized.size() <= max_chars) {
|
||||||
return normalized;
|
return normalized;
|
||||||
@@ -82,11 +82,10 @@ static std::string PrepareRegionContext(std::string_view region_context,
|
|||||||
return normalized;
|
return normalized;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string ToChatPrompt(const llama_model* model,
|
std::string ToChatPrompt(const llama_model* model,
|
||||||
const std::string& system_prompt,
|
const std::string& system_prompt,
|
||||||
const std::string& user_prompt) {
|
const std::string& user_prompt) {
|
||||||
std::string combined_prompt;
|
std::string combined_prompt = system_prompt;
|
||||||
combined_prompt.append(system_prompt);
|
|
||||||
combined_prompt.append("\n\n");
|
combined_prompt.append("\n\n");
|
||||||
combined_prompt.append(user_prompt);
|
combined_prompt.append(user_prompt);
|
||||||
|
|
||||||
@@ -127,7 +126,7 @@ static std::string ToChatPrompt(const llama_model* model,
|
|||||||
int32_t template_result = apply_template_with_resize(messages.data(), 2);
|
int32_t template_result = apply_template_with_resize(messages.data(), 2);
|
||||||
|
|
||||||
if (template_result >= 0) {
|
if (template_result >= 0) {
|
||||||
return {buffer.data(), static_cast<std::size_t>(template_result)};
|
return {buffer.data(), static_cast<size_t>(template_result)};
|
||||||
}
|
}
|
||||||
|
|
||||||
spdlog::warn(
|
spdlog::warn(
|
||||||
@@ -151,74 +150,114 @@ static std::string ToChatPrompt(const llama_model* model,
|
|||||||
return combined_prompt;
|
return combined_prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
return {buffer.data(), static_cast<std::size_t>(template_result)};
|
return {buffer.data(), static_cast<size_t>(template_result)};
|
||||||
}
|
}
|
||||||
|
|
||||||
static void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
||||||
std::string& output) {
|
std::string& output) {
|
||||||
std::array<char, 256> buffer{};
|
constexpr size_t initial_buffer_size = 256;
|
||||||
|
|
||||||
|
std::array<char, initial_buffer_size> buffer{};
|
||||||
|
|
||||||
|
// serialize the sampled token into UTF-8 bytes
|
||||||
|
|
||||||
|
auto buffer_too_small = [](int32_t result) -> bool { return result < 0; };
|
||||||
|
|
||||||
int32_t bytes =
|
int32_t bytes =
|
||||||
llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true);
|
llama_token_to_piece(vocab, token, buffer.data(), buffer.size(), 0, true);
|
||||||
|
|
||||||
if (bytes < 0) {
|
if (!buffer_too_small(bytes)) {
|
||||||
std::vector<char> dynamic_buffer(static_cast<std::size_t>(-bytes));
|
// Append the decoded bytes from the stack buffer.
|
||||||
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
|
output.append(buffer.data(), static_cast<size_t>(bytes));
|
||||||
static_cast<int32_t>(dynamic_buffer.size()), 0,
|
|
||||||
true);
|
|
||||||
if (bytes < 0) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"LlamaGenerator: failed to decode sampled token piece");
|
|
||||||
}
|
|
||||||
|
|
||||||
output.append(dynamic_buffer.data(), static_cast<std::size_t>(bytes));
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
output.append(buffer.data(), static_cast<std::size_t>(bytes));
|
const int32_t required_size = -bytes;
|
||||||
|
std::vector<char> dynamic_buffer(static_cast<size_t>(required_size));
|
||||||
|
|
||||||
|
// Retry token decoding against the larger heap buffer.
|
||||||
|
bytes = llama_token_to_piece(vocab, token, dynamic_buffer.data(),
|
||||||
|
static_cast<int32_t>(dynamic_buffer.size()), 0,
|
||||||
|
true);
|
||||||
|
|
||||||
|
if (!buffer_too_small(bytes)) {
|
||||||
|
output.append(dynamic_buffer.data(), static_cast<size_t>(bytes));
|
||||||
|
}
|
||||||
|
|
||||||
|
throw std::runtime_error(
|
||||||
|
"LlamaGenerator: failed to decode sampled token piece");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shared parser used by the public extractor and JSON validation.
|
||||||
static bool ExtractLastJsonObject(const std::string& text,
|
static bool ExtractLastJsonObject(const std::string& text,
|
||||||
std::string& json_out) {
|
std::string& json_out) {
|
||||||
std::size_t start = std::string::npos;
|
// Remember where the most recent balanced object started.
|
||||||
|
size_t start = std::string::npos;
|
||||||
|
|
||||||
|
// Track nested braces outside of quoted strings.
|
||||||
int depth = 0;
|
int depth = 0;
|
||||||
|
|
||||||
|
// Track whether the scan is currently inside a quoted string.
|
||||||
bool in_string = false;
|
bool in_string = false;
|
||||||
|
|
||||||
|
// Track escape sequences so quotes inside strings are handled correctly.
|
||||||
bool escaped = false;
|
bool escaped = false;
|
||||||
|
|
||||||
|
// Record whether at least one complete object was found.
|
||||||
bool found = false;
|
bool found = false;
|
||||||
|
|
||||||
|
// Keep the latest complete object candidate.
|
||||||
std::string candidate;
|
std::string candidate;
|
||||||
|
|
||||||
for (std::size_t i = 0; i < text.size(); ++i) {
|
// Scan the input text one character at a time.
|
||||||
const char ch = text[i];
|
for (size_t i = 0; i < text.size(); ++i) {
|
||||||
|
// Inspect the current character.
|
||||||
|
const char chr = text[i];
|
||||||
|
|
||||||
|
// Inside a string literal, only escapes and quotes affect state.
|
||||||
if (in_string) {
|
if (in_string) {
|
||||||
if (escaped) {
|
if (escaped) {
|
||||||
|
// The current character was escaped, so clear the escape flag.
|
||||||
escaped = false;
|
escaped = false;
|
||||||
} else if (ch == '\\') {
|
} else if (chr == '\\') {
|
||||||
|
// Mark the next character as escaped.
|
||||||
escaped = true;
|
escaped = true;
|
||||||
} else if (ch == '"') {
|
} else if (chr == '"') {
|
||||||
|
// Closing quote ends the string literal.
|
||||||
in_string = false;
|
in_string = false;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ch == '"') {
|
// Opening quotes enter string mode.
|
||||||
|
if (chr == '"') {
|
||||||
in_string = true;
|
in_string = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ch == '{') {
|
// Opening braces begin or nest a JSON object.
|
||||||
|
if (chr == '{') {
|
||||||
if (depth == 0) {
|
if (depth == 0) {
|
||||||
|
// Record the start of the outermost object.
|
||||||
start = i;
|
start = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Increase nesting depth for the active object.
|
||||||
++depth;
|
++depth;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ch == '}') {
|
// Closing braces may complete an object.
|
||||||
|
if (chr == '}') {
|
||||||
if (depth == 0) {
|
if (depth == 0) {
|
||||||
|
// Ignore stray closing braces.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Drop one level of nesting.
|
||||||
--depth;
|
--depth;
|
||||||
if (depth == 0 && start != std::string::npos) {
|
if (depth == 0 && start != std::string::npos) {
|
||||||
|
// Capture the latest complete object seen so far.
|
||||||
candidate = text.substr(start, i - start + 1);
|
candidate = text.substr(start, i - start + 1);
|
||||||
found = true;
|
found = true;
|
||||||
}
|
}
|
||||||
@@ -229,22 +268,14 @@ static bool ExtractLastJsonObject(const std::string& text,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the captured object text to the caller.
|
||||||
json_out = std::move(candidate);
|
json_out = std::move(candidate);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ExtractLastJsonObjectPublic(const std::string& text) {
|
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||||
std::string extracted;
|
std::string& name_out,
|
||||||
if (ExtractLastJsonObject(text, extracted)) {
|
std::string& description_out) {
|
||||||
return extracted;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::optional<std::string> ValidateBreweryJson(
|
|
||||||
const std::string& raw, std::string& name_out,
|
|
||||||
std::string& description_out) {
|
|
||||||
auto validate_object = [&](const boost::json::value& jv,
|
auto validate_object = [&](const boost::json::value& jv,
|
||||||
std::string& error_out) -> bool {
|
std::string& error_out) -> bool {
|
||||||
if (!jv.is_object()) {
|
if (!jv.is_object()) {
|
||||||
@@ -281,9 +312,11 @@ static std::optional<std::string> ValidateBreweryJson(
|
|||||||
|
|
||||||
std::string name_lower = name_out;
|
std::string name_lower = name_out;
|
||||||
std::string description_lower = description_out;
|
std::string description_lower = description_out;
|
||||||
|
|
||||||
std::transform(
|
std::transform(
|
||||||
name_lower.begin(), name_lower.end(), name_lower.begin(),
|
name_lower.begin(), name_lower.end(), name_lower.begin(),
|
||||||
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||||
|
|
||||||
std::transform(description_lower.begin(), description_lower.end(),
|
std::transform(description_lower.begin(), description_lower.end(),
|
||||||
description_lower.begin(), [](unsigned char c) {
|
description_lower.begin(), [](unsigned char c) {
|
||||||
return static_cast<char>(std::tolower(c));
|
return static_cast<char>(std::tolower(c));
|
||||||
@@ -327,25 +360,12 @@ static std::optional<std::string> ValidateBreweryJson(
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward declarations for helper functions exposed to other translation units
|
std::string ExtractLastJsonObject(const std::string& text) {
|
||||||
std::string PrepareRegionContextPublic(std::string_view region_context,
|
// Reuse the internal parser and return an empty string if none was found.
|
||||||
std::size_t max_chars) {
|
std::string extracted;
|
||||||
return PrepareRegionContext(region_context, max_chars);
|
if (ExtractLastJsonObject(text, extracted)) {
|
||||||
}
|
return extracted;
|
||||||
|
}
|
||||||
|
|
||||||
std::string ToChatPromptPublic(const llama_model* model,
|
return {};
|
||||||
const std::string& system_prompt,
|
|
||||||
const std::string& user_prompt) {
|
|
||||||
return ToChatPrompt(model, system_prompt, user_prompt);
|
|
||||||
}
|
|
||||||
|
|
||||||
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
|
|
||||||
std::string& output) {
|
|
||||||
AppendTokenPiece(vocab, token, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<std::string> ValidateBreweryJsonPublic(
|
|
||||||
const std::string& raw, std::string& name_out,
|
|
||||||
std::string& description_out) {
|
|
||||||
return ValidateBreweryJson(raw, name_out, description_out);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,12 +17,12 @@
|
|||||||
#include "data_generation/llama_generator_helpers.h"
|
#include "data_generation/llama_generator_helpers.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
static constexpr std::size_t kPromptTokenSlack = 8;
|
static constexpr size_t kPromptTokenSlack = 8;
|
||||||
|
|
||||||
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
||||||
const std::string& prompt,
|
const std::string& prompt,
|
||||||
const int max_tokens) {
|
const int max_tokens) {
|
||||||
return InferFormatted(ToChatPromptPublic(model_.get(), system_prompt, prompt),
|
return InferFormatted(ToChatPrompt(model_.get(), system_prompt, prompt),
|
||||||
max_tokens);
|
max_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,16 +54,26 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
*/
|
*/
|
||||||
std::vector<llama_token> prompt_tokens(formatted_prompt.size() +
|
std::vector<llama_token> prompt_tokens(formatted_prompt.size() +
|
||||||
kPromptTokenSlack);
|
kPromptTokenSlack);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int32_t token_count = llama_tokenize(
|
int32_t token_count = llama_tokenize(
|
||||||
vocab, formatted_prompt.c_str(),
|
vocab,
|
||||||
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
|
formatted_prompt.c_str(),
|
||||||
static_cast<int32_t>(prompt_tokens.size()), true, true);
|
static_cast<int32_t>(formatted_prompt.size()),
|
||||||
|
prompt_tokens.data(),
|
||||||
|
static_cast<int32_t>(prompt_tokens.size()),
|
||||||
|
true,
|
||||||
|
true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* If buffer too small, negative return indicates required size
|
* If buffer too small, negative return indicates required size
|
||||||
*/
|
*/
|
||||||
if (token_count < 0) {
|
if (token_count < 0) {
|
||||||
prompt_tokens.resize(static_cast<std::size_t>(-token_count));
|
prompt_tokens.resize(static_cast<size_t>(-token_count));
|
||||||
|
|
||||||
|
|
||||||
token_count = llama_tokenize(
|
token_count = llama_tokenize(
|
||||||
vocab, formatted_prompt.c_str(),
|
vocab, formatted_prompt.c_str(),
|
||||||
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
|
static_cast<int32_t>(formatted_prompt.size()), prompt_tokens.data(),
|
||||||
@@ -91,6 +101,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
*/
|
*/
|
||||||
const int32_t effective_max_tokens =
|
const int32_t effective_max_tokens =
|
||||||
std::max(1, std::min(max_tokens, n_ctx - 1));
|
std::max(1, std::min(max_tokens, n_ctx - 1));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Prompt can use remaining context after reserving space for generation
|
* Prompt can use remaining context after reserving space for generation
|
||||||
*/
|
*/
|
||||||
@@ -100,13 +111,13 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
/**
|
/**
|
||||||
* Truncate prompt if necessary to fit within constraints
|
* Truncate prompt if necessary to fit within constraints
|
||||||
*/
|
*/
|
||||||
prompt_tokens.resize(static_cast<std::size_t>(token_count));
|
prompt_tokens.resize(static_cast<size_t>(token_count));
|
||||||
if (token_count > prompt_budget) {
|
if (token_count > prompt_budget) {
|
||||||
spdlog::warn(
|
spdlog::warn(
|
||||||
"LlamaGenerator: prompt too long ({} tokens), truncating to {} "
|
"LlamaGenerator: prompt too long ({} tokens), truncating to {} "
|
||||||
"tokens to fit n_batch/n_ctx limits",
|
"tokens to fit n_batch/n_ctx limits",
|
||||||
token_count, prompt_budget);
|
token_count, prompt_budget);
|
||||||
prompt_tokens.resize(static_cast<std::size_t>(prompt_budget));
|
prompt_tokens.resize(static_cast<size_t>(prompt_budget));
|
||||||
token_count = prompt_budget;
|
token_count = prompt_budget;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,9 +138,9 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
* end-of-sequence
|
* end-of-sequence
|
||||||
*/
|
*/
|
||||||
std::vector<llama_token> generated_tokens;
|
std::vector<llama_token> generated_tokens;
|
||||||
generated_tokens.reserve(static_cast<std::size_t>(effective_max_tokens));
|
generated_tokens.reserve(static_cast<size_t>(effective_max_tokens));
|
||||||
|
|
||||||
if (sampler_ == nullptr || sampler_->chain == nullptr) {
|
if (!sampler_) {
|
||||||
throw std::runtime_error("LlamaGenerator: sampler not initialized");
|
throw std::runtime_error("LlamaGenerator: sampler not initialized");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +150,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
* Index -1 means use the last output position from previous batch
|
* Index -1 means use the last output position from previous batch
|
||||||
*/
|
*/
|
||||||
const llama_token next =
|
const llama_token next =
|
||||||
llama_sampler_sample(sampler_->chain, context_.get(), -1);
|
llama_sampler_sample(sampler_.get(), context_.get(), -1);
|
||||||
/**
|
/**
|
||||||
* Stop if model predicts end-of-generation token (EOS/EOT)
|
* Stop if model predicts end-of-generation token (EOS/EOT)
|
||||||
*/
|
*/
|
||||||
@@ -165,7 +176,7 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
*/
|
*/
|
||||||
std::string output;
|
std::string output;
|
||||||
for (const llama_token token : generated_tokens) {
|
for (const llama_token token : generated_tokens) {
|
||||||
AppendTokenPiecePublic(vocab, token, output);
|
AppendTokenPiece(vocab, token, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
return output;
|
return output;
|
||||||
|
|||||||
@@ -9,60 +9,31 @@
|
|||||||
#include <random>
|
#include <random>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <filesystem>
|
||||||
|
|
||||||
#include "data_model/application_options.h"
|
#include "data_model/application_options.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
static constexpr uint32_t kMaxContextSize = 32768U;
|
static constexpr uint32_t kMaxContextSize = 32768U;
|
||||||
|
|
||||||
struct SamplerConfig {
|
void LlamaGenerator::ModelDeleter::operator()(
|
||||||
float temperature;
|
llama_model* model) const noexcept {
|
||||||
float top_p;
|
|
||||||
uint32_t top_k;
|
|
||||||
};
|
|
||||||
|
|
||||||
using SamplerPtr =
|
|
||||||
std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
|
||||||
|
|
||||||
void LlamaGenerator::ModelDeleter::operator()(llama_model* model) const noexcept {
|
|
||||||
if (model != nullptr) {
|
if (model != nullptr) {
|
||||||
llama_model_free(model);
|
llama_model_free(model);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void LlamaGenerator::ContextDeleter::operator()(llama_context* context) const noexcept {
|
void LlamaGenerator::ContextDeleter::operator()(
|
||||||
|
llama_context* context) const noexcept {
|
||||||
if (context != nullptr) {
|
if (context != nullptr) {
|
||||||
llama_free(context);
|
llama_free(context);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static SamplerPtr CreateSamplerChain(const SamplerConfig& config,
|
void LlamaGenerator::SamplerDeleter::operator()(
|
||||||
std::mt19937& rng) {
|
llama_sampler* sampler) const noexcept {
|
||||||
const llama_sampler_chain_params sampler_params =
|
if (sampler != nullptr) {
|
||||||
llama_sampler_chain_default_params();
|
llama_sampler_free(sampler);
|
||||||
|
|
||||||
SamplerPtr sampler(llama_sampler_chain_init(sampler_params),
|
|
||||||
&llama_sampler_free);
|
|
||||||
if (!sampler) {
|
|
||||||
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_sampler_chain_add(sampler.get(),
|
|
||||||
llama_sampler_init_temp(config.temperature));
|
|
||||||
llama_sampler_chain_add(
|
|
||||||
sampler.get(),
|
|
||||||
llama_sampler_init_top_k(static_cast<int32_t>(config.top_k)));
|
|
||||||
llama_sampler_chain_add(sampler.get(),
|
|
||||||
llama_sampler_init_top_p(config.top_p, 1));
|
|
||||||
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(rng()));
|
|
||||||
|
|
||||||
return sampler;
|
|
||||||
}
|
|
||||||
|
|
||||||
LlamaGenerator::SamplerState::~SamplerState() {
|
|
||||||
if (chain != nullptr) {
|
|
||||||
llama_sampler_free(chain);
|
|
||||||
chain = nullptr;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,11 +81,25 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
|||||||
n_ctx_ = options.n_ctx;
|
n_ctx_ = options.n_ctx;
|
||||||
|
|
||||||
this->Load(model_path);
|
this->Load(model_path);
|
||||||
const SamplerConfig sampler_config{sampling_temperature_, sampling_top_p_,
|
const llama_sampler_chain_params sampler_params =
|
||||||
sampling_top_k_};
|
llama_sampler_chain_default_params();
|
||||||
auto sampler_chain = CreateSamplerChain(sampler_config, rng_);
|
|
||||||
sampler_ = std::make_unique<SamplerState>();
|
sampler_ = SamplerChainHandle(llama_sampler_chain_init(sampler_params));
|
||||||
sampler_->chain = sampler_chain.release();
|
if (!sampler_) {
|
||||||
|
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampler_chain_add(sampler_.get(),
|
||||||
|
llama_sampler_init_temp(sampling_temperature_));
|
||||||
|
|
||||||
|
llama_sampler_chain_add(
|
||||||
|
sampler_.get(),
|
||||||
|
llama_sampler_init_top_k(static_cast<int32_t>(sampling_top_k_)));
|
||||||
|
|
||||||
|
llama_sampler_chain_add(sampler_.get(),
|
||||||
|
llama_sampler_init_top_p(sampling_top_p_, 1));
|
||||||
|
|
||||||
|
llama_sampler_chain_add(sampler_.get(), llama_sampler_init_dist(rng_()));
|
||||||
}
|
}
|
||||||
|
|
||||||
LlamaGenerator::~LlamaGenerator() = default;
|
LlamaGenerator::~LlamaGenerator() = default;
|
||||||
|
|||||||
@@ -12,8 +12,6 @@
|
|||||||
|
|
||||||
#include "data_generation/llama_generator.h"
|
#include "data_generation/llama_generator.h"
|
||||||
|
|
||||||
namespace fs = std::filesystem;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Loads brewery system prompt from disk or cache.
|
* @brief Loads brewery system prompt from disk or cache.
|
||||||
*
|
*
|
||||||
@@ -21,22 +19,21 @@ namespace fs = std::filesystem;
|
|||||||
* @return Prompt text loaded from disk.
|
* @return Prompt text loaded from disk.
|
||||||
*/
|
*/
|
||||||
std::string LlamaGenerator::LoadBrewerySystemPrompt(
|
std::string LlamaGenerator::LoadBrewerySystemPrompt(
|
||||||
const std::string& prompt_file_path) {
|
const std::filesystem::path& prompt_file_path) {
|
||||||
// Return cached version if already loaded
|
// Return cached version if already loaded
|
||||||
if (!brewery_system_prompt_.empty()) {
|
if (!brewery_system_prompt_.empty()) {
|
||||||
return brewery_system_prompt_;
|
return brewery_system_prompt_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try the provided path only
|
|
||||||
const fs::path prompt_path(prompt_file_path);
|
std::ifstream prompt_file(prompt_file_path);
|
||||||
std::ifstream prompt_file(prompt_path);
|
|
||||||
if (!prompt_file.is_open()) {
|
if (!prompt_file.is_open()) {
|
||||||
spdlog::error(
|
spdlog::error(
|
||||||
"LlamaGenerator: Failed to open brewery system prompt file '{}'",
|
"LlamaGenerator: Failed to open brewery system prompt file '{}'",
|
||||||
prompt_path.string());
|
prompt_file_path.string());
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: missing brewery system prompt file: " +
|
"LlamaGenerator: missing brewery system prompt file: " +
|
||||||
prompt_path.string());
|
prompt_file_path.string());
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string prompt((std::istreambuf_iterator(prompt_file)),
|
const std::string prompt((std::istreambuf_iterator(prompt_file)),
|
||||||
@@ -45,15 +42,15 @@ std::string LlamaGenerator::LoadBrewerySystemPrompt(
|
|||||||
|
|
||||||
if (prompt.empty()) {
|
if (prompt.empty()) {
|
||||||
spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty",
|
spdlog::error("LlamaGenerator: Brewery system prompt file '{}' is empty",
|
||||||
prompt_path.string());
|
prompt_file_path.string());
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"LlamaGenerator: empty brewery system prompt file: " +
|
"LlamaGenerator: empty brewery system prompt file: " +
|
||||||
prompt_path.string());
|
prompt_file_path.string());
|
||||||
}
|
}
|
||||||
|
|
||||||
spdlog::info(
|
spdlog::info(
|
||||||
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)",
|
"LlamaGenerator: Loaded brewery system prompt from '{}' ({} chars)",
|
||||||
prompt_path.string(), prompt.length());
|
prompt_file_path.string(), prompt.length());
|
||||||
brewery_system_prompt_ = prompt;
|
brewery_system_prompt_ = prompt;
|
||||||
return brewery_system_prompt_;
|
return brewery_system_prompt_;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
|
|
||||||
BreweryResult MockGenerator::GenerateBrewery(
|
BreweryResult MockGenerator::GenerateBrewery(
|
||||||
const Location& location, const std::string& /*region_context*/) {
|
const Location& location, const std::string& /*region_context*/) {
|
||||||
const std::size_t hash = DeterministicHash(location);
|
const size_t hash = DeterministicHash(location);
|
||||||
|
|
||||||
const std::string_view adjective =
|
const std::string_view adjective =
|
||||||
kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
|
kBreweryAdjectives.at(hash % kBreweryAdjectives.size());
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
#include "data_generation/mock_generator.h"
|
#include "data_generation/mock_generator.h"
|
||||||
|
|
||||||
UserResult MockGenerator::GenerateUser(const std::string& locale) {
|
UserResult MockGenerator::GenerateUser(const std::string& locale) {
|
||||||
const std::size_t hash = std::hash<std::string>{}(locale);
|
const size_t hash = std::hash<std::string>{}(locale);
|
||||||
|
|
||||||
UserResult result;
|
UserResult result;
|
||||||
const std::string_view username = kUsernames[hash % kUsernames.size()];
|
const std::string_view username = kUsernames[hash % kUsernames.size()];
|
||||||
|
|||||||
@@ -4,16 +4,16 @@
|
|||||||
* initializes shared infrastructure, and executes the pipeline entry flow.
|
* initializes shared infrastructure, and executes the pipeline entry flow.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
|
#include <boost/di.hpp>
|
||||||
|
#include <boost/program_options.hpp>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include <boost/di.hpp>
|
|
||||||
#include <boost/program_options.hpp>
|
|
||||||
#include <spdlog/spdlog.h>
|
|
||||||
|
|
||||||
#include "biergarten_data_generator.h"
|
#include "biergarten_data_generator.h"
|
||||||
#include "data_generation/llama_generator.h"
|
#include "data_generation/llama_generator.h"
|
||||||
#include "data_generation/mock_generator.h"
|
#include "data_generation/mock_generator.h"
|
||||||
|
|||||||
Reference in New Issue
Block a user