#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dorefa.py
# Author: Yuxin Wu <[email protected]>

import tensorflow as tf
from tensorpack.utils.argtools import memoized


@memoized
def get_dorefa(bitW, bitA, bitG):
    """
    Return the three quantization functions fw, fa, fg for weights,
    activations and gradients respectively.

    It's unsafe to call this function multiple times with different parameters.
    """
    G = tf.get_default_graph()

    def quantize(x, k):
        # Uniform k-bit quantization of x in [0, 1], with a straight-through
        # estimator: Round is remapped to Identity in the backward pass.
        n = float(2 ** k - 1)
        with G.gradient_override_map({"Round": "Identity"}):
            return tf.round(x * n) / n
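    # A worked example (illustrative numbers, not from the original file):
    # with k = 2, n = 2**2 - 1 = 3, so an input of 0.4 becomes
    # round(0.4 * 3) / 3 = 1/3, one of the four levels {0, 1/3, 2/3, 1}.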
    def fw(x):
        if bitW == 32:
            return x
        if bitW == 1:   # BWN: binary weights, scaled by the mean absolute value
            with G.gradient_override_map({"Sign": "Identity"}):
                E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
                return tf.sign(x / E) * E
        # Squash weights with tanh, rescale into [0, 1], quantize,
        # then map the result back to [-1, 1].
        x = tf.tanh(x)
        x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
        return 2 * quantize(x, bitW) - 1
    def fa(x):
        # Activations are assumed to already lie in [0, 1]
        # (e.g. after a clipped nonlinearity).
        if bitA == 32:
            return x
        return quantize(x, bitA)
    @tf.RegisterGradient("FGGrad")
    def grad_fg(op, x):
        # Quantize the incoming gradient: normalize per-sample by its max
        # magnitude, add uniform noise before rounding (stochastic rounding),
        # then restore the original scale.
        rank = x.get_shape().ndims
        assert rank is not None
        maxx = tf.reduce_max(tf.abs(x), list(range(1, rank)), keep_dims=True)
        x = x / maxx
        n = float(2 ** bitG - 1)
        x = x * 0.5 + 0.5 + tf.random_uniform(
            tf.shape(x), minval=-0.5 / n, maxval=0.5 / n)
        x = tf.clip_by_value(x, 0.0, 1.0)
        x = quantize(x, bitG) - 0.5
        return x * maxx * 2
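    # For instance (illustrative, assuming bitG = 6): n = 2**6 - 1 = 63, so the
    # injected noise is uniform in [-1/126, 1/126], half a quantization step.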
    def fg(x):
        if bitG == 32:
            return x
        # Identity in the forward pass; the backward pass runs FGGrad instead.
        with G.gradient_override_map({"Identity": "FGGrad"}):
            return tf.identity(x)

    return fw, fa, fg
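

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). A minimal example, assuming a
# TF 1.x graph; `weight` and `activation` below are hypothetical tensors:
#
#     fw, fa, fg = get_dorefa(1, 2, 6)   # the "W1A2G6" DoReFa-Net setting
#     w_q = fw(weight)                   # quantized weights
#     a_q = fa(tf.clip_by_value(activation, 0.0, 1.0))  # activations in [0, 1]
#     x = fg(x)                          # identity forward, quantized gradient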