Skip to content

Commit

Permalink
major updates to Good / Bad stats -- distinguish between US, CS for d…
Browse files Browse the repository at this point in the history
…ifferent stats as appropriate; params updates for BLA, etc.
  • Loading branch information
rcoreilly committed May 31, 2024
1 parent d24e085 commit 92af672
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 36 deletions.
125 changes: 98 additions & 27 deletions examples/choose/choose.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ func main() {
}
}

var (
CSAggStats = []string{"PVposEst", "PVnegEst", "PVposVar", "PVnegVar", "GateVMtxGo", "GateVMtxNo", "GateVMtxGoNo", "GateBLAposAcq", "GateBLAposExt", "GateBLAposAcqExt"}
USAggStats = []string{"Rew_R", "DA_R", "RewPred_R", "VtaDA"}
)

// see params.go for network params, config.go for Config

// Sim encapsulates the entire simulation model, and we define all the
Expand Down Expand Up @@ -868,24 +873,34 @@ func (ss *Sim) GatedStats(di int) {
ss.Stats.SetFloat32("AChShould", nan)
ss.Stats.SetFloat32("AChShouldnt", nan)
ss.Stats.SetFloat32("BadCSGate", nan)
ss.Stats.SetFloat32("BadUSGate", nan)
ss.Stats.SetFloat32("GateVMtxGo", nan)
ss.Stats.SetFloat32("GateVMtxNo", nan)
ss.Stats.SetFloat32("GateVMtxGoNo", nan)
ss.Stats.SetFloat32("GateBLApos", nan)
ss.Stats.SetFloat32("GateBLAposAcq", nan)
ss.Stats.SetFloat32("GateBLAposExt", nan)
ss.Stats.SetFloat32("GateBLAposAcqExt", nan)
hasPos := rp.HasPosUS(ctx, diu)
if justGated && !hasPos {
ss.Stats.SetFloat32("BadCSGate", num.FromBool[float32](!ev.ArmIsBest(ev.Arm)))
vsgo := net.AxonLayerByName("VMtxGo")
vsno := net.AxonLayerByName("VMtxNo")
goact := ss.MaxPoolSpkMax(vsgo, diu)
noact := ss.MaxPoolSpkMax(vsno, diu)
ss.Stats.SetFloat32("GateVMtxGo", goact)
ss.Stats.SetFloat32("GateVMtxNo", noact)
ss.Stats.SetFloat32("GateVMtxGoNo", goact-noact)
bla := net.AxonLayerByName("BLAposAcqD1")
blact := ss.MaxPoolSpkMax(bla, diu)
ss.Stats.SetFloat32("GateBLApos", blact)

if justGated {
if hasPos {
ss.Stats.SetFloat32("BadUSGate", num.FromBool[float32](!ev.ArmIsBest(ev.Arm)))
} else {
ss.Stats.SetFloat32("BadCSGate", num.FromBool[float32](!ev.ArmIsBest(ev.Arm)))
vsgo := net.AxonLayerByName("VMtxGo")
vsno := net.AxonLayerByName("VMtxNo")
goact := ss.MaxPoolSpkMax(vsgo, diu)
noact := ss.MaxPoolSpkMax(vsno, diu)
ss.Stats.SetFloat32("GateVMtxGo", goact)
ss.Stats.SetFloat32("GateVMtxNo", noact)
ss.Stats.SetFloat32("GateVMtxGoNo", goact-noact)
bla := net.AxonLayerByName("BLAposAcqD1")
ble := net.AxonLayerByName("BLAposExtD2")
blaact := ss.MaxPoolSpkMax(bla, diu)
bleact := ss.MaxPoolSpkMax(ble, diu)
ss.Stats.SetFloat32("GateBLAposAcq", blaact)
ss.Stats.SetFloat32("GateBLAposExt", bleact)
ss.Stats.SetFloat32("GateBLAposAcqExt", blaact-bleact)
}
}
if ev.ShouldGate {
if hasPos {
Expand Down Expand Up @@ -1030,8 +1045,11 @@ func (ss *Sim) ConfigLogItems() {
ss.Logs.AddStatAggItem("GateVMtxGo", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddStatAggItem("GateVMtxNo", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddStatAggItem("GateVMtxGoNo", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddStatAggItem("GateBLApos", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddStatAggItem("GateBLAposAcq", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddStatAggItem("GateBLAposExt", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddStatAggItem("GateBLAposAcqExt", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddStatAggItem("BadCSGate", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddStatAggItem("BadUSGate", etime.Run, etime.Epoch, etime.Trial)

// Add a special debug message -- use of etime.Debug triggers
// inclusion
Expand All @@ -1040,8 +1058,9 @@ func (ss *Sim) ConfigLogItems() {
}

stNm := []string{"Good", "Bad"}
allst := append(CSAggStats, USAggStats...)
for wrong := 0; wrong < 2; wrong++ {
for _, st := range ss.Config.Log.AggStats {
for _, st := range allst {
itmName := fmt.Sprintf("%s_%s", stNm[wrong], st)
ss.Logs.AddItem(&elog.Item{
Name: itmName,
Expand Down Expand Up @@ -1103,11 +1122,11 @@ func (ss *Sim) ConfigLogItems() {
}
}

// EpochBadStats aggregates stats separately for BadCSGate = 0 vs. 1
// EpochCSBadStats aggregates stats separately for BadCSGate = 0 vs. 1
// i.e., for trials when it selects the "wrong" option (not the best) = (Bad)
// vs. when it does select the best option (Good)
func (ss *Sim) EpochBadStats() {
lgnm := "EpochBadStats"
func (ss *Sim) EpochCSBadStats() {
lgnm := "EpochCSBadStats"

ix := ss.Logs.IndexView(etime.Train, etime.Trial)
ix.Filter(func(et *table.Table, row int) bool {
Expand All @@ -1123,12 +1142,56 @@ func (ss *Sim) EpochBadStats() {
}
dt := spl.AggsToTable(table.ColumnNameOnly)
ix.Sequential() // return to full table for subsequent use

dt.SetMetaData("PVposEst:On", "+")
dt.SetMetaData("PVnegEst:On", "+")
dt.SetMetaData("GateVMtxGoNo:On", "+")
dt.SetMetaData("GateBLAposAcqExt:On", "+")
// dt.SetMetaData("XAxisRot", "45")
dt.SetMetaData("Type", "Bar")
ss.Logs.MiscTables[lgnm] = dt

// grab selected stats at CS for higher level aggregation,
nrows := dt.Rows
stNm := []string{"Good", "Bad"}
for ri := 0; ri < nrows; ri++ {
wrong := dt.Float("BadCSGate", ri)
for _, st := range CSAggStats {
ss.Stats.SetFloat(fmt.Sprintf("%s_%s", stNm[int(wrong)], st), dt.Float(st, ri))
}
}
if ss.Config.GUI {
plt := ss.GUI.Plots[etime.ScopeKey(lgnm)]
plt.SetTable(dt)
plt.GoUpdatePlot()
}
}

// EpochUSBadStats aggregates stats separately for BadUSGate = 0 vs. 1
// i.e., for trials when it selects the "wrong" option (not the best) = (Bad)
// vs. when it does select the best option (Good)
func (ss *Sim) EpochUSBadStats() {
lgnm := "EpochUSBadStats"

ix := ss.Logs.IndexView(etime.Train, etime.Trial)
ix.Filter(func(et *table.Table, row int) bool {
return !math.IsNaN(et.Float("BadUSGate", row)) // && (et.StringValue("ActAction", row) == "Consume")
})
spl := split.GroupBy(ix, []string{"BadUSGate"})
for _, ts := range ix.Table.ColumnNames {
col := ix.Table.ColumnByName(ts)
if col.DataType() == reflect.String || ts == "BadUSGate" {
continue
}
split.AggColumn(spl, ts, stats.Mean)
}
dt := spl.AggsToTable(table.ColumnNameOnly)
ix.Sequential() // return to full table for subsequent use

dt.SetMetaData("Rew_R:On", "+")
dt.SetMetaData("DA_R:On", "+")
dt.SetMetaData("RewPred_R:On", "+")
dt.SetMetaData("VtaDA:On", "+")
dt.SetMetaData("PVposEst:On", "+")
dt.SetMetaData("PVnegEst:On", "+")
dt.SetMetaData("DA_R:FixMin", "+")
dt.SetMetaData("DA_R:Min", "-1")
dt.SetMetaData("DA_R:FixMax", "-")
Expand All @@ -1138,21 +1201,20 @@ func (ss *Sim) EpochBadStats() {
ss.Logs.MiscTables[lgnm] = dt

// grab selected stats at CS and US for higher level aggregation,

nrows := dt.Rows
stNm := []string{"Good", "Bad"}
for ri := 0; ri < nrows; ri++ {
wrong := dt.Float("BadCSGate", ri)
for _, st := range ss.Config.Log.AggStats {
wrong := dt.Float("BadUSGate", ri)
for _, st := range USAggStats {
ss.Stats.SetFloat(fmt.Sprintf("%s_%s", stNm[int(wrong)], st), dt.Float(st, ri))
}
}

if ss.Config.GUI {
plt := ss.GUI.Plots[etime.ScopeKey(lgnm)]
plt.SetTable(dt)
plt.GoUpdatePlot()
}

}

// Log is the main logging function, handles special things for different scopes
Expand Down Expand Up @@ -1208,7 +1270,8 @@ func (ss *Sim) Log(mode etime.Modes, time etime.Times) {
}
case mode == etime.Train && time == etime.Epoch:
axon.LayerActsLogAvg(ss.Net, &ss.Logs, &ss.GUI, true) // reset recs
ss.EpochBadStats()
ss.EpochCSBadStats()
ss.EpochUSBadStats()
}

ss.Logs.LogRow(mode, time, row) // also logs to file, etc
Expand Down Expand Up @@ -1275,14 +1338,22 @@ func (ss *Sim) ConfigGUI() {

axon.LayerActsLogConfigGUI(&ss.Logs, &ss.GUI)

lgnm := "EpochBadStats"
lgnm := "EpochCSBadStats"
dt := ss.Logs.MiscTable(lgnm)
plt := plotview.NewSubPlot(ss.GUI.Tabs.NewTab(lgnm + " Plot"))
ss.GUI.Plots[etime.ScopeKey(lgnm)] = plt
plt.Params.Title = lgnm
plt.Params.XAxisColumn = "BadCSGate"
plt.SetTable(dt)

lgnm = "EpochUSBadStats"
dt = ss.Logs.MiscTable(lgnm)
plt = plotview.NewSubPlot(ss.GUI.Tabs.NewTab(lgnm + " Plot"))
ss.GUI.Plots[etime.ScopeKey(lgnm)] = plt
plt.Params.Title = lgnm
plt.Params.XAxisColumn = "BadUSGate"
plt.SetTable(dt)

ss.GUI.Body.AddAppBar(func(tb *core.Toolbar) {
ss.GUI.AddToolbarItem(tb, egui.ToolbarItem{Label: "Init", Icon: icons.Update,
Tooltip: "Initialize everything including network weights, and start over. Also applies current params.",
Expand Down
3 changes: 0 additions & 3 deletions examples/choose/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ type RunConfig struct {
// LogConfig has config parameters related to logging data
type LogConfig struct {

// stats to aggregate separately for good and bad choices
AggStats []string `default:"['Rew_R', 'DA_R', 'RewPred_R', 'VtaDA', 'PVposEst', 'PVnegEst', 'PVposVar', 'PVnegVar', 'GateVMtxGo', 'GateVMtxNo', 'GateVMtxGoNo', 'GateBLApos']"`

// if true, save final weights after each run
SaveWts bool

Expand Down
12 changes: 6 additions & 6 deletions examples/choose/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ var ParamSets = netparams.Sets{
}},
{Sel: ".BLAExtPath", Desc: "ext learns relatively fast",
Params: params.Params{
"Path.Learn.LRate.Base": "0.05",
"Path.Learn.LRate.Base": "0.05", // 0.05 > 0.02 = 0.01
}},
{Sel: ".BLAAcqToGo", Desc: "must dominate",
Params: params.Params{
Expand All @@ -191,7 +191,11 @@ var ParamSets = netparams.Sets{
}},
{Sel: ".BLAExtToAcq", Desc: "",
Params: params.Params{
"Path.PathScale.Abs": "2", // note: key param -- 0.5 > 1
"Path.PathScale.Abs": "1", // 1 == 2
}},
{Sel: ".CSToBLApos", Desc: "",
Params: params.Params{
"Path.Learn.LRate.Base": "0.01", // 0.02 > 0.01 for early gating; 0.01 more consistent
}},
{Sel: ".PFCToVSMtx", Desc: "contextual, should be weaker",
Params: params.Params{
Expand All @@ -212,10 +216,6 @@ var ParamSets = netparams.Sets{
"Path.Learn.Trace.LearnThr": "0",
"Path.Learn.LRate.Base": "0.02", // 0.02 needed in test; better overall
}},
{Sel: "#CSToBLAposAcqD1", Desc: "",
Params: params.Params{
"Path.Learn.LRate.Base": "0.02", // was 0.5 -- too fast!?
}},
{Sel: ".CSToBLANovelInhib", Desc: "",
Params: params.Params{
"Path.Learn.LRate.Base": "0.02", // 0.01 > 0.005 -- too slow is bad
Expand Down

0 comments on commit 92af672

Please sign in to comment.