libdl  0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
layernorm.hpp
1#pragma once
2
3#include "../device.hpp"
4#include "../tensor/tensorptr.hpp"
5#include "./model.hpp"
6
7namespace dl {
12 class LayerNorm final : public Model<TensorPtr(TensorPtr)> {
13 private:
14 TensorPtr beta;
15 TensorPtr gamma;
16
17 public:
18 LayerNorm(Shape normShape, const Device& device = Device::getDefault()) noexcept;
19 virtual ~LayerNorm() = default;
20
21 public:
22 virtual TensorPtr forward(TensorPtr input) noexcept override;
23 };
24} // namespace dl
static const Device & getDefault() noexcept
Returns the default device for this thread.
Definition device.hpp:76
Implements layer normalization as proposed by Ba et al. (2016), "Layer Normalization" (arXiv:1607.06450).
Definition layernorm.hpp:12
The Tensor is a managed pointer to a tensor. It can generally be thought of as an std::unique_ptr<T...
Definition tensorptr.hpp:45