libdl
0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
tensorptr.hpp
1
#pragma once
2
3
#include <concepts>
4
#include <memory>
5
#include <numeric>
6
#include <vector>
7
8
namespace
dl {
9
class
TensorImpl;
10
11
/// Intermediate value used to build a tensor from (possibly nested) brace-initializer lists,
/// e.g. `{1, 2, 3}` or `{{1, 2}, {3, 4}}`; TensorPtr has converting constructors taking
/// InitializerTensor<int>, InitializerTensor<float> and InitializerTensor<double>.
///
/// `data` holds the elements flattened with the outermost index varying slowest (row-major);
/// `shape` holds one extent per nesting level, outermost first.
///
/// All constructors are noexcept yet allocate, so an allocation failure calls std::terminate.
template <typename T>
struct InitializerTensor {
private:
    // Deleted: instances are only created through the constructors below.
    InitializerTensor() = delete;

public:
    std::vector<T> data;        // declared before `shape`: the range constructor reads data.size()
    std::vector<size_t> shape;

    /// Takes over the elements and shape of `other`.
    InitializerTensor(InitializerTensor<T>&& other) noexcept
        : data(std::move(other.data)), shape(std::move(other.shape)) {}

    /// Innermost level: a flat list of values, giving shape {value.size()}.
    InitializerTensor(std::initializer_list<T>&& value) noexcept : data(value), shape{value.size()} {}

    /// Nested level: concatenates the sub-tensors' data in order and prepends the number of
    /// sub-tensors to the first sub-tensor's shape, so {{1, 2}, {3, 4}, {5, 6}} gets shape {3, 2}.
    /// An empty outer list yields an empty tensor of shape {0}.
    ///
    /// NOTE(review): sub-tensors are not checked for equal shapes; ragged input such as
    /// {{1, 2}, {3}} produces a `shape` whose element count disagrees with `data.size()`.
    InitializerTensor(std::initializer_list<InitializerTensor>&& value) noexcept : data(), shape{value.size()} {
        if (value.size() == 0) {
            // There is no sub-tensor to take the inner extents from.
            return;
        }
        // Seed with size_t: an `int` seed would make std::accumulate sum in int.
        const size_t total = std::accumulate(
            value.begin(), value.end(), size_t{0},
            [](size_t acc, const InitializerTensor& sub) { return acc + sub.data.size(); });
        data.reserve(total);
        for (const InitializerTensor& sub : value) {
            data.insert(data.end(), sub.data.begin(), sub.data.end());
        }
        const std::vector<size_t>& inner = value.begin()->shape;
        shape.insert(shape.end(), inner.begin(), inner.end());
    }

#ifdef __cpp_concepts
    /// Builds a 1-D tensor from any range whose elements T can be constructed from.
    /// Needs C++20 concepts; it is omitted otherwise so the rest of the type still compiles.
    InitializerTensor(std::ranges::range auto range) noexcept
        : data(std::begin(range), std::end(range)), shape{data.size()} {}
#endif
};
36
45
class
TensorPtr
final {
46
private
:
47
std::shared_ptr<TensorImpl>
data;
48
49
explicit
TensorPtr
(
std::shared_ptr<TensorImpl>
&& data) : data(
std::move
(data)) {}
50
51
public
:
52
TensorPtr
(
TensorPtr
&& other) : data(
std::move
(other.data)){};
53
TensorPtr
(
const
TensorPtr
& other);
54
TensorPtr
(
std::nullptr_t
p) : data(p) {}
55
TensorPtr
(
int
value);
56
TensorPtr
(
float
value);
57
TensorPtr
(
double
value);
58
TensorPtr
(
InitializerTensor<int>
value);
59
TensorPtr
(
InitializerTensor<float>
value);
60
TensorPtr
(
InitializerTensor<double>
value);
61
62
TensorImpl
* operator->()
noexcept
{
return
data.
get
(); }
63
const
TensorImpl
* operator->()
const
noexcept
{
return
data.
get
(); }
64
65
TensorImpl
& operator*()
noexcept
{
return
*data; }
66
const
TensorImpl
& operator*()
const
noexcept
{
return
*data; }
67
68
TensorPtr
& operator=(
const
TensorPtr
& other);
69
TensorPtr
& operator=(
TensorPtr
&& other);
70
71
bool
operator==(
const
std::nullptr_t
& other)
const
noexcept
{
return
data == other; }
72
operator
bool()
const
noexcept
{
return
(
bool
)data; }
73
74
template
<
typename
T,
typename
... Args>
75
static
TensorPtr
create(Args&&... args)
noexcept
{
76
return
TensorPtr
(std::make_unique<T>(std::forward<Args>(args)...));
77
}
78
};
79
80
using
TensorRef
=
std::reference_wrapper<TensorPtr>
;
81
}
// namespace dl
std::accumulate
T accumulate(T... args)
std::begin
T begin(T... args)
dl::TensorImpl
Definition
tensorimpl.hpp:13
dl::TensorPtr
The Tensor is a managed pointer to a tensor. It can generally be thought of like an std::unique_ptr<T...
Definition
tensorptr.hpp:45
std::vector::end
T end(T... args)
std::shared_ptr::get
T get(T... args)
std::initializer_list
std::vector::insert
T insert(T... args)
std::move
T move(T... args)
std::nullptr_t
std::reference_wrapper
std::vector::reserve
T reserve(T... args)
std::shared_ptr
std::vector::size
T size(T... args)
dl::InitializerTensor
Definition
tensorptr.hpp:12
dl::InitializerTensor::InitializerTensor
InitializerTensor(std::initializer_list< InitializerTensor > &&value) noexcept
Definition
tensorptr.hpp:22
std::vector
dl
tensor
tensorptr.hpp
Generated by
1.9.8