Add llama grammar to ensure proper json output

This commit is contained in:
Aaron Po
2026-04-15 13:39:01 -04:00
parent ddf4bcb981
commit 62dfb5e14a
7 changed files with 115 additions and 231 deletions

View File

@@ -6,56 +6,24 @@
#include "data_generation/llama_generator.h"
#include <array>
#include <format>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <spdlog/spdlog.h>
#include "data_generation/llama_generator_helpers.h"
static std::string ExtractFinalJsonPayload(std::string raw_response) {
auto trim = [](const std::string_view text) -> std::string_view {
const size_t first = text.find_first_not_of(" \t\n\r");
if (first == std::string_view::npos) {
return {};
}
const size_t last = text.find_last_not_of(" \t\n\r");
return text.substr(first, last - first + 1);
};
static constexpr std::array<std::string_view, 6> separator_tokens = {
"<|think|>", "<think|>", "<|turn|>",
"<turn|>", "<channel|>", "<|channel|>"};
size_t separator_pos = std::string::npos;
size_t separator_length = 0;
for (const std::string_view token : separator_tokens) {
const size_t candidate_pos = raw_response.rfind(token);
if (candidate_pos != std::string::npos &&
(separator_pos == std::string::npos || candidate_pos > separator_pos)) {
separator_pos = candidate_pos;
separator_length = token.size();
}
}
if (separator_pos != std::string::npos) {
raw_response.erase(0, separator_pos + separator_length);
}
const std::string_view trimmed = trim(raw_response);
const std::string json_candidate =
ExtractLastJsonObject(std::string(trimmed));
if (!json_candidate.empty()) {
return json_candidate;
}
return std::string(trimmed);
}
static constexpr std::string_view kBreweryJsonGrammar = R"json_brewery(
root ::= ws "{" ws "\"name\"" ws ":" ws string ws "," ws "\"description\"" ws ":" ws string ws "}" ws
ws ::= [ \t\n\r]*
string ::= "\"" char+ "\""
char ::= [^"\\\x7F\x00-\x1F] | [\\] escape
escape ::= ["\\/bfnrt] | "u" hex hex hex hex
hex ::= [0-9a-fA-F]
)json_brewery";
BreweryResult LlamaGenerator::GenerateBrewery(
const Location& location, const std::string& region_context) {
@@ -108,7 +76,7 @@ BreweryResult LlamaGenerator::GenerateBrewery(
for (int attempt = 0; attempt < max_attempts; ++attempt) {
constexpr int max_tokens = 1052;
// Generate brewery data from LLM
raw = this->Infer(system_prompt, prompt, max_tokens);
raw = this->Infer(system_prompt, prompt, max_tokens, kBreweryJsonGrammar);
spdlog::debug("LlamaGenerator: raw output (attempt {}): {}", attempt + 1,
raw);
@@ -116,9 +84,8 @@ BreweryResult LlamaGenerator::GenerateBrewery(
std::string name;
std::string description;
const std::string json_only = ExtractFinalJsonPayload(raw);
const std::optional<std::string> validation_error =
ValidateBreweryJson(json_only, name, description);
ValidateBreweryJson(raw, name, description);
if (!validation_error.has_value()) {
// Success: return parsed brewery data
return BreweryResult{.name = std::move(name),