-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathzddb.cc
355 lines (314 loc) · 10.7 KB
/
zddb.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <print>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
// We need to define std::hash for our std::tuple in order to put them in
// a std::unordered_map.
// begin cut-and-paste from stackoverflow
// function has to live in the std namespace
// so that it is picked up by argument-dependent name lookup (ADL).
namespace std {
namespace {
// Code from boost
// Reciprocal of the golden ratio helps spread entropy
// and handles duplicates.
// See Mike Seymour in magic-numbers-in-boosthash-combine:
// https://stackoverflow.com/questions/4948780
template <class T>
inline void hash_combine(std::size_t &seed, T const &v) {
seed ^= hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
// Recursive template code derived from Matthieu M.
template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
struct HashValueImpl {
static void apply(size_t& seed, Tuple const &tuple) {
HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
hash_combine(seed, get<Index>(tuple));
}
};
template <class Tuple>
struct HashValueImpl<Tuple, 0> {
static void apply(size_t& seed, Tuple const &tuple) {
hash_combine(seed, get<0>(tuple));
}
};
}
template <typename ... TT>
struct hash<std::tuple<TT...>> {
size_t operator()(std::tuple<TT...> const &tt) const {
size_t seed = 0;
HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
return seed;
}
};
}
// end cut-and-paste from stackoverflow
namespace {
// IdxTy indexes a node in ZddNodes::nodes; LabelTy names a variable.
using IdxTy = int;
using LabelTy = int;
// Negative indices reserved for the two terminal nodes, which have no entry in
// the node array: lo_idx is the ZDD with no sets, and hi_idx is the ZDD holding
// only the empty set.
constexpr IdxTy lo_idx = -1;
constexpr IdxTy hi_idx = -2;
// Labels that get_label() reports for the two terminals.
constexpr LabelTy LO = -1;
constexpr LabelTy HI = -2;
struct ZddNode {
explicit ZddNode(IdxTy lo, IdxTy hi, LabelTy label)
: lo(lo), hi(hi), label(label) {}
IdxTy lo = -3;
IdxTy hi = -3;
LabelTy label = -3;
};
// An append-only arena of ZDD nodes. Any number of ZDDs over the same variables
// may share one arena. `unique` maps each (lo, hi, label) triple to the index of
// its node, so every distinct node is stored exactly once.
struct ZddNodes {
  std::vector<ZddNode> nodes;
  std::unordered_map<std::tuple<IdxTy, IdxTy, LabelTy>, IdxTy> unique;

  // Label of node `idx`; the two terminals report LO and HI respectively.
  LabelTy get_label(IdxTy idx) const {
    switch (idx) {
    case lo_idx:
      return LO;
    case hi_idx:
      return HI;
    default:
      assert(static_cast<std::size_t>(idx) < nodes.size());
      return nodes[idx].label;
    }
  }

  // Node stored at `idx` (terminals are not stored and must not be passed).
  const ZddNode &operator[](IdxTy idx) const {
    assert(static_cast<std::size_t>(idx) < nodes.size());
    return nodes[idx];
  }

  // Index of the node (lo, hi, label), appending it first if it is new.
  IdxTy get(IdxTy lo, IdxTy hi, LabelTy label) {
    assert(label != LO && label != HI);
    assert(hi != lo_idx);
    const IdxTy fresh = static_cast<IdxTy>(nodes.size());
    const auto [slot, inserted] = unique.try_emplace({lo, hi, label}, fresh);
    if (!inserted)
      return slot->second;
    nodes.emplace_back(lo, hi, label);
    return fresh;
  }
};
// This is a thin wrapper around ZddNodes that represents a single specific ZDD
// by pointing at the underlying flat array of nodes and knowing the index of
// the root node.
struct Zdd {
explicit Zdd() {}
explicit Zdd(ZddNodes *memory) : memory(memory), root(lo_idx) {}
explicit Zdd(ZddNodes *memory, IdxTy root) : memory(memory), root(root) {}
ZddNodes *memory = nullptr;
IdxTy root = -3;
LabelTy get_label(IdxTy idx) const {
assert(memory);
return memory->get_label(idx);
}
const ZddNode &operator[](IdxTy idx) const {
assert(memory);
return (*memory)[idx];
}
IdxTy get(IdxTy lo, IdxTy hi, LabelTy label) {
assert(memory);
return memory->get(lo, hi, label);
}
// Verify that this is a correctly constructed ZDD.
void verify() const {
assert(memory);
for (auto n : memory->nodes) {
(void)n;
assert(n.lo == lo_idx || n.lo == hi_idx ||
n.label < (*memory)[n.lo].label);
assert(n.hi == hi_idx || n.label < (*memory)[n.hi].label);
}
}
// Enumerates every set in the ZDD and calls your callback for them, stops
// when your callback returns true.
template<typename CB>
void enumerate_sets(CB cb) {
if (root == lo_idx)
return;
std::vector<std::pair<IdxTy, std::vector<LabelTy>>> stack;
stack.emplace_back(root, std::vector<LabelTy>{});
do {
auto &[nidx, variables] = stack.back();
assert(nidx != lo_idx);
if (nidx == hi_idx) {
if constexpr (std::is_void_v<decltype(cb(variables))>)
cb(variables);
else if (cb(variables))
return;
stack.pop_back();
continue;
}
const auto &n = (*this)[nidx];
if (n.lo == lo_idx) {
// Reuse existing `stack` entry to follow 'hi' edge. Update index, add
// current node to the list of variables.
nidx = n.hi;
variables.emplace_back(n.label);
continue;
}
// Existing entry in `stack` follows 'lo' edge. Update index but the
// variables stay the same.
nidx = n.lo;
// Follow 'hi' edge and add it to `stack`. Add current node label to
// the variables.
auto variables_copy = variables;
variables_copy.emplace_back(n.label);
stack.emplace_back(n.hi, variables_copy);
} while (!stack.empty());
}
};
// Internal overload, defined below; declared here so the public one can use it.
IdxTy multiunion(ZddNodes &ret, std::vector<Zdd> worklist, bool include_hi);

// Unions any number of ZDDs into `ret` and returns the result's root index
// there. The inputs may live in a different arena than `ret`. Commonly used to
// construct ZDDs by unioning sets.
IdxTy multiunion(ZddNodes &ret, const std::vector<Zdd> &in) {
  std::vector<Zdd> worklist;
  worklist.reserve(in.size());
  bool include_hi = false;
  for (const Zdd &z : in) {
    // An input with no sets contributes nothing to a union.
    if (z.root == lo_idx)
      continue;
    // The family {{}} is carried as a flag rather than as a worklist entry.
    if (z.root == hi_idx) {
      include_hi = true;
      continue;
    }
    worklist.push_back(z);
  }
  return multiunion(ret, std::move(worklist), include_hi);
}
// Internal API. Unions the ZDDs in `worklist`, plus the family {{}} when
// `include_hi` is true, building the result in `ret` and returning its root.
//
// Preconditions: `worklist` must not include any lo_idx or hi_idx ZDDs. A union
// with lo_idx is the identity; to union with hi_idx, set `include_hi` instead.
//
// Every result node is created through `ret.get`, so the result never refers to
// nodes in the input arenas.
IdxTy multiunion(ZddNodes &ret, std::vector<Zdd> worklist, bool include_hi) {
#ifndef NDEBUG
  for (const auto &z : worklist)
    assert(z.root >= 0);
#endif
  if (worklist.empty())
    return include_hi ? hi_idx : lo_idx;
  // When we're down to one ZDD, just copy it into `ret`.
  //
  // This is a performance improvement only, multiunion works correctly if you
  // simply remove this code block.
  if (worklist.size() == 1 && !include_hi) {
    auto root = worklist[0].root;
    if (root < 0)
      return root;
    auto memory = worklist[0].memory;
    // Maps node indices in `memory` to the equivalent node indices in `ret`.
    std::unordered_map<IdxTy, IdxTy> cache{{lo_idx, lo_idx}, {hi_idx, hi_idx}};
    // Post-order DFS: a node is copied once both of its children are in `cache`.
    std::vector<IdxTy> copy_worklist(1, root);
    for (;;) {
      const IdxTy n_idx = copy_worklist.back();
      // ZddNode is three ints; copy it so nothing below holds a reference into
      // the source arena while `ret` is being appended to.
      const ZddNode node = (*memory)[n_idx];
      auto lo_it = cache.find(node.lo), hi_it = cache.find(node.hi);
      const bool contains_lo = lo_it != cache.end();
      const bool contains_hi = hi_it != cache.end();
      if (contains_lo && contains_hi) {
        auto idx = ret.get(lo_it->second, hi_it->second, node.label);
        if (n_idx == root)
          return idx;
        cache[n_idx] = idx;
        copy_worklist.pop_back();
      } else {
        // Don't pop the current node. It is revisited after the children pushed
        // now, and by then both of its children will be in `cache`.
        if (!contains_lo)
          copy_worklist.push_back(node.lo);
        if (!contains_hi && node.lo != node.hi)
          copy_worklist.push_back(node.hi);
      }
    }
  }
  // Find the next lowest label. Every node points to nodes with a greater
  // label or one of the terminal nodes, so the smallest root label is the
  // smallest label anywhere in the worklist.
  LabelTy next_label;
  {
    auto next_it =
        std::min_element(worklist.begin(), worklist.end(),
                         [&](const Zdd &lhs, const Zdd &rhs) {
      // There are no C++ objects for the terminal nodes, detect their special
      // labels so we don't try to do lookups of them. Terminal labels sort
      // last since any other node with a label must occur first along the
      // path to the terminal nodes.
      bool lhs_is_terminal = lhs.root < 0;
      bool rhs_is_terminal = rhs.root < 0;
      if (!lhs_is_terminal && !rhs_is_terminal)
        return lhs.get_label(lhs.root) < rhs.get_label(rhs.root);
      if (lhs_is_terminal && rhs_is_terminal)
        return lhs.root < rhs.root;
      return rhs_is_terminal;
    });
    next_label = next_it->get_label(next_it->root);
  }
  // The worklist never contains 'LO' nodes because unioning with 'LO' is
  // the identity operation. The worklist never contains 'HI' nodes, we store
  // those by setting `include_hi` to true instead.
  assert(next_label != LO && next_label != HI);
  // Partition the remaining nodes into lo-side and hi-side:
  // * for nodes with label == next_label, expand them:
  //   + add their lo to next_lo and hi to next_hi
  //     - except don't add lo_idx
  //     - also don't add hi_idx by setting include_hi instead
  // * otherwise, add the node to next_lo for processing on a later step.
  std::vector<Zdd> next_lo, next_hi;
  bool include_hi_lo = include_hi, include_hi_hi = false;
  for (const auto &z : worklist) {
    const auto root = z.root;
    if (z.get_label(root) != next_label) {
      // Root tests a later variable, so none of its sets contain next_label.
      next_lo.push_back(z);
      continue;
    }
    const ZddNode node = z[root];
    switch (node.lo) {
    case lo_idx:
      break;
    case hi_idx:
      include_hi_lo = true;
      break;
    default:
      next_lo.emplace_back(z.memory, node.lo);
    }
    if (node.hi == hi_idx)
      include_hi_hi = true;
    else
      next_hi.emplace_back(z.memory, node.hi);
  }
  // The sub-worklists are dead after this point; move them into the recursive
  // calls instead of copying them.
  return ret.get(multiunion(ret, std::move(next_lo), include_hi_lo),
                 multiunion(ret, std::move(next_hi), include_hi_hi),
                 next_label);
}
// Produce a ZDD with a single set in it representing the bytes in `line`.
// The byte at 0-based column c, read as an unsigned value b in [0 .. 255], is
// represented by label b + (c * 256). Each column therefore owns a disjoint,
// increasing label range, and no label can collide with the negative terminal
// labels LO / HI. Nodes are added to `ret`'s arena; returns the root index.
IdxTy line_to_zdd(Zdd &ret, const std::string &line) {
  // The largest label is 255 + 256 * (size - 1); this bound keeps it below
  // 2^31 so it fits in a signed 32-bit LabelTy.
  assert(line.size() < 8388608); // assumes 32-bit 'LabelTy'
  IdxTy hi = hi_idx;
  // Build bottom-up from the last byte so each new node's hi edge points at the
  // chain already built for the rest of the line.
  for (std::size_t column = line.size(); column-- > 0;) {
    // Plain char may be signed; go through unsigned char so bytes >= 0x80 stay
    // in [0 .. 255] instead of producing negative (even LO/HI-valued) labels.
    const auto byte = static_cast<unsigned char>(line[column]);
    hi = ret.get(lo_idx, hi, static_cast<LabelTy>(byte + 256 * column));
  }
  return hi;
}
} // end anonymous namespace
// Reads the file named by argv[1] and encodes each line as a one-set ZDD. It
// unions them all into a single ZDD and prints every set in it, one line each.
int main(int argc, char **argv) {
  if (argc < 2) {
    std::cerr << "usage: " << (argc > 0 ? argv[0] : "zddb") << " <file>\n";
    return 1;
  }
  ZddNodes flat2;
  Zdd all_lines(&flat2);
  {
    // flat1 holds the per-line ZDDs only while they are unioned. multiunion
    // builds every result node in flat2, so flat1 is released at scope end.
    ZddNodes flat1;
    std::vector<Zdd> lines;
    {
      std::ifstream ifs(argv[1]);
      if (!ifs) {
        std::cerr << "error: cannot open '" << argv[1] << "'\n";
        return 1;
      }
      for (std::string line; std::getline(ifs, line);) {
        lines.emplace_back(&flat1);
        lines.back().root = line_to_zdd(lines.back(), line);
      }
    }
    // verify() scans the whole arena and nodes are append-only, so a single
    // check after loading covers every node a per-line check would.
    assert((Zdd(&flat1).verify(), true));
    all_lines.root = multiunion(flat2, lines);
    assert((all_lines.verify(), true));
  }
  all_lines.enumerate_sets([](const auto &variables) {
    // Each label is byte + 256 * column, so the byte is label % 256.
    std::string text;
    text.reserve(variables.size());
    for (auto v : variables)
      text.push_back(static_cast<char>(v % 256));
    std::println("{}", text);
  });
}