libdl  0.0.1
Simple yet powerful deep learning
transformer.hpp
#pragma once

#include "../layernorm.hpp"
#include "../linear.hpp"
#include "../model.hpp"

#include <cmath>
#include <vector>

namespace dl {

    /// Computes the sinusoidal positional encoding for position pos and
    /// embedding dimension i of a model of width dimModel: sine for even i,
    /// cosine for odd i. The casts matter; integer division would collapse
    /// the exponent to zero.
    constexpr double calcPosEncoding(size_t pos, size_t i, size_t dimModel) {
        return (i % 2 == 0) ? std::sin(pos / std::pow(10000, static_cast<double>(i) / dimModel))
                            : std::cos(pos / std::pow(10000, static_cast<double>(i - 1) / dimModel));
    }
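
    // Example (illustrative; the nested-vector container is an assumption,
    // libdl presumably stores the encoding in a Tensor): filling a full
    // (seqLen x dimModel) positional-encoding table with calcPosEncoding.
    //
    //   std::vector<std::vector<double>> table(seqLen, std::vector<double>(dimModel));
    //   for (size_t pos = 0; pos < seqLen; ++pos)
    //       for (size_t i = 0; i < dimModel; ++i)
    //           table[pos][i] = calcPosEncoding(pos, i, dimModel);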

    struct TransformerConf {
        struct {
            size_t model;
            size_t key;
            size_t value;
            size_t inner;
        } dimensions;
        size_t numEncoders;
        size_t numAttnHeads;
    };
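
    // Illustrative configuration (the values are assumptions, roughly
    // BERT-base-sized; they are not libdl defaults):
    //
    //   dl::TransformerConf conf;
    //   conf.dimensions.model = 768;  // embedding width d_model
    //   conf.dimensions.key   = 64;   // per-head key width d_k
    //   conf.dimensions.value = 64;   // per-head value width d_v
    //   conf.dimensions.inner = 3072; // feed-forward hidden width
    //   conf.numEncoders  = 12;
    //   conf.numAttnHeads = 12;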

    class TransformerEncoder final : public Model<TensorPtr(TensorPtr)> {
    public:
        TransformerConf conf;
        // Multi-Head Attention
        dl::Linear weightQuery;
        dl::Linear weightKey;
        dl::Linear weightValue;
        dl::Linear weightOut;
        dl::Linear weightIntermed;
        dl::LayerNorm mhaNorm;
        // FFN
        dl::Linear weightIntermedOut;
        dl::LayerNorm ffnNorm;

        TransformerEncoder(TransformerEncoder&& other) = delete;

    public:
        virtual TensorPtr forward(TensorPtr input) override;

        /// The precomputed inverse square root of dimKeys.
        const float dimKeysInvSqrt;

        /// Implements the scaled dot-product attention.
        TensorPtr scaledDotProductAttention(TensorPtr query, TensorPtr key, TensorPtr value) noexcept;

        /// Implements the transformer's multi-head attention.
        TensorPtr multiHeadAttention(TensorPtr query, TensorPtr key, TensorPtr value) noexcept;
    };
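
    // For reference, the two attention members above follow the standard
    // formulation from Vaswani et al., "Attention Is All You Need" (2017);
    // dimKeysInvSqrt precomputes the 1/sqrt(d_k) scaling factor:
    //
    //   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    //   MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O,
    //     where head_j = Attention(Q W_j^Q, K W_j^K, V W_j^V)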

    class Transformer final : public Model<TensorPtr(TensorPtr)> {
    public:
        const TransformerConf conf;
        dl::Linear weightOut;

    public:
        Transformer(TransformerConf conf) noexcept;

        virtual TensorPtr forward(TensorPtr input) override;
    };
}; // namespace dl
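
A minimal construction sketch, reusing the configuration values from the example above. Creating the input TensorPtr requires libdl's tensor API, which this header does not show, so the forward call is described in comments only:

    #include "transformer.hpp" // include path is an assumption

    int main() {
        // Aggregate-initialize {model, key, value, inner}, numEncoders, numAttnHeads.
        dl::TransformerConf conf{{768, 64, 64, 3072}, 12, 12};
        dl::Transformer model(conf);
        // model.forward(input) then runs the encoder stack over an embedded
        // input sequence and applies the final linear projection; `input`
        // would be a dl::TensorPtr built with libdl's tensor factories.
    }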