Update main.cc

This commit is contained in:
Aaron Po
2026-05-17 01:29:32 -04:00
parent c58e4c1986
commit 5d80b53351

View File

@@ -5,6 +5,7 @@
*/ */
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
#include <spdlog/fmt/fmt.h>
#include <boost/di.hpp> #include <boost/di.hpp>
#include <boost/program_options.hpp> #include <boost/program_options.hpp>
@@ -47,6 +48,13 @@ int main(const int argc, char** argv) {
std::shared_ptr<ILogger> log_producer = std::shared_ptr<ILogger> log_producer =
std::make_shared<LogProducer>(log_channel); std::make_shared<LogProducer>(log_channel);
auto shutdown = [&](const int exit_code) {
log_channel.Close();
log_thread.join();
return exit_code;
};
try { try {
Timer timer; Timer timer;
@@ -54,57 +62,70 @@ int main(const int argc, char** argv) {
const LlamaBackendState llama_backend_state; const LlamaBackendState llama_backend_state;
#endif #endif
log_producer->Log(LogLevel::Info, PipelinePhase::Startup, log_producer->Log(LogLevel::Info, PipelinePhase::Startup, "STARTING PIPELINE");
"STARTING PIPELINE");
const std::optional<ApplicationOptions> parsed_options = const std::optional<ApplicationOptions> parsed_options =
ParseArguments(argc, argv, log_producer); ParseArguments(argc, argv, log_producer);
if (!parsed_options.has_value()) { if (!parsed_options.has_value()) {
log_channel.Close(); return shutdown(0);
log_thread.join();
return 0;
} }
const auto options = *parsed_options; const auto options = *parsed_options;
const std::string model_path = options.generator.model_path.string(); const std::string model_path = options.generator.model_path.string();
const auto sampling = const auto sampling = options.generator.sampling.value_or(SamplingOptions{});
options.generator.sampling.value_or(SamplingOptions{});
// -----------------------------------------------------------------------
// Prompt directory
// Conditionally constructed before the injector; moved into LlamaGenerator.
// -----------------------------------------------------------------------
std::unique_ptr<IPromptDirectory> prompt_directory; std::unique_ptr<IPromptDirectory> prompt_directory;
if (!options.generator.use_mocked) { if (!options.generator.use_mocked) {
try { try {
prompt_directory = std::make_unique<PromptDirectory>( prompt_directory = std::make_unique<PromptDirectory>(
options.pipeline.prompt_dir, log_producer); options.pipeline.prompt_dir, log_producer);
} catch (const std::exception& dir_error) { } catch (const std::exception& dir_error) {
log_producer->Log( log_producer->Log(LogLevel::Error, PipelinePhase::Startup,
LogLevel::Error, PipelinePhase::Startup, fmt::format("Invalid --prompt-dir: {}", dir_error.what()));
std::string("Invalid --prompt-dir: ") + dir_error.what()); return shutdown(1);
log_channel.Close();
log_thread.join();
return 1;
} }
} }
// -----------------------------------------------------------------------
// Dependency injection
// -----------------------------------------------------------------------
const auto injector = di::make_injector( const auto injector = di::make_injector(
di::bind<ILogger>().to(log_producer), di::bind<ILogger>().to(log_producer),
di::bind<ApplicationOptions>().to(options), di::bind<ApplicationOptions>().to(options),
di::bind<std::string>().to(model_path), di::bind<std::string>().to(model_path),
di::bind<WebClient>().to<HttpWebClient>(),
di::bind<IExportService>().to<SqliteExportService>(), di::bind<IExportService>().to<SqliteExportService>(),
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(), di::bind<IPromptFormatter>().to(
[options, log_producer] {
if (options.generator.use_mocked) {
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Prompt formatter: none (mock mode)");
return std::unique_ptr<IPromptFormatter>(nullptr);
}
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Prompt formatter: Gemma4JinjaPromptFormatter");
return std::unique_ptr<IPromptFormatter>(
std::make_unique<Gemma4JinjaPromptFormatter>());
}),
di::bind<WebClient>().to([options, log_producer] {
if (options.generator.use_mocked) {
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Web client: none (mock mode)");
return std::unique_ptr<WebClient>(nullptr);
}
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Web client: HttpWebClient");
return std::unique_ptr<WebClient>(
std::make_unique<HttpWebClient>(log_producer));
}),
di::bind<IEnrichmentService>().to( di::bind<IEnrichmentService>().to(
[options, &log_producer]( [options, &log_producer](
const auto& inj) -> std::unique_ptr<IEnrichmentService> { const auto& inj) -> std::unique_ptr<IEnrichmentService> {
// if (options.generator.use_mocked) { if (options.generator.use_mocked) {
// return std::make_unique<MockEnrichmentService>(); log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
// } "Enrichment: mock");
return std::make_unique<MockEnrichmentService>();
}
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Enrichment: Wikipedia");
return std::make_unique<WikipediaEnrichmentService>( return std::make_unique<WikipediaEnrichmentService>(
inj.template create<std::unique_ptr<WebClient>>(), inj.template create<std::unique_ptr<WebClient>>(),
log_producer); log_producer);
@@ -113,57 +134,45 @@ int main(const int argc, char** argv) {
[&options, &model_path, &sampling, &prompt_directory, [&options, &model_path, &sampling, &prompt_directory,
&log_producer](const auto& inj) -> std::unique_ptr<DataGenerator> { &log_producer](const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.generator.use_mocked) { if (options.generator.use_mocked) {
log_producer->Log( log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
LogLevel::Info, PipelinePhase::Startup, "Generator: mock");
"Using MockGenerator (no model path provided)");
return std::make_unique<MockGenerator>(); return std::make_unique<MockGenerator>();
} }
log_producer->Log( log_producer->Log(
LogLevel::Info, PipelinePhase::Startup, LogLevel::Info, PipelinePhase::Startup,
"Using LlamaGenerator: " + model_path + fmt::format(
" (temperature=" + std::to_string(sampling.temperature) + "Generator: LlamaGenerator | model={} | temp={:.2f} top_p={:.2f} top_k={} n_ctx={} seed={}",
", top-p=" + std::to_string(sampling.top_p) + model_path,
", top-k=" + std::to_string(sampling.top_k) + sampling.temperature,
", n_ctx=" + std::to_string(sampling.n_ctx) + sampling.top_p,
", seed=" + std::to_string(sampling.seed) + ")"); sampling.top_k,
sampling.n_ctx,
sampling.seed));
return std::make_unique<LlamaGenerator>( return std::make_unique<LlamaGenerator>(
options, model_path, log_producer, options, model_path, log_producer,
inj.template create<std::unique_ptr<IPromptFormatter>>(), inj.template create<std::unique_ptr<IPromptFormatter>>(),
std::move(prompt_directory)); std::move(prompt_directory));
})); }));
// -----------------------------------------------------------------------
// Pipeline execution
// -----------------------------------------------------------------------
const auto orchestrator = const auto orchestrator =
injector.create<std::unique_ptr<BiergartenPipelineOrchestrator>>(); injector.create<std::unique_ptr<BiergartenPipelineOrchestrator>>();
if (!orchestrator->Run()) { if (!orchestrator->Run()) {
log_producer->Log(LogLevel::Error, PipelinePhase::Teardown, log_producer->Log(LogLevel::Error, PipelinePhase::Teardown,
"Pipeline execution failed"); "Pipeline execution failed");
log_channel.Close(); return shutdown(1);
log_thread.join();
return 1;
} }
log_producer->Log(LogLevel::Info, PipelinePhase::Teardown, log_producer->Log(LogLevel::Info, PipelinePhase::Teardown,
"Pipeline executed successfully in " + fmt::format("Pipeline complete in {} ms", timer.Elapsed()));
std::to_string(timer.Elapsed()) + " ms");
log_channel.Close(); return shutdown(0);
log_thread.join();
return 0;
} catch (const std::exception& exception) { } catch (const std::exception& exception) {
// Attempt to use the logging infrastructure; if channel/dispatcher are
// compromised this is a best-effort fallback.
if (log_producer) { if (log_producer) {
log_producer->Log( log_producer->Log(LogLevel::Error, PipelinePhase::Teardown,
LogLevel::Error, PipelinePhase::Teardown, fmt::format("Unhandled fatal error: {}", exception.what()));
std::string("Unhandled fatal error in main: ") + exception.what());
} }
log_channel.Close(); return shutdown(1);
log_thread.join();
return 1;
} }
} }