From 5d80b533514b5b3bffaae9f010efec65e872beaf Mon Sep 17 00:00:00 2001 From: Aaron Po Date: Sun, 17 May 2026 01:29:32 -0400 Subject: [PATCH] Update main.cc --- tooling/pipeline/src/main.cc | 115 +++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 53 deletions(-) diff --git a/tooling/pipeline/src/main.cc b/tooling/pipeline/src/main.cc index fbfd301..f586f4d 100644 --- a/tooling/pipeline/src/main.cc +++ b/tooling/pipeline/src/main.cc @@ -5,6 +5,7 @@ */ #include +#include #include #include @@ -47,6 +48,13 @@ int main(const int argc, char** argv) { std::shared_ptr log_producer = std::make_shared(log_channel); + + auto shutdown = [&](const int exit_code) { + log_channel.Close(); + log_thread.join(); + return exit_code; + }; + try { Timer timer; @@ -54,57 +62,70 @@ int main(const int argc, char** argv) { const LlamaBackendState llama_backend_state; #endif - log_producer->Log(LogLevel::Info, PipelinePhase::Startup, - "STARTING PIPELINE"); + log_producer->Log(LogLevel::Info, PipelinePhase::Startup, "STARTING PIPELINE"); + const std::optional parsed_options = ParseArguments(argc, argv, log_producer); if (!parsed_options.has_value()) { - log_channel.Close(); - log_thread.join(); - return 0; + return shutdown(0); } const auto options = *parsed_options; const std::string model_path = options.generator.model_path.string(); - const auto sampling = - options.generator.sampling.value_or(SamplingOptions{}); + const auto sampling = options.generator.sampling.value_or(SamplingOptions{}); - // ----------------------------------------------------------------------- - // Prompt directory - // Conditionally constructed before the injector; moved into LlamaGenerator. - // ----------------------------------------------------------------------- std::unique_ptr prompt_directory; + if (!options.generator.use_mocked) { try { prompt_directory = std::make_unique( options.pipeline.prompt_dir, log_producer); } catch (const std::exception& dir_error) { - log_producer->Log( - LogLevel::Error, PipelinePhase::Startup, - std::string("Invalid --prompt-dir: ") + dir_error.what()); - log_channel.Close(); - log_thread.join(); - return 1; + log_producer->Log(LogLevel::Error, PipelinePhase::Startup, + fmt::format("Invalid --prompt-dir: {}", dir_error.what())); + return shutdown(1); } } - // ----------------------------------------------------------------------- - // Dependency injection - // ----------------------------------------------------------------------- const auto injector = di::make_injector( di::bind().to(log_producer), di::bind().to(options), di::bind().to(model_path), - di::bind().to(), di::bind().to(), - di::bind().to(), + di::bind().to( + [options, log_producer] { + if (options.generator.use_mocked) { + log_producer->Log(LogLevel::Info, PipelinePhase::Startup, + "Prompt formatter: none (mock mode)"); + return std::unique_ptr(nullptr); + } + log_producer->Log(LogLevel::Info, PipelinePhase::Startup, + "Prompt formatter: Gemma4JinjaPromptFormatter"); + return std::unique_ptr( + std::make_unique()); + }), + di::bind().to([options, log_producer] { + if (options.generator.use_mocked) { + log_producer->Log(LogLevel::Info, PipelinePhase::Startup, + "Web client: none (mock mode)"); + return std::unique_ptr(nullptr); + } + log_producer->Log(LogLevel::Info, PipelinePhase::Startup, + "Web client: HttpWebClient"); + return std::unique_ptr( + std::make_unique(log_producer)); + }), di::bind().to( [options, &log_producer]( const auto& inj) -> std::unique_ptr { - // if (options.generator.use_mocked) { - // return std::make_unique(); - // } + if (options.generator.use_mocked) { + log_producer->Log(LogLevel::Info, PipelinePhase::Startup, + "Enrichment: mock"); + return std::make_unique(); + } + log_producer->Log(LogLevel::Info, PipelinePhase::Startup, + "Enrichment: Wikipedia"); return std::make_unique( inj.template create>(), log_producer); @@ -113,57 +134,45 @@ int main(const int argc, char** argv) { [&options, &model_path, &sampling, &prompt_directory, &log_producer](const auto& inj) -> std::unique_ptr { if (options.generator.use_mocked) { - log_producer->Log( - LogLevel::Info, PipelinePhase::Startup, - "Using MockGenerator (no model path provided)"); + log_producer->Log(LogLevel::Info, PipelinePhase::Startup, + "Generator: mock"); return std::make_unique(); } log_producer->Log( LogLevel::Info, PipelinePhase::Startup, - "Using LlamaGenerator: " + model_path + - " (temperature=" + std::to_string(sampling.temperature) + - ", top-p=" + std::to_string(sampling.top_p) + - ", top-k=" + std::to_string(sampling.top_k) + - ", n_ctx=" + std::to_string(sampling.n_ctx) + - ", seed=" + std::to_string(sampling.seed) + ")"); + fmt::format( + "Generator: LlamaGenerator | model={} | temp={:.2f} top_p={:.2f} top_k={} n_ctx={} seed={}", + model_path, + sampling.temperature, + sampling.top_p, + sampling.top_k, + sampling.n_ctx, + sampling.seed)); return std::make_unique( options, model_path, log_producer, inj.template create>(), std::move(prompt_directory)); })); - // ----------------------------------------------------------------------- - // Pipeline execution - // ----------------------------------------------------------------------- const auto orchestrator = injector.create>(); if (!orchestrator->Run()) { log_producer->Log(LogLevel::Error, PipelinePhase::Teardown, "Pipeline execution failed"); - log_channel.Close(); - log_thread.join(); - return 1; + return shutdown(1); } log_producer->Log(LogLevel::Info, PipelinePhase::Teardown, - "Pipeline executed successfully in " + - std::to_string(timer.Elapsed()) + " ms"); + fmt::format("Pipeline complete in {} ms", timer.Elapsed())); - log_channel.Close(); - log_thread.join(); - return 0; + return shutdown(0); } catch (const std::exception& exception) { - // Attempt to use the logging infrastructure; if channel/dispatcher are - // compromised this is a best-effort fallback. if (log_producer) { - log_producer->Log( - LogLevel::Error, PipelinePhase::Teardown, - std::string("Unhandled fatal error in main: ") + exception.what()); + log_producer->Log(LogLevel::Error, PipelinePhase::Teardown, + fmt::format("Unhandled fatal error: {}", exception.what())); } - log_channel.Close(); - log_thread.join(); - return 1; + return shutdown(1); } } \ No newline at end of file