libdl 0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
dl::Trainer< Model, Dataset, Optimizer > Class Template Reference (final)
Inheritance diagram for dl::Trainer< Model, Dataset, Optimizer >:
Collaboration diagram for dl::Trainer< Model, Dataset, Optimizer >:

Public Member Functions

 Trainer (TrainerConf< Model, Dataset, Optimizer > &&conf)
 
void fit (Model &model, auto adapter)
 
void validate (Model &model, auto evaluator, auto adapter)
 
auto test (Model &model, auto evaluator, auto adapter)
 

Detailed Description

template<typename Model, typename Dataset, typename Optimizer>
class dl::Trainer< Model, Dataset, Optimizer >

Definition at line 108 of file trainer.hpp.

Constructor & Destructor Documentation

◆ Trainer()

template<typename Model , typename Dataset , typename Optimizer >
dl::Trainer< Model, Dataset, Optimizer >::Trainer ( TrainerConf< Model, Dataset, Optimizer > &&  conf)
inline

Definition at line 124 of file trainer.hpp.

124 : conf(std::move(conf)) {
125 notify(&TrainerObserver::setSubject, (TrainerSubject&)*this);
126 }
References: std::move()

Member Function Documentation

◆ fit()

template<typename Model , typename Dataset , typename Optimizer >
void dl::Trainer< Model, Dataset, Optimizer >::fit ( Model & model,
auto  adapter 
)
inline
Todo:
Log the loss.
Todo:
With batching, increment the progress counter by the batch size instead of by 1.

Definition at line 128 of file trainer.hpp.

128 {
129 setRunning();
130 auto& dataset = conf.dataset;
131 assert(dataset != nullptr);
132 auto dataloader = dataset->trainingData();
133 assert(dataloader != nullptr);
134 notify(&TrainerObserver::onTrainingBegun, model);
135 const auto trainsetSize = std::distance(std::begin(*dataloader), std::end(*dataloader));
136 for (size_t epoch = 0; isRunning(); ++epoch) {
137 size_t progress = 0;
138 for (auto&& [out, in] : *dataloader) {
139 if (!isRunning())
140 break;
141 auto loss = adapter(model, in, out);
142 conf.optimizer->step(loss);
144 notify(&TrainerObserver::progressChanged, epoch, trainsetSize, progress);
145 progress += 1;
146 }
147 }
148 auto tmp = &TrainerObserver::onTrainingEnded;
149 notify(&TrainerObserver::onTrainingEnded, model);
150 }
References: std::begin(), std::distance(), std::end()

◆ test()

template<typename Model , typename Dataset , typename Optimizer >
auto dl::Trainer< Model, Dataset, Optimizer >::test ( Model & model,
auto  evaluator,
auto  adapter 
)
inline

Definition at line 152 of file trainer.hpp.

152 {
153 auto& dataset = conf.dataset;
154 assert(dataset != nullptr);
155 auto dataloader = dataset->testData();
156 assert(dataloader != nullptr);
157 for (auto&& [out, in] : *dataloader) {
158 evaluator += adapter(model, in, out);
159 }
160 return evaluator.aggregated();
161 }

◆ validate()

template<typename Model , typename Dataset , typename Optimizer >
void dl::Trainer< Model, Dataset, Optimizer >::validate ( Model & model,
auto  evaluator,
auto  adapter 
)
inline

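Definition at line 151 of file trainer.hpp. Not yet implemented: the body is empty, so the model, evaluator, and adapter arguments are currently ignored.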

151{}

The documentation for this class was generated from the following file: trainer.hpp