void backward(bool enableAutodiff = false) noexcept;
const TensorPtr& gradient() const noexcept { return grad; }
void discardGradient() noexcept { grad = nullptr; }
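As a minimal sketch (assuming TensorPtr is a smart pointer to this interface and that a concrete backend has produced a scalar loss), the autograd members above might be used like this; the helper is hypothetical:

// Hypothetical training-step helper: run a backward pass, read the parameter's gradient,
// then release it. `loss` and `param` are assumed to come from a concrete Tensor backend.
void backwardStep(const TensorPtr& loss, const TensorPtr& param) {
    loss->backward();                        // run the backward pass with the default flag
    const TensorPtr& g = param->gradient();  // gradient accumulated for `param`
    (void)g;                                 // an optimizer would consume `g` here
    param->discardGradient();                // free the gradient before the next step
}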
virtual bool operator==(const TensorPtr& other) const noexcept = 0;
virtual bool allclose(const TensorPtr& other, float rtol = 1e-5, float atol = 1e-8) const noexcept = 0;
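A small sketch of the comparison members; the tolerances shown are simply the declared defaults:

// Hypothetical equality check that tolerates floating-point noise.
bool roughlyEqual(const TensorPtr& a, const TensorPtr& b) {
    return a->allclose(b, /*rtol=*/1e-5f, /*atol=*/1e-8f);
}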
virtual TensorPtr add(const TensorPtr& other) const noexcept = 0;
virtual TensorPtr sub(const TensorPtr& other) const noexcept = 0;
virtual TensorPtr mul(const TensorPtr& other) const noexcept = 0;
virtual TensorPtr div(const TensorPtr& other) const noexcept = 0;
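A sketch composing the element-wise operations (each call is assumed to return a freshly allocated result), with x, w, and b standing in for shape-compatible tensors supplied by the caller:

// Hypothetical affine step y = x * w + b, chained through the virtual element-wise ops.
TensorPtr affine(const TensorPtr& x, const TensorPtr& w, const TensorPtr& b) {
    return x->mul(w)->add(b);
}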
virtual TensorPtr transpose(std::vector<size_t>&& permutation) const noexcept = 0;
virtual TensorPtr pow(float exponent) const noexcept = 0;
virtual TensorPtr exp() const noexcept = 0;
virtual TensorPtr log() const noexcept = 0;
virtual TensorPtr sqrt() const noexcept = 0;
virtual TensorPtr rsqrt() const noexcept = 0;
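A sketch of transpose and the unary math members; the {1, 0} permutation assumes a 2-D tensor:

// Hypothetical helper: swap the two axes of a 2-D tensor, then compute |x| as sqrt(x^2).
TensorPtr transposedAbs(const TensorPtr& x) {
    TensorPtr xt = x->transpose({1, 0});  // the permutation is moved into the callee
    return xt->pow(2.0f)->sqrt();
}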
virtual TensorPtr mean() const noexcept = 0;
virtual TensorPtr mean(int dim, bool keepdim) const noexcept = 0;
virtual TensorPtr sum() const noexcept = 0;
virtual TensorPtr sum(int dim, bool keepdim) const noexcept = 0;
virtual TensorPtr min() const noexcept = 0;
virtual TensorPtr min(int dim, bool keepdim) const noexcept = 0;
virtual TensorPtr max() const noexcept = 0;
virtual TensorPtr max(int dim, bool keepdim) const noexcept = 0;
virtual TensorPtr var(int dim, DOF dof) const noexcept = 0;
virtual TensorPtr erf() const noexcept = 0;
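A sketch of the dimension-wise reductions, standardizing along dim 1; it assumes the element-wise ops broadcast a keepdim result against the original tensor, and it rebuilds the variance from mean/sub/mul rather than calling var() because the DOF convention is backend-specific:

// Hypothetical per-row standardization: (x - mean) * rsqrt(mean of squared deviations).
TensorPtr standardize(const TensorPtr& x) {
    TensorPtr mu       = x->mean(/*dim=*/1, /*keepdim=*/true);
    TensorPtr centered = x->sub(mu);                              // assumes broadcasting
    TensorPtr variance = centered->mul(centered)->mean(1, true);  // biased variance (DOF = 0)
    return centered->mul(variance->rsqrt());
}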
virtual void mul_inplace(const TensorPtr& other) noexcept = 0;
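The in-place variant mutates the receiver instead of allocating a result; a hypothetical use is scaling an existing tensor, assuming scale is shape-compatible:

// Hypothetical in-place scaling: no new tensor is allocated.
void scaleInPlace(const TensorPtr& t, const TensorPtr& scale) {
    t->mul_inplace(scale);
}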
virtual TensorPtr clone() const noexcept = 0;
virtual Shape shape() const noexcept = 0;
virtual size_t shape(int dim) const noexcept = 0;
size_t numDim() const noexcept { return shape().size(); }
virtual TensorPtr flatten() const noexcept = 0;
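A sketch built on the shape queries; the element-count helper below is hypothetical and only uses numDim() and shape(dim):

// Hypothetical element count derived from the per-dimension sizes.
size_t numElements(const TensorPtr& t) {
    size_t n = 1;
    for (size_t d = 0; d < t->numDim(); ++d) {
        n *= t->shape(static_cast<int>(d));
    }
    return n;
}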
virtual size_t toBytes(char* buffer, size_t buflen) const noexcept = 0;
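A sketch of serialization through toBytes; that the return value is the number of bytes written into the caller's buffer is an assumption, not something the declaration states:

#include <vector>

// Hypothetical serialization into a caller-sized buffer.
std::vector<char> serialize(const TensorPtr& t, size_t maxBytes) {
    std::vector<char> buffer(maxBytes);
    size_t written = t->toBytes(buffer.data(), buffer.size());
    buffer.resize(written);  // assumes the return value counts bytes actually written
    return buffer;
}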
bool requiresGrad() const noexcept
Returns true iff this tensor requires a gradient, i.e., a gradient is computed for it during backpropagation so that it can later be updated.
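A brief sketch of how requiresGrad() can gate gradient bookkeeping; the cleanup helper is hypothetical:

#include <vector>

// Hypothetical cleanup pass: drop gradients only for tensors that actually track them.
void clearGradients(const std::vector<TensorPtr>& params) {
    for (const TensorPtr& p : params) {
        if (p->requiresGrad()) {
            p->discardGradient();
        }
    }
}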