libdl  0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
trainer.hpp
1#pragma once
2
3#include "../model/model.hpp"
4#include "dataloader.hpp"
5#include "dataset.hpp"
6#include "optimizer.hpp"
7
8#include <functional>
9#include <memory>
10#include <variant>
11#include <vector>
12
13namespace dl {
14 class ModelBase;
15 template <typename>
16 class Model;
17 class Device;
18 template <typename>
19 class Dataset;
20
/// Stage of a training run; observers receive it through TrainerObserver::enterTrainingStage.
enum class TrainStage {
    Fitting,
    Evaluation,
    Validation,
};
22
23 class TrainerSubject;
24
26 private:
// NOTE(review): this private section declares nothing and can be dropped.
27 public:
28 virtual ~TrainerObserver() = default;
/// Receives the trainer this observer is attached to. Trainer invokes it on every registered
/// observer (apparently from its constructor — confirm). The reference gives access to
/// TrainerSubject::stop() and TrainerSubject::isRunning().
29 virtual void setSubject(TrainerSubject& trainer) = 0;
/// Called by Trainer::fit once, before the first epoch starts.
30 virtual void onTrainingBegun(const ModelBase& model) = 0;
/// Called by Trainer::fit after its epoch loop has exited, i.e. after stop() was requested.
31 virtual void onTrainingEnded(const ModelBase& model) = 0;
/// NOTE(review): neither stage hook (enter/exit) is invoked by Trainer in this header.
32 virtual void enterTrainingStage(TrainStage stage) = 0;
33 virtual void exitTrainingStage() = 0;
/// Called by Trainer::fit after every optimizer step.
/// @param epoch zero-based index of the current epoch
/// @param total number of elements the training dataloader yields per epoch
/// @param step  zero-based index, within the epoch, of the element that was just processed
34 virtual void progressChanged(size_t epoch, size_t total, size_t step) = 0;
35 };
36
/// Aggregate bundling what a Trainer needs: the observers to notify, the dataset whose
/// trainingData()/testData() it iterates, and the optimizer it steps.
/// TrainerConfBuilder::build() returns one.
/// From the builder's designated initializers, the members are declared in the order
/// `observers` (a container that std::unique_ptr<TrainerObserver>s are push_back'ed into),
/// `dataset` and `optimizer`. The builder initializes the last two from nullptr or
/// std::make_unique.
37 template <typename Model, typename Dataset, typename Optimizer>
38 struct TrainerConf {
// NOTE(review): `public:` is redundant in a struct.
39 public:
43 };
44
/// Fluent builder for TrainerConf.
///
/// setDataset<D>(args...) and setOptimizer<O>(args...) construct the component with
/// std::make_unique and return a builder of a *different* specialization, with the matching
/// template argument replaced by the new type. The members configured so far are moved into
/// the returned builder, so the builder they were called on is left with moved-from members.
/// addObserver appends in place and returns *this.
/// The `void*` defaults presumably mark components that have not been set yet.
///
/// NOTE(review): setDataset, setOptimizer and both addObserver overloads are noexcept, yet they
/// allocate (std::make_unique, push_back) and forward to arbitrary constructors. Any exception
/// thrown there ends in std::terminate.
45 template <typename Model = void*, typename Dataset = void*, typename Optimizer = void*>
46 class TrainerConfBuilder final {
// All specializations are friends, presumably so setDataset/setOptimizer can use the private
// constructor of the builder type they return.
47 template <typename M, typename D, typename O>
48 friend class TrainerConfBuilder;
49
50 private:
// Starts empty: no observers, no dataset, no optimizer.
51 TrainerConf<Model, Dataset, Optimizer> conf{.observers = {}, .dataset = nullptr, .optimizer = nullptr};
52
// Adopts an already populated configuration. The parameter shadows the member, so the
// initializer moves from the parameter into the member.
53 explicit TrainerConfBuilder(TrainerConf<Model, Dataset, Optimizer>&& conf) noexcept : conf(std::move(conf)) {}
54
55 public:
/// Creates a builder with an empty configuration.
56 TrainerConfBuilder() noexcept {}
57
/// Sets the dataset to `std::make_unique<DSet>(args...)` and returns a builder whose Dataset
/// parameter is DSet. Observers and optimizer are moved over from this builder.
58 template <typename DSet, typename... Args>
59 TrainerConfBuilder<Model, DSet, Optimizer> setDataset(Args&&... args) noexcept {
61 .observers = std::move(conf.observers),
62 .dataset = std::make_unique<DSet>(std::forward<Args>(args)...),
63 .optimizer = std::move(conf.optimizer)
64 }};
65 }
66
/// Sets the optimizer to `std::make_unique<Opt>(args...)` and returns a builder whose Optimizer
/// parameter is Opt. Observers and dataset are moved over from this builder.
67 template <typename Opt, typename... Args>
68 TrainerConfBuilder<Model, Dataset, Opt> setOptimizer(Args&&... args) noexcept {
70 .observers = std::move(conf.observers),
71 .dataset = std::move(conf.dataset),
72 .optimizer = std::make_unique<Opt>(std::forward<Args>(args)...)
73 }};
74 }
75
/// Constructs an observer of type O from `args...` and registers it.
76 template <typename O, typename... Args>
77 TrainerConfBuilder& addObserver(Args&&... args) noexcept {
78 return addObserver(std::make_unique<O>(std::forward<Args>(args)...));
79 }
80
/// Takes ownership of `observer` and appends it to the observer list.
/// NOTE(review): a null pointer is accepted here and stored as-is.
81 TrainerConfBuilder& addObserver(std::unique_ptr<TrainerObserver> observer) noexcept {
82 conf.observers.push_back(std::move(observer));
83 return *this;
84 }
85
/// Moves the accumulated configuration out. The builder keeps only moved-from state afterwards.
86 TrainerConf<Model, Dataset, Optimizer> build() noexcept { return std::move(conf); }
87 };
88
// Stop flag shared between a Trainer (which derives from this class) and its observers.
// Observers receive a TrainerSubject& through TrainerObserver::setSubject.
// NOTE(review): `stopped` is a plain bool. If stop() may be called from a thread other than the
// one running the training loop, it needs to be std::atomic<bool>.
90 private:
91 bool stopped;
92
93 protected:
/// Starts in the running state.
94 TrainerSubject() noexcept : stopped(false) {}
/// Clears a previous stop request. Trainer::fit calls this before it starts training.
95 void setRunning() noexcept { stopped = false; }
96
97 public:
/// Stops any currently running training, validation and test processes.
/// Trainer::fit polls isRunning() before each element and after each epoch.
/// NOTE(review): Trainer::test does not check isRunning(), so stop() does not interrupt it.
102 void stop() noexcept { stopped = true; }
103
/// True until stop() is called, and true again after setRunning().
104 bool isRunning() const noexcept { return !stopped; }
105 };
106
107 template <typename Model, typename Dataset, typename Optimizer>
108 class Trainer final : TrainerSubject {
109 private:
111
112 Trainer(const Trainer&) = delete;
113 Trainer(Trainer&& other) = delete;
114 Trainer& operator=(const Trainer&) = delete;
115 Trainer& operator=(Trainer&&) = delete;
116
117 template <typename Callable, typename... Args>
118 void notify(Callable&& fn, Args&&... args) {
119 for (auto&& observer : conf.observers)
120 std::invoke(fn, observer, std::forward<Args>(args)...);
121 }
122
123 public:
125 notify(&TrainerObserver::setSubject, (TrainerSubject&)*this);
126 }
127
/// Trains `model` on the training split of the configured dataset until the run is stopped.
///
/// Each epoch iterates the dataloader returned by `conf.dataset->trainingData()`. For every
/// element (bound as `[out, in]`) it:
///   1. calls `adapter(model, in, out)`,
///   2. passes the result to `conf.optimizer->step(...)`,
///   3. reports progress to all observers.
///
/// The epoch loop has no upper bound. It only ends once `isRunning()` turns false, which
/// happens only after `stop()` is called, for example by an observer holding the
/// TrainerSubject it received via setSubject. `setRunning()` at the start clears a stop
/// request left over from a previous run.
/// NOTE(review): if the dataloader is empty and nothing else calls stop(), observers receive no
/// progress callbacks and this never returns.
///
/// @param model   the model being trained; also passed to onTrainingBegun / onTrainingEnded
/// @param adapter callable `(Model&, in, out) -> loss` whose result is handed to the optimizer.
///                The meaning of `in`/`out` (presumably input/target) is defined by the
///                dataloader — confirm.
128 void fit(Model& model, auto adapter) {
129 setRunning();
130 auto& dataset = conf.dataset;
131 assert(dataset != nullptr);
132 auto dataloader = dataset->trainingData();
133 assert(dataloader != nullptr);
134 notify(&TrainerObserver::onTrainingBegun, model);
// Elements per epoch, obtained by walking the dataloader once.
// NOTE(review): this is a signed difference type, implicitly converted to size_t when passed to
// progressChanged. It also costs a full pass if the dataloader's iterators are not random-access.
135 const auto trainsetSize = std::distance(std::begin(*dataloader), std::end(*dataloader));
136 for (size_t epoch = 0; isRunning(); ++epoch) {
137 size_t progress = 0;
138 for (auto&& [out, in] : *dataloader) {
// Re-checked before every element so that stop() takes effect mid-epoch.
139 if (!isRunning())
140 break;
141 auto loss = adapter(model, in, out);
142 conf.optimizer->step(loss);
// `progress` is the zero-based index of the element that was just processed.
144 notify(&TrainerObserver::progressChanged, epoch, trainsetSize, progress);
145 progress += 1;
146 }
147 }
// NOTE(review): `tmp` is never used and can be removed.
148 auto tmp = &TrainerObserver::onTrainingEnded;
149 notify(&TrainerObserver::onTrainingEnded, model);
150 }
151 void validate(Model& model, auto evaluator, auto adapter) {}
152 auto test(Model& model, auto evaluator, auto adapter) {
153 auto& dataset = conf.dataset;
154 assert(dataset != nullptr);
155 auto dataloader = dataset->testData();
156 assert(dataloader != nullptr);
157 for (auto&& [out, in] : *dataloader) {
158 evaluator += adapter(model, in, out);
159 }
160 return evaluator.aggregated();
161 }
162 };
163
165
166 namespace detail {
// Trait mapping a model type T to its signature type. The template shown here takes the nested
// `T::signature`; dl::ModelSignature is the public alias for this trait.
// NOTE(review): `T::signature` without `typename` relies on C++20 (P0634). Identifiers that begin
// with an underscore followed by an uppercase letter (such as `_ModelSignature`) are reserved
// for the implementation in C++.
172 template <typename T>
174 using type = T::signature;
175 };
176 } // namespace detail
177
183 template <typename T>
184 using ModelSignature = typename detail::_ModelSignature<T>::type;
185
186 namespace observers {
187 std::unique_ptr<TrainerObserver> limitEpochs(size_t epochs) noexcept;
188 std::unique_ptr<TrainerObserver> earlyStopping(size_t patience) noexcept;
189 std::unique_ptr<TrainerObserver> consoleUI() noexcept;
190 } // namespace observers
191} // namespace dl
T begin(T... args)
void stop() noexcept
Stops any currently running training, validation and test processes.
Definition trainer.hpp:102
void fit(Model &model, auto adapter)
Definition trainer.hpp:128
T distance(T... args)
T end(T... args)
T move(T... args)
Infers the model signature from the provided model type.
Definition trainer.hpp:173