forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNamedTensor.cpp
92 lines (79 loc) · 2.91 KB
/
NamedTensor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#ifdef BUILD_NAMEDTENSOR
#include <ATen/NamedTensor.h>
#include <ATen/core/Tensor.h>
#include <torch/csrc/utils/memory.h>
namespace at {
bool NamedTensorMeta::has_names() const {
  // A tensor "has names" iff at least one dimension carries a real
  // (non-wildcard) name; an all-wildcard name list counts as unnamed.
  return std::any_of(
      names_.begin(), names_.end(), [](const Dimname& name) {
        return name.type() != NameType::WILDCARD;
      });
}
namespace impl {
// Two Dimnames cannot be in the same Tensor if one of them can refer to the
// other. In practice, this constraint means that a Tensor cannot have
// duplicate names unless they are tagged and the tags are different.
//
// Returns an iterator to the first name in [begin, end) that is incompatible
// with `target` in that sense, or `end` if every name is compatible.
static DimnameList::const_iterator find_incompatible_name(
    DimnameList::const_iterator begin,
    DimnameList::const_iterator end,
    const Dimname& target) {
  auto clashes_with_target = [&target](const Dimname& candidate) {
    return target.can_refer_to(candidate) || candidate.can_refer_to(target);
  };
  return std::find_if(begin, end, clashes_with_target);
}
// Verifies that `names` contains no pair of incompatible (duplicate) names,
// raising an error via TORCH_CHECK otherwise.
//
// Strategy: Compare each element with the ones that come after it.
// Although this is O(N^2), in practice N is small (no more than 25).
//
// NOTE: the original code looped over all incompatible matches, but the loop
// body unconditionally throws on the first match (the second TORCH_CHECK has
// a `false` condition), so anything past the first match was unreachable.
// A single find per element is equivalent.
static void check_unique_names(DimnameList names) {
  for (auto it = names.begin(); it != names.end(); ++it) {
    const auto dup = find_incompatible_name(it + 1, names.end(), *it);
    if (dup == names.end()) {
      continue;
    }
    // Simple error message if you're not using tags
    TORCH_CHECK(it->type() == NameType::TAGGED || dup->type() == NameType::TAGGED,
        "Cannot construct a tensor with duplicate names. Got names: ",
        names, ".");
    // Complicated error message if you're using tags
    TORCH_CHECK(false,
        "Cannot construct a tensor with duplicate names unless they are tagged ",
        "and have different tags. Got names: ", names, ", offending names: (",
        *it, " and ", *dup, ").");
  }
}
// Returns the NamedTensorMeta attached to `impl`, or nullptr if the tensor
// has no names metadata.
// NOTE(review): the static_cast is an unchecked downcast; it assumes
// named_tensor_meta() only ever stores a NamedTensorMeta — confirm against
// the TensorImpl setter (internal_set_names_inplace below only stores that
// type, but other call sites are not visible here).
static NamedTensorMeta* get_named_tensor_meta(TensorImpl* impl) {
  return static_cast<NamedTensorMeta*>(impl->named_tensor_meta());
}
// Sets (or clears) the dimension names stored on `impl`, in place.
//
// - `nullopt` clears any existing names metadata.
// - Otherwise the name list must match the tensor's dimensionality and must
//   not contain incompatible duplicates (see check_unique_names); on success
//   the names are copied into new or existing NamedTensorMeta.
//
// Raises via TORCH_CHECK on a size mismatch or duplicate names.
void internal_set_names_inplace(TensorImpl* impl, optional<DimnameList> names) {
  if (!names) {
    impl->set_named_tensor_meta(nullptr);
    return;
  }
  const auto ndim = impl->dim();
  // Cast avoids a signed/unsigned comparison between the tensor's (signed)
  // dimension count and DimnameList::size().
  TORCH_CHECK(ndim == static_cast<int64_t>(names->size()),
      "Number of names (", names->size(), ") and "
      "number of dimensions in tensor (", ndim, ") ",
      "do not match.");
  check_unique_names(*names);
  auto* meta = get_named_tensor_meta(impl);
  if (meta == nullptr) {
    // First time names are attached: allocate fresh metadata.
    impl->set_named_tensor_meta(torch::make_unique<NamedTensorMeta>(*names));
  } else {
    // Reuse the existing metadata object.
    meta->set_names_(*names);
  }
}
// Returns the dimension names attached to `impl`, or nullopt if the tensor
// carries no names metadata.
optional<DimnameList> internal_get_names(TensorImpl* impl) {
  if (const auto* meta = get_named_tensor_meta(impl)) {
    return meta->names();
  }
  return nullopt;
}
// True iff `impl` has names metadata containing at least one non-wildcard
// name (see NamedTensorMeta::has_names).
bool internal_is_named(TensorImpl* impl) {
  const auto* meta = get_named_tensor_meta(impl);
  if (meta == nullptr) {
    return false;
  }
  return meta->has_names();
}
} // namespace impl
} // namespace at
#endif