libdl  0.0.1
Simple yet powerful deep learning
dl::TransformerEncoder Class Reference (final)
Inheritance diagram for dl::TransformerEncoder (diagram not shown)
Collaboration diagram for dl::TransformerEncoder (diagram not shown)

Public Member Functions

 TransformerEncoder (TransformerEncoder &other)=delete
 
 TransformerEncoder (TransformerEncoder &&other)=delete
 
 TransformerEncoder (TransformerConf conf) noexcept
 
virtual TensorPtr forward (TensorPtr input) override
 
TensorPtr scaledDotProductAttention (TensorPtr query, TensorPtr key, TensorPtr value) noexcept
 Implements the scaled dot-product attention.
 
TensorPtr multiHeadAttention (TensorPtr query, TensorPtr key, TensorPtr value) noexcept
 Implements the transformer's multi-head attention.
 

Public Attributes

TransformerConf conf
 
dl::Linear weightQuery
 
dl::Linear weightKey
 
dl::Linear weightValue
 
dl::Linear weightOut
 
dl::Linear weightIntermed
 
dl::LayerNorm mhaNorm
 
dl::Linear weightIntermedOut
 
dl::LayerNorm ffnNorm
 
const float dimKeysInvSqrt
 The precomputed inverse square root of dimKeys.
 

Detailed Description

Definition at line 36 of file transformer.hpp.
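
A minimal usage sketch, based only on the constructor and forward() signatures listed above. How a TransformerConf is filled in and how the input TensorPtr is obtained are not documented on this page, so those parts are placeholders; the include path and the dl::TransformerConf qualification are assumptions.

#include "transformer.hpp"  // file named in this section

// Encode one input tensor with a freshly constructed encoder.
dl::TensorPtr encode(dl::TensorPtr input, dl::TransformerConf conf) {
	dl::TransformerEncoder encoder(conf);  // TransformerEncoder(TransformerConf conf) noexcept
	return encoder.forward(input);         // virtual TensorPtr forward(TensorPtr input) override
}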

Member Function Documentation

◆ multiHeadAttention()

TensorPtr dl::TransformerEncoder::multiHeadAttention (TensorPtr query, TensorPtr key, TensorPtr value) noexcept

Implements the transformer's multi-head attention.

Multi-head attention is described in Section 3.2.2 of the transformer paper [8]. Let \(W_i^Q, W_i^K\in \mathbb{R}^{d_\text{model} \times d_k}, W_i^V \in \mathbb{R}^{d_\text{model} \times d_v}\) denote the query, key and value matrices of the i-th head, respectively, for \(1\leq i \leq h\), where \(h\) is the total number of heads. Further, let \(W^O\) denote the output linearity. Multi-head attention is defined as the concatenated attention of each attention head:

\[ \text{MHA}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\cdot W^O \text{ with } \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V), \]

where "Attention" denotes scaledDotProductAttention (TransformerEncoder::scaledDotProductAttention).

Note
In practice, the linearities are usually accompanied by biases \(b_i^Q, b_i^K \in \mathbb{R}^{d_k}\), \(b_i^V \in \mathbb{R}^{d_v}\) and \(b^O \in \mathbb{R}^{d_\text{model}}\), such that MHA is more accurately described as

\[\text{MHA}(Q, K, V) = \text{Concat}( \text{Attention}(QW_1^Q+b_1^Q, KW_1^K+b_1^K, VW_1^V+b_1^V), \dots)\cdot W^O + b^O.\]

Parameters
query	the query tensor \(Q\)
key	the key tensor \(K\)
value	the value tensor \(V\)
Returns
the multi-head attention \(\text{MHA}(Q, K, V)\) as a TensorPtr
See also
For a more detailed description, please read the Transformer page.
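
Below is a minimal, self-contained sketch of this computation on plain row-major std::vector matrices. It deliberately avoids libdl's TensorPtr and dl::Linear API, omits the biases, and all helper names and shapes are illustrative assumptions; in the class itself the projections are the dl::Linear members weightQuery, weightKey, weightValue and weightOut listed above. The attention helper implements the scaled dot-product attention documented in the next entry.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // row-major, illustrative type only

// (n x p) * (p x m) matrix product.
Matrix matmul(const Matrix& a, const Matrix& b) {
	Matrix c(a.size(), std::vector<float>(b[0].size(), 0.0f));
	for (std::size_t i = 0; i < a.size(); ++i)
		for (std::size_t k = 0; k < b.size(); ++k)
			for (std::size_t j = 0; j < b[0].size(); ++j)
				c[i][j] += a[i][k] * b[k][j];
	return c;
}

Matrix transpose(const Matrix& a) {
	Matrix t(a[0].size(), std::vector<float>(a.size()));
	for (std::size_t i = 0; i < a.size(); ++i)
		for (std::size_t j = 0; j < a[0].size(); ++j)
			t[j][i] = a[i][j];
	return t;
}

// Row-wise softmax, applied in place.
void softmaxRows(Matrix& m) {
	for (auto& row : m) {
		const float mx = *std::max_element(row.begin(), row.end());
		float sum = 0.0f;
		for (float& v : row) { v = std::exp(v - mx); sum += v; }
		for (float& v : row) v /= sum;
	}
}

// Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, see scaledDotProductAttention() below.
Matrix attention(const Matrix& q, const Matrix& k, const Matrix& v) {
	Matrix scores = matmul(q, transpose(k));
	const float invSqrtDk = 1.0f / std::sqrt(static_cast<float>(k[0].size()));
	for (auto& row : scores)
		for (float& s : row) s *= invSqrtDk;  // scale by 1/sqrt(d_k)
	softmaxRows(scores);
	return matmul(scores, v);
}

// MHA(Q, K, V) = Concat(head_1, ..., head_h) * W^O with head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V).
// wq/wk/wv hold one projection matrix per head; biases are omitted for brevity.
Matrix multiHeadAttention(const Matrix& q, const Matrix& k, const Matrix& v,
                          const std::vector<Matrix>& wq, const std::vector<Matrix>& wk,
                          const std::vector<Matrix>& wv, const Matrix& wo) {
	Matrix concat(q.size());
	for (std::size_t h = 0; h < wq.size(); ++h) {
		Matrix head = attention(matmul(q, wq[h]), matmul(k, wk[h]), matmul(v, wv[h]));
		for (std::size_t r = 0; r < head.size(); ++r)  // append head_h to each row
			concat[r].insert(concat[r].end(), head[r].begin(), head[r].end());
	}
	return matmul(concat, wo);  // output projection W^O
}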

◆ scaledDotProductAttention()

TensorPtr dl::TransformerEncoder::scaledDotProductAttention (TensorPtr query, TensorPtr key, TensorPtr value) noexcept

Implements the scaled dot-product attention.

Scaled dot-product attention is Eq. (1) in the transformer paper [8]:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V.\]

Parameters
query	the query tensor \(Q\)
key	the key tensor \(K\)
value	the value tensor \(V\)
Returns
the scaled dot-product attention.
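
A small worked example of Eq. (1) with made-up numbers: take a single query \(Q = (1, 0)\), keys \(K = \begin{pmatrix}1 & 0\\ 0 & 1\end{pmatrix}\), values \(V = \begin{pmatrix}1 & 2\\ 3 & 4\end{pmatrix}\) and \(d_k = 2\). Then

\[ \frac{QK^\top}{\sqrt{2}} = \left(\tfrac{1}{\sqrt{2}},\, 0\right), \qquad \text{softmax}\left(\tfrac{1}{\sqrt{2}},\, 0\right) \approx (0.67,\, 0.33), \qquad \text{Attention}(Q, K, V) \approx 0.67\cdot(1, 2) + 0.33\cdot(3, 4) = (1.66,\, 2.66). \]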

Member Data Documentation

◆ conf

TransformerConf dl::TransformerEncoder::conf

Definition at line 38 of file transformer.hpp.

◆ dimKeysInvSqrt

const float dl::TransformerEncoder::dimKeysInvSqrt

The precomputed inverse square root of dimKeys.

This is the precomputed normalization factor \(\sqrt{d_k}^{-1}\) used in the scaled dot-product attention.

Definition at line 62 of file transformer.hpp.
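
A minimal sketch of how such a factor can be precomputed once; the free function below and passing the key dimension in directly are assumptions for illustration, as this page only documents the member itself.

#include <cmath>
#include <cstddef>

// Compute 1/sqrt(d_k) once so every attention call multiplies by it instead of
// dividing by a square root.
float precomputeDimKeysInvSqrt(std::size_t dimKeys) {
	return 1.0f / std::sqrt(static_cast<float>(dimKeys));
}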

◆ ffnNorm

dl::LayerNorm dl::TransformerEncoder::ffnNorm

Definition at line 48 of file transformer.hpp.

◆ mhaNorm

dl::LayerNorm dl::TransformerEncoder::mhaNorm

Definition at line 45 of file transformer.hpp.

◆ weightIntermed

dl::Linear dl::TransformerEncoder::weightIntermed

Definition at line 44 of file transformer.hpp.

◆ weightIntermedOut

dl::Linear dl::TransformerEncoder::weightIntermedOut

Definition at line 47 of file transformer.hpp.

◆ weightKey

dl::Linear dl::TransformerEncoder::weightKey

Definition at line 41 of file transformer.hpp.

◆ weightOut

dl::Linear dl::TransformerEncoder::weightOut

Definition at line 43 of file transformer.hpp.

◆ weightQuery

dl::Linear dl::TransformerEncoder::weightQuery

Definition at line 40 of file transformer.hpp.

◆ weightValue

dl::Linear dl::TransformerEncoder::weightValue

Definition at line 42 of file transformer.hpp.


The documentation for this class was generated from the following file:
transformer.hpp