Update main.cc

This commit is contained in:
Aaron Po
2026-05-17 01:29:32 -04:00
parent c58e4c1986
commit 5d80b53351

View File

@@ -5,6 +5,7 @@
*/
#include <spdlog/spdlog.h>
#include <spdlog/fmt/fmt.h>
#include <boost/di.hpp>
#include <boost/program_options.hpp>
@@ -47,6 +48,13 @@ int main(const int argc, char** argv) {
std::shared_ptr<ILogger> log_producer =
std::make_shared<LogProducer>(log_channel);
auto shutdown = [&](const int exit_code) {
log_channel.Close();
log_thread.join();
return exit_code;
};
try {
Timer timer;
@@ -54,57 +62,70 @@ int main(const int argc, char** argv) {
const LlamaBackendState llama_backend_state;
#endif
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"STARTING PIPELINE");
log_producer->Log(LogLevel::Info, PipelinePhase::Startup, "STARTING PIPELINE");
const std::optional<ApplicationOptions> parsed_options =
ParseArguments(argc, argv, log_producer);
if (!parsed_options.has_value()) {
log_channel.Close();
log_thread.join();
return 0;
return shutdown(0);
}
const auto options = *parsed_options;
const std::string model_path = options.generator.model_path.string();
const auto sampling =
options.generator.sampling.value_or(SamplingOptions{});
const auto sampling = options.generator.sampling.value_or(SamplingOptions{});
// -----------------------------------------------------------------------
// Prompt directory
// Conditionally constructed before the injector; moved into LlamaGenerator.
// -----------------------------------------------------------------------
std::unique_ptr<IPromptDirectory> prompt_directory;
if (!options.generator.use_mocked) {
try {
prompt_directory = std::make_unique<PromptDirectory>(
options.pipeline.prompt_dir, log_producer);
} catch (const std::exception& dir_error) {
log_producer->Log(
LogLevel::Error, PipelinePhase::Startup,
std::string("Invalid --prompt-dir: ") + dir_error.what());
log_channel.Close();
log_thread.join();
return 1;
log_producer->Log(LogLevel::Error, PipelinePhase::Startup,
fmt::format("Invalid --prompt-dir: {}", dir_error.what()));
return shutdown(1);
}
}
// -----------------------------------------------------------------------
// Dependency injection
// -----------------------------------------------------------------------
const auto injector = di::make_injector(
di::bind<ILogger>().to(log_producer),
di::bind<ApplicationOptions>().to(options),
di::bind<std::string>().to(model_path),
di::bind<WebClient>().to<HttpWebClient>(),
di::bind<IExportService>().to<SqliteExportService>(),
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
di::bind<IPromptFormatter>().to(
[options, log_producer] {
if (options.generator.use_mocked) {
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Prompt formatter: none (mock mode)");
return std::unique_ptr<IPromptFormatter>(nullptr);
}
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Prompt formatter: Gemma4JinjaPromptFormatter");
return std::unique_ptr<IPromptFormatter>(
std::make_unique<Gemma4JinjaPromptFormatter>());
}),
di::bind<WebClient>().to([options, log_producer] {
if (options.generator.use_mocked) {
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Web client: none (mock mode)");
return std::unique_ptr<WebClient>(nullptr);
}
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Web client: HttpWebClient");
return std::unique_ptr<WebClient>(
std::make_unique<HttpWebClient>(log_producer));
}),
di::bind<IEnrichmentService>().to(
[options, &log_producer](
const auto& inj) -> std::unique_ptr<IEnrichmentService> {
// if (options.generator.use_mocked) {
// return std::make_unique<MockEnrichmentService>();
// }
if (options.generator.use_mocked) {
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Enrichment: mock");
return std::make_unique<MockEnrichmentService>();
}
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Enrichment: Wikipedia");
return std::make_unique<WikipediaEnrichmentService>(
inj.template create<std::unique_ptr<WebClient>>(),
log_producer);
@@ -113,57 +134,45 @@ int main(const int argc, char** argv) {
[&options, &model_path, &sampling, &prompt_directory,
&log_producer](const auto& inj) -> std::unique_ptr<DataGenerator> {
if (options.generator.use_mocked) {
log_producer->Log(
LogLevel::Info, PipelinePhase::Startup,
"Using MockGenerator (no model path provided)");
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
"Generator: mock");
return std::make_unique<MockGenerator>();
}
log_producer->Log(
LogLevel::Info, PipelinePhase::Startup,
"Using LlamaGenerator: " + model_path +
" (temperature=" + std::to_string(sampling.temperature) +
", top-p=" + std::to_string(sampling.top_p) +
", top-k=" + std::to_string(sampling.top_k) +
", n_ctx=" + std::to_string(sampling.n_ctx) +
", seed=" + std::to_string(sampling.seed) + ")");
fmt::format(
"Generator: LlamaGenerator | model={} | temp={:.2f} top_p={:.2f} top_k={} n_ctx={} seed={}",
model_path,
sampling.temperature,
sampling.top_p,
sampling.top_k,
sampling.n_ctx,
sampling.seed));
return std::make_unique<LlamaGenerator>(
options, model_path, log_producer,
inj.template create<std::unique_ptr<IPromptFormatter>>(),
std::move(prompt_directory));
}));
// -----------------------------------------------------------------------
// Pipeline execution
// -----------------------------------------------------------------------
const auto orchestrator =
injector.create<std::unique_ptr<BiergartenPipelineOrchestrator>>();
if (!orchestrator->Run()) {
log_producer->Log(LogLevel::Error, PipelinePhase::Teardown,
"Pipeline execution failed");
log_channel.Close();
log_thread.join();
return 1;
return shutdown(1);
}
log_producer->Log(LogLevel::Info, PipelinePhase::Teardown,
"Pipeline executed successfully in " +
std::to_string(timer.Elapsed()) + " ms");
fmt::format("Pipeline complete in {} ms", timer.Elapsed()));
log_channel.Close();
log_thread.join();
return 0;
return shutdown(0);
} catch (const std::exception& exception) {
// Attempt to use the logging infrastructure; if channel/dispatcher are
// compromised this is a best-effort fallback.
if (log_producer) {
log_producer->Log(
LogLevel::Error, PipelinePhase::Teardown,
std::string("Unhandled fatal error in main: ") + exception.what());
log_producer->Log(LogLevel::Error, PipelinePhase::Teardown,
fmt::format("Unhandled fatal error: {}", exception.what()));
}
log_channel.Close();
log_thread.join();
return 1;
return shutdown(1);
}
}