-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBanditTaskParam.py
159 lines (137 loc) · 8.73 KB
/
BanditTaskParam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 19 12:12:28 2020
@author: kblackw1
"""
############ reward ################
rwd={'error':-5,'reward':10,'base':-1,'none':-1,'partial':0}
#rwd={'error':-1,'reward':8,'base':0,'none':0,'partial':0} #use these for Opal and Bogacz to show that 3 step tasks do not work - same mean rwd per trial
#rwd={'error':-1,'reward':6,'base':0,'none':0,'partial':1} #use these for Opal and Bogacz to improve performance on 3 step tasks - same mean rwd per trial
#rwd={'error':-2,'reward':5,'base':-0.5,'none':-0.5,'partial':0} #show more lose-shift and sampling behavior? No
######### Parameters for the agent ##################################
params={}
params['wt_learning']=False
params['wt_noise']=False #whether to multiply noise by learning_rate - not helpful
params['numQ']=2
params['alpha']=[0.3,0.06] # learning rate 0.3 and 0.06 produce learning in 400 trials,#slower for Q2 - D2 neurons
params['beta']=0.9 # inverse temperature, controls exploration
params['beta_min']=0.1
params['gamma']=0.9 #discount factor#
params['hist_len']=40
params['state_thresh']=[0.12,0.2] #similarity of noisy state to ideal state
#if lower state_creation_threshold for Q[0] compared to Q[1], then Q[0] will have fewer states
#possibly multiply state_thresh by learn_rate? to change dynamically?
params['sigma']=0.25 #similarity of noisey state to ideal state,std used in Mahalanobis distance.
params['time_inc']=0.1 #increment time since reward by this much in no reward
params['moving_avg_window']=3 #This in units of trials, the actual window is this times the number of events per trial
params['decision_rule']=None #'combo', 'delta', 'sumQ2', None ## None means use direct negative of D1 rule
params['Q2other']=0.1
params['forgetting']=0
params['reward_cues']=None ##options: 'TSR', 'RewHx3', 'reward', None
params['distance']='Euclidean'
params['initQ']=-1 #split states, initialize Q values to best matching
params['events_per_trial']=3
############### Make sure you have all the state transitions needed ##########
def validate_T(T,msg=''):
print(msg)
for st in T.keys():
for newst_list in T[st].values():
for newst in newst_list:
if newst[0] not in T.keys():
print('new state', newst[0],'not in Tacq')
def validate_R(T,R,msg=''):
print(msg)
for st in T.keys():
if st not in R.keys():
print('state', st,'in T,but not in R')
else:
for a in T[st].keys():
if a not in R[st].keys():
print('state/action:', st,a,'in T, but not in R')
######### Parameters for the environment ##################################
include_wander=True
act={'center':0,'left':1,'return':2,'right':3,'hold':4} # , 'wander':5
if include_wander:
act['wander']=5
states={'loc':{'start':0,'Pport':2,'Lport':1,'other':-1,'Rport':3},
'tone':{'blip':0,'6kHz':6,'success':2,'error':-2,'10kHz':10}}
params['state_units']={'loc':False,'tone':True} #Try false/true
#some convenient variables
start=(states['loc']['start'],states['tone']['blip']) #used many times
env_params={'start':start}
loc=states['loc'] #used only to define R and T
tone=states['tone'] #used only to define R and T
move=['left','right','center'] #,'wander'
stay=['hold']
if include_wander:
move=move+['wander']
################# Two arm bandit task of Josh Berke (Hamid et al. Nat Neuro V19)
Rbandit={};Tbandit={} #dictionaries to improve readability/prevent mistakes
prwdR=0.8; prwdL=0.5 #initial values. These change with each block of trials
####value of T dict is the new state
#from start - best response is poke
Tbandit[start]={a:[(start,1)] for a in act.values()} # stay at start, unless move
for a in ['left','right']:
Tbandit[start][act[a]]=[((loc['other'],tone['error']),1)] #error if go to left or right port from start box
Tbandit[start][act['center']]=[((loc['Pport'],tone['6kHz']),1)] #poke at start tone, go to poke port
if include_wander:
Tbandit[start][act['wander']]=[((loc['other'],tone['blip']),1)] #meandering, not yet at poke port **
#What happens if agent doesn't go to poke port immediately **
Tbandit[(loc['other'],tone['blip'])]={a:[((loc['other'],tone['blip']),1)] for a in act.values()} #default: remain in other, unless
Tbandit[(loc['other'],tone['blip'])][act['center']]=[((loc['Pport'],tone['6kHz']),1)] #go to center port, after wandering
Tbandit[(loc['other'],tone['blip'])][act['return']]=[(start,1)] #return to start
#from poke port - correct response is 'left'
Tbandit[(loc['Pport'],tone['6kHz'])]={a:[((loc['Pport'],tone['6kHz']),1)] for a in act.values()} #default - stay in poke port
if include_wander:
Tbandit[(loc['Pport'],tone['6kHz'])][act['wander']]=[((loc['other'],tone['error']),1)] #incorrect movements
Tbandit[(loc['Pport'],tone['6kHz'])][act['return']]=[((loc['start'],tone['error']),1)] #incorrect movements
Tbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[((loc['Rport'],tone['success']),prwdR),((loc['Rport'],tone['error']),1-prwdR)]
Tbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[((loc['Lport'],tone['success']),prwdL),((loc['Lport'],tone['error']),1-prwdL)] #hear tone in poke port, go left, in left port/success
Tbandit[(loc['other'],tone['error'])]={a:[((loc['other'],tone['error']),1)] for a in act.values()}#remain in other unless
Tbandit[(loc['start'],tone['error'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move}
Tbandit[(loc['start'],tone['error'])][act['hold']]=[(start,1)]
#from left port or right port - best response is return
Tbandit[(loc['Lport'],tone['success'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move} #default, wandering around
Tbandit[(loc['Lport'],tone['success'])][act['hold']]=[((loc['Lport'],tone['error']),1)] #staying at Lport, but not continued success (reward)
Tbandit[(loc['Lport'],tone['success'])][act['return']]=[(start,1)] #go back to start to begin again
Tbandit[(loc['Rport'],tone['success'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move} #default, wandering around
Tbandit[(loc['Rport'],tone['success'])][act['hold']]=[((loc['Rport'],tone['error']),1)] #staying at Lport, but not continued success (reward)
Tbandit[(loc['Rport'],tone['success'])][act['return']]=[(start,1)] #go back to start to begin again
#from right or left port or any with error tone - best response is return
Tbandit[(loc['Rport'],tone['error'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move}
Tbandit[(loc['Rport'],tone['error'])][act['hold']]=[((loc['Rport'],tone['error']),1)] #remain in Rport if no movement,
Tbandit[(loc['Lport'],tone['error'])] = {act[a]:[((loc['other'],tone['error']),1)] for a in move}
Tbandit[(loc['Lport'],tone['error'])][act['hold']]=[((loc['Lport'],tone['error']),1)] #remain in Lport if no movement,
for location in ['Rport','Lport','start','other']:
Tbandit[(loc[location],tone['error'])][act['return']]=[(start,1)] #return to start
#error tone is associated with penalty when agent makes incorrect movement
for k in Tbandit.keys(): #Tbandit determines what states pairs need reward values
Rbandit[k]={a:[(rwd['base'],1)] for a in act.values()} #default: cost of basic action
Rbandit[k][act['hold']]=[(rwd['none'],1)] #not moving - no cost
if rwd['partial']>0:
Rbandit[start][act['center']]=[(rwd['partial'],1)]
Rbandit[(loc['Rport'],tone['success'])][act['return']]=[(rwd['partial'],1)]
Rbandit[(loc['Lport'],tone['success'])][act['return']]=[(rwd['partial'],1)]
#reward for correct response
Rbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[(rwd['reward'],prwdR),(rwd['base'],1-prwdR)]
Rbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[(rwd['reward'],prwdL),(rwd['base'],1-prwdL)]
#Error if go anywhere but Left or Right port after tone (but hold is not an error)
Rbandit[(loc['Pport'],tone['6kHz'])][act['return']]=[(rwd['error'],1)]
if include_wander:
Rbandit[(loc['Pport'],tone['6kHz'])][act['wander']]=[(rwd['error'],1)]
#Error if go straight to left or right port from start box - same error for discrimination
for a in['right','left']:
Rbandit[(loc['start'],tone['blip'])][act[a]]=[(rwd['error'],1)]
###### Error if agent wanders or stays at Rport or Lport
# These are needed to keep agent from not returning for new trial after success
for a in ['right','left','center']:
Rbandit[(loc['Rport'],tone['success'])][act[a]]=[(rwd['error'],1)]
Rbandit[(loc['Lport'],tone['success'])][act[a]]=[(rwd['error'],1)]
if include_wander:
Rbandit[(loc['Rport'],tone['success'])][act['wander']]=[(rwd['error'],1)]
Rbandit[(loc['Lport'],tone['success'])][act['wander']]=[(rwd['error'],1)]
############## End bandit parameters ###################
if __name__== '__main__':
######## Make sure all needed transitions have been created
validate_T(Tbandit,msg='validate bandit T')
validate_R(Tbandit,Rbandit,msg='validate bandit R')