-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathinterface.go
122 lines (105 loc) · 3.93 KB
/
interface.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package cfr
import (
"encoding"
"io"
)
// NodeType identifies the kind of a node in an extensive-form game tree.
type NodeType int

// The node kinds a game tree may contain. The numeric values are spelled out
// explicitly; they match declaration order starting from zero.
const (
	// ChanceNodeType marks a chance node.
	ChanceNodeType NodeType = 0
	// TerminalNodeType marks a terminal node.
	TerminalNodeType NodeType = 1
	// PlayerNodeType marks a player node.
	PlayerNodeType NodeType = 2
)
// InfoSet describes the game history as it can be observed by a single player.
//
// In addition to exposing a lookup key, every InfoSet must be able to
// serialize itself to, and restore itself from, a binary representation.
type InfoSet interface {
	encoding.BinaryMarshaler
	encoding.BinaryUnmarshaler

	// Key returns the identifier under which this InfoSet is uniquely
	// looked up when probabilities are accumulated in tabular CFR.
	//
	// The key is an opaque sequence of bytes: it does not have to be
	// human-readable and could be, for instance, a simplified abstraction
	// or a hash of the full game history.
	Key() []byte
}
// ChanceNode is a node whose children are selected according to a
// pre-defined probability distribution.
type ChanceNode interface {
	// SampleChild draws a single child of this chance node according to
	// the probability distribution over its children.
	//
	// Implementations are free to reuse sampling.SampleChanceNode, which
	// samples from the CDF by scanning over GetChildProbability, or to
	// provide their own, more efficient sampling.
	//
	// NOTE(review): p is presumably the probability with which the returned
	// child was drawn — confirm against the implementations.
	SampleChild() (child GameTreeNode, p float64)

	// GetChildProbability returns the probability of the ith child of
	// this node. It may only be called on nodes of chance type.
	GetChildProbability(i int) float64
}
// PlayerNode is a node at which one of the players acts.
type PlayerNode interface {
	// InfoSet returns the information set of this node as seen by the
	// given player.
	InfoSet(player int) InfoSet

	// InfoSetKey returns the same bytes as InfoSet(player).Key(), but lets
	// callers avoid the allocations incurred by going through the InfoSet
	// interface.
	InfoSetKey(player int) []byte

	// Player returns the player acting at this node.
	// It may only be called on non-chance nodes.
	Player() int

	// Utility returns the utility of this node for the given player.
	// It must only be called on terminal nodes.
	Utility(player int) float64
}
// TreeNode represents a node in a directed rooted tree.
type TreeNode interface {
	// Parent returns the parent of this node.
	Parent() GameTreeNode

	// NumChildren returns the number of direct children of this node.
	NumChildren() int

	// GetChild returns the ith child of this node.
	GetChild(i int) GameTreeNode
}
// GameTreeNode is the interface for a node in an extensive-form game tree.
type GameTreeNode interface {
// NodeType returns the type of game node.
Type() NodeType
// Release resources held by this node (including any children).
Close()
TreeNode
ChanceNode
PlayerNode
}
// StrategyProfile holds the collection of regret-matching policies, one for
// each player node in the game tree.
//
// Implementations of StrategyProfile are provided by the policytable and
// deepcfr packages.
type StrategyProfile interface {
	encoding.BinaryMarshaler
	encoding.BinaryUnmarshaler
	io.Closer

	// GetPolicy returns the NodePolicy associated with the given node.
	GetPolicy(node GameTreeNode) NodePolicy

	// Update calculates the next strategy profile for all visited nodes.
	Update()

	// Iter returns the current iteration, i.e. the number of times Update
	// has been called.
	Iter() int
}
// NodePolicy holds the action policy of a single player node.
type NodePolicy interface {
	// IsEmpty reports whether the NodePolicy is new, i.e. it has no
	// accumulated regret yet.
	IsEmpty() bool

	// GetStrategy returns the current strategy vector: element i is the
	// probability with which the ith available action should be played.
	GetStrategy() []float32

	// GetAverageStrategy returns the strategy averaged over all iterations.
	GetAverageStrategy() []float32

	// AddStrategyWeight adds the current strategy, weighted by w, to the
	// average strategy.
	AddStrategyWeight(w float32)

	// AddRegret supplies newly observed instantaneous regrets, which are
	// added with weight w to the total accumulated regret.
	//
	// NOTE(review): samplingQ is not described here; presumably it carries
	// the per-action sampling probabilities — confirm with implementations.
	AddRegret(w float32, samplingQ, instantaneousRegrets []float32)

	// GetBaseline returns the current vector of action-dependent baseline
	// values, as used in VR-MCCFR.
	GetBaseline() []float32

	// UpdateBaseline updates the current vector of baseline values.
	UpdateBaseline(w float32, action int, value float32)
}