libdl  0.0.1
Simple yet powerful deep learning
Loading...
Searching...
No Matches
model.hpp
#pragma once

#include "../tensor/tensorimpl.hpp"
#include "../tensor/tensorptr.hpp"

#include <cstddef>
#include <format>
#include <map>
#include <ranges>
#include <string>
#include <utility>

namespace dl {
    class Device;

    /**
     * @brief Non-templated base shared by every model.
     *
     * Holds the registry of named parameters that derived models fill in via
     * registerParameter() and registerParameters() (and, in Model, registerSubmodel()).
     *
     * NOTE(review): the class is implicitly copyable. If TensorRef refers to a tensor
     * owned by the derived model (registerParameter() takes the tensor by non-const
     * reference), a copied model would keep references into the original object's
     * tensors — confirm TensorRef semantics and consider deleting the copy operations.
     */
    class ModelBase {
    private:
        /// All registered parameters, keyed by their (possibly "<prefix>."-qualified) name.
        std::map<std::string, dl::TensorRef> _parameters;

    protected:
        /// Registers @p tensor under @p name. Only declared here; the definition is not part of this header.
        void registerParameter(std::string name, TensorPtr& tensor);

        /**
         * @brief Registers every element of @p tensors under the name "<prefix>.<key>".
         *
         * @param prefix  Prepended to each key, separated by a '.'.
         * @param tensors Any range whose elements destructure into [key, value], for example
         *                the map returned by parameters(). Both lvalue ranges and rvalue
         *                ranges (e.g. a temporary view) are accepted.
         *
         * Entries are added with std::map::insert, so a name that is already registered
         * keeps its existing value and the new entry is silently dropped.
         */
        void registerParameters(std::string prefix, std::ranges::range auto&& tensors) {
            for (auto&& [key, value] : tensors)
                _parameters.insert({std::format("{}.{}", prefix, key), value});
        }

    public:
        virtual ~ModelBase() = default;

        /// Only declared here. NOTE(review): presumably the total parameter count — confirm against the definition.
        size_t numParameters() const noexcept;
        /// Only declared here. NOTE(review): presumably the trainable parameter count — confirm against the definition.
        size_t numTrainableParams() const noexcept;

        /// Mutable access to the parameter registry.
        std::map<std::string, dl::TensorRef>& parameters() noexcept { return _parameters; }
        /// Read-only access to the parameter registry.
        const std::map<std::string, dl::TensorRef>& parameters() const noexcept { return _parameters; }
    };

    /**
     * @brief Primary template; intentionally declared but not defined.
     *
     * Model is only meant to be instantiated with a function signature (see the
     * Model<R(Args...)> specialization below). Leaving the primary template undefined
     * turns a misuse such as Model<int> into a compile error instead of a silently
     * empty class that is not a ModelBase.
     */
    template <typename>
    class Model;

    /**
     * @brief Base class for a model that is callable with the signature R(Args...).
     *
     * Derived classes implement forward(); callers invoke the model through operator().
     * ModelBase is inherited virtually, so a class deriving from several Model<...>
     * specializations has a single ModelBase subobject (one parameter registry).
     */
    template <typename R, typename... Args>
    class Model<R(Args...)> : public virtual ModelBase {
    public:
        /// The call signature this model implements.
        using signature = R(Args...);

    protected:
        /**
         * @brief Registers all parameters of @p model under "<prefix>.<name>".
         *
         * NOTE(review): the registry entries are copied from @p model, so they presumably
         * refer to tensors owned by @p model, which must then outlive this model —
         * confirm against TensorRef semantics.
         */
        void registerSubmodel(std::string prefix, const ModelBase& model) {
            registerParameters(std::move(prefix), model.parameters());
        }

        /// Computes the model output; implemented by derived classes and invoked via operator().
        virtual R forward(Args... args) = 0;
        // virtual R forward(Args... args) const = 0;

    public:
        virtual ~Model() = default;

        /// Only declared here. NOTE(review): a member of a class template must be defined in a
        /// header (or explicitly instantiated) to be callable — confirm where this is defined.
        void to(const Device& device) noexcept;

        /**
         * @brief Runs the model by passing @p args on to forward().
         *
         * The parameters are declared exactly like forward()'s (Args..., not Args&&...).
         * Args is fixed by the class template, so Args&& would not be a forwarding
         * reference and would reject lvalue arguments whenever an Args is a value type.
         * std::forward<Args> moves value parameters into forward() and passes reference
         * parameters through unchanged.
         */
        R operator()(Args... args) { return this->forward(std::forward<Args>(args)...); }
        // R operator()(Args... args) const { return this->forward(std::forward<Args>(args)...); }
    };
} // namespace dl
virtual ~Model()=default
The Tensor is a managed pointer to a tensor. It can generally be thought of like an std::unique_ptr<T...
Definition tensorptr.hpp:45
T insert(T... args)