Commit
fully updated to new linear regression based synaptic ca; validated regression model with lasso & ridge regression on larger data set; smaller calcium levels work with existing lrates; even better learning performance it seems. regression weights are sensible in terms of CaP, CaD dynamics.
rcoreilly committed Jun 13, 2024
1 parent 413b3fe commit e39a5a5
Showing 48 changed files with 242 additions and 741 deletions.
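The heart of the change is in axon/learn.go and the kinase package: the per-synapse, per-spike calcium integration (the old StdSynCa path) is replaced by a linear-regression estimate of each synapse's CaP and CaD, computed from the sender and receiver neurons' spike-driven calcium variables. The sketch below is only meant to convey the shape of that computation; the type name, field layout, coefficient values, and FinalCa helper are illustrative assumptions, not the actual kinase.SynCaLinear API.

package main

import "fmt"

// SynCaLinear (sketch) holds regression coefficients that map products of
// sender (s) and receiver (r) neuron-level calcium variables onto synaptic
// CaP and CaD estimates. Real names and fitted values live in the kinase
// package; everything here is a placeholder.
type SynCaLinear struct {
	CaP [4]float32 // weights on sM*rM, sP*rP, sD*rD, plus a bias, for CaP
	CaD [4]float32 // weights on sM*rM, sP*rP, sD*rD, plus a bias, for CaD
}

// FinalCa estimates end-of-trial synaptic CaP and CaD as a linear
// combination of sender/receiver CaSpkM, CaSpkP, CaSpkD products, instead
// of integrating a calcium cascade at every synapse on every spike.
func (sc *SynCaLinear) FinalCa(sM, sP, sD, rM, rP, rD float32) (caP, caD float32) {
	caP = sc.CaP[0]*sM*rM + sc.CaP[1]*sP*rP + sc.CaP[2]*sD*rD + sc.CaP[3]
	caD = sc.CaD[0]*sM*rM + sc.CaD[1]*sP*rP + sc.CaD[2]*sD*rD + sc.CaD[3]
	return
}

func main() {
	// hypothetical coefficients standing in for the lasso/ridge-fitted weights
	sc := SynCaLinear{
		CaP: [4]float32{0.2, 0.5, 0.3, 0},
		CaD: [4]float32{0.3, 0.3, 0.4, 0},
	}
	caP, caD := sc.FinalCa(0.5, 0.6, 0.55, 0.4, 0.5, 0.45)
	fmt.Printf("caP=%.3f caD=%.3f\n", caP, caD)
}

Because the estimate depends only on neuron-level variables, it can be evaluated once at weight-update time, which is what allows the per-cycle SynCa passes and the per-synapse calcium state to be removed in the files below.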
1 change: 0 additions & 1 deletion axon/act.go
@@ -842,7 +842,6 @@ func (ac *ActParams) DecayLearnCa(ctx *Context, ni, di uint32, decay float32) {

AddNrnV(ctx, ni, di, CaLrn, -decay*NrnV(ctx, ni, di, CaLrn))

- AddNrnV(ctx, ni, di, CaSyn, -decay*NrnV(ctx, ni, di, CaSyn))
AddNrnV(ctx, ni, di, CaSpkM, -decay*NrnV(ctx, ni, di, CaSpkM))
AddNrnV(ctx, ni, di, CaSpkP, -decay*NrnV(ctx, ni, di, CaSpkP))
AddNrnV(ctx, ni, di, CaSpkD, -decay*NrnV(ctx, ni, di, CaSpkD))
63 changes: 10 additions & 53 deletions axon/enumgen.go

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions axon/gpu.go
@@ -308,7 +308,6 @@ func (gp *GPU) Config(ctx *Context, net *Network) {
gp.Sys.NewComputePipelineEmbed("Cycle", content, "shaders/gpu_cycle.spv")
gp.Sys.NewComputePipelineEmbed("CycleInc", content, "shaders/gpu_cycleinc.spv")
gp.Sys.NewComputePipelineEmbed("SendSpike", content, "shaders/gpu_sendspike.spv")
gp.Sys.NewComputePipelineEmbed("SynCa", content, "shaders/gpu_synca.spv")
gp.Sys.NewComputePipelineEmbed("CyclePost", content, "shaders/gpu_cyclepost.spv")

gp.Sys.NewComputePipelineEmbed("NewStatePool", content, "shaders/gpu_newstate_pool.spv")
@@ -1217,7 +1216,6 @@ func (gp *GPU) RunCycleOneCmd() vk.CommandBuffer {
gp.RunPipelineMemWait(cmd, "CyclePost", maxData)
} else {
gp.RunPipelineMemWait(cmd, "CyclePost", maxData)
gp.RunPipelineMemWait(cmd, "SynCa", neurDataN)
}

gp.Sys.ComputeCopyFromGPU(cmd, cxr, glr, lvr, plr, nrr, nrar)
@@ -1282,7 +1280,6 @@ func (gp *GPU) RunCyclesCmd() vk.CommandBuffer {
gp.RunPipelineMemWait(cmd, "CyclePost", maxData)
} else {
gp.RunPipelineMemWait(cmd, "CyclePost", maxData)
gp.RunPipelineMemWait(cmd, "SynCa", neurDataN)
}
if ci < CyclesN-1 {
gp.RunPipelineMemWait(cmd, "CycleInc", 1) // we do
@@ -1315,7 +1312,6 @@ func (gp *GPU) RunCycleSeparateFuns() {
gp.RunPipelineWait("CyclePost", maxData)
if !gp.Ctx.Testing.IsTrue() {
gp.RunPipelineWait("CyclePost", maxData)
gp.RunPipelineWait("SynCa", neurDataN)
}
gp.SyncLayerStateFromGPU()
}
@@ -1649,7 +1645,7 @@ func (gp *GPU) TestSynCa() bool {
limit := 2
failed := false

- for vr := CaM; vr < SynapseCaVarsN; vr++ {
+ for vr := Tr; vr < SynapseCaVarsN; vr++ {
nfail := 0
for syni := uint32(0); syni < uint32(4); syni++ {
for di := uint32(0); di < gp.Net.MaxData; di++ {
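The loop bound above changes from CaM to Tr because the per-synapse calcium state (CaM, CaP, CaD, CaUpT) is no longer stored at all; only the learning-trace and weight-change variables remain per synapse per data index. A rough sketch of the reduced SynapseCaVars set, inferred from the loops and test code in this diff rather than from the generated enum itself:

package axon

// SynapseCaVars (sketch): after this commit only trace and weight-change
// variables are kept per synapse per data index. Ordering and comments are
// assumptions based on the surrounding diff.
type SynapseCaVars int32

const (
	Tr    SynapseCaVars = iota // credit-assignment trace
	DTr                        // new trace increment for the current trial
	DiDWt                      // per-data-index weight change
	SynapseCaVarsN
)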
130 changes: 0 additions & 130 deletions axon/gpu_hlsl/gpu_synca.hlsl

This file was deleted.

8 changes: 0 additions & 8 deletions axon/gpu_hlsl/gpu_test_synca.hlsl
@@ -50,18 +50,10 @@ void WriteSynCa(in Context ctx, uint syni, uint di) {
// uint pi = SynI(ctx, syni, SynPathIndex);
// uint si = SynI(ctx, syni, SynSendIndex);
// uint ri = SynI(ctx, syni, SynRecvIndex);
- // SetSynCaV(ctx, syni, di, CaM, asfloat(pi));
- // SetSynCaV(ctx, syni, di, CaP, asfloat(si));
- // SetSynCaV(ctx, syni, di, CaD, asfloat(ri));
- // SetSynCaV(ctx, syni, di, CaUpT, asfloat(syni));
// SetSynCaV(ctx, syni, di, Tr, asfloat(di));
// SetSynCaV(ctx, syni, di, DTr, asfloat(bank));
// SetSynCaV(ctx, syni, di, DiDWt, asfloat(res));

- SetSynCaV(ctx, syni, di, CaM, asfloat(uint(ctx.SynapseCaVars.Index(syni, di, CaM) % 0xFFFFFFFF)));
- SetSynCaV(ctx, syni, di, CaP, asfloat(uint(ctx.SynapseCaVars.Index(syni, di, CaP) % 0xFFFFFFFF)));
- SetSynCaV(ctx, syni, di, CaD, asfloat(uint(ctx.SynapseCaVars.Index(syni, di, CaD) % 0xFFFFFFFF)));
- SetSynCaV(ctx, syni, di, CaUpT, asfloat(uint(ctx.SynapseCaVars.Index(syni, di, CaUpT) % 0xFFFFFFFF)));
SetSynCaV(ctx, syni, di, Tr, asfloat(uint(ctx.SynapseCaVars.Index(syni, di, Tr) % 0xFFFFFFFF)));
SetSynCaV(ctx, syni, di, DTr, asfloat(uint(ctx.SynapseCaVars.Index(syni, di, DTr) % 0xFFFFFFFF)));
SetSynCaV(ctx, syni, di, DiDWt, asfloat(uint(ctx.SynapseCaVars.Index(syni, di, DiDWt) % 0xFFFFFFFF)));
27 changes: 0 additions & 27 deletions axon/layer_compute.go
@@ -212,33 +212,6 @@ func (ly *Layer) SendSpike(ctx *Context, ni uint32) {
}
}

- // SynCa updates synaptic calcium based on spiking, for SynSpkTheta mode.
- // Optimized version only updates at point of spiking, threaded over neurons.
- // Called directly by Network, iterates over data.
- func (ly *Layer) SynCa(ctx *Context, ni uint32) {
- for di := uint32(0); di < ctx.NetIndexes.NData; di++ {
- if NrnV(ctx, ni, di, Spike) == 0 { // di has to be outer loop b/c of this test
- continue
- }
- updtThr := ly.Params.Learn.CaLearn.UpdateThr
- if NrnV(ctx, ni, di, CaSpkP) < updtThr && NrnV(ctx, ni, di, CaSpkD) < updtThr {
- continue
- }
- for _, sp := range ly.SndPaths {
- if sp.IsOff() {
- continue
- }
- sp.SynCaSend(ctx, ni, di, updtThr)
- }
- for _, rp := range ly.RcvPaths {
- if rp.IsOff() {
- continue
- }
- rp.SynCaRecv(ctx, ni, di, updtThr)
- }
- }
- }

// LDTSrcLayAct returns the overall activity level for given source layer
// for purposes of computing ACh salience value.
// Typically the input is a superior colliculus (SC) layer that rapidly
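With the per-cycle Layer.SynCa pass deleted from layer_compute.go, nothing touches synaptic state during spiking: the sender and receiver CaSpk variables carry the full spike history, and synaptic CaP/CaD are only reconstructed when weight changes are computed. A hedged sketch of how a trace-based DWt step could consume the linear estimate (the function and its exact form are illustrative; the real code lives in the pathway DWt methods and the kinase package):

package axon

// TraceDWtSketch shows, under stated assumptions, how a trace-learning weight
// change can be computed from linearly estimated synaptic calcium. finalCa is
// any estimator like the SynCaLinear sketch earlier; tau and lrate are the
// usual trace time constant and learning rate.
func TraceDWtSketch(tr *float32, sM, sP, sD, rM, rP, rD, tau, lrate float32,
	finalCa func(sM, sP, sD, rM, rP, rD float32) (caP, caD float32)) float32 {
	caP, caD := finalCa(sM, sP, sD, rM, rP, rD)
	dtr := caD               // new trace increment from the slower CaD signal
	*tr += (dtr - *tr) / tau // exponential trace integration, rate 1/tau
	return lrate * (*tr) * (caP - caD) // CaP-CaD contrast drives the weight change
}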
48 changes: 5 additions & 43 deletions axon/learn.go
@@ -315,7 +315,6 @@ func (ln *LearnNeurParams) InitNeurCa(ctx *Context, ni, di uint32) {

SetNrnV(ctx, ni, di, CaLrn, 0)

- SetNrnV(ctx, ni, di, CaSyn, 0)
SetNrnV(ctx, ni, di, CaSpkM, 0)
SetNrnV(ctx, ni, di, CaSpkP, 0)
SetNrnV(ctx, ni, di, CaSpkD, 0)
@@ -344,13 +343,11 @@ func (ln *LearnNeurParams) LrnNMDAFromRaw(ctx *Context, ni, di uint32, geTot flo
// CaFromSpike updates all spike-driven calcium variables, including CaLrn and CaSpk.
// Computed after new activation for current cycle is updated.
func (ln *LearnNeurParams) CaFromSpike(ctx *Context, ni, di uint32) {

- caSyn := NrnV(ctx, ni, di, CaSyn)
+ var caSyn float32
caSpkM := NrnV(ctx, ni, di, CaSpkM)
caSpkP := NrnV(ctx, ni, di, CaSpkP)
caSpkD := NrnV(ctx, ni, di, CaSpkD)
ln.CaSpk.CaFromSpike(NrnV(ctx, ni, di, Spike), &caSyn, &caSpkM, &caSpkP, &caSpkD)
- SetNrnV(ctx, ni, di, CaSyn, caSyn)
SetNrnV(ctx, ni, di, CaSpkM, caSpkM)
SetNrnV(ctx, ni, di, CaSpkP, caSpkP)
SetNrnV(ctx, ni, di, CaSpkD, caSpkD)
@@ -608,22 +605,6 @@ func (sp *SWtParams) WtFromDWt(wt, lwt *float32, dwt, swt float32) {
*wt = sp.WtValue(swt, *lwt)
}

- // InitSynCa initializes synaptic calcium state, including CaUpT
- func InitSynCa(ctx *Context, syni, di uint32) {
- SetSynCaV(ctx, syni, di, CaUpT, 0)
- SetSynCaV(ctx, syni, di, CaM, 0)
- SetSynCaV(ctx, syni, di, CaP, 0)
- SetSynCaV(ctx, syni, di, CaD, 0)
- }
-
- // DecaySynCa decays synaptic calcium by given factor (between trials)
- // Not used by default.
- func DecaySynCa(ctx *Context, syni, di uint32, decay float32) {
- AddSynCaV(ctx, syni, di, CaM, -decay*SynCaV(ctx, syni, di, CaM))
- AddSynCaV(ctx, syni, di, CaP, -decay*SynCaV(ctx, syni, di, CaP))
- AddSynCaV(ctx, syni, di, CaD, -decay*SynCaV(ctx, syni, di, CaD))
- }

//gosl:end learn

// InitWtsSyn initializes weight values based on WtInit randomness parameters
@@ -682,27 +663,11 @@ func (ls *LRateParams) Init() {
ls.UpdateEff()
}

- // SynCaFuns are different ways of computing synaptic calcium (experimental)
- type SynCaFuns int32 //enums:enum
-
- const (
- // StdSynCa uses standard synaptic calcium integration method
- StdSynCa SynCaFuns = iota
-
- // LinearSynCa uses linear regression generated calcium integration (much faster)
- LinearSynCa
-
- // NeurSynCa uses simple product of separately-integrated neuron values (much faster)
- NeurSynCa
- )

// TraceParams manages parameters associated with temporal trace learning
type TraceParams struct {

- // how to compute the synaptic calcium (experimental)
- SynCa SynCaFuns

- // time constant for integrating trace over theta cycle timescales -- governs the decay rate of synaptic trace
+ // time constant for integrating trace over theta cycle timescales.
+ // governs the decay rate of synaptic trace
Tau float32 `default:"1,2,4"`

// amount of the mean dWt to subtract, producing a zero-sum effect -- 1.0 = full zero-sum dWt -- only on non-zero DWts. typically set to 0 for standard trace learning pathways, although some require it for stability over the long haul. can use SetSubMean to set to 1 after significant early learning has occurred with 0. Some special path types (e.g., Hebb) benefit from SubMean = 1 always
@@ -713,12 +678,9 @@

// rate = 1 / tau
Dt float32 `view:"-" json:"-" xml:"-" edit:"-"`

- pad, pad1, pad2 float32
}

func (tp *TraceParams) Defaults() {
- tp.SynCa = LinearSynCa
tp.Tau = 1
tp.SubMean = 0
tp.LearnThr = 0
@@ -860,8 +822,8 @@ type LearnSynParams struct {
// trace-based learning parameters
Trace TraceParams `view:"inline"`

- // kinase calcium Ca integration parameters
- KinaseCa kinase.SynCaParams `view:"inline"`
+ // kinase calcium Ca integration parameters: using linear regression parameters
+ KinaseCa kinase.SynCaLinear `view:"inline"`

// hebbian learning option, which overrides the default learning rules
Hebb HebbParams `view:"inline"`
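TraceParams.SubMean, described in the struct above, subtracts a fraction of the mean dWt across a neuron's synapses, restricted to synapses whose dWt is non-zero, producing the zero-sum effect. A small self-contained sketch of that step (the function name and slice-based signature are illustrative, not the actual axon code):

package axon

// SubMeanDWt subtracts subMean times the mean of the non-zero entries of
// dwts, in place. subMean = 1 gives fully zero-sum weight changes for the
// active synapses; subMean = 0 leaves them untouched.
func SubMeanDWt(dwts []float32, subMean float32) {
	if subMean == 0 {
		return
	}
	var sum float32
	n := 0
	for _, dw := range dwts {
		if dw != 0 {
			sum += dw
			n++
		}
	}
	if n == 0 {
		return
	}
	mean := subMean * (sum / float32(n))
	for i, dw := range dwts {
		if dw != 0 {
			dwts[i] = dw - mean
		}
	}
}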
3 changes: 0 additions & 3 deletions axon/network.go
@@ -114,9 +114,6 @@ func (nt *Network) Cycle(ctx *Context) {
nt.NeuronMapPar(ctx, func(ly *Layer, ni uint32) { ly.CycleNeuron(ctx, ni) }, "CycleNeuron")
nt.NeuronMapPar(ctx, func(ly *Layer, ni uint32) { ly.PostSpike(ctx, ni) }, "PostSpike")
nt.NeuronMapPar(ctx, func(ly *Layer, ni uint32) { ly.SendSpike(ctx, ni) }, "SendSpike")
- if ctx.Testing.IsFalse() {
- nt.NeuronMapPar(ctx, func(ly *Layer, ni uint32) { ly.SynCa(ctx, ni) }, "SynCa")
- }
var ldt, vta *Layer
for _, ly := range nt.Layers {
if ly.LayerType() == VTALayer {
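After this removal, the per-cycle work in Network.Cycle is just the neuron update, post-spike, and spike-propagation passes (plus CyclePost); synapses are no longer visited on every cycle. A simplified sketch of the resulting pass order, with the special-layer handling and GPU paths omitted:

package axon

// cycleSketch summarizes the per-cycle passes after this commit; simplified
// from Network.Cycle above, with error handling and VTA/LDT layers omitted.
func (nt *Network) cycleSketch(ctx *Context) {
	nt.NeuronMapPar(ctx, func(ly *Layer, ni uint32) { ly.CycleNeuron(ctx, ni) }, "CycleNeuron")
	nt.NeuronMapPar(ctx, func(ly *Layer, ni uint32) { ly.PostSpike(ctx, ni) }, "PostSpike")
	nt.NeuronMapPar(ctx, func(ly *Layer, ni uint32) { ly.SendSpike(ctx, ni) }, "SendSpike")
	// no per-cycle SynCa pass: synaptic calcium is estimated later, at DWt time
}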
2 changes: 1 addition & 1 deletion axon/networkbase.go
@@ -1500,7 +1500,7 @@ func (nt *NetworkBase) DiffFrom(ctx *Context, on *NetworkBase, maxDiff int) stri
}
for di := uint32(0); di < ctx.NetIndexes.NData; di++ {
for si := uint32(0); si < nt.NSyns; si++ {
- for svar := CaM; svar < SynapseCaVarsN; svar++ {
+ for svar := Tr; svar < SynapseCaVarsN; svar++ {
sv := nt.Synapses[ctx.SynapseCaVars.Index(di, si, svar)]
ov := on.Synapses[ctx.SynapseCaVars.Index(di, si, svar)]
if sv != ov {

0 comments on commit e39a5a5
