libdl  0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
linear.hpp
1#pragma once
2
3#include "../tensor/tensorptr.hpp"
4#include "./model.hpp"
5
6namespace dl {
12 class Linear final : public Model<TensorPtr(TensorPtr)> {
13 private:
14 TensorPtr _weights;
15 TensorPtr _bias;
16
17 public:
18 Linear(size_t inFeatures, size_t outFeatures, const Device& device, bool bias = true) noexcept;
19 Linear(size_t inFeatures, size_t outFeatures, bool bias = true) noexcept;
20
21 public:
22 virtual TensorPtr forward(TensorPtr input) noexcept override;
23
24 inline TensorPtr& weights() noexcept { return _weights; }
25 inline const TensorPtr& weights() const noexcept { return _weights; }
26 inline TensorPtr& bias() noexcept { return _bias; }
27 inline const TensorPtr& bias() const noexcept { return _bias; }
28
29 dl::TensorPtr operator()(TensorPtr input) noexcept { return forward(input); }
30 };
31} // namespace dl
Applies a learnable linear transformation with optional bias.
Definition linear.hpp:12
TensorPtr is a managed pointer to a tensor. It can generally be thought of as an std::unique_ptr<T...
Definition tensorptr.hpp:45