forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTensorNames.cpp
127 lines (107 loc) · 3.77 KB
/
TensorNames.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtils.h>

#include <utility>
namespace at { namespace namedinference {
// Returns a copy of the Dimname held by this TensorName (`name_`).
Dimname TensorName::toDimname() const {
  return name_;
}
// Unifies this name with `other` and returns a reference to the resulting
// TensorName. `op_name` is only used as a prefix for error messages.
//
// Outcomes, writing None for a wildcard name:
//   unify(None, None) -> *this
//   unify(A, A)       -> *this
//   unify(A, None)    -> *this, after checking that A does not already
//                        appear anywhere in other's origin names
//   unify(None, A)    -> delegated to other.unify(*this, op_name)
//   unify(A, B)       -> fails the TORCH_CHECK (names differ)
const TensorName& TensorName::unify(const TensorName& other, const char* op_name) const {
  // Both wildcards, or both the same name: nothing to reconcile.
  if ((name_.isWildcard() && other.name_.isWildcard()) || name_ == other.name_) {
    return *this;
  }

  if (!other.name_.isWildcard()) {
    // unify(None, A): flip the operands so the named side drives the check.
    if (name_.isWildcard()) {
      return other.unify(*this, op_name);
    }

    // unify(A, B): the equal-name case already returned above, so reaching
    // this point means the two names differ and the check reports it.
    TORCH_CHECK(name_ == other.name_,
        op_name, ":",
        " Expected ", *this,
        " to match ", other,
        " but they do not match.");
    return *this;
  }

  // unify(A, None): our name must not already be present among the names
  // that `other` originated from.
  const auto& other_origin = other.origin_;
  const bool already_present =
      std::find(other_origin.begin(), other_origin.end(), name_) != other_origin.end();
  TORCH_CHECK(!already_present,
      op_name, ":",
      " Cannot match ", *this, " with ", other,
      " because the latter names already have ", name_, ".",
      " Are your tensors misaligned?");
  return *this;
}
// Builds one TensorName per entry of `names`. Each entry is constructed from
// the full `names` list together with its own position in that list.
TensorNames::TensorNames(ArrayRef<Dimname> names) {
  const auto count = names.size();
  names_.reserve(count);
  for (int64_t pos = 0; pos < count; ++pos) {
    names_.emplace_back(names, pos);
  }
}
// Builds TensorNames for the positions [start, end) of `names`.
//
// Both bounds are first passed through maybe_wrap_dim against names.size()
// (presumably so callers may pass negative indices -- confirm against
// maybe_wrap_dim). Every TensorName is still constructed from the full
// `names` list plus its index, not from the sub-range.
//
// NOTE(review): nothing here verifies start <= end after wrapping; the
// difference is handed straight to reserve(). Confirm callers guarantee it.
TensorNames::TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end) {
  const auto ndim = names.size();
  start = maybe_wrap_dim(start, ndim);
  end = maybe_wrap_dim(end, ndim);

  names_.reserve(end - start);
  for (int64_t pos = start; pos < end; ++pos) {
    names_.emplace_back(names, pos);
  }
}
// Unifies `other` into this list in place, aligning the two lists at their
// right-hand (last) elements, and returns *this.
//
// - If this list is longer, only its trailing other.size() entries are
//   unified with `other`; the leading entries are left untouched.
// - Otherwise this list is first left-padded with the leading entries of
//   `other` so both have the same length, then the original entries are
//   unified position by position.
//
// `op_name` is forwarded to TensorName::unify for use in error messages.
TensorNames& TensorNames::unifyFromRightInplace(const TensorNames& other, const char* op_name) {
  // Work in the signed domain: subtracting the two unsigned sizes directly
  // would wrap around whenever this list is the shorter one.
  const auto self_size = static_cast<int64_t>(names_.size());
  const auto other_size = static_cast<int64_t>(other.names_.size());

  if (self_size > other_size) {
    const int64_t offset = self_size - other_size;
    for (int64_t idx = offset; idx < self_size; ++idx) {
      names_[idx] = names_[idx].unify(other.names_[idx - offset], op_name);
    }
  } else {
    const int64_t offset = other_size - self_size;
    // pad names_ to the same length as other.names_ before unification
    names_.insert(
        names_.begin(),
        other.names_.begin(),
        other.names_.begin() + offset);
    // After padding, names_ holds other_size entries; the first `offset` of
    // them are copies taken from `other` and need no unification.
    for (int64_t idx = offset; idx < other_size; ++idx) {
      names_[idx] = names_[idx].unify(other.names_[idx], op_name);
    }
  }
  return *this;
}
// Appends `name` to the end of this list.
//
// `name` is an rvalue-reference parameter, but a named parameter is an lvalue
// inside the function body, so it must be std::move'd for emplace_back to
// select the move constructor instead of the copy constructor.
void TensorNames::append(TensorName&& name) {
  names_.emplace_back(std::move(name));
}
// Verifies that no non-wildcard name occurs more than once in this list.
// On the first duplicate found, the TORCH_CHECK fails with a message that is
// prefixed by `op_name` and names both offending entries.
//
// This is a quadratic scan: each entry is compared against every later
// entry. Named tensors can have at most N = 64 dimensions, so this doesn't
// matter unless benchmarking tells us it does; a set-like structure would
// likely cost more than it saves at these sizes.
void TensorNames::checkUnique(const char* op_name) const {
  const auto last = names_.end();
  for (auto cur = names_.begin(); cur != last; ++cur) {
    const auto cur_name = cur->toDimname();
    // Wildcards may repeat freely; only real names must be unique.
    if (!cur_name.isWildcard()) {
      const auto has_same_name = [&cur_name](const TensorName& candidate) {
        return candidate.toDimname() == cur_name;
      };
      // Only look to the right of `cur`; earlier pairs were already checked.
      const auto dup = std::find_if(cur + 1, last, has_same_name);
      TORCH_CHECK(dup == last,
          op_name, ": ",
          "Attempted to propagate dims ", *cur, " and ", *dup, " to the output, ",
          "but that would create a tensor with duplicate names [", toDimnameVec(),
          "]. Please rename your inputs with Tensor.rename to prevent this.");
    }
  }
}
// Streams a TensorName as: <name> (index <origin index> of <origin names>).
//
// For example, if the TensorName represents 'C' in ['N', 'C', 'H', 'W'],
// it should print like:
//   'C' (index 1 of ['N', 'C', 'H', 'W'])
std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
  return out << tensorname.name_
             << " (index " << tensorname.origin_idx_
             << " of " << tensorname.origin_
             << ")";
}
std::vector<Dimname> TensorNames::toDimnameVec() const {
std::vector<Dimname> result;
result.reserve(names_.size());
for (const auto& tensor_name : names_) {
result.emplace_back(tensor_name.toDimname());
}
return result;
}
}} // namespace at::namedinference