libdl  0.0.1
Simple yet powerful deep learning
tensorimpl.hpp
#pragma once

#include "math.hpp"
#include "shape.hpp"

#include <any>
#include <functional>
#include <iostream>

namespace dl {
    class Device;

    class TensorImpl {
        using GradFn = std::function<void(TensorPtr&)>;

    private:
        bool _requiresGrad;
        Device const& _device;

    public:
        GradFn gradfn = nullptr;
        TensorPtr grad = nullptr;

    protected:
        TensorImpl(Device const& device, bool requiresGrad) noexcept;

    public:
        TensorPtr to(Device const& other) const noexcept;
        Device const& device() const noexcept;

        void setRequiresGrad(bool requiresGrad) noexcept;
        bool requiresGrad() const noexcept;

        void backward(bool enableAutodiff = false) noexcept;

        const TensorPtr& gradient() const noexcept { return grad; }
        void discardGradient() noexcept {
            gradfn = nullptr;
            grad = nullptr;
        }

        virtual std::ostream& writeToStream(std::ostream& stream) const noexcept = 0;
        virtual bool operator==(const TensorPtr& other) const noexcept = 0;
        virtual bool allclose(const TensorPtr& other, float rtol = 1e-5, float atol = 1e-8) const noexcept = 0;

        virtual TensorPtr add(const TensorPtr& other) const noexcept = 0;
        virtual TensorPtr sub(const TensorPtr& other) const noexcept = 0;
        virtual TensorPtr mul(const TensorPtr& other) const noexcept = 0;
        virtual TensorPtr div(const TensorPtr& other) const noexcept = 0;

        virtual TensorPtr fma(const TensorPtr& factor, const TensorPtr& summand) const noexcept = 0;
        virtual TensorPtr matmul(const TensorPtr& other) const noexcept = 0;
        virtual TensorPtr transpose(std::vector<size_t>&& permutation) const noexcept = 0;

        // Powers:
        virtual TensorPtr pow(float exponent) const noexcept = 0;
        virtual TensorPtr exp() const noexcept = 0;
        virtual TensorPtr log() const noexcept = 0;
        virtual TensorPtr sqrt() const noexcept = 0;
        virtual TensorPtr rsqrt() const noexcept = 0;

        // Statistical functions:
        virtual TensorPtr mean() const noexcept = 0;
        virtual TensorPtr mean(int dim, bool keepdim) const noexcept = 0;
        virtual TensorPtr sum() const noexcept = 0;
        virtual TensorPtr sum(int dim, bool keepdim) const noexcept = 0;
        virtual TensorPtr min() const noexcept = 0;
        virtual TensorPtr min(int dim, bool keepdim) const noexcept = 0;
        virtual TensorPtr min(const TensorPtr& other) const noexcept = 0;
        virtual TensorPtr max() const noexcept = 0;
        virtual TensorPtr max(int dim, bool keepdim) const noexcept = 0;
        virtual TensorPtr max(const TensorPtr& other) const noexcept = 0;
        virtual TensorPtr var(DOF dof) const noexcept = 0;
        virtual TensorPtr var(int dim, DOF dof) const noexcept = 0;

        virtual TensorPtr erf() const noexcept = 0;

        virtual void mul_inplace(const TensorPtr& other) noexcept = 0;
        virtual void reshape(SShape shape) noexcept = 0;

        virtual TensorPtr clone() const noexcept = 0;

        virtual Shape shape() const noexcept = 0;
        virtual size_t shape(int dim) const noexcept = 0;
        size_t numDim() const noexcept { return shape().size(); }

        virtual TensorPtr flatten() const noexcept = 0;

        virtual size_t toBytes(char* buffer, size_t buflen) const noexcept = 0;
    };
} // namespace dl
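
TensorImpl itself is abstract and is normally reached through the TensorPtr handle described in the notes below. As a rough orientation, the gradient-related members declared above (setRequiresGrad, backward, gradient, discardGradient) suggest a workflow like the following minimal sketch. It rests on stated assumptions: the makeOnes factory is purely hypothetical (tensor creation is not declared in this header), and operator-> access through TensorPtr is assumed to behave like std::unique_ptr.

#include <iostream>

#include "tensorimpl.hpp"

// Hypothetical factory: how concrete tensors are created is not part of this header.
dl::TensorPtr makeOnes(const dl::Shape& shape, dl::Device const& device);

void autodiffSketch(dl::Device const& device, const dl::Shape& shape) {
    dl::TensorPtr a = makeOnes(shape, device);
    dl::TensorPtr b = makeOnes(shape, device);
    a->setRequiresGrad(true);                      // mark a as a leaf that needs a gradient

    dl::TensorPtr loss = a->mul(b)->mean();        // forward pass through the virtual ops above
    loss->backward();                              // presumably walks the recorded gradfn chain and fills grad
    const dl::TensorPtr& gradOfA = a->gradient();  // gradient accumulated for a (may still be nullptr)
    if (gradOfA)
        gradOfA->writeToStream(std::cout);         // print it via the stream interface above
    a->discardGradient();                          // reset gradfn and grad before the next iteration
}

The to(Device const&) member slots into the same pattern when a tensor first needs to be copied to another device.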
virtual size_t toBytes(char *buffer, size_t buflen) const noexcept=0
Writes this tensor's data into the byte array.
virtual TensorPtr fma(const TensorPtr &factor, const TensorPtr &summand) const noexcept=0
Performs "fused multiply and add"; see the sketch after this list.
TensorPtr to(Device const &other) const noexcept
Creates a copy of this tensor on the requested device and returns a pointer to it....
bool requiresGrad() const noexcept
Returns true iff this tensor requires a gradient, i.e., needs to be updated during backpropagation.
Device const & device() const noexcept
Returns the device this tensor is stored on.
virtual void reshape(SShape shape) noexcept=0
Reshapes the tensor to fit the specified size.
void setRequiresGrad(bool requiresGrad) noexcept
Sets this tensor's requirement for a gradient.
The Tensor is a managed pointer to a tensor. It can generally be thought of like an std::unique_ptr<T...
Definition tensorptr.hpp:45
Implements auto-diff enabled wrappers around concrete tensor implementations.
Wrapper around std::size_t to discern between var(TensorPtr, DOF) and var(TensorPtr,...
Definition math.hpp:177
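
Two of the entries above are easy to misread without an example. For fma(factor, summand), the parameter names suggest an element-wise result of the form this * factor + summand, and DOF exists only so that the two var overloads cannot be confused, with dof presumably acting as the degrees-of-freedom correction for the variance. The sketch below illustrates both; the exact fma semantics, the namespace of DOF, and its construction from an integer are assumptions, not documented behaviour.

#include "tensorimpl.hpp"

// Sketch only: fma's argument-order semantics and DOF's constructor are inferred, not documented.
void statsSketch(const dl::TensorPtr& t, const dl::TensorPtr& factor, const dl::TensorPtr& summand) {
    // The parameter names suggest an element-wise fused multiply-add: r = t * factor + summand.
    dl::TensorPtr r = t->fma(factor, summand);

    // DOF (see math.hpp) wraps std::size_t purely to keep the two var overloads apart.
    dl::TensorPtr vAll = t->var(dl::DOF(1));     // variance over all elements
    dl::TensorPtr vDim = t->var(0, dl::DOF(1));  // variance along dimension 0
}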