buffer.h
Go to the documentation of this file.
1 #ifndef TTG_MADNESS_BUFFER_H
2 #define TTG_MADNESS_BUFFER_H
3 
5 
6 namespace ttg_madness {
7 
9 template<typename T, typename Allocator>
10 struct Buffer : private Allocator {
11 
12  using element_type = std::decay_t<T>;
13 
14  using allocator_traits = std::allocator_traits<Allocator>;
15  using allocator_type = typename allocator_traits::allocator_type;
16 
17  static_assert(std::is_trivially_copyable_v<element_type>,
18  "Only trivially copyable types are supported for devices.");
19  static_assert(std::is_default_constructible_v<element_type>,
20  "Only default constructible types are supported for devices.");
21 
22 private:
23  using delete_fn_t = std::function<void(element_type*)>;
24 
25  using host_data_ptr = std::add_pointer_t<element_type>;
26  host_data_ptr m_host_data = nullptr;
27  std::size_t m_count = 0;
28  bool m_owned= false;
29 
30  static void delete_non_owned(element_type *ptr) {
31  // nothing to be done, we don't own the memory
32  }
33 
34  allocator_type& get_allocator_reference() { return static_cast<allocator_type&>(*this); }
35 
36  element_type* allocate(std::size_t n) {
37  return allocator_traits::allocate(get_allocator_reference(), n);
38  }
39 
40  void deallocate() {
41  allocator_traits::deallocate(get_allocator_reference(), m_host_data, m_count);
42  }
43 
44 public:
45 
46  Buffer() : Buffer(nullptr, 0)
47  { }
48 
49  Buffer(std::size_t n)
50  : allocator_type()
51  , m_host_data(allocate(n))
52  , m_count(n)
53  , m_owned(true)
54  { }
55 
56  /* Constructing a buffer using application-managed memory.
57  * The memory pointed to by ptr must be accessible during
58  * the life-time of the buffer. */
59  Buffer(element_type* ptr, std::size_t n = 1)
60  : allocator_type()
61  , m_host_data(ptr)
62  , m_count(n)
63  , m_owned(false)
64  { }
65 
66  virtual ~Buffer() {
67  if (m_owned) {
68  deallocate();
69  m_owned = false;
70  }
71  unpin(); // make sure the copies are not pinned
72  }
73 
74  /* allow moving device buffers */
75  Buffer(Buffer&& db)
76  : allocator_type(std::move(db))
77  , m_host_data(db.m_host_data)
78  , m_count(db.m_count)
79  , m_owned(db.m_owned)
80  {
81  db.m_host_data = nullptr;
82  db.m_count = 0;
83  db.m_owned = false;
84  }
85 
86  /* explicitly disable copying of buffers
87  * TODO: should we allow this? What data to use?
88  */
89  Buffer(const Buffer& db) = delete;
90 
91  /* allow moving device buffers */
93  allocator_type::operator=(std::move(db));
94  std::swap(m_host_data, db.m_host_data);
95  std::swap(m_count, db.m_count);
96  std::swap(m_owned, db.m_owned);
97  return *this;
98  }
99 
100  /* explicitly disable copying of buffers
101  * TODO: should we allow this? What data to use?
102  */
103  Buffer& operator=(const Buffer& db) = delete;
104 
105  /* set the current device, useful when a device
106  * buffer was modified outside of a TTG */
108  assert(is_valid());
109  if (!device.is_host()) throw std::runtime_error("MADNESS backend does not support non-host memory!");
110  /* no-op */
111  }
112 
113  /* Get the owner device ID, i.e., the last updated
114  * device buffer. */
116  assert(is_valid());
117  return {}; // host only
118  }
119 
120  /* Get the pointer on the currently active device. */
122  assert(is_valid());
123  return m_host_data;
124  }
125 
126  /* Get the pointer on the currently active device. */
128  assert(is_valid());
129  return m_host_data;
130  }
131 
132  /* Get the pointer on the owning device.
133  * @note: This may not be the device assigned to the currently executing task.
134  * See \ref ttg::device::current_device for that. */
136  assert(is_valid());
137  return m_host_data;
138  }
139 
140  /* get the current device pointer */
142  assert(is_valid());
143  return m_host_data;
144  }
145 
146  /* get the device pointer at the given device
147  */
149  assert(is_valid());
150  if (device.is_device()) throw std::runtime_error("MADNESS missing support for non-host memory!");
151  return m_host_data;
152  }
153 
154  /* get the device pointer at the given device
155  */
156  const element_type* device_ptr_on(const ttg::device::Device& device) const {
157  assert(is_valid());
158  if (device.is_device()) throw std::runtime_error("MADNESS missing support for non-host memory!");
159  return m_host_data;
160  }
161 
163  return m_host_data;
164  }
165 
166  const element_type* host_ptr() const {
167  return m_host_data;
168  }
169 
170  bool is_valid_on(const ttg::device::Device& device) const {
171  assert(is_valid());
172  if (device.is_device()) throw std::runtime_error("MADNESS missing support for non-host memory!");
173  return true;
174  }
175 
176  void allocate_on(const ttg::device::Device& device_id) {
177  /* TODO: need exposed PaRSEC memory allocator */
178  throw std::runtime_error("not implemented yet");
179  }
180 
181  /* TODO: can we do this automatically?
182  * Pin the memory on all devices we currently track.
183  * Pinned memory won't be released by PaRSEC and can be used
184  * at any time.
185  */
186  void pin() {
187  // nothing to do
188  }
189 
190  /* Unpin the memory on all devices we currently track. */
191  void unpin() {
192  // nothing to do
193  }
194 
195  /* Pin the memory on a given device */
196  void pin_on(int device_id) {
197  /* TODO: how can we pin memory on a device? */
198  }
199 
200  /* Pin the memory on a given device */
201  void unpin_on(int device_id) {
202  /* TODO: how can we unpin memory on a device? */
203  }
204 
205  bool is_valid() const {
206  return true;
207  }
208 
209  operator bool() const {
210  return true;
211  }
212 
213  std::size_t size() const {
214  return m_count;
215  }
216 
217  /* Reallocate the buffer with count elements */
218  void reset(std::size_t n) {
219 
220  if (m_owned) {
221  deallocate();
222  m_owned = false;
223  }
224 
225  if (n == 0) {
226  m_host_data = nullptr;
227  m_owned = false;
228  } else {
229  m_host_data = allocate(n);
230  m_owned = true;
231  }
232  m_count = n;
233  }
234 
235  /* Reset the buffer to use the ptr to count elements */
236  void reset(T* ptr, std::size_t n = 1) {
237  /* TODO: can we resize if count is smaller than m_count? */
238  if (n == m_count) {
239  return;
240  }
241 
242  if (m_owned) {
243  deallocate();
244  }
245 
246  if (nullptr == ptr) {
247  m_host_data = nullptr;
248  m_count = 0;
249  m_owned = false;
250  } else {
251  m_host_data = ptr;
252  m_count = n;
253  m_owned = false;
254  }
255  }
256 
257  /* serialization support */
258 
259 #if defined(TTG_SERIALIZATION_SUPPORTS_BOOST) && 0
260  template <typename Archive>
261  void serialize(Archive& ar, const unsigned int version) {
262  if constexpr (ttg::detail::is_output_archive_v<Archive>) {
263  std::size_t s = size();
264  ar& s;
265  /* TODO: how to serialize the array? */
266  } else {
267  std::size_t s;
268  ar & s;
269  /* initialize internal pointers and then reset */
270  reset(s);
271  /* TODO: how to deserialize the array? */
272  }
273  }
274 #endif // TTG_SERIALIZATION_SUPPORTS_BOOST
275 
276 #if defined(TTG_SERIALIZATION_SUPPORTS_MADNESS)
277  template <typename Archive>
278  std::enable_if_t<std::is_base_of_v<madness::archive::BufferInputArchive, Archive> ||
279  std::is_base_of_v<madness::archive::BufferOutputArchive, Archive>>
280  serialize(Archive& ar) {
281  if constexpr (ttg::detail::is_output_archive_v<Archive>) {
282  std::size_t s = size();
283  ar& s;
284  ar << wrap(host_ptr(), s);
285  } else {
286  std::size_t s;
287  ar & s;
288  reset(s);
289  ar >> wrap(host_ptr(), s); // MatrixTile<T>(bm.rows(), bm.cols());
290  }
291  }
292 #endif // TTG_SERIALIZATION_SUPPORTS_MADNESS
293 
294 
295 };
296 
297 } // namespace ttg_madness
298 
299 #endif // TTG_MADNESS_BUFFER_H
Represents a device in a specific execution space.
Definition: device.h:14
bool is_host() const
Definition: device.h:47
bool is_device() const
Definition: device.h:43
auto wrap(funcT &&func, const std::tuple< ttg::Edge< keyT, input_edge_valuesT >... > &inedges, const std::tuple< output_edgesT... > &outedges, const std::string &name="wrapper", const std::vector< std::string > &innames=std::vector< std::string >(sizeof...(input_edge_valuesT), "input"), const std::vector< std::string > &outnames=std::vector< std::string >(sizeof...(output_edgesT), "output"))
Definition: make_tt.h:595
this contains MADNESS-based TTG functionality
Definition: fwd.h:16
std::array< int, 3 > version()
Definition: version.cc:4
A runtime-managed buffer mirrored between host and device memory.
Definition: buffer.h:10
ttg::device::Device get_owner_device() const
Definition: buffer.h:115
Buffer & operator=(const Buffer &db)=delete
const element_type * device_ptr_on(const ttg::device::Device &device) const
Definition: buffer.h:156
Buffer & operator=(Buffer &&db)
Definition: buffer.h:92
const element_type * current_device_ptr() const
Definition: buffer.h:127
const element_type * host_ptr() const
Definition: buffer.h:166
virtual ~Buffer()
Definition: buffer.h:66
void unpin_on(int device_id)
Definition: buffer.h:201
Buffer(Buffer &&db)
Definition: buffer.h:75
void reset(std::size_t n)
Definition: buffer.h:218
std::decay_t< T > element_type
Definition: buffer.h:12
element_type * device_ptr_on(const ttg::device::Device &device)
Definition: buffer.h:148
const element_type * owner_device_ptr() const
Definition: buffer.h:141
element_type * host_ptr()
Definition: buffer.h:162
Buffer(const Buffer &db)=delete
Buffer(element_type *ptr, std::size_t n=1)
Definition: buffer.h:59
bool is_valid() const
Definition: buffer.h:205
element_type * owner_device_ptr()
Definition: buffer.h:135
element_type * current_device_ptr()
Definition: buffer.h:121
void reset(T *ptr, std::size_t n=1)
Definition: buffer.h:236
Buffer(std::size_t n)
Definition: buffer.h:49
std::size_t size() const
Definition: buffer.h:213
bool is_valid_on(const ttg::device::Device &device) const
Definition: buffer.h:170
void allocate_on(const ttg::device::Device &device_id)
Definition: buffer.h:176
typename allocator_traits::allocator_type allocator_type
Definition: buffer.h:15
void pin_on(int device_id)
Definition: buffer.h:196
std::allocator_traits< Allocator > allocator_traits
Definition: buffer.h:14
void set_current_device(const ttg::device::Device &device)
Definition: buffer.h:107