libdl
0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
embedding.hpp
1
#pragma once
2
3
#include "../device.hpp"
4
#include "../tensor/tensorptr.hpp"
5
#include "model.hpp"
6
7
namespace
dl {
8
/**
 * @brief Embedding layer exposed as a dl::Model mapping one dl::TensorPtr to another.
 *
 * The layer holds a single tensor, `weight`, created with dimensions
 * {numEmbeddings, embeddingDim} and registered under the name "weight".
 *
 * @note The forward pass is a placeholder and currently always throws.
 */
class Embedding : public dl::Model<dl::TensorPtr(const dl::TensorPtr)> {
public:
    /**
     * @brief Creates the weight tensor and registers it as a parameter.
     *
     * @param numEmbeddings first dimension passed to dl::empty for the weight tensor
     * @param embeddingDim  second dimension passed to dl::empty for the weight tensor
     *
     * NOTE(review): dl::empty presumably allocates without initializing the
     * values -- confirm before relying on the initial contents of `weight`.
     */
    Embedding(size_t numEmbeddings, size_t embeddingDim)
        : weight(dl::empty({numEmbeddings, embeddingDim})) {
        // registerParameter is not declared in this class; presumably it is
        // inherited from dl::Model -- verify against model.hpp.
        registerParameter("weight", weight);
    }

    virtual ~Embedding() = default;

    /**
     * @brief Forward pass of the layer (not implemented yet).
     *
     * @param input the input tensor; currently unused
     * @return never returns normally
     * @throws std::runtime_error always, with the message "Not yet implemented"
     */
    virtual dl::TensorPtr forward(const dl::TensorPtr input) {
        throw std::runtime_error("Not yet implemented");
    }

private:
    /// Tensor of dimensions {numEmbeddings, embeddingDim}, registered as "weight".
    dl::TensorPtr weight;
};
23
}
// namespace dl
dl::Embedding
Definition
embedding.hpp:8
dl::Embedding::forward
virtual dl::TensorPtr forward(const dl::TensorPtr input)
Definition
embedding.hpp:18
dl::Model
Definition
model.hpp:33
dl::TensorPtr
The Tensor is a managed pointer to a tensor. It can generally be thought of as a std::unique_ptr<T...
Definition
tensorptr.hpp:45
std::runtime_error
dl
model
embedding.hpp
Generated by
1.9.8