TPDE
Loading...
Searching...
No Matches
RegisterFile.hpp
1// SPDX-FileCopyrightText: 2025 Contributors to TPDE <https://tpde.org>
2//
3// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4#pragma once
5
6#include "tpde/ValLocalIdx.hpp"
7#include "tpde/base.hpp"
8#include "tpde/util/misc.hpp"
9
10#include <array>
11
12namespace tpde {
13
14struct Reg {
15 u8 reg_id;
16
17 explicit constexpr Reg(const u64 id) noexcept : reg_id(static_cast<u8>(id)) {
18 assert(id <= 255);
19 }
20
21 constexpr u8 id() const noexcept { return reg_id; }
22
23 constexpr bool invalid() const noexcept { return reg_id == 0xFF; }
24
25 constexpr bool valid() const noexcept { return reg_id != 0xFF; }
26
27 constexpr static Reg make_invalid() noexcept { return Reg{(u8)0xFF}; }
28
29 constexpr bool operator==(const Reg &other) const noexcept {
30 return reg_id == other.reg_id;
31 }
32};
33
34struct RegBank {
35private:
36 u8 bank;
37
38public:
39 constexpr RegBank() noexcept : bank(u8(-1)) {}
40
41 constexpr explicit RegBank(u8 bank) noexcept : bank(bank) {}
42
43 constexpr u8 id() const noexcept { return bank; }
44
45 constexpr bool operator==(const RegBank &other) const noexcept {
46 return bank == other.bank;
47 }
48};
49
50template <unsigned NumBanks, unsigned RegsPerBank>
51class RegisterFile {
52public:
53 static constexpr unsigned NumRegs = NumBanks * RegsPerBank;
54
55 static_assert(RegsPerBank > 0 && (RegsPerBank & (RegsPerBank - 1)) == 0,
56 "RegsPerBank must be a power of two");
57 static_assert(NumRegs < Reg::make_invalid().id());
58 static_assert(NumRegs <= 64);
59
60 // later add the possibility for more than 64 registers
61 // for architectures that require it
62 using RegBitSet = u64;
63
64 /// Registers that are generally allocatable and not reserved.
65 RegBitSet allocatable = 0;
66 /// Registers that are currently in use. Requires allocatable.
67 RegBitSet used = 0;
68 /// Registers that were clobbered at some point. Used to track registers that
69 /// need to be saved/restored.
70 RegBitSet clobbered = 0;
71 std::array<u8, NumBanks> clocks{};
72
73 struct Assignment {
74 ValLocalIdx local_idx;
75 u32 part;
76 };
77
78 std::array<Assignment, NumRegs> assignments;
79
80 std::array<u8, NumRegs> lock_counts{};
81
82 void reset() noexcept {
83 used = {};
84 clobbered = {};
85 clocks = {};
86 lock_counts = {};
87 }
88
89 [[nodiscard]] bool is_used(const Reg reg) const noexcept {
90 assert(reg.id() < NumRegs);
91 return (used & 1ull << reg.id()) != 0;
92 }
93
94 [[nodiscard]] bool is_fixed(const Reg reg) const noexcept {
95 assert(reg.id() < NumRegs);
96 return lock_counts[reg.id()] > 0;
97 }
98
99 [[nodiscard]] bool is_clobbered(const Reg reg) const noexcept {
100 assert(reg.id() < NumRegs);
101 return (clobbered & 1ull << reg.id()) != 0;
102 }
103
104 void mark_used(const Reg reg,
105 const ValLocalIdx local_idx,
106 const u32 part) noexcept {
107 assert(reg.id() < NumRegs);
108 assert(!is_used(reg));
109 assert(!is_fixed(reg));
110 assert(lock_counts[reg.id()] == 0);
111 used |= (1ull << reg.id());
112 assignments[reg.id()] = Assignment{.local_idx = local_idx, .part = part};
113 }
114
115 void update_reg_assignment(const Reg reg,
116 const ValLocalIdx local_idx,
117 const u32 part) noexcept {
118 assert(is_used(reg));
119 assignments[reg.id()].local_idx = local_idx;
120 assignments[reg.id()].part = part;
121 }
122
123 void unmark_used(const Reg reg) noexcept {
124 assert(reg.id() < NumRegs);
125 assert(is_used(reg));
126 assert(!is_fixed(reg));
127 assert(lock_counts[reg.id()] == 0);
128 used &= ~(1ull << reg.id());
129 }
130
131 void mark_fixed(const Reg reg) noexcept {
132 assert(reg.id() < NumRegs);
133 assert(is_used(reg));
134 assert(lock_counts[reg.id()] == 0);
135 lock_counts[reg.id()] = 1;
136 }
137
138 void unmark_fixed(const Reg reg) noexcept {
139 assert(reg.id() < NumRegs);
140 assert(is_used(reg));
141 assert(is_fixed(reg));
142 assert(lock_counts[reg.id()] == 1);
143 lock_counts[reg.id()] = 0;
144 }
145
146 void inc_lock_count(const Reg reg) noexcept {
147 assert(reg.id() < NumRegs);
148 assert(is_used(reg));
149 ++lock_counts[reg.id()];
150 }
151
152 /// Returns true if the last lock was released.
153 bool dec_lock_count(const Reg reg) noexcept {
154 assert(reg.id() < NumRegs);
155 assert(is_used(reg));
156 assert(lock_counts[reg.id()] > 0);
157 if (--lock_counts[reg.id()] == 0) {
158 return true;
159 }
160 return false;
161 }
162
163 /// Decrement lock count by sub, and assert that the register is now unlocked
164 void dec_lock_count_must_zero(const Reg reg,
165 [[maybe_unused]] u8 sub = 1) noexcept {
166 assert(reg.id() < NumRegs);
167 assert(is_used(reg));
168 assert(lock_counts[reg.id()] == sub);
169 lock_counts[reg.id()] = 0;
170 }
171
172 void mark_clobbered(const Reg reg) noexcept {
173 assert(reg.id() < NumRegs);
174 clobbered |= (1ull << reg.id());
175 }
176
177 [[nodiscard]] ValLocalIdx reg_local_idx(const Reg reg) const noexcept {
178 assert(is_used(reg));
179 return assignments[reg.id()].local_idx;
180 }
181
182 [[nodiscard]] u32 reg_part(const Reg reg) const noexcept {
183 assert(is_used(reg));
184 return assignments[reg.id()].part;
185 }
186
187 [[nodiscard]] util::BitSetIterator<> used_regs() const noexcept {
188 return util::BitSetIterator<>{used};
189 }
190
191 [[nodiscard]] Reg
192 find_first_free_excluding(const RegBank bank,
193 const u64 exclusion_mask) const noexcept {
194 // TODO(ts): implement preferred registers
195 const RegBitSet free_bank = allocatable & ~used & bank_regs(bank);
196 const RegBitSet selectable = free_bank & ~exclusion_mask;
197 if (selectable == 0) {
198 return Reg::make_invalid();
199 }
200 return Reg{static_cast<u8>(util::cnt_tz(selectable))};
201 }
202
203 [[nodiscard]] Reg
204 find_first_nonfixed_excluding(const RegBank bank,
205 const u64 exclusion_mask) const noexcept {
206 // TODO(ts): implement preferred registers
207 for (auto reg_id : util::BitSetIterator<>{used & bank_regs(bank)}) {
208 if (!is_fixed(Reg{reg_id}) && !((u64{1} << reg_id) & exclusion_mask)) {
209 return Reg{reg_id};
210 }
211 }
212 return Reg::make_invalid();
213 }
214
215 [[nodiscard]] static RegBank reg_bank(const Reg reg) noexcept {
216 return RegBank(reg.id() / RegsPerBank);
217 }
218
219 [[nodiscard]] static RegBitSet bank_regs(const RegBank bank) noexcept {
220 assert(bank.id() <= 1);
221 return ((1ull << RegsPerBank) - 1) << (bank.id() * RegsPerBank);
222 }
223};
224
225} // namespace tpde