TPDE
Loading...
Searching...
No Matches
SmallVector.hpp
1// SPDX-FileCopyrightText: 2025 Contributors to TPDE <https://tpde.org>
2//
3// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4#pragma once
5
#include "tpde/util/AddressSanitizer.hpp"

#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
13
14namespace tpde::util {
15
/// Type-erased state shared by all SmallVector instantiations: the storage
/// pointer, the number of live elements and the capacity (in elements).
///
/// The inline ("small") buffer is not a member of this class. Derived classes
/// must place it immediately after this subobject, which is exactly the
/// address small_ptr() returns.
class SmallVectorUntypedBase {
public:
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;

protected:
  /// Current element storage: either small_ptr() or a heap buffer.
  void *ptr;
  /// Number of live (constructed) elements.
  size_type sz;
  /// Number of elements that fit into the storage at ptr.
  size_type cap;

  SmallVectorUntypedBase() = delete;
  /// Start empty, using the inline buffer which holds \p cap elements.
  explicit SmallVectorUntypedBase(size_type cap)
      : ptr(small_ptr()), sz(0), cap(cap) {}

  /// Address directly past this subobject, where the inline buffer lives.
  void *small_ptr() { return static_cast<void *>(this + 1); }
  const void *small_ptr() const { return static_cast<const void *>(this + 1); }

  /// Whether the elements currently live in the inline buffer.
  bool is_small() const { return ptr == small_ptr(); }

  // Growth helpers, defined out of line (not in this header). As used by
  // SmallVectorBase: grow_malloc() returns a buffer that becomes the new
  // storage (later released with free()) and reports its capacity, in
  // elements, through new_cap; grow_trivial() is the growth path for trivially
  // copyable/destructible element types and must update the storage itself.
  // NOTE(review): exact min_size / growth-factor semantics are inferred from
  // the call sites -- confirm against the implementation.
  void *grow_malloc(size_type min_size, size_type elem_sz, size_type &new_cap);
  void grow_trivial(size_type min_size, size_type elem_sz);

public:
  size_type size() const { return sz; }
  size_type capacity() const { return cap; }
  bool empty() const { return size() == 0; }
};
42
43template <typename T>
44class SmallVectorBase : public SmallVectorUntypedBase {
45 // TODO: support types with larger alignment
46 static_assert(alignof(T) <= alignof(SmallVectorUntypedBase),
47 "SmallVector only supports types with pointer-sized alignment");
48
49public:
50 using value_type = T;
51 using pointer = T *;
52 using const_pointer = const T *;
53 using reference = T &;
54 using const_reference = const T &;
55 using iterator = T *;
56 using const_iterator = const T *;
57
58 static constexpr bool IsTrivial = std::is_trivially_copy_constructible_v<T> &&
59 std::is_trivially_move_constructible_v<T> &&
60 std::is_trivially_destructible_v<T>;
61
62protected:
63 SmallVectorBase(size_type cap) : SmallVectorUntypedBase(cap) {
64 poison_memory_region(ptr, cap * sizeof(T));
65 }
66
67 ~SmallVectorBase() {
68 std::destroy(begin(), end());
69 if (!is_small()) {
70 free(ptr);
71 }
72 }
73
74 SmallVectorBase &operator=(SmallVectorBase &&other) {
75 clear();
76 if (!other.is_small()) {
77 ptr = other.ptr;
78 sz = other.sz;
79 cap = other.cap;
80 other.ptr = other.small_ptr();
81 other.sz = 0;
82 other.cap = 0;
83 } else {
84 reserve(other.size());
85 unpoison_memory_region(begin(), other.size() * sizeof(T));
86 std::uninitialized_move(other.begin(), other.end(), begin());
87 sz = other.size();
88 other.clear();
89 }
90 return *this;
91 }
92
93public:
94 pointer data() { return reinterpret_cast<T *>(ptr); }
95 const_pointer data() const { return reinterpret_cast<const T *>(ptr); }
96
97 iterator begin() { return data(); }
98 iterator end() { return data() + sz; }
99 const_iterator begin() const { return data(); }
100 const_iterator end() const { return data() + sz; }
101 const_iterator cbegin() const { return data(); }
102 const_iterator cend() const { return data() + sz; }
103
104 /// Maximum number of elements the vector can hold.
105 static size_type max_size() { return size_type(-1) / sizeof(T); }
106
107 reference operator[](size_type idx) {
108 assert(idx < size());
109 return data()[idx];
110 }
111 const_reference operator[](size_type idx) const {
112 assert(idx < size());
113 return data()[idx];
114 }
115
116 reference front() { return (*this)[0]; }
117 const_reference front() const { return (*this)[0]; }
118 reference back() { return (*this)[size() - 1]; }
119 const_reference back() const { return (*this)[size() - 1]; }
120
121private:
122 void ensure_space(size_type num_elems) {
123 assert(num_elems <= max_size() - size());
124 if (size() + num_elems > capacity()) [[unlikely]] {
125 grow(size() + num_elems);
126 }
127 }
128
129 void grow(size_type new_size)
130 requires IsTrivial
131 {
132 grow_trivial(new_size, sizeof(T));
133 }
134
135 void grow(size_type new_size)
136 requires(!IsTrivial);
137
138public:
139 /// Append an element. elem must not be a reference inside to the vector.`
140 void push_back(const T &elem) {
141 ensure_space(1);
142 unpoison_memory_region(end(), sizeof(T));
143 ::new (reinterpret_cast<void *>(end())) T(elem);
144 sz += 1;
145 }
146
147 /// Append an element. elem must not be a reference inside to the vector.`
148 void push_back(T &&elem) {
149 ensure_space(1);
150 unpoison_memory_region(end(), sizeof(T));
151 ::new (reinterpret_cast<void *>(end())) T(::std::move(elem));
152 sz += 1;
153 }
154
155 template <typename... ArgT>
156 reference emplace_back(ArgT &&...args) {
157 ensure_space(1);
158 unpoison_memory_region(end(), sizeof(T));
159 ::new (reinterpret_cast<void *>(end())) T(::std::forward<ArgT>(args)...);
160 sz += 1;
161 return back();
162 }
163
164 void pop_back() {
165 back().~T();
166 poison_memory_region(end(), sizeof(T));
167 sz -= 1;
168 }
169
170 void reserve(size_type new_cap) {
171 if (cap < new_cap) {
172 grow(new_cap);
173 }
174 }
175
176private:
177 template <bool Initialize>
178 void resize(size_type new_size) {
179 if (sz > new_size) {
180 std::destroy(begin() + new_size, end());
181 poison_memory_region(begin() + new_size, (sz - new_size) * sizeof(T));
182 sz = new_size;
183 } else if (sz < new_size) {
184 reserve(new_size);
185 unpoison_memory_region(end(), (new_size - sz) * sizeof(T));
186 for (pointer it = end(), e = begin() + new_size; it != e; ++it) {
187 Initialize ? ::new (it) T() : ::new (it) T;
188 }
189 sz = new_size;
190 }
191 }
192
193public:
194 void resize(size_type new_size) { resize<true>(new_size); }
195 void resize_uninitialized(size_type new_size) { resize<false>(new_size); }
196
197 void resize(size_type new_size, const T &init) {
198 if (sz > new_size) {
199 std::destroy(begin() + new_size, end());
200 poison_memory_region(begin() + new_size, (sz - new_size) * sizeof(T));
201 sz = new_size;
202 } else if (sz < new_size) {
203 reserve(new_size);
204 unpoison_memory_region(end(), (new_size - sz) * sizeof(T));
205 std::uninitialized_fill(begin() + sz, begin() + new_size, init);
206 sz = new_size;
207 }
208 }
209
210 iterator erase(iterator start_it, iterator end_it) {
211 iterator res = start_it;
212 while (end_it != end()) {
213 *start_it++ = std::move(*end_it++);
214 }
215 std::destroy(start_it, end_it);
216 sz = start_it - begin();
217 return res;
218 }
219
220 template <typename It>
221 void append(It start_it, It end_it) {
222 size_type n = std::distance(start_it, end_it);
223 reserve(sz + n);
224 unpoison_memory_region(end(), n * sizeof(T));
225 std::uninitialized_copy(start_it, end_it, end());
226 sz += n;
227 }
228
229 void clear() {
230 std::destroy(begin(), end());
231 poison_memory_region(begin(), sz * sizeof(T));
232 sz = 0;
233 }
234};
235
236template <class T>
237void SmallVectorBase<T>::grow(size_type new_size)
238 requires(!IsTrivial)
239{
240 size_type new_cap;
241 T *new_alloc = static_cast<T *>(grow_malloc(new_size, sizeof(T), new_cap));
242 std::uninitialized_move(begin(), end(), new_alloc);
243 poison_memory_region(new_alloc + sz, (new_cap - sz) * sizeof(T));
244 std::destroy(begin(), end());
245 if (!is_small()) {
246 free(ptr);
247 } else {
248 poison_memory_region(data(), cap * sizeof(T));
249 }
250 ptr = new_alloc;
251 cap = new_cap;
252}
253
/// Default inline capacity of SmallVector<T>: as many elements as fit into
/// 256 bytes, but at least one.
template <typename T>
inline constexpr std::size_t SmallVectorDefaultSize =
    sizeof(T) < 256 ? 256 / sizeof(T) : 1;
256
257template <class T, size_t N = SmallVectorDefaultSize<T>>
258class SmallVector : public SmallVectorBase<T> {
259 // This is required, is_small() checks whether the current allocation is
260 // immediately following the SmallVectorBase. For zero-sized small allocation,
261 // the heap-allocated buffer could not be distinguished otherwise.
262 // TODO: find an efficient way to avoid this.
263 alignas(T) char elements[N == 0 ? 1 : N * sizeof(T)];
264
265public:
266 SmallVector() : SmallVectorBase<T>(N) {}
267
268 SmallVector(const SmallVector &) = delete;
269 SmallVector &operator=(const SmallVector &) = delete;
270 // TODO: Implement this if required.
271#if 0
272 SmallVector(SmallVector &other) : SmallVectorBase<T>(N) {
273 SmallVectorBase<T>::operator=(other);
274 }
275 SmallVector(SmallVectorBase<T> &other) : SmallVectorBase<T>(N) {
276 SmallVectorBase<T>::operator=(other);
277 }
278#endif
279
280 SmallVector(SmallVector &&other) : SmallVectorBase<T>(N) {
281 SmallVectorBase<T>::operator=(std::move(other));
282 }
283 SmallVector(SmallVectorBase<T> &&other) : SmallVectorBase<T>(N) {
284 SmallVectorBase<T>::operator=(std::move(other));
285 }
286
287#if 0
288 SmallVector &operator=(SmallVector &other) {
289 SmallVectorBase<T>::operator=(other);
290 return *this;
291 }
292 SmallVector &operator=(SmallVectorBase<T> &other) {
293 SmallVectorBase<T>::operator=(other);
294 return *this;
295 }
296#endif
297
298 SmallVector &operator=(SmallVector &&other) {
299 SmallVectorBase<T>::operator=(std::move(other));
300 return *this;
301 }
302 SmallVector &operator=(SmallVectorBase<T> &&other) {
303 SmallVectorBase<T>::operator=(std::move(other));
304 return *this;
305 }
306};
307
308} // end namespace tpde::util