libdl  0.0.1
Simple yet powerful deep learning
linear.hpp
 1 #pragma once
 2
 3 #include "../tensor/tensorptr.hpp"
 4 #include "./model.hpp"
 5
 6 namespace dl {
12     class Linear final : public Model<TensorPtr(TensorPtr)> {
13     private:
14         TensorPtr weights;
15         TensorPtr bias;
16
17     public:
18         Linear(size_t inFeatures, size_t outFeatures, const Device& device, bool bias = true) noexcept;
19         Linear(size_t inFeatures, size_t outFeatures, bool bias = true) noexcept;
20
21     public:
22         virtual TensorPtr forward(TensorPtr input) noexcept override;
25     };
26 } // namespace dl
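The header above declares only the interface, so here is a minimal, dependency-free sketch of what a layer like this typically computes. It is not libdl's implementation: `LinearSketch`, `weight()`, `biasAt()` and `hasBias()` are hypothetical names, plain `std::vector<float>` stands in for `TensorPtr`, the weight matrix is assumed to be row-major with shape [outFeatures × inFeatures], parameters start at zero instead of whatever initialization libdl uses, and the `Device` overload is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for dl::Linear (NOT libdl's code).
// Assumptions: std::vector<float> replaces TensorPtr, weights are stored
// row-major as [outFeatures x inFeatures], and parameters start at zero.
class LinearSketch {
public:
    LinearSketch(std::size_t inFeatures, std::size_t outFeatures, bool bias = true)
        : in_(inFeatures),
          out_(outFeatures),
          weights_(inFeatures * outFeatures, 0.0f),
          bias_(bias ? outFeatures : 0, 0.0f) {}  // empty vector means "no bias"

    // Computes y = W x (+ b) for one input vector x of length inFeatures.
    std::vector<float> forward(const std::vector<float>& x) const {
        if (x.size() != in_) {
            throw std::invalid_argument("input size does not match inFeatures");
        }
        std::vector<float> y(out_, 0.0f);
        for (std::size_t o = 0; o < out_; ++o) {
            // Start from the bias term if this layer has one.
            float acc = bias_.empty() ? 0.0f : bias_[o];
            // Dot product of weight row o with the input.
            for (std::size_t i = 0; i < in_; ++i) {
                acc += weights_[o * in_ + i] * x[i];
            }
            y[o] = acc;
        }
        return y;
    }

    // Direct parameter access so callers can set known values.
    float& weight(std::size_t o, std::size_t i) { return weights_[o * in_ + i]; }
    float& biasAt(std::size_t o) { return bias_.at(o); }  // throws if no bias
    bool hasBias() const { return !bias_.empty(); }

private:
    std::size_t in_;
    std::size_t out_;
    std::vector<float> weights_;
    std::vector<float> bias_;
};
```

For example, with `LinearSketch layer(2, 1);`, weights set to 2 and 3 and the bias set to 1, `layer.forward({4.0f, 5.0f})` returns `{24.0f}` (2·4 + 3·5 + 1). Passing `false` as the third constructor argument drops the bias term, mirroring the `bool bias = true` parameter in the real declaration. Unlike the real `forward`, which is declared `noexcept`, the sketch throws `std::invalid_argument` on a size mismatch so that its precondition is explicit.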
dl::Linear: Applies a learnable linear transformation with optional bias.
Definition: linear.hpp:12
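Written out, with d_in = inFeatures and d_out = outFeatures, a learnable linear transformation with optional bias is the map below. The orientation of W (d_out × d_in) and the convention that a batch stores one input per row are assumptions borrowed from common practice, since linear.hpp does not state how `weights` is shaped.

```latex
\begin{align*}
\mathbf{y} &= W\mathbf{x} + \mathbf{b},
  && W \in \mathbb{R}^{d_{\mathrm{out}} \times d_{\mathrm{in}}},\;
     \mathbf{x} \in \mathbb{R}^{d_{\mathrm{in}}},\;
     \mathbf{b} \in \mathbb{R}^{d_{\mathrm{out}}} \\
Y &= X W^{\top} + \mathbf{1}_N \mathbf{b}^{\top},
  && X \in \mathbb{R}^{N \times d_{\mathrm{in}}},\;
     Y \in \mathbb{R}^{N \times d_{\mathrm{out}}} \\
\#\text{params} &= d_{\mathrm{out}} d_{\mathrm{in}} + d_{\mathrm{out}} \;(\text{bias} = \text{true}),
  && \#\text{params} = d_{\mathrm{out}} d_{\mathrm{in}} \;(\text{bias} = \text{false})
\end{align*}
```

With `bias = false` the term b (and 1_N b^T in the batched form) is simply absent. Under these assumptions, `Linear(784, 10)` would hold 784·10 + 10 = 7850 learnable values, or 7840 without a bias.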
dl::TensorPtr: The Tensor is a managed pointer to a tensor. It can generally be thought of as an std::unique_ptr<T...
Definition: tensorptr.hpp:45