mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
cleanup
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#ifndef BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_
|
||||
#define BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_H_
|
||||
|
||||
#include <filesystem>
|
||||
|
||||
/**
|
||||
* @file data_generation/llama_generator.h
|
||||
* @brief llama.cpp-backed implementation of DataGenerator.
|
||||
@@ -34,12 +36,16 @@ class LlamaGenerator final : public DataGenerator {
|
||||
LlamaGenerator(const ApplicationOptions& options,
|
||||
const std::string& model_path);
|
||||
|
||||
/// @brief Releases model/context resources.
|
||||
~LlamaGenerator() override;
|
||||
|
||||
// disable copy constructor
|
||||
LlamaGenerator(const LlamaGenerator&) = delete;
|
||||
|
||||
// disable copy assignment operator
|
||||
LlamaGenerator& operator=(const LlamaGenerator&) = delete;
|
||||
// disable move constructor
|
||||
LlamaGenerator(LlamaGenerator&&) = delete;
|
||||
// disable move assignment operator
|
||||
LlamaGenerator& operator=(LlamaGenerator&&) = delete;
|
||||
|
||||
/**
|
||||
@@ -61,7 +67,7 @@ class LlamaGenerator final : public DataGenerator {
|
||||
UserResult GenerateUser(const std::string& locale) override;
|
||||
|
||||
private:
|
||||
static constexpr int kDefaultMaxTokens = 10000;
|
||||
static constexpr int32_t kDefaultMaxTokens = 10000;
|
||||
static constexpr float kDefaultSamplingTopP = 0.95F;
|
||||
static constexpr uint32_t kDefaultSamplingTopK = 64;
|
||||
static constexpr uint32_t kDefaultContextSize = 8192;
|
||||
@@ -69,25 +75,16 @@ class LlamaGenerator final : public DataGenerator {
|
||||
struct ModelDeleter {
|
||||
void operator()(llama_model* model) const noexcept;
|
||||
};
|
||||
|
||||
struct ContextDeleter {
|
||||
void operator()(llama_context* context) const noexcept;
|
||||
};
|
||||
struct SamplerDeleter {
|
||||
void operator()(llama_sampler* sampler) const noexcept;
|
||||
};
|
||||
|
||||
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
|
||||
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
|
||||
|
||||
struct SamplerState {
|
||||
SamplerState() = default;
|
||||
~SamplerState();
|
||||
|
||||
SamplerState(const SamplerState&) = delete;
|
||||
SamplerState& operator=(const SamplerState&) = delete;
|
||||
SamplerState(SamplerState&&) = delete;
|
||||
SamplerState& operator=(SamplerState&&) = delete;
|
||||
|
||||
llama_sampler* chain = nullptr;
|
||||
};
|
||||
using SamplerChainHandle = std::unique_ptr<llama_sampler, SamplerDeleter>;
|
||||
|
||||
/**
|
||||
* @brief Loads model and prepares inference context.
|
||||
@@ -126,12 +123,12 @@ class LlamaGenerator final : public DataGenerator {
|
||||
* @param prompt_file_path Prompt file path to try first.
|
||||
* @return Loaded prompt text.
|
||||
*/
|
||||
std::string LoadBrewerySystemPrompt(const std::string& prompt_file_path);
|
||||
std::string LoadBrewerySystemPrompt(const std::filesystem::path& prompt_file_path);
|
||||
|
||||
ModelHandle model_;
|
||||
ContextHandle context_;
|
||||
/// @brief Persistent sampler chain reused across inference calls.
|
||||
std::unique_ptr<SamplerState> sampler_;
|
||||
SamplerChainHandle sampler_;
|
||||
float sampling_temperature_ = 1.0F;
|
||||
float sampling_top_p_ = kDefaultSamplingTopP;
|
||||
uint32_t sampling_top_k_ = kDefaultSamplingTopK;
|
||||
|
||||
@@ -7,14 +7,14 @@
|
||||
*/
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
|
||||
struct llama_model;
|
||||
struct llama_vocab;
|
||||
typedef int llama_token;
|
||||
typedef int32_t llama_token;
|
||||
|
||||
/**
|
||||
* @brief Normalizes and truncates regional context.
|
||||
@@ -23,18 +23,8 @@ typedef int llama_token;
|
||||
* @param max_chars Maximum output length.
|
||||
* @return Processed region context.
|
||||
*/
|
||||
std::string PrepareRegionContextPublic(std::string_view region_context,
|
||||
std::size_t max_chars = 2000);
|
||||
|
||||
/**
|
||||
* @brief Parses a response expected to contain two logical lines.
|
||||
*
|
||||
* @param raw Raw model output.
|
||||
* @param error_message Error message thrown on parse failure.
|
||||
* @return Pair containing first and second parsed fields.
|
||||
*/
|
||||
std::pair<std::string, std::string> ParseTwoLineResponsePublic(
|
||||
const std::string& raw, const std::string& error_message);
|
||||
std::string PrepareRegionContext(std::string_view region_context,
|
||||
size_t max_chars = 2000);
|
||||
|
||||
/**
|
||||
* @brief Applies model chat template to system and user prompts.
|
||||
@@ -44,9 +34,9 @@ std::pair<std::string, std::string> ParseTwoLineResponsePublic(
|
||||
* @param user_prompt User prompt text.
|
||||
* @return Model-formatted prompt.
|
||||
*/
|
||||
std::string ToChatPromptPublic(const llama_model* model,
|
||||
const std::string& system_prompt,
|
||||
const std::string& user_prompt);
|
||||
std::string ToChatPrompt(const llama_model* model,
|
||||
const std::string& system_prompt,
|
||||
const std::string& user_prompt);
|
||||
|
||||
/**
|
||||
* @brief Decodes a sampled token and appends it to output text.
|
||||
@@ -55,8 +45,8 @@ std::string ToChatPromptPublic(const llama_model* model,
|
||||
* @param token Sampled token id.
|
||||
* @param output Output text buffer.
|
||||
*/
|
||||
void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
|
||||
std::string& output);
|
||||
void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
||||
std::string& output);
|
||||
|
||||
/**
|
||||
* @brief Validates and parses brewery JSON output.
|
||||
@@ -66,9 +56,9 @@ void AppendTokenPiecePublic(const llama_vocab* vocab, llama_token token,
|
||||
* @param description_out Parsed brewery description.
|
||||
* @return Validation error message if invalid, or std::nullopt on success.
|
||||
*/
|
||||
std::optional<std::string> ValidateBreweryJsonPublic(
|
||||
const std::string& raw, std::string& name_out,
|
||||
std::string& description_out);
|
||||
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||
std::string& name_out,
|
||||
std::string& description_out);
|
||||
|
||||
/**
|
||||
* @brief Extracts the last balanced JSON object from text.
|
||||
@@ -76,6 +66,6 @@ std::optional<std::string> ValidateBreweryJsonPublic(
|
||||
* @param text Input text.
|
||||
* @return Extracted JSON object or an empty string if none exists.
|
||||
*/
|
||||
std::string ExtractLastJsonObjectPublic(const std::string& text);
|
||||
std::string ExtractLastJsonObject(const std::string& text);
|
||||
|
||||
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
|
||||
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
|
||||
|
||||
@@ -42,7 +42,7 @@ class MockGenerator final : public DataGenerator {
|
||||
* @param location City and country names.
|
||||
* @return Deterministic hash value.
|
||||
*/
|
||||
static std::size_t DeterministicHash(const Location& location);
|
||||
static size_t DeterministicHash(const Location& location);
|
||||
|
||||
static constexpr std::array<std::string_view, 18> kBreweryAdjectives = {
|
||||
"Craft", "Heritage", "Local", "Artisan", "Pioneer", "Golden",
|
||||
|
||||
Reference in New Issue
Block a user