mirror of
https://github.com/aaronpo97/the-biergarten-app.git
synced 2026-05-31 17:53:59 +00:00
Update main.cc
This commit is contained in:
@@ -5,6 +5,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <spdlog/spdlog.h>
|
#include <spdlog/spdlog.h>
|
||||||
|
#include <spdlog/fmt/fmt.h>
|
||||||
|
|
||||||
#include <boost/di.hpp>
|
#include <boost/di.hpp>
|
||||||
#include <boost/program_options.hpp>
|
#include <boost/program_options.hpp>
|
||||||
@@ -47,6 +48,13 @@ int main(const int argc, char** argv) {
|
|||||||
|
|
||||||
std::shared_ptr<ILogger> log_producer =
|
std::shared_ptr<ILogger> log_producer =
|
||||||
std::make_shared<LogProducer>(log_channel);
|
std::make_shared<LogProducer>(log_channel);
|
||||||
|
|
||||||
|
auto shutdown = [&](const int exit_code) {
|
||||||
|
log_channel.Close();
|
||||||
|
log_thread.join();
|
||||||
|
return exit_code;
|
||||||
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Timer timer;
|
Timer timer;
|
||||||
|
|
||||||
@@ -54,57 +62,70 @@ int main(const int argc, char** argv) {
|
|||||||
const LlamaBackendState llama_backend_state;
|
const LlamaBackendState llama_backend_state;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
|
log_producer->Log(LogLevel::Info, PipelinePhase::Startup, "STARTING PIPELINE");
|
||||||
"STARTING PIPELINE");
|
|
||||||
const std::optional<ApplicationOptions> parsed_options =
|
const std::optional<ApplicationOptions> parsed_options =
|
||||||
ParseArguments(argc, argv, log_producer);
|
ParseArguments(argc, argv, log_producer);
|
||||||
|
|
||||||
if (!parsed_options.has_value()) {
|
if (!parsed_options.has_value()) {
|
||||||
log_channel.Close();
|
return shutdown(0);
|
||||||
log_thread.join();
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto options = *parsed_options;
|
const auto options = *parsed_options;
|
||||||
const std::string model_path = options.generator.model_path.string();
|
const std::string model_path = options.generator.model_path.string();
|
||||||
const auto sampling =
|
const auto sampling = options.generator.sampling.value_or(SamplingOptions{});
|
||||||
options.generator.sampling.value_or(SamplingOptions{});
|
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
// Prompt directory
|
|
||||||
// Conditionally constructed before the injector; moved into LlamaGenerator.
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
std::unique_ptr<IPromptDirectory> prompt_directory;
|
std::unique_ptr<IPromptDirectory> prompt_directory;
|
||||||
|
|
||||||
if (!options.generator.use_mocked) {
|
if (!options.generator.use_mocked) {
|
||||||
try {
|
try {
|
||||||
prompt_directory = std::make_unique<PromptDirectory>(
|
prompt_directory = std::make_unique<PromptDirectory>(
|
||||||
options.pipeline.prompt_dir, log_producer);
|
options.pipeline.prompt_dir, log_producer);
|
||||||
} catch (const std::exception& dir_error) {
|
} catch (const std::exception& dir_error) {
|
||||||
log_producer->Log(
|
log_producer->Log(LogLevel::Error, PipelinePhase::Startup,
|
||||||
LogLevel::Error, PipelinePhase::Startup,
|
fmt::format("Invalid --prompt-dir: {}", dir_error.what()));
|
||||||
std::string("Invalid --prompt-dir: ") + dir_error.what());
|
return shutdown(1);
|
||||||
log_channel.Close();
|
|
||||||
log_thread.join();
|
|
||||||
return 1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
// Dependency injection
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
const auto injector = di::make_injector(
|
const auto injector = di::make_injector(
|
||||||
di::bind<ILogger>().to(log_producer),
|
di::bind<ILogger>().to(log_producer),
|
||||||
di::bind<ApplicationOptions>().to(options),
|
di::bind<ApplicationOptions>().to(options),
|
||||||
di::bind<std::string>().to(model_path),
|
di::bind<std::string>().to(model_path),
|
||||||
di::bind<WebClient>().to<HttpWebClient>(),
|
|
||||||
di::bind<IExportService>().to<SqliteExportService>(),
|
di::bind<IExportService>().to<SqliteExportService>(),
|
||||||
di::bind<IPromptFormatter>().to<Gemma4JinjaPromptFormatter>(),
|
di::bind<IPromptFormatter>().to(
|
||||||
|
[options, log_producer] {
|
||||||
|
if (options.generator.use_mocked) {
|
||||||
|
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
|
||||||
|
"Prompt formatter: none (mock mode)");
|
||||||
|
return std::unique_ptr<IPromptFormatter>(nullptr);
|
||||||
|
}
|
||||||
|
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
|
||||||
|
"Prompt formatter: Gemma4JinjaPromptFormatter");
|
||||||
|
return std::unique_ptr<IPromptFormatter>(
|
||||||
|
std::make_unique<Gemma4JinjaPromptFormatter>());
|
||||||
|
}),
|
||||||
|
di::bind<WebClient>().to([options, log_producer] {
|
||||||
|
if (options.generator.use_mocked) {
|
||||||
|
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
|
||||||
|
"Web client: none (mock mode)");
|
||||||
|
return std::unique_ptr<WebClient>(nullptr);
|
||||||
|
}
|
||||||
|
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
|
||||||
|
"Web client: HttpWebClient");
|
||||||
|
return std::unique_ptr<WebClient>(
|
||||||
|
std::make_unique<HttpWebClient>(log_producer));
|
||||||
|
}),
|
||||||
di::bind<IEnrichmentService>().to(
|
di::bind<IEnrichmentService>().to(
|
||||||
[options, &log_producer](
|
[options, &log_producer](
|
||||||
const auto& inj) -> std::unique_ptr<IEnrichmentService> {
|
const auto& inj) -> std::unique_ptr<IEnrichmentService> {
|
||||||
// if (options.generator.use_mocked) {
|
if (options.generator.use_mocked) {
|
||||||
// return std::make_unique<MockEnrichmentService>();
|
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
|
||||||
// }
|
"Enrichment: mock");
|
||||||
|
return std::make_unique<MockEnrichmentService>();
|
||||||
|
}
|
||||||
|
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
|
||||||
|
"Enrichment: Wikipedia");
|
||||||
return std::make_unique<WikipediaEnrichmentService>(
|
return std::make_unique<WikipediaEnrichmentService>(
|
||||||
inj.template create<std::unique_ptr<WebClient>>(),
|
inj.template create<std::unique_ptr<WebClient>>(),
|
||||||
log_producer);
|
log_producer);
|
||||||
@@ -113,57 +134,45 @@ int main(const int argc, char** argv) {
|
|||||||
[&options, &model_path, &sampling, &prompt_directory,
|
[&options, &model_path, &sampling, &prompt_directory,
|
||||||
&log_producer](const auto& inj) -> std::unique_ptr<DataGenerator> {
|
&log_producer](const auto& inj) -> std::unique_ptr<DataGenerator> {
|
||||||
if (options.generator.use_mocked) {
|
if (options.generator.use_mocked) {
|
||||||
log_producer->Log(
|
log_producer->Log(LogLevel::Info, PipelinePhase::Startup,
|
||||||
LogLevel::Info, PipelinePhase::Startup,
|
"Generator: mock");
|
||||||
"Using MockGenerator (no model path provided)");
|
|
||||||
return std::make_unique<MockGenerator>();
|
return std::make_unique<MockGenerator>();
|
||||||
}
|
}
|
||||||
log_producer->Log(
|
log_producer->Log(
|
||||||
LogLevel::Info, PipelinePhase::Startup,
|
LogLevel::Info, PipelinePhase::Startup,
|
||||||
"Using LlamaGenerator: " + model_path +
|
fmt::format(
|
||||||
" (temperature=" + std::to_string(sampling.temperature) +
|
"Generator: LlamaGenerator | model={} | temp={:.2f} top_p={:.2f} top_k={} n_ctx={} seed={}",
|
||||||
", top-p=" + std::to_string(sampling.top_p) +
|
model_path,
|
||||||
", top-k=" + std::to_string(sampling.top_k) +
|
sampling.temperature,
|
||||||
", n_ctx=" + std::to_string(sampling.n_ctx) +
|
sampling.top_p,
|
||||||
", seed=" + std::to_string(sampling.seed) + ")");
|
sampling.top_k,
|
||||||
|
sampling.n_ctx,
|
||||||
|
sampling.seed));
|
||||||
return std::make_unique<LlamaGenerator>(
|
return std::make_unique<LlamaGenerator>(
|
||||||
options, model_path, log_producer,
|
options, model_path, log_producer,
|
||||||
inj.template create<std::unique_ptr<IPromptFormatter>>(),
|
inj.template create<std::unique_ptr<IPromptFormatter>>(),
|
||||||
std::move(prompt_directory));
|
std::move(prompt_directory));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
// Pipeline execution
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
const auto orchestrator =
|
const auto orchestrator =
|
||||||
injector.create<std::unique_ptr<BiergartenPipelineOrchestrator>>();
|
injector.create<std::unique_ptr<BiergartenPipelineOrchestrator>>();
|
||||||
|
|
||||||
if (!orchestrator->Run()) {
|
if (!orchestrator->Run()) {
|
||||||
log_producer->Log(LogLevel::Error, PipelinePhase::Teardown,
|
log_producer->Log(LogLevel::Error, PipelinePhase::Teardown,
|
||||||
"Pipeline execution failed");
|
"Pipeline execution failed");
|
||||||
log_channel.Close();
|
return shutdown(1);
|
||||||
log_thread.join();
|
|
||||||
return 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log_producer->Log(LogLevel::Info, PipelinePhase::Teardown,
|
log_producer->Log(LogLevel::Info, PipelinePhase::Teardown,
|
||||||
"Pipeline executed successfully in " +
|
fmt::format("Pipeline complete in {} ms", timer.Elapsed()));
|
||||||
std::to_string(timer.Elapsed()) + " ms");
|
|
||||||
|
|
||||||
log_channel.Close();
|
return shutdown(0);
|
||||||
log_thread.join();
|
|
||||||
return 0;
|
|
||||||
|
|
||||||
} catch (const std::exception& exception) {
|
} catch (const std::exception& exception) {
|
||||||
// Attempt to use the logging infrastructure; if channel/dispatcher are
|
|
||||||
// compromised this is a best-effort fallback.
|
|
||||||
if (log_producer) {
|
if (log_producer) {
|
||||||
log_producer->Log(
|
log_producer->Log(LogLevel::Error, PipelinePhase::Teardown,
|
||||||
LogLevel::Error, PipelinePhase::Teardown,
|
fmt::format("Unhandled fatal error: {}", exception.what()));
|
||||||
std::string("Unhandled fatal error in main: ") + exception.what());
|
|
||||||
}
|
}
|
||||||
log_channel.Close();
|
return shutdown(1);
|
||||||
log_thread.join();
|
|
||||||
return 1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user