mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-06-01 01:54:00 +00:00
Add llama grammar to ensure proper json output
This commit is contained in:
@@ -19,7 +19,6 @@
|
||||
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
struct llama_sampler;
|
||||
|
||||
/**
|
||||
* @brief Data generator implementation backed by llama.cpp.
|
||||
@@ -78,13 +77,9 @@ class LlamaGenerator final : public DataGenerator {
|
||||
struct ContextDeleter {
|
||||
void operator()(llama_context* context) const noexcept;
|
||||
};
|
||||
struct SamplerDeleter {
|
||||
void operator()(llama_sampler* sampler) const noexcept;
|
||||
};
|
||||
|
||||
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
|
||||
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
|
||||
using SamplerChainHandle = std::unique_ptr<llama_sampler, SamplerDeleter>;
|
||||
|
||||
/**
|
||||
* @brief Loads model and prepares inference context.
|
||||
@@ -102,20 +97,24 @@ class LlamaGenerator final : public DataGenerator {
|
||||
* @param system_prompt System role prompt.
|
||||
* @param prompt User prompt.
|
||||
* @param max_tokens Maximum tokens to generate.
|
||||
* @param grammar Optional GBNF grammar constraining generated output.
|
||||
* @return Generated text.
|
||||
*/
|
||||
std::string Infer(const std::string& system_prompt, const std::string& prompt,
|
||||
int max_tokens = kDefaultMaxTokens);
|
||||
int max_tokens = kDefaultMaxTokens,
|
||||
std::string_view grammar = {});
|
||||
|
||||
/**
|
||||
* @brief Runs inference on an already-formatted prompt.
|
||||
*
|
||||
* @param formatted_prompt Prompt preformatted for model chat template.
|
||||
* @param max_tokens Maximum tokens to generate.
|
||||
* @param grammar Optional GBNF grammar constraining generated output.
|
||||
* @return Generated text.
|
||||
*/
|
||||
std::string InferFormatted(const std::string& formatted_prompt,
|
||||
int max_tokens = kDefaultMaxTokens);
|
||||
int max_tokens = kDefaultMaxTokens,
|
||||
std::string_view grammar = {});
|
||||
|
||||
/**
|
||||
* @brief Loads the brewery system prompt from disk.
|
||||
@@ -127,8 +126,6 @@ class LlamaGenerator final : public DataGenerator {
|
||||
|
||||
ModelHandle model_;
|
||||
ContextHandle context_;
|
||||
/// @brief Persistent sampler chain reused across inference calls.
|
||||
SamplerChainHandle sampler_;
|
||||
float sampling_temperature_ = 1.0F;
|
||||
float sampling_top_p_ = kDefaultSamplingTopP;
|
||||
uint32_t sampling_top_k_ = kDefaultSamplingTopK;
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
struct llama_model;
|
||||
struct llama_vocab;
|
||||
typedef int32_t llama_token;
|
||||
using llama_token = int32_t;
|
||||
|
||||
/**
|
||||
* @brief Normalizes and truncates regional context.
|
||||
@@ -60,12 +60,4 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||
std::string& name_out,
|
||||
std::string& description_out);
|
||||
|
||||
/**
|
||||
* @brief Extracts the last balanced JSON object from text.
|
||||
*
|
||||
* @param text Input text.
|
||||
* @return Extracted JSON object or an empty string if none exists.
|
||||
*/
|
||||
std::string ExtractLastJsonObject(const std::string& text);
|
||||
|
||||
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
|
||||
|
||||
Reference in New Issue
Block a user