libdl  0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
gradientdescent.hpp
1#ifndef DL_LEARNING_OPTIMIZERS_GRADIENTDESCENT_HPP
2#define DL_LEARNING_OPTIMIZERS_GRADIENTDESCENT_HPP
3
4#include "../../tensor/tensorptr.hpp"
5#include "../optimizer.hpp"
6
7#include <map>
8#include <string>
9
10namespace dl::optim {
12 private:
14 const float learnrate;
15
16 public:
17 explicit GradientDescent(std::map<std::string, dl::TensorRef>& parameters, float learnrate = 0.001f)
18 : dl::Optimizer(), parameters(parameters), learnrate(learnrate) {}
19
20 virtual void step(dl::TensorPtr& loss) override {
21 loss->backward();
22 for (auto&& [_, tensor] : parameters) {
23 auto& gradient = tensor.get()->gradient();
24 assert(gradient != nullptr);
25 tensor.get() = tensor.get()->add(gradient->mul(dl::constant(-learnrate, gradient->device())));
26 }
27 }
28 };
29} // namespace dl::optim
30
31#endif
Defines an optimization strategy for a given set of Parameters.
Definition optimizer.hpp:11
The Tensor is a managed pointer to a tensor. It can generally be thought of like an std::unique_ptr<T...
Definition tensorptr.hpp:45