libdl  0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
embedding.hpp
1#pragma once
2
#include <cstddef>
#include <stdexcept>

#include "../device.hpp"
#include "../tensor/tensorptr.hpp"
#include "model.hpp"
6
7namespace dl {
8 class Embedding : public dl::Model<dl::TensorPtr(const dl::TensorPtr)> {
9 private:
10 dl::TensorPtr weight;
11
12 public:
13 Embedding(size_t numEmbeddings, size_t embeddingDim) : weight(dl::empty({numEmbeddings, embeddingDim})) {
14 registerParameter("weight", weight);
15 }
16 virtual ~Embedding() = default;
17
18 virtual dl::TensorPtr forward(const dl::TensorPtr input) {
20 throw std::runtime_error("Not yet implemented");
21 }
22 };
23} // namespace dl
virtual dl::TensorPtr forward(const dl::TensorPtr input)
Definition embedding.hpp:18
The Tensor is a managed pointer to a tensor. It can generally be thought of like an std::unique_ptr<T...
Definition tensorptr.hpp:45