Implement pipeline logging with bounded channels and orchestrator integration

This commit is contained in:
Aaron Po
2026-05-14 21:08:17 -04:00
parent f93b14897b
commit 74f11b57e2
16 changed files with 361 additions and 61 deletions

View File

@@ -12,8 +12,10 @@
#include <memory>
#include <optional>
#include <string>
#include <thread>
#include "biergarten_data_generator.h"
#include "concurrency/bounded_channel.h"
#include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h"
#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h"
@@ -25,12 +27,22 @@
#include "services/enrichment/enrichment_service.h"
#include "services/enrichment/mock_enrichment.h"
#include "services/enrichment/wikipedia_service.h"
#include "services/logging/channel_logger.h"
#include "services/logging/log_consumer.h"
#include "services/logging/log_entry.h"
#include "services/logging/logger.h"
#include "services/prompting/prompt_directory.h"
#include "web_client/http_web_client.h"
namespace di = boost::di;
static constexpr size_t kLogMaxCount = 512;
int main(const int argc, char** argv) {
auto log_channel = std::make_shared<BoundedChannel<LogEntry>>(kLogMaxCount);
ChannelLogger channel_logger(*log_channel);
LogConsumer log_worker(*log_channel);
std::thread log_thread([&log_worker] { log_worker.Run(); });
try {
Timer timer;
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v");
@@ -46,6 +58,8 @@ int main(const int argc, char** argv) {
ParseArguments(argc, argv);
if (!parsed_options.has_value()) {
log_channel->Close();
log_thread.join();
return 0;
}
@@ -54,66 +68,97 @@ int main(const int argc, char** argv) {
const auto sampling =
options.generator.sampling.value_or(SamplingOptions{});
// -----------------------------------------------------------------------
// Prompt directory
// Conditionally constructed before the injector; moved into LlamaGenerator.
// -----------------------------------------------------------------------
std::unique_ptr<IPromptDirectory> prompt_directory;
if (!options.generator.use_mocked) {
try {
prompt_directory =
std::make_unique<PromptDirectory>(options.pipeline.prompt_dir);
} catch (const std::exception& dir_error) {
spdlog::error("[Startup] Invalid --prompt-dir: {}", dir_error.what());
channel_logger.Log(
LogLevel::Error, PipelinePhase::Startup,
std::string("Invalid --prompt-dir: ") + dir_error.what());
log_channel->Close();
log_thread.join();
return 1;
}
}
// -----------------------------------------------------------------------
// Dependency injection
// -----------------------------------------------------------------------
const auto injector = di::make_injector(
di::bind<ApplicationOptions>().to(options),
di::bind<std::string>().to(model_path),
di::bind<WebClient>().to<HttpWebClient>(),
di::bind<IExportService>().to<SqliteExportService>(),
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
di::bind<ILogger>().to(
[log_channel](const auto&) -> std::unique_ptr<ILogger> {
return std::make_unique<ChannelLogger>(*log_channel);
}),
di::bind<IEnrichmentService>().to(
[options](const auto& inj) -> std::unique_ptr<IEnrichmentService> {
if (options.generator.use_mocked) {
return std::make_unique<MockEnrichmentService>();
}
return std::make_unique<WikipediaEnrichmentService>(
inj.template create<std::unique_ptr<WebClient>>());
}),
di::bind<DataGenerator>().to(
[options, model_path, sampling, &prompt_directory](
[&options, &model_path, &sampling, &prompt_directory,
&channel_logger](
const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.generator.use_mocked) {
spdlog::info(
"[Generator] Using MockGenerator (no model path provided)");
channel_logger.Log(
LogLevel::Info, PipelinePhase::Startup,
"Using MockGenerator (no model path provided)");
return std::make_unique<MockGenerator>();
}
spdlog::info(
"[Generator] Using LlamaGenerator: {} (temperature={}, "
"top-p={}, top-k={}, n_ctx={}, seed={})",
model_path, sampling.temperature, sampling.top_p,
sampling.top_k, sampling.n_ctx, sampling.seed);
channel_logger.Log(
LogLevel::Info, PipelinePhase::Startup,
"Using LlamaGenerator: " + model_path +
" (temperature=" + std::to_string(sampling.temperature) +
", top-p=" + std::to_string(sampling.top_p) +
", top-k=" + std::to_string(sampling.top_k) +
", n_ctx=" + std::to_string(sampling.n_ctx) +
", seed=" + std::to_string(sampling.seed) + ")");
return std::make_unique<LlamaGenerator>(
options, model_path,
inj.template create<std::unique_ptr<IPromptFormatter>>(),
std::move(prompt_directory));
})
}));
);
// -----------------------------------------------------------------------
// Pipeline execution
// -----------------------------------------------------------------------
const auto orchestrator =
injector.create<std::unique_ptr<BiergartenPipelineOrchestrator>>();
const auto generator =
injector.create<std::unique_ptr<BiergartenDataGenerator>>();
if (!generator->Run()) {
spdlog::error("Pipeline execution failed");
if (!orchestrator->Run()) {
channel_logger.Log(LogLevel::Error, PipelinePhase::Teardown,
"Pipeline execution failed");
log_channel->Close();
log_thread.join();
return 1;
}
spdlog::info("Pipeline executed successfully in {} ms", timer.Elapsed());
channel_logger.Log(LogLevel::Info, PipelinePhase::Teardown,
"Pipeline executed successfully in " +
std::to_string(timer.Elapsed()) + " ms");
log_channel->Close();
log_thread.join();
return 0;
} catch (const std::exception& exception) {
// Channel may be in an unknown state; fall back to spdlog directly.
spdlog::critical("Unhandled fatal error in main: {}", exception.what());
log_channel->Close();
log_thread.join();
return 1;
}
}
}