libdl
0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
bert.hpp
1
#pragma once
2
3
#include <cstddef>
#include <stdexcept>

#include <dl/model/embedding.hpp>
#include <dl/model/model.hpp>
#include <dl/model/transformer/transformer.hpp>
#include <dl/tensor/tensorptr.hpp>

#include <dl/utils/composed.hpp>
9
10
namespace
nlp {
11
12
/**
 * @brief Size parameters consumed by BERTEmbeddings when it builds its embedding tables.
 *
 * Every field defaults to 0 so that a default-initialized BERTConfig never
 * carries indeterminate values. The struct remains an aggregate, so positional
 * and designated initialization keep working.
 */
struct BERTConfig {
    /// Passed as the first constructor argument of the word-embedding table.
    size_t vocabSize = 0;
    /// Passed as the first constructor argument of the position-embedding table.
    size_t maxPositionEmbeddings = 0;
    /// Passed as the first constructor argument of the token-type-embedding table.
    size_t typeVocabSize = 0;
};
17
18
/**
 * @brief Embedding stage of the BERT model.
 *
 * Owns three embedding tables (word, position, token type) and a layer
 * normalization, and registers each of them as a named submodel.
 * forward() is not implemented yet and always throws.
 */
class BERTEmbeddings : public dl::Model<dl::TensorPtr(const dl::TensorPtr&, const dl::TensorPtr&)> {
private:
    dl::Embedding wordEmbeddings;
    dl::Embedding positionalEmbeddings;
    dl::Embedding tokenTypeEmbeddings;
    dl::LayerNorm layerNorm;

public:
    /**
     * @brief Constructs the embedding tables and the layer norm, then registers them as submodels.
     *
     * @param bertConf Supplies the first constructor argument of each embedding table
     *                 (vocabSize, maxPositionEmbeddings and typeVocabSize respectively).
     * @param config   Supplies `dimensions.model`, used as the second constructor argument of
     *                 every embedding table and as the only element of the list handed to the
     *                 layer norm.
     */
    BERTEmbeddings(BERTConfig bertConf, dl::TransformerConf config)
        : wordEmbeddings(bertConf.vocabSize, config.dimensions.model),
          positionalEmbeddings(bertConf.maxPositionEmbeddings, config.dimensions.model),
          tokenTypeEmbeddings(bertConf.typeVocabSize, config.dimensions.model),
          layerNorm({config.dimensions.model}) {
        registerSubmodel("word_embeddings", wordEmbeddings);
        registerSubmodel("position_embeddings", positionalEmbeddings);
        registerSubmodel("token_type_embeddings", tokenTypeEmbeddings);
        registerSubmodel("LayerNorm", layerNorm);
    }

    /**
     * @brief Placeholder for the embedding forward pass; not implemented yet.
     *
     * @param inputIds        Currently unused.
     * @param inputTokenTypes Currently unused.
     * @throws std::runtime_error Always, with the message "Not yet implemented".
     */
    virtual dl::TensorPtr forward(const dl::TensorPtr& inputIds, const dl::TensorPtr& inputTokenTypes) override {
        // Draft of the intended implementation, kept for reference. Note that it
        // refers to `posEmbeds`, which is not computed anywhere in this draft.
        // auto& inputEmbeds = wordEmbeddings.forward(std::forward<decltype(inputIds)>(inputIds));
        // auto& typeEmbeds = tokenTypeEmbeddings.forward(std::forward<decltype(inputTokenTypes)>(inputTokenTypes));
        // return layerNorm.forward(inputEmbeds + typeEmbeds + posEmbeds);
        throw std::runtime_error("Not yet implemented");
    }
};
47
48
/**
 * @brief Pooling stage of the BERT model, consisting of a single linear layer.
 *
 * The linear layer is registered as the submodel "dense". forward() is
 * currently a placeholder that returns a null TensorPtr.
 */
class BERTPooling : public dl::Model<dl::TensorPtr(const dl::TensorPtr&)> {
private:
    dl::Linear dense;

public:
    /**
     * @brief Constructs the linear layer and registers it as a submodel.
     *
     * @param conf Supplies `dimensions.model`, which is passed as both constructor
     *             arguments of the linear layer.
     */
    // NOTE(review): this constructor is noexcept while the other constructors in this
    // file are not. Confirm that neither dl::Linear's constructor nor registerSubmodel
    // can throw; otherwise an escaping exception would call std::terminate.
    BERTPooling(dl::TransformerConf conf) noexcept : dense(conf.dimensions.model, conf.dimensions.model) {
        registerSubmodel("dense", dense);
    }

    /**
     * @brief Placeholder for the pooling forward pass; not implemented yet.
     *
     * @param input Currently unused.
     * @return A TensorPtr constructed from nullptr.
     */
    // TODO: unlike the other unimplemented forward() methods in this file, this stub
    // does not throw; callers receive a null TensorPtr and must check for it.
    virtual dl::TensorPtr forward(const dl::TensorPtr& input) override { return nullptr; }
};
59
64
/**
 * @brief Top-level BERT model composed of an embedding stage, a transformer encoder and a pooling stage.
 *
 * The transformer hyperparameters are fixed at compile time in transformerConf;
 * only the embedding table sizes are configurable through the constructor.
 * forward() is not implemented yet and always throws.
 */
class BERT : public dl::Model<dl::TensorPtr(const dl::TensorPtr&)> {
public:
    /// Compile-time transformer configuration shared by all three stages:
    /// model width 768, key/value width 64, inner width 3072, 12 encoders, 12 attention heads.
    static constexpr dl::TransformerConf transformerConf{
        .dimensions = {.model = 768, .key = 64, .value = 64, .inner = 3072},
        .numEncoders = 12,
        .numAttnHeads = 12
    };

    BERTEmbeddings embeddings;
    dl::Transformer encoder;
    BERTPooling pooling;

    /**
     * @brief Builds the three stages and registers them as submodels.
     *
     * The stages are registered under the names "bert.embeddings", "bert"
     * and "bert.pooler" respectively.
     *
     * @param config Embedding table sizes, forwarded to the embedding stage.
     */
    BERT(BERTConfig config)
        : embeddings(config, transformerConf),
          encoder(transformerConf),
          pooling(transformerConf) {
        registerSubmodel("bert.embeddings", embeddings);
        registerSubmodel("bert", encoder);
        registerSubmodel("bert.pooler", pooling);
    }

    /**
     * @brief Placeholder for the full forward pass; not implemented yet.
     *
     * @param input Currently unused.
     * @throws std::runtime_error Always, with the message "Not yet implemented".
     */
    virtual dl::TensorPtr forward(const dl::TensorPtr& input) override {
        // Intended composition once the individual stages are implemented:
        // return pooling.forward(encoder.forward(embeddings.forward(input)));
        throw std::runtime_error("Not yet implemented");
    }
};
89
}
// namespace nlp
dl::Embedding
Definition
embedding.hpp:8
dl::LayerNorm
Implements layer normalization as proposed by Ba et al. (2016).
Definition
layernorm.hpp:12
dl::Linear
Applies a learnable linear transformation with optional bias.
Definition
linear.hpp:12
dl::Model
Definition
model.hpp:33
dl::TensorPtr
The Tensor is a managed pointer to a tensor. It can generally be thought of like an std::unique_ptr<T...
Definition
tensorptr.hpp:45
dl::Transformer
Definition
transformer.hpp:107
nlp::BERTEmbeddings
Definition
bert.hpp:18
nlp::BERTEmbeddings::forward
virtual dl::TensorPtr forward(const dl::TensorPtr &inputIds, const dl::TensorPtr &inputTokenTypes)
Definition
bert.hpp:37
nlp::BERTPooling
Definition
bert.hpp:48
nlp::BERT
Definition
bert.hpp:64
nlp::BERT::forward
virtual dl::TensorPtr forward(const dl::TensorPtr &input) override
Definition
bert.hpp:83
std::runtime_error
dl::TransformerConf
Definition
transformer.hpp:25
dl::TransformerConf::model
size_t model
Definition
transformer.hpp:27
nlp::BERTConfig
Definition
bert.hpp:12
nlp
transformer
bert.hpp
Generated by
1.9.8