forked from TESSEorg/ttg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathblockmatrix.h
206 lines (172 loc) · 6.4 KB
/
blockmatrix.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#pragma once

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <utility>

#include <ttg/serialization.h>
/// Dense, row-major block of rows() x cols() elements of type T.
///
/// The element buffer is held through a std::shared_ptr, so copying a BlockMatrix is shallow:
/// the copy refers to (and can modify) the same elements as the original.
template <typename T>
class BlockMatrix {
 private:
  int _rows = 0;
  int _cols = 0;
  // Holds an array of _rows * _cols elements; every constructor installs a delete[] deleter.
  // Should become std::shared_ptr<T[]>, but older Apple clang could not be made to accept it.
  std::shared_ptr<T> m_block;

 public:
  /// Empty 0 x 0 block with no storage (get() returns nullptr).
  BlockMatrix() = default;

  /// Allocates rows * cols elements with new T[]. Elements are default-initialized, which leaves
  /// arithmetic T indeterminate: call fill() or assign them before reading.
  BlockMatrix(int rows, int cols) : _rows(rows), _cols(cols), m_block(new T[_rows * _cols], [](T* p) { delete[] p; }) {}

  /// Takes ownership of `block`, which must be an array allocated with new T[rows * cols]
  /// (a rows x cols block needs rows * cols elements, which only array-new provides).
  /// The array is released with delete[]; the default shared_ptr deleter would use scalar
  /// delete, which is undefined behavior for memory obtained from new[].
  BlockMatrix(int rows, int cols, T* block) : _rows(rows), _cols(cols), m_block(block, [](T* p) { delete[] p; }) {}

  // Keeping a user-declared destructor suppresses the implicit move operations. A "moved" block
  // is therefore copied (a shared_ptr refcount bump), and the source keeps a buffer that matches
  // its _rows/_cols. Removing it would let moves leave a null buffer behind non-zero dimensions.
  ~BlockMatrix() = default;

  /// Total number of elements (rows * cols).
  int size() const { return _rows * _cols; }
  int rows() const { return _rows; }
  int cols() const { return _cols; }

  /// Raw pointer to the row-major element buffer (nullptr for a default-constructed block).
  const T* get() const { return m_block.get(); }
  T* get() { return m_block.get(); }

  /// Sets every element to 1.
  void fill() { std::fill_n(m_block.get(), size(), 1); }

  /// Element-wise equality. Blocks of different shape are never equal. The shape check must come
  /// first: the element walk below assumes both buffers hold size() elements.
  bool operator==(const BlockMatrix& m) const {
    if (_rows != m._rows || _cols != m._cols) return false;
    const T* lhs = get();
    const T* rhs = m.get();
    for (int i = 0, n = size(); i < n; ++i) {
      if (lhs[i] != rhs[i]) return false;
    }
    return true;
  }
  bool operator!=(const BlockMatrix& m) const { return !(*this == m); }

  /// Element (row, col), returned by reference. No bounds checking.
  inline T& operator()(int row, int col) { return m_block.get()[row * _cols + col]; }
  inline const T& operator()(int row, int col) const { return m_block.get()[row * _cols + col]; }
  /// Assigns val to element (row, col). No bounds checking.
  void operator()(int row, int col, T val) { m_block.get()[row * _cols + col] = val; }

#ifdef TTG_SERIALIZATION_SUPPORTS_BOOST
  /// Boost.Serialization: writes rows, cols, then the rows * cols elements.
  template <typename Archive>
  void save(Archive& ar, const unsigned int version) const {
    ar << rows() << cols();
    ar << boost::serialization::make_array(get(), rows() * cols());
  }
  /// Boost.Serialization: reads the shape, allocates fresh storage, then reads the elements.
  template <typename Archive>
  void load(Archive& ar, const unsigned int version) {
    int rows, cols;
    ar >> rows >> cols;
    *this = BlockMatrix<T>(rows, cols);
    ar >> boost::serialization::make_array(get(), this->rows() * this->cols());
  }
  BOOST_SERIALIZATION_SPLIT_MEMBER();
#endif  // TTG_SERIALIZATION_SUPPORTS_BOOST
};
#ifdef TTG_SERIALIZATION_SUPPORTS_MADNESS
namespace madness {
namespace archive {

/// MADNESS serialization of BlockMatrix<T>: the number of rows, the number of columns, then the
/// rows * cols elements in storage order.
template <class Archive, typename T>
struct ArchiveStoreImpl<Archive, BlockMatrix<T>> {
  static inline void store(const Archive& ar, const BlockMatrix<T>& matrix) {
    const int nrows = matrix.rows();
    const int ncols = matrix.cols();
    ar << nrows << ncols;
    ar << wrap(matrix.get(), matrix.size());
  }
};

/// Inverse of ArchiveStoreImpl: reads the shape, replaces `matrix` with freshly allocated
/// storage of that shape, then reads the elements into it.
template <class Archive, typename T>
struct ArchiveLoadImpl<Archive, BlockMatrix<T>> {
  static inline void load(const Archive& ar, BlockMatrix<T>& matrix) {
    int nrows, ncols;
    ar >> nrows >> ncols;
    matrix = BlockMatrix<T>(nrows, ncols);
    ar >> wrap(matrix.get(), matrix.size());
  }
};

}  // namespace archive
}  // namespace madness

// Compile-time check that the specializations above make BlockMatrix<double> MADNESS-serializable.
static_assert(madness::is_serializable_v<madness::archive::BufferOutputArchive, BlockMatrix<double>>);
#endif  // TTG_SERIALIZATION_SUPPORTS_MADNESS
/// Writes m to s as text: one matrix row per line, each element followed by a single space.
/// Returns s to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& s, const BlockMatrix<T>& m) {
  for (int i = 0; i < m.rows(); ++i) {
    for (int j = 0; j < m.cols(); ++j) s << m(i, j) << ' ';
    // '\n' instead of std::endl: same text, but no forced flush on every row.
    s << '\n';
  }
  return s;
}
// The standard library provides no std::hash specialization for std::pair, so an unordered_map
// keyed by a pair needs a user-supplied hasher.
// See https://stackoverflow.com/questions/32685540/why-cant-i-compile-an-unordered-map-with-a-pair-as-key
struct pair_hash {
  /// Hashes a pair by mixing std::hash of each member with the boost::hash_combine formula.
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& p) const {
    std::size_t seed = std::hash<T1>{}(p.first);
    // A plain h1 ^ h2 is symmetric, so (i, j) and (j, i) collide. It also sends every key whose
    // members hash equally, e.g. each diagonal block (i, i), to 0. The shifts and the
    // golden-ratio constant break that symmetry.
    seed ^= std::hash<T2>{}(p.second) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};
/// A matrix stored as a grid of rows() x cols() BlockMatrix<T> blocks, each b_rows x b_cols
/// elements, keyed by (block_row, block_col).
template <typename T>
class Matrix {
 private:
  // Initialized so a default-constructed Matrix has a well-defined empty 0 x 0 shape; without
  // these, rows()/cols()/size()/fill()/print() would read indeterminate values.
  int nb_row = 0;  // number of block rows
  int nb_col = 0;  // number of block columns
  int b_rows = 0;  // number of element rows in each block
  int b_cols = 0;  // number of element columns in each block
  // Blocks keyed by (block_row, block_col).
  std::unordered_map<std::pair<int, int>, BlockMatrix<T>, pair_hash> m;

 public:
  Matrix() = default;

  /// Creates nb_row x nb_col blocks, each a newly allocated b_rows x b_cols BlockMatrix whose
  /// elements are not initialized (call fill() to set them).
  Matrix(int nb_row, int nb_col, int b_rows, int b_cols)
      : nb_row(nb_row), nb_col(nb_col), b_rows(b_rows), b_cols(b_cols) {
    if (nb_row > 0 && nb_col > 0) {
      // One rehash up front instead of repeated growth while inserting every block.
      m.reserve(static_cast<std::size_t>(nb_row) * static_cast<std::size_t>(nb_col));
    }
    for (int i = 0; i < nb_row; i++)
      for (int j = 0; j < nb_col; j++) m.try_emplace(std::make_pair(i, j), b_rows, b_cols);
  }

  // User-declared destructor: suppresses implicit move operations, so moving a Matrix copies
  // its block map (each block copy shares element storage).
  ~Matrix() = default;

  /// Total number of scalar elements: (nb_row * b_rows) * (nb_col * b_cols).
  int size() const { return (nb_row * b_rows) * (nb_col * b_cols); }
  /// Number of block rows.
  int rows() const { return nb_row; }
  /// Number of block columns.
  int cols() const { return nb_col; }
  /// Returns a copy of the block map; the copied blocks still share element storage.
  std::unordered_map<std::pair<int, int>, BlockMatrix<T>, pair_hash> get() const { return m; }

  /// Calls fill() on every block (i, j) with 0 <= i < rows(), 0 <= j < cols().
  void fill() {
    for (int i = 0; i < nb_row; i++)
      for (int j = 0; j < nb_col; j++) m[std::make_pair(i, j)].fill();
  }

  /// Equal iff the block maps are equal: same keys and element-wise equal blocks.
  bool operator==(const Matrix& matrix) const { return (matrix.m == m); }
  bool operator!=(const Matrix& matrix) const { return (matrix.m != m); }

  /// Returns block (block_row, block_col) by value. The returned BlockMatrix shares element
  /// storage with the stored block, so writes through it are visible in this Matrix.
  /// NOTE: uses operator[], so a key that is not present inserts an empty default block.
  BlockMatrix<T> operator()(int block_row, int block_col) { return m[std::make_pair(block_row, block_col)]; }

  /// Prints every block to std::cout in block-row-major order.
  void print() const {
    for (int i = 0; i < nb_row; i++) {
      for (int j = 0; j < nb_col; j++) {
        // A missing block would have printed as nothing (an empty default block), so skipping it
        // keeps the output and avoids inserting it, which lets print() be const.
        const auto it = m.find(std::make_pair(i, j));
        if (it != m.end()) std::cout << it->second;
      }
    }
  }
};