Files
the-biergarten-app/tooling/pipeline/src/main.cc
2026-05-17 01:29:32 -04:00

178 lines
6.8 KiB
C++

/**
* @file main.cc
* @brief Parses command-line options, validates runtime mode selection,
* initializes shared infrastructure, and executes the pipeline entry flow.
*/
#include <spdlog/spdlog.h>
#include <spdlog/fmt/fmt.h>
#include <boost/di.hpp>
#include <boost/program_options.hpp>
#include <exception>
#include <memory>
#include <optional>
#include <string>
#include <thread>
#include "biergarten_pipeline_orchestrator.h"
#include "concurrency/bounded_channel.h"
#include "data_generation/llama_generator.h"
#include "data_generation/mock_generator.h"
#include "data_generation/prompt_formatting/gemma4_jinja_prompt_formatter.h"
#include "data_model/models.h"
#include "llama_backend_state.h"
#include "services/database/export_service.h"
#include "services/database/sqlite_export_service.h"
#include "services/datetime/timer.h"
#include "services/enrichment/enrichment_service.h"
#include "services/enrichment/mock_enrichment.h"
#include "services/enrichment/wikipedia_service.h"
#include "services/logging/log_dispatcher.h"
#include "services/logging/log_entry.h"
#include "services/logging/log_producer.h"
#include "services/logging/logger.h"
#include "services/prompting/prompt_directory.h"
#include "web_client/http_web_client.h"
// Short alias for the Boost.DI namespace; used by the injector bindings in main().
namespace di = boost::di;
/**
 * @brief Size argument passed to the BoundedChannel<LogEntry> constructed in
 * main(). Presumably the maximum number of queued log entries — confirm
 * against BoundedChannel.
 */
static constexpr size_t kLogMaxCount{512};
int main(const int argc, char** argv) {
spdlog::set_level(spdlog::level::debug);
spdlog::set_pattern("│ %Y-%m-%d %H:%M:%S.%e │ %^%-7l%$ │ %v");
BoundedChannel<LogEntry> log_channel(kLogMaxCount);
auto log_dispatcher = std::make_unique<LogDispatcher>(log_channel);
std::thread log_thread([&log_dispatcher] { log_dispatcher->Run(); });
std::shared_ptr<ILogger> log_producer =
std::make_shared<LogProducer>(log_channel);
auto shutdown = [&](const int exit_code) {
log_channel.Close();
log_thread.join();
return exit_code;
};
try {
Timer timer;
#ifndef BIERGARTEN_MOCK_ONLY
const LlamaBackendState llama_backend_state;
#endif
log_producer->Log(LogLevel::Info, PipelinePhase::Startup, "STARTING PIPELINE");
const std::optional<ApplicationOptions> parsed_options =
ParseArguments(argc, argv, log_producer);
if (!parsed_options.has_value()) {
return shutdown(0);
}
const auto options = *parsed_options;
const std::string model_path = options.generator.model_path.string();
const auto sampling = options.generator.sampling.value_or(SamplingOptions{});
std::unique_ptr<IPromptDirectory> prompt_directory;
if (!options.generator.use_mocked) {
try {
prompt_directory = std::make_unique<PromptDirectory>(
options.pipeline.prompt_dir, log_producer);
} catch (const std::exception& dir_error) {
log_producer->Log(LogLevel::Error, PipelinePhase::Startup,
fmt::format("Invalid --prompt-dir: {}", dir_error.what()));
return shutdown(1);
}
}
const auto injector = di::make_injector(
di::bind<ILogger>().to(log_producer),
di::bind<ApplicationOptions>().to(options),
di::bind<std::string>().to(model_path),
di::bind<IExportService>().to<SqliteExportService>(),
di::bind<IPromptFormatter>().to(
[options, log_producer] {
if (options.generator.use_mocked) {
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Prompt formatter: none (mock mode)");
return std::unique_ptr<IPromptFormatter>(nullptr);
}
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Prompt formatter: Gemma4JinjaPromptFormatter");
return std::unique_ptr<IPromptFormatter>(
std::make_unique<Gemma4JinjaPromptFormatter>());
}),
di::bind<WebClient>().to([options, log_producer] {
if (options.generator.use_mocked) {
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Web client: none (mock mode)");
return std::unique_ptr<WebClient>(nullptr);
}
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Web client: HttpWebClient");
return std::unique_ptr<WebClient>(
std::make_unique<HttpWebClient>(log_producer));
}),
di::bind<IEnrichmentService>().to(
[options, &log_producer](
const auto& inj) -> std::unique_ptr<IEnrichmentService> {
if (options.generator.use_mocked) {
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Enrichment: mock");
return std::make_unique<MockEnrichmentService>();
}
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Enrichment: Wikipedia");
return std::make_unique<WikipediaEnrichmentService>(
inj.template create<std::unique_ptr<WebClient>>(),
log_producer);
}),
di::bind<DataGenerator>().to(
[&options, &model_path, &sampling, &prompt_directory,
&log_producer](const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.generator.use_mocked) {
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Generator: mock");
return std::make_unique<MockGenerator>();
}
log_producer->Log(
LogLevel::Info, PipelinePhase::Startup,
fmt::format(
"Generator: LlamaGenerator | model={} | temp={:.2f} top_p={:.2f} top_k={} n_ctx={} seed={}",
model_path,
sampling.temperature,
sampling.top_p,
sampling.top_k,
sampling.n_ctx,
sampling.seed));
return std::make_unique<LlamaGenerator>(
options, model_path, log_producer,
inj.template create<std::unique_ptr<IPromptFormatter>>(),
std::move(prompt_directory));
}));
const auto orchestrator =
injector.create<std::unique_ptr<BiergartenPipelineOrchestrator>>();
if (!orchestrator->Run()) {
log_producer->Log(LogLevel::Error, PipelinePhase::Teardown,
"Pipeline execution failed");
return shutdown(1);
}
log_producer->Log(LogLevel::Info, PipelinePhase::Teardown,
fmt::format("Pipeline complete in {} ms", timer.Elapsed()));
return shutdown(0);
} catch (const std::exception& exception) {
if (log_producer) {
log_producer->Log(LogLevel::Error, PipelinePhase::Teardown,
fmt::format("Unhandled fatal error: {}", exception.what()));
}
return shutdown(1);
}
}