Add llama grammar to ensure proper json output

This commit is contained in:
Aaron Po
2026-04-15 13:39:01 -04:00
parent ddf4bcb981
commit 62dfb5e14a
7 changed files with 115 additions and 231 deletions

View File

@@ -19,7 +19,6 @@
struct llama_model;
struct llama_context;
struct llama_sampler;
/**
* @brief Data generator implementation backed by llama.cpp.
@@ -78,13 +77,9 @@ class LlamaGenerator final : public DataGenerator {
struct ContextDeleter {
void operator()(llama_context* context) const noexcept;
};
struct SamplerDeleter {
void operator()(llama_sampler* sampler) const noexcept;
};
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
using SamplerChainHandle = std::unique_ptr<llama_sampler, SamplerDeleter>;
/**
* @brief Loads model and prepares inference context.
@@ -102,20 +97,24 @@ class LlamaGenerator final : public DataGenerator {
* @param system_prompt System role prompt.
* @param prompt User prompt.
* @param max_tokens Maximum tokens to generate.
* @param grammar Optional GBNF grammar constraining generated output.
* @return Generated text.
*/
std::string Infer(const std::string& system_prompt, const std::string& prompt,
int max_tokens = kDefaultMaxTokens);
int max_tokens = kDefaultMaxTokens,
std::string_view grammar = {});
/**
* @brief Runs inference on an already-formatted prompt.
*
* @param formatted_prompt Prompt preformatted for model chat template.
* @param max_tokens Maximum tokens to generate.
* @param grammar Optional GBNF grammar constraining generated output.
* @return Generated text.
*/
std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = kDefaultMaxTokens);
int max_tokens = kDefaultMaxTokens,
std::string_view grammar = {});
/**
* @brief Loads the brewery system prompt from disk.
@@ -127,8 +126,6 @@ class LlamaGenerator final : public DataGenerator {
ModelHandle model_;
ContextHandle context_;
/// @brief Persistent sampler chain reused across inference calls.
SamplerChainHandle sampler_;
float sampling_temperature_ = 1.0F;
float sampling_top_p_ = kDefaultSamplingTopP;
uint32_t sampling_top_k_ = kDefaultSamplingTopK;