-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathparams.go
55 lines (46 loc) · 1.48 KB
/
params.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
package cfr
import (
"math"
)
// DiscountParams configures the weighting scheme that GetDiscountFactors
// uses to discount accumulated values on each iteration.
//
// The zero value is valid: every discount factor is then 1, which
// corresponds to traditional (MC)CFR without weighting.
type DiscountParams struct {
	// UseRegretMatchingPlus selects CFR+: the negative factor becomes 0
	// (no negative regrets), unless DiscountBeta is non-zero, in which
	// case DiscountBeta takes precedence.
	UseRegretMatchingPlus bool

	// LinearWeighting selects linear CFR: the sum factor becomes
	// t / (t+1), unless DiscountGamma is non-zero, in which case
	// DiscountGamma takes precedence.
	LinearWeighting bool

	// DiscountAlpha is the discounted-CFR exponent for the positive
	// factor, t^alpha / (t^alpha + 1). Zero leaves the factor untouched.
	DiscountAlpha float32

	// DiscountBeta is the discounted-CFR exponent for the negative
	// factor, t^beta / (t^beta + 1). Zero leaves the factor untouched.
	DiscountBeta float32

	// DiscountGamma is the discounted-CFR exponent for the sum factor,
	// (t / (t+1))^gamma. Zero leaves the factor untouched.
	DiscountGamma float32
}
// GetDiscountFactors returns the multiplicative discount factors for
// iteration iter, as configured by the parameters for the various CFR
// weighting schemes: CFR+, linear CFR, discounted CFR. With an empty
// DiscountParams all three factors are 1.
//
// The results are named positive, negative and sum; presumably they are
// applied to accumulated positive regrets, negative regrets and the
// strategy sum respectively (confirm against callers).
//
// Settings are applied in order, so later ones win: a non-zero
// DiscountBeta replaces the zero negative factor chosen by
// UseRegretMatchingPlus, and a non-zero DiscountGamma replaces the sum
// factor chosen by LinearWeighting.
//
// iter is expected to be non-negative.
//
// See: https://arxiv.org/pdf/1809.04040.pdf
func (p DiscountParams) GetDiscountFactors(iter int) (positive, negative, sum float32) {
	positive, negative, sum = 1.0, 1.0, 1.0

	// ratio computes t^exponent / (t^exponent + 1) for t = iter.
	ratio := func(exponent float32) float32 {
		x := math.Pow(float64(iter), float64(exponent))
		if x > math.MaxFloat32 {
			// t^exponent is +Inf or overflows float32 (large exponent,
			// or iter == 0 with a negative exponent). Evaluating
			// Inf / (Inf + 1) would yield NaN, so return the limit of
			// x / (x + 1) as x grows without bound.
			return 1.0
		}
		x32 := float32(x)
		return x32 / (x32 + 1.0)
	}

	// Linear CFR is equivalent to weighting the reach prob on each
	// iteration by (t / (t+1)), and this reduces numerical instability.
	if p.LinearWeighting {
		sum = float32(iter) / float32(iter+1)
	}

	if p.UseRegretMatchingPlus {
		negative = 0.0 // No negative regrets.
	}

	if p.DiscountAlpha != 0 {
		// t^alpha / (t^alpha + 1)
		positive = ratio(p.DiscountAlpha)
	}

	if p.DiscountBeta != 0 {
		// t^beta / (t^beta + 1)
		negative = ratio(p.DiscountBeta)
	}

	if p.DiscountGamma != 0 {
		// (t / (t+1)) ^ gamma
		base := float64(iter) / float64(iter+1)
		sum = float32(math.Pow(base, float64(p.DiscountGamma)))
	}

	return positive, negative, sum
}