libdl  0.0.1
Simple yet powerful deep learning
nlp::BERTEmbeddings Class Reference

Public Member Functions

 BERTEmbeddings (BERTConfig bertConf, dl::TransformerConf config)

virtual dl::TensorPtr forward (const dl::TensorPtr &inputIds, const dl::TensorPtr &inputTokenTypes)

Detailed Description

Definition at line 18 of file bert.hpp.

Constructor & Destructor Documentation

◆ BERTEmbeddings()

nlp::BERTEmbeddings::BERTEmbeddings(BERTConfig bertConf, dl::TransformerConf config)
inline

Definition at line 26 of file bert.hpp.

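BERTEmbeddings(BERTConfig bertConf, dl::TransformerConf config)
    // One learned lookup table per input stream (token ids, positions,
    // token types), each producing vectors of the model dimension.
    : wordEmbeddings(bertConf.vocabSize, config.dimensions.model),
      positionalEmbeddings(bertConf.maxPositionEmbeddings, config.dimensions.model),
      tokenTypeEmbeddings(bertConf.typeVocabSize, config.dimensions.model),
      // Layer normalization over a feature shape of {model dimension}.
      layerNorm({config.dimensions.model}) {
    // Register the submodels under the parameter names used by the original
    // BERT checkpoints (presumably so pretrained weights can be loaded by name).
    registerSubmodel("word_embeddings", wordEmbeddings);
    registerSubmodel("position_embeddings", positionalEmbeddings);
    registerSubmodel("token_type_embeddings", tokenTypeEmbeddings);
    registerSubmodel("LayerNorm", layerNorm);
}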
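For orientation, the constructor allocates three lookup tables that all share the model dimension, plus a LayerNorm over that dimension. The following is a minimal standalone sketch of those shapes using plain std::vector rather than libdl's types; `SketchBertConfig`, `EmbeddingTable`, and `SketchBertEmbeddings` are hypothetical names, and only the config fields the constructor actually reads are modelled.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for BERTConfig: only the fields the constructor reads.
struct SketchBertConfig {
    std::size_t vocabSize;
    std::size_t maxPositionEmbeddings;
    std::size_t typeVocabSize;
};

// A learned lookup table holding one row of `dim` floats per id.
struct EmbeddingTable {
    std::size_t rows, dim;
    std::vector<float> weights;  // row-major, rows * dim entries, zero-initialized here
    EmbeddingTable(std::size_t r, std::size_t d) : rows(r), dim(d), weights(r * d, 0.0f) {}
};

// Mirrors the member initializer list: three tables sharing the model
// dimension, plus the feature size that LayerNorm normalizes over.
struct SketchBertEmbeddings {
    EmbeddingTable wordEmbeddings, positionalEmbeddings, tokenTypeEmbeddings;
    std::size_t layerNormDim;
    SketchBertEmbeddings(const SketchBertConfig& c, std::size_t modelDim)
        : wordEmbeddings(c.vocabSize, modelDim),
          positionalEmbeddings(c.maxPositionEmbeddings, modelDim),
          tokenTypeEmbeddings(c.typeVocabSize, modelDim),
          layerNormDim(modelDim) {}
};
```

With BERT-base's published settings (vocabSize 30522, maxPositionEmbeddings 512, typeVocabSize 2, model dimension 768), the word table alone holds 30522 × 768 = 23,440,896 parameters, which dominates the size of this layer.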

Member Function Documentation

◆ forward()

virtual dl::TensorPtr nlp::BERTEmbeddings::forward(const dl::TensorPtr &inputIds, const dl::TensorPtr &inputTokenTypes)
inline, virtual

Note that the embedding lookups return references to tensors rather than tensors, because they refer directly to the learned embedding weights instead of copies of them.
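To illustrate the distinction the note draws, here is a minimal sketch using plain std::vector rather than libdl's tensors; `Table` and `lookup` are hypothetical names, not part of libdl.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical embedding table: one row of `dim` floats per id.
struct Table {
    std::size_t dim;
    std::vector<std::vector<float>> rows;
    Table(std::size_t n, std::size_t d) : dim(d), rows(n, std::vector<float>(d, 0.0f)) {}

    // Returns a reference to the stored (learned) row, not a copy, so the
    // caller sees, and can update, the weights themselves.
    std::vector<float>& lookup(std::size_t id) { return rows.at(id); }
};
```

Binding the result with `auto&` (as the planned `forward` code does) keeps a handle on the learned weights; binding it with plain `auto` would silently copy the row, and writes through the copy would not reach the table.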

Todo:
implement

Definition at line 37 of file bert.hpp.

virtual dl::TensorPtr forward(const dl::TensorPtr &inputIds, const dl::TensorPtr &inputTokenTypes) {
    // Planned implementation (not yet enabled):
    // auto& inputEmbeds = wordEmbeddings.forward(inputIds);
    // auto& typeEmbeds = tokenTypeEmbeddings.forward(inputTokenTypes);
    // (posEmbeds, the positional embeddings, is not computed yet.)
    // return layerNorm.forward(inputEmbeds + typeEmbeds + posEmbeds);
    throw std::runtime_error("Not yet implemented");
}
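The planned body sums word, token-type, and positional embeddings and applies layer normalization, which is the standard BERT embedding computation. As a hedged sketch of that computation, independent of libdl's dl::Tensor API (whose operators are not documented here), the following uses plain std::vector. `Matrix` and `bertEmbeddingsForward` are hypothetical names; position ids are assumed to be 0..seqLen-1, LayerNorm's scale and shift are assumed to be 1 and 0, eps = 1e-12 is assumed (the value used by the original BERT configuration), and the dropout that the original BERT applies after LayerNorm is omitted.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // [rows][dim]

// out[i] = LayerNorm(word[inputIds[i]] + position[i] + tokenType[tokenTypes[i]]),
// with gamma = 1 and beta = 0 assumed. All tables must share the same dim.
Matrix bertEmbeddingsForward(const Matrix& word, const Matrix& position,
                             const Matrix& tokenType,
                             const std::vector<std::size_t>& inputIds,
                             const std::vector<std::size_t>& tokenTypes,
                             float eps = 1e-12f) {
    if (inputIds.size() != tokenTypes.size())
        throw std::invalid_argument("inputIds and tokenTypes must have equal length");
    if (inputIds.size() > position.size())
        throw std::invalid_argument("sequence longer than maxPositionEmbeddings");
    Matrix out;
    for (std::size_t i = 0; i < inputIds.size(); ++i) {
        // Rows are taken by const reference, so the tables are not copied.
        const std::vector<float>& w = word.at(inputIds[i]);
        const std::vector<float>& p = position[i];  // position id i
        const std::vector<float>& t = tokenType.at(tokenTypes[i]);
        const std::size_t d = w.size();
        std::vector<float> x(d);
        float mean = 0.0f;
        for (std::size_t k = 0; k < d; ++k) { x[k] = w[k] + p[k] + t[k]; mean += x[k]; }
        mean /= static_cast<float>(d);
        float var = 0.0f;  // biased (population) variance, as in LayerNorm
        for (float v : x) var += (v - mean) * (v - mean);
        var /= static_cast<float>(d);
        const float inv = 1.0f / std::sqrt(var + eps);
        for (float& v : x) v = (v - mean) * inv;
        out.push_back(x);
    }
    return out;
}
```

For example, with word rows {1, 2} and {3, 5}, position rows {0, 0} and {1, 1}, zero token-type rows, inputIds {1, 0}, and tokenTypes {0, 0}, both output rows normalize to approximately {-1, 1}.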

The documentation for this class was generated from the following file:
bert.hpp