TPDE
Loading...
Searching...
No Matches
SmallVector.hpp
1// SPDX-FileCopyrightText: 2025 Contributors to TPDE <https://tpde.org>
2//
3// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4#pragma once
5
#include "tpde/util/AddressSanitizer.hpp"

#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
13
14namespace tpde::util {
15
/// Type-erased state shared by all SmallVector instantiations: the buffer
/// pointer plus the number of constructed elements and the capacity (both
/// counted in elements). Derived classes place their inline buffer directly
/// after this object; is_small() relies on that layout.
class SmallVectorUntypedBase {
public:
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;

protected:
  void *ptr;
  size_type sz;
  size_type cap;

  SmallVectorUntypedBase() = delete;

  /// Start empty, using the inline buffer of \p cap elements that is expected
  /// to follow this object in memory.
  explicit SmallVectorUntypedBase(size_type cap)
      : ptr(reinterpret_cast<void *>(this + 1)), sz(0), cap(cap) {}

  // A copy would share `ptr` with the source (for the inline buffer: point
  // into the source object), leading to double destruction/free.
  SmallVectorUntypedBase(const SmallVectorUntypedBase &) = delete;
  SmallVectorUntypedBase &operator=(const SmallVectorUntypedBase &) = delete;

  /// Whether the elements currently live in the inline buffer that directly
  /// follows this object.
  bool is_small() const {
    return ptr == reinterpret_cast<const void *>(this + 1);
  }

  // Defined out of line. NOTE(review): presumably allocates room for at least
  // min_size elements of elem_sz bytes and reports the capacity via new_cap;
  // the result is later released with free() — confirm against the .cpp.
  void *grow_malloc(size_type min_size, size_type elem_sz, size_type &new_cap);
  // Defined out of line; used to grow storage for trivially copyable/movable/
  // destructible element types to at least min_size elements.
  void grow_trivial(size_type min_size, size_type elem_sz);

public:
  /// Number of constructed elements.
  size_type size() const { return sz; }
  /// Number of elements that fit without reallocating.
  size_type capacity() const { return cap; }
  /// Whether size() is zero.
  bool empty() const { return size() == 0; }
};
42
43template <typename T>
44class SmallVectorBase : public SmallVectorUntypedBase {
45 // TODO: support types with larger alignment
46 static_assert(alignof(T) <= alignof(SmallVectorUntypedBase),
47 "SmallVector only supports types with pointer-sized alignment");
48
49public:
50 using value_type = T;
51 using pointer = T *;
52 using const_pointer = const T *;
53 using reference = T &;
54 using const_reference = const T &;
55 using iterator = T *;
56 using const_iterator = const T *;
57
58 static constexpr bool IsTrivial = std::is_trivially_copy_constructible_v<T> &&
59 std::is_trivially_move_constructible_v<T> &&
60 std::is_trivially_destructible_v<T>;
61
62protected:
63 SmallVectorBase(size_type cap) : SmallVectorUntypedBase(cap) {
64 poison_memory_region(ptr, cap * sizeof(T));
65 }
66
67 ~SmallVectorBase() {
68 std::destroy(begin(), end());
69 if (!is_small()) {
70 free(ptr);
71 }
72 }
73
74public:
75 pointer data() { return reinterpret_cast<T *>(ptr); }
76 const_pointer data() const { return reinterpret_cast<const T *>(ptr); }
77
78 iterator begin() { return data(); }
79 iterator end() { return data() + sz; }
80 const_iterator begin() const { return data(); }
81 const_iterator end() const { return data() + sz; }
82 const_iterator cbegin() const { return data(); }
83 const_iterator cend() const { return data() + sz; }
84
85 /// Maximum number of elements the vector can hold.
86 static size_type max_size() { return size_type(-1) / sizeof(T); }
87
88 reference operator[](size_type idx) {
89 assert(idx < size());
90 return data()[idx];
91 }
92 const_reference operator[](size_type idx) const {
93 assert(idx < size());
94 return data()[idx];
95 }
96
97 reference front() { return (*this)[0]; }
98 const_reference front() const { return (*this)[0]; }
99 reference back() { return (*this)[size() - 1]; }
100 const_reference back() const { return (*this)[size() - 1]; }
101
102private:
103 void ensure_space(size_type num_elems) {
104 assert(num_elems <= max_size() - size());
105 if (size() + num_elems > capacity()) [[unlikely]] {
106 grow(size() + num_elems);
107 }
108 }
109
110 void grow(size_type new_size)
111 requires IsTrivial
112 {
113 grow_trivial(new_size, sizeof(T));
114 }
115
116 void grow(size_type new_size)
117 requires(!IsTrivial);
118
119public:
120 /// Append an element. elem must not be a reference inside to the vector.`
121 void push_back(const T &elem) {
122 ensure_space(1);
123 unpoison_memory_region(end(), sizeof(T));
124 ::new (reinterpret_cast<void *>(end())) T(elem);
125 sz += 1;
126 }
127
128 /// Append an element. elem must not be a reference inside to the vector.`
129 void push_back(T &&elem) {
130 ensure_space(1);
131 unpoison_memory_region(end(), sizeof(T));
132 ::new (reinterpret_cast<void *>(end())) T(::std::move(elem));
133 sz += 1;
134 }
135
136 template <typename... ArgT>
137 reference emplace_back(ArgT &&...args) {
138 ensure_space(1);
139 unpoison_memory_region(end(), sizeof(T));
140 ::new (reinterpret_cast<void *>(end())) T(::std::forward<ArgT>(args)...);
141 sz += 1;
142 return back();
143 }
144
145 void pop_back() {
146 back().~T();
147 poison_memory_region(end(), sizeof(T));
148 sz -= 1;
149 }
150
151 void reserve(size_type new_cap) {
152 if (cap < new_cap) {
153 grow(new_cap);
154 }
155 }
156
157private:
158 template <bool Initialize>
159 void resize(size_type new_size) {
160 if (sz > new_size) {
161 std::destroy(begin() + new_size, end());
162 poison_memory_region(begin() + new_size, (sz - new_size) * sizeof(T));
163 sz = new_size;
164 } else if (sz < new_size) {
165 reserve(new_size);
166 unpoison_memory_region(end(), (new_size - sz) * sizeof(T));
167 for (pointer it = end(), e = begin() + new_size; it != e; ++it) {
168 Initialize ? ::new (it) T() : ::new (it) T;
169 }
170 sz = new_size;
171 }
172 }
173
174public:
175 void resize(size_type new_size) { resize<true>(new_size); }
176 void resize_uninitialized(size_type new_size) { resize<false>(new_size); }
177
178 void resize(size_type new_size, const T &init) {
179 if (sz > new_size) {
180 std::destroy(begin() + new_size, end());
181 poison_memory_region(begin() + new_size, (sz - new_size) * sizeof(T));
182 sz = new_size;
183 } else if (sz < new_size) {
184 reserve(new_size);
185 unpoison_memory_region(end(), (new_size - sz) * sizeof(T));
186 std::uninitialized_fill(begin() + sz, begin() + new_size, init);
187 sz = new_size;
188 }
189 }
190
191 iterator erase(iterator start_it, iterator end_it) {
192 iterator res = start_it;
193 while (end_it != end()) {
194 *start_it++ = std::move(*end_it++);
195 }
196 std::destroy(start_it, end_it);
197 sz = start_it - begin();
198 return res;
199 }
200
201 template <typename It>
202 void append(It start_it, It end_it) {
203 size_type n = std::distance(start_it, end_it);
204 reserve(sz + n);
205 unpoison_memory_region(end(), n * sizeof(T));
206 std::uninitialized_copy(start_it, end_it, end());
207 sz += n;
208 }
209
210 void clear() {
211 std::destroy(begin(), end());
212 poison_memory_region(begin(), sz * sizeof(T));
213 sz = 0;
214 }
215};
216
217template <class T>
218void SmallVectorBase<T>::grow(size_type new_size)
219 requires(!IsTrivial)
220{
221 size_type new_cap;
222 T *new_alloc = static_cast<T *>(grow_malloc(new_size, sizeof(T), new_cap));
223 std::uninitialized_move(begin(), end(), new_alloc);
224 poison_memory_region(new_alloc + sz, (new_cap - sz) * sizeof(T));
225 std::destroy(begin(), end());
226 if (!is_small()) {
227 free(ptr);
228 } else {
229 poison_memory_region(data(), cap * sizeof(T));
230 }
231 ptr = new_alloc;
232 cap = new_cap;
233}
234
/// Default inline capacity of SmallVector: as many elements as fit into 256
/// bytes, but never fewer than one.
template <typename T>
constexpr std::size_t SmallVectorDefaultSize =
    sizeof(T) >= 256 ? std::size_t{1} : 256 / sizeof(T);
237
238template <class T, size_t N = SmallVectorDefaultSize<T>>
239class SmallVector : public SmallVectorBase<T> {
240 // This is required, is_small() checks whether the current allocation is
241 // immediately following the SmallVectorBase. For zero-sized small allocation,
242 // the heap-allocated buffer could not be distinguished otherwise.
243 // TODO: find an efficient way to avoid this.
244 alignas(T) char elements[N == 0 ? 1 : N * sizeof(T)];
245
246public:
247 SmallVector() : SmallVectorBase<T>(N) {}
248
249 SmallVector(const SmallVector &) = delete;
250 SmallVector(SmallVector &&) = delete;
251 SmallVector &operator=(const SmallVector &) = delete;
252 SmallVector &operator=(SmallVector &&) = delete;
253 // TODO: Implement this if required.
254#if 0
255 SmallVector(SmallVector &other) : SmallVectorBase<T>(N) {
256 SmallVectorBase<T>::operator=(other);
257 }
258 SmallVector(SmallVectorBase<T> &other) : SmallVectorBase<T>(N) {
259 SmallVectorBase<T>::operator=(other);
260 }
261
262 SmallVector(SmallVector &&other) : SmallVectorBase<T>(N) {
263 SmallVectorBase<T>::operator=(std::move(other));
264 }
265 SmallVector(SmallVectorBase<T> &&other) : SmallVectorBase<T>(N) {
266 SmallVectorBase<T>::operator=(std::move(other));
267 }
268
269 SmallVector &operator=(SmallVector &other) {
270 SmallVectorBase<T>::operator=(other);
271 return *this;
272 }
273 SmallVector &operator=(SmallVectorBase<T> &other) {
274 SmallVectorBase<T>::operator=(other);
275 return *this;
276 }
277
278 SmallVector &operator=(SmallVector &&other) {
279 SmallVectorBase<T>::operator=(std::move(other));
280 return *this;
281 }
282 SmallVector &operator=(SmallVectorBase<T> &&other) {
283 SmallVectorBase<T>::operator=(std::move(other));
284 return *this;
285 }
286#endif
287};
288
289} // end namespace tpde::util