Add llama grammar to ensure proper JSON output

This commit is contained in:
Aaron Po
2026-04-15 13:39:01 -04:00
parent ddf4bcb981
commit 62dfb5e14a
7 changed files with 115 additions and 231 deletions

View File

@@ -19,7 +19,6 @@
// Forward declarations of llama.cpp handle types. In the visible part of this
// header they are only used through pointers and smart-pointer aliases.
struct llama_model;
struct llama_context;
struct llama_sampler;
/**
* @brief Data generator implementation backed by llama.cpp.
@@ -78,13 +77,9 @@ class LlamaGenerator final : public DataGenerator {
/// @brief Deleter functor used by ContextHandle for llama_context pointers.
/// NOTE(review): the definition is not in this header; presumably it frees the
/// context through the llama.cpp API - confirm in the implementation file.
struct ContextDeleter {
/// @param context Pointer handed over by the owning smart pointer.
void operator()(llama_context* context) const noexcept;
};
/// @brief Deleter functor used by SamplerChainHandle for llama_sampler pointers.
/// NOTE(review): the definition is not in this header; presumably it frees the
/// sampler through the llama.cpp API - confirm in the implementation file.
struct SamplerDeleter {
/// @param sampler Pointer handed over by the owning smart pointer.
void operator()(llama_sampler* sampler) const noexcept;
};
/// @brief std::unique_ptr aliases that pair each llama.cpp handle type with its
/// custom deleter, so the handles are released when the owner is destroyed.
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
using SamplerChainHandle = std::unique_ptr<llama_sampler, SamplerDeleter>;
/**
* @brief Loads model and prepares inference context.
@@ -102,20 +97,24 @@ class LlamaGenerator final : public DataGenerator {
* @param system_prompt System role prompt.
* @param prompt User prompt.
* @param max_tokens Maximum tokens to generate.
* @param grammar Optional GBNF grammar constraining generated output.
* @return Generated text.
*/
std::string Infer(const std::string& system_prompt, const std::string& prompt,
int max_tokens = kDefaultMaxTokens);
int max_tokens = kDefaultMaxTokens,
std::string_view grammar = {});
/**
* @brief Runs inference on an already-formatted prompt.
*
* @param formatted_prompt Prompt preformatted for model chat template.
* @param max_tokens Maximum tokens to generate; defaults to kDefaultMaxTokens.
* @param grammar Optional GBNF grammar constraining generated output; defaults
* to an empty view. NOTE(review): an empty grammar presumably means
* unconstrained sampling - confirm against the implementation.
* @return Generated text.
*/
std::string InferFormatted(const std::string& formatted_prompt,
int max_tokens = kDefaultMaxTokens);
int max_tokens = kDefaultMaxTokens,
std::string_view grammar = {});
/**
* @brief Loads the brewery system prompt from disk.
@@ -127,8 +126,6 @@ class LlamaGenerator final : public DataGenerator {
/// @brief Owning handle for the llama_model.
ModelHandle model_;
/// @brief Owning handle for the llama_context.
ContextHandle context_;
/// @brief Persistent sampler chain reused across inference calls.
SamplerChainHandle sampler_;
/// @brief Sampling temperature; initialized to 1.0.
float sampling_temperature_ = 1.0F;
/// @brief Top-p sampling parameter; initialized to kDefaultSamplingTopP.
float sampling_top_p_ = kDefaultSamplingTopP;
/// @brief Top-k sampling parameter; initialized to kDefaultSamplingTopK.
uint32_t sampling_top_k_ = kDefaultSamplingTopK;

View File

@@ -14,7 +14,7 @@
// Forward declarations of llama-prefixed types.
struct llama_model;
struct llama_vocab;
// Token id alias declared locally as int32_t.
// NOTE(review): presumably mirrors the alias in llama.h - confirm the two stay
// in sync.
typedef int32_t llama_token;
using llama_token = int32_t;
/**
* @brief Normalizes and truncates regional context.
@@ -60,12 +60,4 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
std::string& name_out,
std::string& description_out);
/**
* @brief Extracts the last balanced JSON object from text.
*
* @param text Input text to search.
* @return Extracted JSON object or an empty string if none exists.
*
* NOTE(review): only the declaration is visible here; how "balanced" is
* determined should be confirmed against the implementation.
*/
std::string ExtractLastJsonObject(const std::string& text);
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_