-
Notifications
You must be signed in to change notification settings - Fork 105
/
Copy pathattention.html
206 lines (182 loc) · 6.44 KB
/
attention.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Interactive Point Transformation</title>
<script src="https://d3js.org/d3.v6.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<style>
.plot {
margin: 20px;
display: inline-block;
}
svg {
border: 1px solid black;
}
.slider-container {
margin-bottom: 20px;
}
</style>
</head>
<body>
<div>
<p>This demo shows how the representation of a query token changes in a transformer based on its relation to other tokens and the specified transformation \( v(x) \).</p>
<p>Equation for the updated query representation \( x_q' \):</p>
<p>\[ x_q' = \sum_{x_k \in M(x_q, S_{x_k})} a(x_q, x_k; \theta_a) v(x_k; \theta_v) \]</p>
<div class="slider-container">
<label for="lambda-slider">Lambda: </label>
<input type="range" id="lambda-slider" min="0.001" max="3.0" step="0.001" value="1.0">
<span id="lambda-value">1.0</span>
</div>
</div>
<div>
<p>(c) Fayyaz Minhas</p>
</div>
<div>
<div id="plot1" class="plot">
<p>Tokens in original feature space</p>
</div>
<div id="plot2" class="plot">
<p>Tokens in \( v(x) \)</p>
</div>
<div id="plot3" class="plot">
<p>Representation of the query token (red) after transformation</p>
</div>
</div>
<script>
// --- Data -----------------------------------------------------------------
// N key tokens in 2-D; every coordinate is a separate draw from a freshly
// created d3.randomNormal() generator.
const N = 5;
const points = Array.from({ length: N }, () => [d3.randomNormal()(), d3.randomNormal()()]);

// Kernel bandwidth. Declared with `let` because it is reassigned at runtime.
let lambda = 1.0;

// Similarity exp(-lambda * squared Euclidean distance) between two 2-D points.
// `lambda` is read at call time, so later reassignments take effect.
const kernel = (a, b) => {
  const squaredDistance = Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2);
  return Math.exp(-lambda * squaredDistance);
};

// Value transformation v(x): negates both coordinates.
const v = ([x0, x1]) => [-x0, -x1];

// --- Geometry -------------------------------------------------------------
const width = 400;
const height = 400;
const margin = 40;

// Data range [-3, 3] mapped to pixels inside the margins; the y axis is
// flipped so that larger values are drawn higher up.
const xScale = d3.scaleLinear().domain([-3, 3]).range([margin, width - margin]);
const yScale = d3.scaleLinear().domain([-3, 3]).range([height - margin, margin]);

// --- SVG containers ---------------------------------------------------------
// One <svg> of identical size appended to each plot <div>.
const [svg1, svg2, svg3] = ["#plot1", "#plot2", "#plot3"].map((selector) =>
  d3.select(selector).append("svg").attr("width", width).attr("height", height)
);

// Draws one circle and one text label per entry of `points` into `svg`.
// `toCoords` maps a point to the coordinates that get plotted, `fill` is the
// circle colour and `toLabel` turns the 1-based point number into its label.
function drawLabelledPoints(svg, toCoords, fill, toLabel) {
  svg.selectAll("circle")
    .data(points)
    .enter()
    .append("circle")
    .attr("cx", (d) => xScale(toCoords(d)[0]))
    .attr("cy", (d) => yScale(toCoords(d)[1]))
    .attr("r", 5)
    .attr("fill", fill);

  svg.selectAll("text")
    .data(points)
    .enter()
    .append("text")
    .attr("x", (d) => xScale(toCoords(d)[0]) + 7)
    .attr("y", (d) => yScale(toCoords(d)[1]) - 7)
    .text((d, i) => toLabel(i + 1));
}

// Plot-1: the tokens as generated.
drawLabelledPoints(svg1, (d) => d, "blue", (n) => `x_${n}`);
// Plot-2: the same tokens after applying v(x).
drawLabelledPoints(svg2, v, "green", (n) => `v(x_${n})`);
/**
 * Normalised kernel weights a(x_q, x_k) of a query against a list of keys.
 *
 * @param {number[]} query - 2-D query point x_q.
 * @param {number[][]} keys - 2-D key points x_k.
 * @returns {number[]} One weight per key, in the order of `keys`; each is
 *   kernel(query, key) divided by the sum of all kernel values.
 */
function attentionWeights(query, keys) {
  const scores = keys.map((key) => kernel(query, key));
  const total = scores.reduce((sum, score) => sum + score, 0);
  return scores.map((score) => score / total);
}

/**
 * Weighted sum of the transformed keys: sum_k weights[k] * v(keys[k]).
 *
 * @param {number[][]} keys - 2-D key points x_k.
 * @param {number[]} weights - One weight per key, same order as `keys`.
 * @returns {number[]} The resulting 2-D point.
 */
function weightedValueSum(keys, weights) {
  return keys.reduce((acc, key, i) => {
    const [vx, vy] = v(key);
    return [acc[0] + weights[i] * vx, acc[1] + weights[i] * vy];
  }, [0, 0]);
}

/**
 * Redraws everything that depends on the query token x_q, using the pixel
 * position stored in `lastClick` and the current value of `lambda`.
 *
 * Plot-1 gets the query point plus one line per token whose stroke width is
 * proportional to that token's weight, Plot-2 gets v(x_q), and Plot-3 is
 * rebuilt with the transformed tokens and the resulting x_q'.
 * Does nothing when `lastClick` is null or undefined.
 */
function updatePlots() {
  // Without a query position there is nothing to compute or draw.
  if (lastClick == null) {
    return;
  }

  const [mx, my] = lastClick;
  // lastClick is in SVG pixels; convert back to data coordinates.
  const x_q = [xScale.invert(mx), yScale.invert(my)];

  // The query attends to every token and to itself, so x_q is appended as
  // the last key. Indices 0..N-1 of a_qk therefore still line up with `points`.
  const S_xk = [...points, x_q];
  const a_qk = attentionWeights(x_q, S_xk);

  // Remove any existing user-defined points and lines
  svg1.selectAll(".user-point").remove();
  svg1.selectAll(".user-line").remove();
  svg2.selectAll(".user-point").remove();

  // Add the new point to Plot-1
  svg1.append("circle")
    .attr("class", "user-point")
    .attr("cx", mx)
    .attr("cy", my)
    .attr("r", 5)
    .attr("fill", "red");

  // Draw lines from x_q to all points in Plot-1; thicker line = larger weight
  svg1.selectAll(".user-line")
    .data(points)
    .enter()
    .append("line")
    .attr("class", "user-line")
    .attr("x1", mx)
    .attr("y1", my)
    .attr("x2", (d) => xScale(d[0]))
    .attr("y2", (d) => yScale(d[1]))
    .attr("stroke", "black")
    .attr("stroke-width", (d, i) => 5 * a_qk[i]);

  // Show v(x_q) in Plot-2. The label shares the "user-point" class so that
  // the cleanup above removes it together with the circle.
  const [vqx, vqy] = v(x_q);
  svg2.append("circle")
    .attr("class", "user-point")
    .attr("cx", xScale(vqx))
    .attr("cy", yScale(vqy))
    .attr("r", 5)
    .attr("fill", "red");
  svg2.append("text")
    .attr("class", "user-point")
    .attr("x", xScale(vqx) + 7)
    .attr("y", yScale(vqy) - 7)
    .text("v(x_q)");

  // x_q' is the weighted sum of v(x_k) over all keys, including x_q itself
  const x_q_prime = weightedValueSum(S_xk, a_qk);

  // Clear and draw transformed points in Plot-3
  svg3.selectAll("*").remove();
  svg3.selectAll("circle")
    .data(points)
    .enter()
    .append("circle")
    .attr("cx", (d) => xScale(v(d)[0]))
    .attr("cy", (d) => yScale(v(d)[1]))
    .attr("r", 5)
    .attr("fill", "green");

  // Add x_q' to Plot-3
  svg3.append("circle")
    .attr("cx", xScale(x_q_prime[0]))
    .attr("cy", yScale(x_q_prime[1]))
    .attr("r", 5)
    .attr("fill", "red");
  svg3.append("text")
    .attr("x", xScale(x_q_prime[0]) + 7)
    .attr("y", yScale(x_q_prime[1]) - 7)
    .text("x_q'");
}
// Pixel position (relative to the Plot-1 SVG) of the most recent click, or
// null while no query token has been placed. Starting at null instead of a
// dummy [0, 0] keeps the slider from inventing a query token in the SVG's
// top-left corner before the user has clicked anywhere.
let lastClick = null;

// Add x_q interactively in Plot-1: remember where the click happened, then redraw.
svg1.on("click", function(event) {
  lastClick = d3.pointer(event);
  updatePlots();
});

// Update lambda and its displayed value on slider change. The plots depend
// on lambda only through the query token, so redraw only once one exists.
d3.select("#lambda-slider").on("input", function() {
  lambda = +this.value;
  d3.select("#lambda-value").text(lambda.toFixed(3));
  if (lastClick !== null) {
    updatePlots();
  }
});
</script>
</body>
</html>