libdl  0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
tensorptr.hpp
1#pragma once
2
3#include <concepts>
4#include <memory>
5#include <numeric>
6#include <vector>
7
8namespace dl {
9 class TensorImpl;
10
11 template <typename T>
13 private:
// Default construction is deleted: an InitializerTensor must be built from
// element data via one of the constructors below.
14 InitializerTensor() = delete;
15
16 public:
// Flattened element storage. Nested initializers append each sub-tensor's
// data in order, so the outermost index varies slowest (row-major layout).
17 std::vector<T> data;
19
20 InitializerTensor(InitializerTensor<T>&& other) noexcept : data(std::move(data)), shape(std::move(shape)) {}
21 InitializerTensor(std::initializer_list<T>&& value) noexcept : data(value), shape({value.size()}) {}
22 InitializerTensor(std::initializer_list<InitializerTensor>&& value) noexcept : data(), shape() {
24 shape = {value.size()};
25 data.reserve(std::accumulate(value.begin(), value.end(), 0, [](size_t acc, auto& v) {
26 return acc + v.data.size();
27 }));
28 for (auto&& v : value) {
29 data.insert(data.end(), v.data.begin(), v.data.end());
30 }
31 shape.insert(shape.end(), std::begin(value.begin()->shape), std::end(value.begin()->shape));
32 }
33 InitializerTensor(std::ranges::range auto range) noexcept
34 : data(std::begin(range), std::end(range)), shape({data.size()}) {}
35 };
36
45 class TensorPtr final {
46 private:
48
49 explicit TensorPtr(std::shared_ptr<TensorImpl>&& data) : data(std::move(data)) {}
50
51 public:
52 TensorPtr(TensorPtr&& other) : data(std::move(other.data)){};
53 TensorPtr(const TensorPtr& other);
54 TensorPtr(std::nullptr_t p) : data(p) {}
55 TensorPtr(int value);
56 TensorPtr(float value);
57 TensorPtr(double value);
61
62 TensorImpl* operator->() noexcept { return data.get(); }
63 const TensorImpl* operator->() const noexcept { return data.get(); }
64
65 TensorImpl& operator*() noexcept { return *data; }
66 const TensorImpl& operator*() const noexcept { return *data; }
67
68 TensorPtr& operator=(const TensorPtr& other);
69 TensorPtr& operator=(TensorPtr&& other);
70
71 bool operator==(const std::nullptr_t& other) const noexcept { return data == other; }
72 operator bool() const noexcept { return (bool)data; }
73
74 template <typename T, typename... Args>
75 static TensorPtr create(Args&&... args) noexcept {
76 return TensorPtr(std::make_unique<T>(std::forward<Args>(args)...));
77 }
78 };
79
81} // namespace dl
T accumulate(T... args)
T begin(T... args)
The Tensor is a managed pointer to a tensor. It can generally be thought of as an std::unique_ptr<T...
Definition tensorptr.hpp:45
T end(T... args)
T get(T... args)
T insert(T... args)
T move(T... args)
T reserve(T... args)
T size(T... args)
InitializerTensor(std::initializer_list< InitializerTensor > &&value) noexcept
Definition tensorptr.hpp:22