libdl 0.0.1
Simple yet powerful deep learning
bert.hpp
1#pragma once
2
#include <cstddef>
#include <stdexcept>

#include <dl/model/embedding.hpp>
#include <dl/model/model.hpp>
#include <dl/model/transformer/transformer.hpp>
#include <dl/tensor/tensorptr.hpp>

#include <dl/utils/composed.hpp>
9
10namespace nlp {
11
/**
 * @brief Size hyperparameters for the BERT embedding tables.
 *
 * Every field defaults to 0 so that a default-constructed config never carries indeterminate values.
 * Set them explicitly (e.g. `BERTConfig{vocab, maxPos, typeVocab}`) before building a model.
 */
struct BERTConfig {
	/// Number of entries in the word-embedding table (first constructor argument of the word dl::Embedding).
	std::size_t vocabSize{0};
	/// Number of entries in the positional-embedding table; presumably the longest supported sequence — TODO confirm.
	std::size_t maxPositionEmbeddings{0};
	/// Number of entries in the token-type (segment) embedding table.
	std::size_t typeVocabSize{0};
};
17
18 class BERTEmbeddings : public dl::Model<dl::TensorPtr(const dl::TensorPtr&, const dl::TensorPtr&)> {
19 private:
20 dl::Embedding wordEmbeddings;
21 dl::Embedding positionalEmbeddings;
22 dl::Embedding tokenTypeEmbeddings;
23 dl::LayerNorm layerNorm;
24
25 public:
27 : wordEmbeddings(bertConf.vocabSize, config.dimensions.model),
28 positionalEmbeddings(bertConf.maxPositionEmbeddings, config.dimensions.model),
29 tokenTypeEmbeddings(bertConf.typeVocabSize, config.dimensions.model),
30 layerNorm({config.dimensions.model}) {
31 registerSubmodel("word_embeddings", wordEmbeddings);
32 registerSubmodel("position_embeddings", positionalEmbeddings);
33 registerSubmodel("token_type_embeddings", tokenTypeEmbeddings);
34 registerSubmodel("LayerNorm", layerNorm);
35 }
36
37 virtual dl::TensorPtr forward(const dl::TensorPtr& inputIds, const dl::TensorPtr& inputTokenTypes) {
40 // auto& inputEmbeds = wordEmbeddings.forward(std::forward<decltype(inputIds)>(inputIds));
41 // auto& typeEmbeds = tokenTypeEmbeddings.forward(std::forward<decltype(inputTokenTypes)>(inputTokenTypes));
42 // return layerNorm.forward(inputEmbeds + typeEmbeds + posEmbeds);
44 throw std::runtime_error("Not yet implemented");
45 }
46 };
47
48 class BERTPooling : public dl::Model<dl::TensorPtr(const dl::TensorPtr&)> {
49 private:
50 dl::Linear dense;
51
52 public:
53 BERTPooling(dl::TransformerConf conf) noexcept : dense(conf.dimensions.model, conf.dimensions.model) {
54 registerSubmodel("dense", dense);
55 }
56
57 virtual dl::TensorPtr forward(const dl::TensorPtr& input) { return nullptr; }
58 };
59
64 class BERT : public dl::Model<dl::TensorPtr(const dl::TensorPtr&)> {
65 public:
66 static constexpr dl::TransformerConf transformerConf{
67 .dimensions = {.model = 768, .key = 64, .value = 64, .inner = 3072},
68 .numEncoders = 12,
69 .numAttnHeads = 12
70 };
71 BERTEmbeddings embeddings;
72 dl::Transformer encoder;
73 BERTPooling pooling;
74
75 public:
76 BERT(BERTConfig config)
77 : embeddings(config, transformerConf), encoder(transformerConf), pooling(transformerConf) {
78 registerSubmodel("bert.embeddings", embeddings);
79 registerSubmodel("bert", encoder);
80 registerSubmodel("bert.pooler", pooling);
81 }
82
83 virtual dl::TensorPtr forward(const dl::TensorPtr& input) override {
85 throw std::runtime_error("Not yet implemented");
86 // return pooling.forward(encoder.forward(embeddings.forward(input)));
87 }
88 };
89} // namespace nlp
Implements layer normalization as proposed by Ba et al. (2016), "Layer Normalization".
Definition layernorm.hpp:12
Applies a learnable linear transformation with optional bias.
Definition linear.hpp:12
The Tensor is a managed pointer to a tensor. It can generally be thought of like an std::unique_ptr<T...
Definition tensorptr.hpp:45
virtual dl::TensorPtr forward(const dl::TensorPtr &inputIds, const dl::TensorPtr &inputTokenTypes)
Definition bert.hpp:37
virtual dl::TensorPtr forward(const dl::TensorPtr &input) override
Definition bert.hpp:83