mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-05-31 17:53:59 +00:00
Add llama grammar to ensure proper json output
This commit is contained in:
@@ -19,7 +19,6 @@
|
|||||||
|
|
||||||
struct llama_model;
|
struct llama_model;
|
||||||
struct llama_context;
|
struct llama_context;
|
||||||
struct llama_sampler;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Data generator implementation backed by llama.cpp.
|
* @brief Data generator implementation backed by llama.cpp.
|
||||||
@@ -78,13 +77,9 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
struct ContextDeleter {
|
struct ContextDeleter {
|
||||||
void operator()(llama_context* context) const noexcept;
|
void operator()(llama_context* context) const noexcept;
|
||||||
};
|
};
|
||||||
struct SamplerDeleter {
|
|
||||||
void operator()(llama_sampler* sampler) const noexcept;
|
|
||||||
};
|
|
||||||
|
|
||||||
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
|
using ModelHandle = std::unique_ptr<llama_model, ModelDeleter>;
|
||||||
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
|
using ContextHandle = std::unique_ptr<llama_context, ContextDeleter>;
|
||||||
using SamplerChainHandle = std::unique_ptr<llama_sampler, SamplerDeleter>;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Loads model and prepares inference context.
|
* @brief Loads model and prepares inference context.
|
||||||
@@ -102,20 +97,24 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
* @param system_prompt System role prompt.
|
* @param system_prompt System role prompt.
|
||||||
* @param prompt User prompt.
|
* @param prompt User prompt.
|
||||||
* @param max_tokens Maximum tokens to generate.
|
* @param max_tokens Maximum tokens to generate.
|
||||||
|
* @param grammar Optional GBNF grammar constraining generated output.
|
||||||
* @return Generated text.
|
* @return Generated text.
|
||||||
*/
|
*/
|
||||||
std::string Infer(const std::string& system_prompt, const std::string& prompt,
|
std::string Infer(const std::string& system_prompt, const std::string& prompt,
|
||||||
int max_tokens = kDefaultMaxTokens);
|
int max_tokens = kDefaultMaxTokens,
|
||||||
|
std::string_view grammar = {});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Runs inference on an already-formatted prompt.
|
* @brief Runs inference on an already-formatted prompt.
|
||||||
*
|
*
|
||||||
* @param formatted_prompt Prompt preformatted for model chat template.
|
* @param formatted_prompt Prompt preformatted for model chat template.
|
||||||
* @param max_tokens Maximum tokens to generate.
|
* @param max_tokens Maximum tokens to generate.
|
||||||
|
* @param grammar Optional GBNF grammar constraining generated output.
|
||||||
* @return Generated text.
|
* @return Generated text.
|
||||||
*/
|
*/
|
||||||
std::string InferFormatted(const std::string& formatted_prompt,
|
std::string InferFormatted(const std::string& formatted_prompt,
|
||||||
int max_tokens = kDefaultMaxTokens);
|
int max_tokens = kDefaultMaxTokens,
|
||||||
|
std::string_view grammar = {});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Loads the brewery system prompt from disk.
|
* @brief Loads the brewery system prompt from disk.
|
||||||
@@ -127,8 +126,6 @@ class LlamaGenerator final : public DataGenerator {
|
|||||||
|
|
||||||
ModelHandle model_;
|
ModelHandle model_;
|
||||||
ContextHandle context_;
|
ContextHandle context_;
|
||||||
/// @brief Persistent sampler chain reused across inference calls.
|
|
||||||
SamplerChainHandle sampler_;
|
|
||||||
float sampling_temperature_ = 1.0F;
|
float sampling_temperature_ = 1.0F;
|
||||||
float sampling_top_p_ = kDefaultSamplingTopP;
|
float sampling_top_p_ = kDefaultSamplingTopP;
|
||||||
uint32_t sampling_top_k_ = kDefaultSamplingTopK;
|
uint32_t sampling_top_k_ = kDefaultSamplingTopK;
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
struct llama_model;
|
struct llama_model;
|
||||||
struct llama_vocab;
|
struct llama_vocab;
|
||||||
typedef int32_t llama_token;
|
using llama_token = int32_t;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Normalizes and truncates regional context.
|
* @brief Normalizes and truncates regional context.
|
||||||
@@ -60,12 +60,4 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
|||||||
std::string& name_out,
|
std::string& name_out,
|
||||||
std::string& description_out);
|
std::string& description_out);
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Extracts the last balanced JSON object from text.
|
|
||||||
*
|
|
||||||
* @param text Input text.
|
|
||||||
* @return Extracted JSON object or an empty string if none exists.
|
|
||||||
*/
|
|
||||||
std::string ExtractLastJsonObject(const std::string& text);
|
|
||||||
|
|
||||||
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
|
#endif // BIERGARTEN_PIPELINE_INCLUDES_DATA_GENERATION_LLAMA_GENERATOR_HELPERS_H_
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
#include "biergarten_data_generator.h"
|
#include "biergarten_data_generator.h"
|
||||||
#include "json_handling/json_loader.h"
|
#include "json_handling/json_loader.h"
|
||||||
|
|
||||||
static constexpr size_t kBreweryAmount = 4;
|
static constexpr size_t kBreweryAmount = 50;
|
||||||
|
|
||||||
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
|
std::vector<Location> BiergartenDataGenerator::QueryCitiesWithCountries() {
|
||||||
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
|
spdlog::info("\n=== GEOGRAPHIC DATA OVERVIEW ===");
|
||||||
|
|||||||
@@ -6,56 +6,24 @@
|
|||||||
|
|
||||||
#include "data_generation/llama_generator.h"
|
#include "data_generation/llama_generator.h"
|
||||||
|
|
||||||
#include <array>
|
|
||||||
#include <format>
|
#include <format>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
|
||||||
#include <spdlog/spdlog.h>
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
#include "data_generation/llama_generator_helpers.h"
|
#include "data_generation/llama_generator_helpers.h"
|
||||||
|
|
||||||
static std::string ExtractFinalJsonPayload(std::string raw_response) {
|
static constexpr std::string_view kBreweryJsonGrammar = R"json_brewery(
|
||||||
auto trim = [](const std::string_view text) -> std::string_view {
|
root ::= ws "{" ws "\"name\"" ws ":" ws string ws "," ws "\"description\"" ws ":" ws string ws "}" ws
|
||||||
const size_t first = text.find_first_not_of(" \t\n\r");
|
ws ::= [ \t\n\r]*
|
||||||
if (first == std::string_view::npos) {
|
string ::= "\"" char+ "\""
|
||||||
return {};
|
char ::= [^"\\\x7F\x00-\x1F] | [\\] escape
|
||||||
}
|
escape ::= ["\\/bfnrt] | "u" hex hex hex hex
|
||||||
|
hex ::= [0-9a-fA-F]
|
||||||
const size_t last = text.find_last_not_of(" \t\n\r");
|
)json_brewery";
|
||||||
return text.substr(first, last - first + 1);
|
|
||||||
};
|
|
||||||
|
|
||||||
static constexpr std::array<std::string_view, 6> separator_tokens = {
|
|
||||||
"<|think|>", "<think|>", "<|turn|>",
|
|
||||||
"<turn|>", "<channel|>", "<|channel|>"};
|
|
||||||
|
|
||||||
size_t separator_pos = std::string::npos;
|
|
||||||
size_t separator_length = 0;
|
|
||||||
for (const std::string_view token : separator_tokens) {
|
|
||||||
const size_t candidate_pos = raw_response.rfind(token);
|
|
||||||
if (candidate_pos != std::string::npos &&
|
|
||||||
(separator_pos == std::string::npos || candidate_pos > separator_pos)) {
|
|
||||||
separator_pos = candidate_pos;
|
|
||||||
separator_length = token.size();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (separator_pos != std::string::npos) {
|
|
||||||
raw_response.erase(0, separator_pos + separator_length);
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::string_view trimmed = trim(raw_response);
|
|
||||||
const std::string json_candidate =
|
|
||||||
ExtractLastJsonObject(std::string(trimmed));
|
|
||||||
|
|
||||||
if (!json_candidate.empty()) {
|
|
||||||
return json_candidate;
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::string(trimmed);
|
|
||||||
}
|
|
||||||
|
|
||||||
BreweryResult LlamaGenerator::GenerateBrewery(
|
BreweryResult LlamaGenerator::GenerateBrewery(
|
||||||
const Location& location, const std::string& region_context) {
|
const Location& location, const std::string& region_context) {
|
||||||
@@ -108,7 +76,7 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
|||||||
for (int attempt = 0; attempt < max_attempts; ++attempt) {
|
for (int attempt = 0; attempt < max_attempts; ++attempt) {
|
||||||
constexpr int max_tokens = 1052;
|
constexpr int max_tokens = 1052;
|
||||||
// Generate brewery data from LLM
|
// Generate brewery data from LLM
|
||||||
raw = this->Infer(system_prompt, prompt, max_tokens);
|
raw = this->Infer(system_prompt, prompt, max_tokens, kBreweryJsonGrammar);
|
||||||
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
|
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
|
||||||
raw);
|
raw);
|
||||||
|
|
||||||
@@ -116,9 +84,8 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
|||||||
|
|
||||||
std::string name;
|
std::string name;
|
||||||
std::string description;
|
std::string description;
|
||||||
const std::string json_only = ExtractFinalJsonPayload(raw);
|
|
||||||
const std::optional<std::string> validation_error =
|
const std::optional<std::string> validation_error =
|
||||||
ValidateBreweryJson(json_only, name, description);
|
ValidateBreweryJson(raw, name, description);
|
||||||
if (!validation_error.has_value()) {
|
if (!validation_error.has_value()) {
|
||||||
// Success: return parsed brewery data
|
// Success: return parsed brewery data
|
||||||
return BreweryResult{.name = std::move(name),
|
return BreweryResult{.name = std::move(name),
|
||||||
|
|||||||
@@ -11,7 +11,6 @@
|
|||||||
#include <boost/json.hpp>
|
#include <boost/json.hpp>
|
||||||
#include <cctype>
|
#include <cctype>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <sstream>
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
@@ -97,11 +96,16 @@ std::string ToChatPrompt(const llama_model* model,
|
|||||||
return combined_prompt;
|
return combined_prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::array<llama_chat_message, 2> messages = {
|
const std::array<llama_chat_message, 2> messages = {{
|
||||||
{{"system", system_prompt.c_str()}, {"user", user_prompt.c_str()}}};
|
{.role = "system", .content = system_prompt.c_str()},
|
||||||
|
{.role = "user", .content = user_prompt.c_str()},
|
||||||
|
}};
|
||||||
|
|
||||||
|
constexpr std::size_t min_template_buffer_size = 1024;
|
||||||
|
|
||||||
std::vector<char> buffer(std::max<std::size_t>(
|
std::vector<char> buffer(std::max<std::size_t>(
|
||||||
1024, (system_prompt.size() + user_prompt.size()) * 4));
|
min_template_buffer_size,
|
||||||
|
(system_prompt.size() + user_prompt.size()) * 4));
|
||||||
|
|
||||||
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
|
auto apply_template_with_resize = [&](const llama_chat_message* chat_messages,
|
||||||
int32_t message_count) -> int32_t {
|
int32_t message_count) -> int32_t {
|
||||||
@@ -113,11 +117,11 @@ std::string ToChatPrompt(const llama_model* model,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (result >= static_cast<int32_t>(buffer.size())) {
|
const auto buffer_size = static_cast<int32_t>(buffer.size());
|
||||||
|
if (result >= buffer_size) {
|
||||||
buffer.resize(static_cast<std::size_t>(result) + 1);
|
buffer.resize(static_cast<std::size_t>(result) + 1);
|
||||||
result = llama_chat_apply_template(tmpl, chat_messages, message_count,
|
result = llama_chat_apply_template(tmpl, chat_messages, message_count,
|
||||||
true, buffer.data(),
|
true, buffer.data(), buffer_size);
|
||||||
static_cast<int32_t>(buffer.size()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
@@ -136,8 +140,9 @@ std::string ToChatPrompt(const llama_model* model,
|
|||||||
|
|
||||||
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
|
// FALLBACK: If the template fails (e.g., Model rejecting the "system" role),
|
||||||
// combine the system and user prompts into a single "user" message.
|
// combine the system and user prompts into a single "user" message.
|
||||||
const std::array<llama_chat_message, 1> fallback_msg = {
|
const std::array<llama_chat_message, 1> fallback_msg = {{
|
||||||
{{"user", combined_prompt.c_str()}}};
|
{.role = "user", .content = combined_prompt.c_str()},
|
||||||
|
}};
|
||||||
|
|
||||||
template_result = apply_template_with_resize(fallback_msg.data(), 1);
|
template_result = apply_template_with_resize(fallback_msg.data(), 1);
|
||||||
|
|
||||||
@@ -188,102 +193,17 @@ void AppendTokenPiece(const llama_vocab* vocab, llama_token token,
|
|||||||
"LlamaGenerator: failed to decode sampled token piece");
|
"LlamaGenerator: failed to decode sampled token piece");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shared parser used by the public extractor and JSON validation.
|
|
||||||
static bool ExtractLastJsonObject(const std::string& text,
|
|
||||||
std::string& json_out) {
|
|
||||||
// Remember where the most recent balanced object started.
|
|
||||||
size_t start = std::string::npos;
|
|
||||||
|
|
||||||
// Track nested braces outside of quoted strings.
|
|
||||||
int depth = 0;
|
|
||||||
|
|
||||||
// Track whether the scan is currently inside a quoted string.
|
|
||||||
bool in_string = false;
|
|
||||||
|
|
||||||
// Track escape sequences so quotes inside strings are handled correctly.
|
|
||||||
bool escaped = false;
|
|
||||||
|
|
||||||
// Record whether at least one complete object was found.
|
|
||||||
bool found = false;
|
|
||||||
|
|
||||||
// Keep the latest complete object candidate.
|
|
||||||
std::string candidate;
|
|
||||||
|
|
||||||
// Scan the input text one character at a time.
|
|
||||||
for (size_t i = 0; i < text.size(); ++i) {
|
|
||||||
// Inspect the current character.
|
|
||||||
const char chr = text[i];
|
|
||||||
|
|
||||||
// Inside a string literal, only escapes and quotes affect state.
|
|
||||||
if (in_string) {
|
|
||||||
if (escaped) {
|
|
||||||
// The current character was escaped, so clear the escape flag.
|
|
||||||
escaped = false;
|
|
||||||
} else if (chr == '\\') {
|
|
||||||
// Mark the next character as escaped.
|
|
||||||
escaped = true;
|
|
||||||
} else if (chr == '"') {
|
|
||||||
// Closing quote ends the string literal.
|
|
||||||
in_string = false;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Opening quotes enter string mode.
|
|
||||||
if (chr == '"') {
|
|
||||||
in_string = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Opening braces begin or nest a JSON object.
|
|
||||||
if (chr == '{') {
|
|
||||||
if (depth == 0) {
|
|
||||||
// Record the start of the outermost object.
|
|
||||||
start = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increase nesting depth for the active object.
|
|
||||||
++depth;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Closing braces may complete an object.
|
|
||||||
if (chr == '}') {
|
|
||||||
if (depth == 0) {
|
|
||||||
// Ignore stray closing braces.
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Drop one level of nesting.
|
|
||||||
--depth;
|
|
||||||
if (depth == 0 && start != std::string::npos) {
|
|
||||||
// Capture the latest complete object seen so far.
|
|
||||||
candidate = text.substr(start, i - start + 1);
|
|
||||||
found = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!found) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the captured object text to the caller.
|
|
||||||
json_out = std::move(candidate);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
||||||
std::string& name_out,
|
std::string& name_out,
|
||||||
std::string& description_out) {
|
std::string& description_out) {
|
||||||
auto validate_object = [&](const boost::json::value& jv,
|
auto validate_object = [&](const boost::json::value& json_value,
|
||||||
std::string& error_out) -> bool {
|
std::string& error_out) -> bool {
|
||||||
if (!jv.is_object()) {
|
if (!json_value.is_object()) {
|
||||||
error_out = "JSON root must be an object";
|
error_out = "JSON root must be an object";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& obj = jv.get_object();
|
const auto& obj = json_value.get_object();
|
||||||
if (!obj.contains("name") || !obj.at("name").is_string()) {
|
if (!obj.contains("name") || !obj.at("name").is_string()) {
|
||||||
error_out = "JSON field 'name' is missing or not a string";
|
error_out = "JSON field 'name' is missing or not a string";
|
||||||
return false;
|
return false;
|
||||||
@@ -313,14 +233,15 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
|||||||
std::string name_lower = name_out;
|
std::string name_lower = name_out;
|
||||||
std::string description_lower = description_out;
|
std::string description_lower = description_out;
|
||||||
|
|
||||||
std::transform(
|
std::ranges::transform(name_lower, name_lower.begin(),
|
||||||
name_lower.begin(), name_lower.end(), name_lower.begin(),
|
[](unsigned char character) {
|
||||||
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
return static_cast<char>(std::tolower(character));
|
||||||
|
});
|
||||||
|
|
||||||
std::transform(description_lower.begin(), description_lower.end(),
|
std::ranges::transform(description_lower, description_lower.begin(),
|
||||||
description_lower.begin(), [](unsigned char c) {
|
[](unsigned char character) {
|
||||||
return static_cast<char>(std::tolower(c));
|
return static_cast<char>(std::tolower(character));
|
||||||
});
|
});
|
||||||
|
|
||||||
if (name_lower == "string" || description_lower == "string") {
|
if (name_lower == "string" || description_lower == "string") {
|
||||||
error_out = "JSON appears to be a schema placeholder, not content";
|
error_out = "JSON appears to be a schema placeholder, not content";
|
||||||
@@ -331,41 +252,16 @@ std::optional<std::string> ValidateBreweryJson(const std::string& raw,
|
|||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
boost::system::error_code ec;
|
boost::system::error_code error_code;
|
||||||
boost::json::value jv = boost::json::parse(raw, ec);
|
boost::json::value json_value = boost::json::parse(raw, error_code);
|
||||||
std::string validation_error;
|
std::string validation_error;
|
||||||
if (ec) {
|
if (error_code) {
|
||||||
std::string extracted;
|
return "JSON parse error: " + error_code.message();
|
||||||
if (!ExtractLastJsonObject(raw, extracted)) {
|
|
||||||
return "JSON parse error: " + ec.message();
|
|
||||||
}
|
|
||||||
|
|
||||||
ec.clear();
|
|
||||||
jv = boost::json::parse(extracted, ec);
|
|
||||||
if (ec) {
|
|
||||||
return "JSON parse error: " + ec.message();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!validate_object(jv, validation_error)) {
|
|
||||||
return validation_error;
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!validate_object(jv, validation_error)) {
|
if (!validate_object(json_value, validation_error)) {
|
||||||
return validation_error;
|
return validation_error;
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ExtractLastJsonObject(const std::string& text) {
|
|
||||||
// Reuse the internal parser and return an empty string if none was found.
|
|
||||||
std::string extracted;
|
|
||||||
if (ExtractLastJsonObject(text, extracted)) {
|
|
||||||
return extracted;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "data_generation/llama_generator.h"
|
#include "data_generation/llama_generator.h"
|
||||||
@@ -19,15 +20,68 @@
|
|||||||
|
|
||||||
static constexpr size_t kPromptTokenSlack = 8;
|
static constexpr size_t kPromptTokenSlack = 8;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using SamplerHandle = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;
|
||||||
|
|
||||||
|
struct SamplerConfig {
|
||||||
|
float temperature;
|
||||||
|
uint32_t top_k;
|
||||||
|
float top_p;
|
||||||
|
uint32_t seed;
|
||||||
|
};
|
||||||
|
|
||||||
|
SamplerHandle MakeSamplerChain(const llama_vocab* vocab,
|
||||||
|
const SamplerConfig& config,
|
||||||
|
std::string_view grammar) {
|
||||||
|
const llama_sampler_chain_params sampler_params =
|
||||||
|
llama_sampler_chain_default_params();
|
||||||
|
|
||||||
|
SamplerHandle chain(llama_sampler_chain_init(sampler_params),
|
||||||
|
llama_sampler_free);
|
||||||
|
if (!chain) {
|
||||||
|
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto add_sampler = [&](llama_sampler* sampler, const char* error_message) {
|
||||||
|
if (sampler == nullptr) {
|
||||||
|
throw std::runtime_error(error_message);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampler_chain_add(chain.get(), sampler);
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!grammar.empty()) {
|
||||||
|
const std::string grammar_text(grammar);
|
||||||
|
add_sampler(llama_sampler_init_grammar(vocab, grammar_text.c_str(), "root"),
|
||||||
|
"LlamaGenerator: failed to initialize grammar sampler");
|
||||||
|
}
|
||||||
|
|
||||||
|
add_sampler(llama_sampler_init_temp(config.temperature),
|
||||||
|
"LlamaGenerator: failed to initialize temperature sampler");
|
||||||
|
add_sampler(llama_sampler_init_top_k(static_cast<int32_t>(config.top_k)),
|
||||||
|
"LlamaGenerator: failed to initialize top-k sampler");
|
||||||
|
add_sampler(llama_sampler_init_top_p(config.top_p, 1),
|
||||||
|
"LlamaGenerator: failed to initialize top-p sampler");
|
||||||
|
add_sampler(llama_sampler_init_dist(config.seed),
|
||||||
|
"LlamaGenerator: failed to initialize distribution sampler");
|
||||||
|
|
||||||
|
return chain;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
std::string LlamaGenerator::Infer(const std::string& system_prompt,
|
||||||
const std::string& prompt,
|
const std::string& prompt,
|
||||||
const int max_tokens) {
|
const int max_tokens,
|
||||||
|
std::string_view grammar) {
|
||||||
return InferFormatted(ToChatPrompt(model_.get(), system_prompt, prompt),
|
return InferFormatted(ToChatPrompt(model_.get(), system_prompt, prompt),
|
||||||
max_tokens);
|
max_tokens, grammar);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||||
const int max_tokens) {
|
const int max_tokens,
|
||||||
|
std::string_view grammar) {
|
||||||
/**
|
/**
|
||||||
* Validate that model and context are loaded
|
* Validate that model and context are loaded
|
||||||
*/
|
*/
|
||||||
@@ -43,6 +97,14 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
throw std::runtime_error("LlamaGenerator: vocab unavailable");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const SamplerConfig sampler_config{
|
||||||
|
.temperature = sampling_temperature_,
|
||||||
|
.top_k = sampling_top_k_,
|
||||||
|
.top_p = sampling_top_p_,
|
||||||
|
.seed = rng_(),
|
||||||
|
};
|
||||||
|
auto sampler = MakeSamplerChain(vocab, sampler_config, grammar);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clear KV cache to ensure clean inference state (no residual context)
|
* Clear KV cache to ensure clean inference state (no residual context)
|
||||||
*/
|
*/
|
||||||
@@ -140,17 +202,13 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
|||||||
std::vector<llama_token> generated_tokens;
|
std::vector<llama_token> generated_tokens;
|
||||||
generated_tokens.reserve(static_cast<size_t>(effective_max_tokens));
|
generated_tokens.reserve(static_cast<size_t>(effective_max_tokens));
|
||||||
|
|
||||||
if (!sampler_) {
|
|
||||||
throw std::runtime_error("LlamaGenerator: sampler not initialized");
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < effective_max_tokens; ++i) {
|
for (int i = 0; i < effective_max_tokens; ++i) {
|
||||||
/**
|
/**
|
||||||
* Sample next token using configured sampler chain and model logits
|
* Sample next token using configured sampler chain and model logits
|
||||||
* Index -1 means use the last output position from previous batch
|
* Index -1 means use the last output position from previous batch
|
||||||
*/
|
*/
|
||||||
const llama_token next =
|
const llama_token next =
|
||||||
llama_sampler_sample(sampler_.get(), context_.get(), -1);
|
llama_sampler_sample(sampler.get(), context_.get(), -1);
|
||||||
/**
|
/**
|
||||||
* Stop if model predicts end-of-generation token (EOS/EOT)
|
* Stop if model predicts end-of-generation token (EOS/EOT)
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -30,13 +30,6 @@ void LlamaGenerator::ContextDeleter::operator()(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void LlamaGenerator::SamplerDeleter::operator()(
|
|
||||||
llama_sampler* sampler) const noexcept {
|
|
||||||
if (sampler != nullptr) {
|
|
||||||
llama_sampler_free(sampler);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
||||||
const std::string& model_path)
|
const std::string& model_path)
|
||||||
: rng_(std::random_device{}()) {
|
: rng_(std::random_device{}()) {
|
||||||
@@ -81,25 +74,6 @@ LlamaGenerator::LlamaGenerator(const ApplicationOptions& options,
|
|||||||
n_ctx_ = options.n_ctx;
|
n_ctx_ = options.n_ctx;
|
||||||
|
|
||||||
this->Load(model_path);
|
this->Load(model_path);
|
||||||
const llama_sampler_chain_params sampler_params =
|
|
||||||
llama_sampler_chain_default_params();
|
|
||||||
|
|
||||||
sampler_ = SamplerChainHandle(llama_sampler_chain_init(sampler_params));
|
|
||||||
if (!sampler_) {
|
|
||||||
throw std::runtime_error("LlamaGenerator: failed to initialize sampler");
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_sampler_chain_add(sampler_.get(),
|
|
||||||
llama_sampler_init_temp(sampling_temperature_));
|
|
||||||
|
|
||||||
llama_sampler_chain_add(
|
|
||||||
sampler_.get(),
|
|
||||||
llama_sampler_init_top_k(static_cast<int32_t>(sampling_top_k_)));
|
|
||||||
|
|
||||||
llama_sampler_chain_add(sampler_.get(),
|
|
||||||
llama_sampler_init_top_p(sampling_top_p_, 1));
|
|
||||||
|
|
||||||
llama_sampler_chain_add(sampler_.get(), llama_sampler_init_dist(rng_()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LlamaGenerator::~LlamaGenerator() = default;
|
LlamaGenerator::~LlamaGenerator() = default;
|
||||||
|
|||||||
Reference in New Issue
Block a user