libdl
0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
linear.hpp
1
#pragma once
2
3
#include "../tensor/tensorptr.hpp"
4
#include "./model.hpp"
5
6
namespace
dl {
12
/**
 * @brief Applies a learnable linear transformation with optional bias.
 *
 * Concrete, non-derivable implementation of the Model<TensorPtr(TensorPtr)>
 * interface. The constructors and forward() are only declared here; their
 * definitions live outside this header.
 */
class Linear final : public Model<TensorPtr(TensorPtr)> {
private:
    // Declaration order is kept as _weights, then _bias, because it fixes the
    // member initialization order.
    TensorPtr _weights;
    TensorPtr _bias;

public:
    /**
     * @brief Declares construction of a layer on an explicitly given device.
     *
     * @param inFeatures  presumably the number of input features - confirm in the definition.
     * @param outFeatures presumably the number of output features - confirm in the definition.
     * @param device      device handed to the layer; how it is used is not visible in this header.
     * @param bias        defaults to true; presumably toggles the optional bias term - confirm.
     */
    Linear(size_t inFeatures, size_t outFeatures, const Device& device, bool bias = true) noexcept;

    /**
     * @brief Overload without an explicit device.
     *
     * NOTE(review): which device is chosen here is decided by the out-of-line
     * definition and cannot be told from this header.
     *
     * @param inFeatures  presumably the number of input features - confirm in the definition.
     * @param outFeatures presumably the number of output features - confirm in the definition.
     * @param bias        defaults to true; presumably toggles the optional bias term - confirm.
     */
    Linear(size_t inFeatures, size_t outFeatures, bool bias = true) noexcept;

    /**
     * @brief Overrides the forward() declared by the Model base (defined out of line).
     *
     * @param input tensor handed to the layer, taken by value.
     * @return the tensor produced by the out-of-line implementation.
     */
    TensorPtr forward(TensorPtr input) noexcept override;

    /** @brief Mutable reference to the stored weights member. */
    TensorPtr& weights() noexcept {
        return _weights;
    }

    /** @brief Read-only reference to the stored weights member. */
    const TensorPtr& weights() const noexcept {
        return _weights;
    }

    /**
     * @brief Mutable reference to the stored bias member.
     *
     * NOTE(review): what this member holds when the layer was constructed with
     * bias = false is not visible in this header.
     */
    TensorPtr& bias() noexcept {
        return _bias;
    }

    /** @brief Read-only reference to the stored bias member. */
    const TensorPtr& bias() const noexcept {
        return _bias;
    }

    /**
     * @brief Call-operator convenience: hands @p input to forward() and returns its result.
     */
    TensorPtr operator()(TensorPtr input) noexcept {
        return forward(input);
    }
};
31
}
// namespace dl
dl::Device
Definition
device.hpp:8
dl::Linear
Applies a learnable linear transformation with optional bias.
Definition
linear.hpp:12
dl::Model
Definition
model.hpp:33
dl::TensorPtr
The TensorPtr is a managed pointer to a tensor. It can generally be thought of like an std::unique_ptr<T...
Definition
tensorptr.hpp:45
dl
model
linear.hpp
Generated by
1.9.8