-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathgeneralized_sampling.go
144 lines (122 loc) · 4.09 KB
/
generalized_sampling.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package cfr
import (
"math/rand"
"github.com/timpalpant/go-cfr/internal/f32"
)
// GeneralizedSamplingCFR holds the state that Run and its helpers use while
// walking a game tree: the collaborators supplied at construction time,
// pooled scratch storage, and bookkeeping that is reset on every Run.
type GeneralizedSamplingCFR struct {
	// Collaborators supplied to NewGeneralizedSampling.

	// strategyProfile provides the current iteration number (Iter) and the
	// policy attached to each node (GetPolicy).
	strategyProfile StrategyProfile
	// sampler yields one sampling probability per action at nodes where the
	// traversing player acts; actions with a positive probability are
	// explored recursively.
	sampler Sampler

	// Reusable scratch storage.

	// slicePool hands out the per-node scratch slices used while evaluating
	// a traversing-player node.
	slicePool *floatSlicePool
	// mapPool hands out the maps stored in sampledActions.
	mapPool *keyIntMapPool

	// rng drives action sampling for the non-traversing player and for
	// probe rollouts.
	rng *rand.Rand

	// Per-traversal state.

	// traversingPlayer is set by Run to the iteration number modulo 2.
	traversingPlayer int
	// sampledActions is passed to getOrSample when the non-traversing player
	// acts. It is replaced by a fresh map while the children of a
	// traversing-player node are visited and restored afterwards.
	// NOTE(review): keys are presumably infoset identifiers — confirm
	// against getOrSample.
	sampledActions map[string]int
}
// NewGeneralizedSampling builds a GeneralizedSamplingCFR that reads and
// updates strategyProfile and asks sampler for per-action sampling
// probabilities. The instance gets empty slice and map pools and its own
// random generator, seeded from the package-level math/rand source.
//
// traversingPlayer and sampledActions are left at their zero values; Run
// assigns both before traversing.
func NewGeneralizedSampling(strategyProfile StrategyProfile, sampler Sampler) *GeneralizedSamplingCFR {
	seed := rand.Int63()

	gs := new(GeneralizedSamplingCFR)
	gs.strategyProfile = strategyProfile
	gs.sampler = sampler
	gs.slicePool = &floatSlicePool{}
	gs.mapPool = &keyIntMapPool{}
	gs.rng = rand.New(rand.NewSource(seed))
	return gs
}
// Run performs one traversal of the game tree rooted at node and returns the
// value computed for the root.
//
// The traversing player is the strategy profile's iteration number modulo 2,
// so it alternates from one iteration to the next. A sampled-action map is
// taken from the pool for the duration of the traversal and returned to the
// pool when Run exits.
func (c *GeneralizedSamplingCFR) Run(node GameTreeNode) float32 {
	c.traversingPlayer = int(c.strategyProfile.Iter() % 2)

	c.sampledActions = c.mapPool.alloc()
	// The argument is evaluated here, so the map allocated above is the one
	// that gets freed even though the field is reassigned during traversal.
	defer c.mapPool.free(c.sampledActions)

	// The traversal starts with a sample probability of 1.
	rootPlayer := node.Player()
	return c.runHelper(node, rootPlayer, 1.0)
}
// runHelper evaluates node and returns its value, then closes the node.
//
// lastPlayer is the player whose utility is read at terminal nodes.
// sampleProb is the sample probability accumulated on the path to node.
func (c *GeneralizedSamplingCFR) runHelper(node GameTreeNode, lastPlayer int, sampleProb float32) float32 {
	var value float32

	switch node.Type() {
	case TerminalNodeType:
		value = float32(node.Utility(lastPlayer))

	case ChanceNodeType:
		value = c.handleChanceNode(node, lastPlayer, sampleProb)

	default:
		// Player node: the value comes back from the acting player's handler
		// and is scaled by getSign(lastPlayer, actingPlayer).
		// NOTE(review): getSign presumably returns +1 when the two players
		// match and -1 otherwise — confirm against its definition.
		value = getSign(lastPlayer, node.Player()) * c.handlePlayerNode(node, sampleProb)
	}

	// The node is closed only after its whole subtree has been evaluated.
	node.Close()
	return value
}
// handleChanceNode samples a single outcome of the chance node and continues
// the traversal into it, passing lastPlayer and sampleProb through unchanged.
func (c *GeneralizedSamplingCFR) handleChanceNode(node GameTreeNode, lastPlayer int, sampleProb float32) float32 {
	// The second result of SampleChild is deliberately dropped: sampling
	// probabilities cancel out in the calculation of counterfactual value.
	outcome, _ := node.SampleChild()
	return c.runHelper(outcome, lastPlayer, sampleProb)
}
// handlePlayerNode dispatches on who acts at node: the traversing player's
// nodes go to handleTraversingPlayerNode, every other player's nodes go to
// handleSampledPlayerNode.
func (c *GeneralizedSamplingCFR) handlePlayerNode(node GameTreeNode, sampleProb float32) float32 {
	if node.Player() != c.traversingPlayer {
		return c.handleSampledPlayerNode(node, sampleProb)
	}
	return c.handleTraversingPlayerNode(node, sampleProb)
}
// handleTraversingPlayerNode processes a node at which the traversing player
// acts. It obtains a value for every child, passes the resulting per-action
// regrets to the node's policy via AddRegret, and returns the node value
// computed from the policy's current strategy and the child values.
//
// sampleProb is the sample probability accumulated on the path to node; it
// starts at 1 in Run and is multiplied by the sampler's probability each time
// a traversing-player action is explored recursively.
func (c *GeneralizedSamplingCFR) handleTraversingPlayerNode(node GameTreeNode, sampleProb float32) float32 {
	player := node.Player()
	numActions := node.NumChildren()

	// Optimization to skip trivial nodes with no real choice.
	if numActions == 1 {
		return c.runHelper(node.GetChild(0), player, sampleProb)
	}

	policy := c.strategyProfile.GetPolicy(node)

	// Per-action sampling probabilities, copied into a pooled slice.
	sampleProbs := c.slicePool.alloc(numActions)
	copy(sampleProbs, c.sampler.Sample(node, policy))

	// Holds the child values first, then (after AddConst below) the regrets.
	values := c.slicePool.alloc(numActions)

	// Children are visited with a fresh sampled-action map; the caller's map
	// is put back before returning.
	savedActions := c.sampledActions
	c.sampledActions = c.mapPool.alloc()

	for action := range sampleProbs {
		q := sampleProbs[action]
		child := node.GetChild(action)
		switch {
		case q > 0:
			// Explored action: recurse with the sample probability scaled by q.
			values[action] = c.runHelper(child, player, q*sampleProb)
		default:
			// Unexplored action: estimate its value with a probe rollout,
			// which performs no policy updates.
			values[action] = c.probe(child, player)
		}
	}

	// NOTE(review): DotUnitary is presumably the dot product of the strategy
	// with the child values, and AddConst presumably adds -nodeValue to every
	// element, turning values into regrets — confirm against package f32.
	nodeValue := f32.DotUnitary(policy.GetStrategy(), values)
	f32.AddConst(-nodeValue, values)

	invProb := 1.0 / sampleProb
	policy.AddRegret(invProb, sampleProbs, values)

	// Hand the scratch storage back and restore the caller's map.
	c.slicePool.free(sampleProbs)
	c.slicePool.free(values)
	c.mapPool.free(c.sampledActions)
	c.sampledActions = savedActions

	return nodeValue
}
// handleSampledPlayerNode processes a node at which a player other than the
// traversing player acts. A single action is chosen through getOrSample and
// only that child is explored; no regrets are updated here.
//
// The chosen action is looked up in, or recorded into, c.sampledActions so
// that the same choice is reused if this infoset is hit again.
func (c *GeneralizedSamplingCFR) handleSampledPlayerNode(node GameTreeNode, sampleProb float32) float32 {
	policy := c.strategyProfile.GetPolicy(node)

	// Update the average strategy for this node. These are the "stochastic"
	// updates described in the MC-CFR paper, weighted by 1/sampleProb. The
	// update is skipped unless sampleProb is positive.
	if sampleProb > 0 {
		policy.AddStrategyWeight(1.0 / sampleProb)
	}

	// Sampling probabilities cancel out in the calculation of counterfactual
	// value, so sampleProb is passed down unchanged.
	action := getOrSample(c.sampledActions, node, policy, c.rng)
	next := node.GetChild(action)
	return c.runHelper(next, node.Player(), sampleProb)
}
// probe estimates the value of node for player by following one sampled path
// down to a terminal node. Chance nodes are resolved with SampleChild; player
// nodes draw one action from the policy's current strategy using c.rng.
// No regrets or strategy weights are updated along the way, and every node
// visited is closed before probe returns.
func (c *GeneralizedSamplingCFR) probe(node GameTreeNode, player int) float32 {
	var value float32

	switch node.Type() {
	case TerminalNodeType:
		value = float32(node.Utility(player))

	case ChanceNodeType:
		outcome, _ := node.SampleChild()
		value = c.probe(outcome, player)

	default:
		// Player node: pick one action according to the current strategy,
		// using a single uniform draw from c.rng.
		action := sampleOne(c.strategyProfile.GetPolicy(node).GetStrategy(), c.rng.Float32())
		value = c.probe(node.GetChild(action), player)
	}

	node.Close()
	return value
}