mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-04-05 18:09:04 +00:00
fix: address critical correctness, reliability, and design issues in pipeline
CORRECTNESS FIXES: - json_loader: Add RollbackTransaction() and call it on exception instead of CommitTransaction(). Prevents partial data corruption on parse/disk errors. - wikipedia_service: Fix invalid MediaWiki API parameter explaintext=true -> explaintext=1. Now returns plain text instead of HTML markup in contexts. - helpers: Fix ParseTwoLineResponse filter to only remove known thinking tags (<think>, <reasoning>, <reflect>) instead of any <...> pattern. Prevents silently removing legitimate output like <username>content</username>. RELIABILITY & DESIGN IMPROVEMENTS: - load/main: Make n_ctx (context window size) configurable via --n-ctx flag (default 2048, range 1-32768) to support larger models like Qwen3-14B. - generate_brewery: Prevent retry prompt growth by extracting location context into constant and using compact retry format (error + schema + location only). Avoids token truncation on final retry attempts. - database: Fix data representativeness by changing QueryCities from ORDER BY name (alphabetic bias) to ORDER BY RANDOM() for unbiased sampling. Convert all SQLITE_STATIC to SQLITE_TRANSIENT to prevent use-after-free risks. POLISH: - infer: Advance sampling seed between generation calls to improve diversity across brewery and user generation. - data_downloader: Remove unnecessary commit hash truncation; use full hash. - json_loader: Fix misleading log message from "RapidJSON" to "Boost.JSON".
This commit is contained in:
@@ -25,15 +25,10 @@ std::string DataDownloader::DownloadCountriesDatabase(
|
||||
return cache_path;
|
||||
}
|
||||
|
||||
std::string short_commit = commit;
|
||||
if (commit.length() > 7) {
|
||||
short_commit = commit.substr(0, 7);
|
||||
}
|
||||
|
||||
std::string url =
|
||||
"https://raw.githubusercontent.com/dr5hn/"
|
||||
"countries-states-cities-database/" +
|
||||
short_commit + "/json/countries+states+cities.json";
|
||||
commit + "/json/countries+states+cities.json";
|
||||
|
||||
spdlog::info("[DataDownloader] Downloading: {}", url);
|
||||
|
||||
|
||||
@@ -50,6 +50,14 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
? std::string(".")
|
||||
: std::string(". Regional context: ") + safe_region_context);
|
||||
|
||||
/**
|
||||
* Store location context for retry prompts (without repeating full context)
|
||||
*/
|
||||
const std::string retry_location =
|
||||
"Location: " + city_name +
|
||||
(country_name.empty() ? std::string("")
|
||||
: std::string(", ") + country_name);
|
||||
|
||||
/**
|
||||
* RETRY LOOP with validation and error correction
|
||||
* Attempts to generate valid brewery data up to 3 times, with feedback-based
|
||||
@@ -84,19 +92,16 @@ BreweryResult LlamaGenerator::GenerateBrewery(
|
||||
spdlog::warn("LlamaGenerator: malformed brewery JSON (attempt {}): {}",
|
||||
attempt + 1, validation_error);
|
||||
|
||||
// Update prompt with error details to guide LLM toward correct output
|
||||
// Update prompt with error details to guide LLM toward correct output.
|
||||
// For retries, use a compact prompt format to avoid exceeding token
|
||||
// limits.
|
||||
prompt =
|
||||
"Your previous response was invalid. Error: " + validation_error +
|
||||
"\nReturn ONLY valid JSON with this exact schema: "
|
||||
"{\"name\": \"string\", \"description\": \"string\"}."
|
||||
"\nDo not include markdown, comments, or extra keys."
|
||||
"\n\nLocation: " +
|
||||
city_name +
|
||||
(country_name.empty() ? std::string("")
|
||||
: std::string(", ") + country_name) +
|
||||
(safe_region_context.empty()
|
||||
? std::string("")
|
||||
: std::string("\nRegional context: ") + safe_region_context);
|
||||
"\n\n" +
|
||||
retry_location;
|
||||
}
|
||||
|
||||
// All retry attempts exhausted: log failure and throw exception
|
||||
|
||||
@@ -147,7 +147,17 @@ std::pair<std::string, std::string> ParseTwoLineResponse(
|
||||
std::transform(low.begin(), low.end(), low.begin(), [](unsigned char c) {
|
||||
return static_cast<char>(std::tolower(c));
|
||||
});
|
||||
if (!l.empty() && l.front() == '<' && low.back() == '>') continue;
|
||||
// Filter known thinking tags like <think>...</think>, but be conservative
|
||||
// to avoid removing legitimate output. Only filter specific known
|
||||
// patterns.
|
||||
if (!l.empty() && l.front() == '<' && low.back() == '>') {
|
||||
// Only filter if it's a known thinking tag: <think>, <reasoning>, etc.
|
||||
if (low.find("think") != std::string::npos ||
|
||||
low.find("reasoning") != std::string::npos ||
|
||||
low.find("reflect") != std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (low.rfind("okay,", 0) == 0 || low.rfind("hmm", 0) == 0) continue;
|
||||
filtered.push_back(std::move(l));
|
||||
}
|
||||
|
||||
@@ -186,5 +186,11 @@ std::string LlamaGenerator::InferFormatted(const std::string& formatted_prompt,
|
||||
std::string output;
|
||||
for (const llama_token token : generated_tokens)
|
||||
AppendTokenPiecePublic(vocab, token, output);
|
||||
|
||||
/**
|
||||
* Advance seed for next generation to improve output diversity
|
||||
*/
|
||||
sampling_seed_ = (sampling_seed_ == 0xFFFFFFFFu) ? 0 : sampling_seed_ + 1;
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ void LlamaGenerator::Load(const std::string& model_path) {
|
||||
}
|
||||
|
||||
llama_context_params context_params = llama_context_default_params();
|
||||
context_params.n_ctx = 2048;
|
||||
context_params.n_ctx = n_ctx_;
|
||||
|
||||
context_ = llama_init_from_model(model_, context_params);
|
||||
if (context_ == nullptr) {
|
||||
|
||||
@@ -48,3 +48,18 @@ void LlamaGenerator::SetSamplingOptions(float temperature, float top_p,
|
||||
sampling_seed_ = (seed < 0) ? static_cast<uint32_t>(LLAMA_DEFAULT_SEED)
|
||||
: static_cast<uint32_t>(seed);
|
||||
}
|
||||
|
||||
void LlamaGenerator::SetContextSize(uint32_t n_ctx) {
|
||||
/**
|
||||
* Validate context size: must be positive and reasonable for the model
|
||||
*/
|
||||
if (n_ctx == 0 || n_ctx > 32768) {
|
||||
throw std::runtime_error(
|
||||
"LlamaGenerator: context size must be in range [1, 32768]");
|
||||
}
|
||||
|
||||
/**
|
||||
* Store context size for use during model loading
|
||||
*/
|
||||
n_ctx_ = n_ctx;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user