# -*- coding: utf-8 -*-
"""
Created on Wed Aug 19 12:12:28 2020
@author: kblackw1
"""
############ reward ################
rwd={'error':-1,'reward':15,'base':-1,'none':0}
######### Parameters for the agent ##################################
params={}
params['beta_GPi']=10
params['wt_learning']=False
params['wt_noise']=False #whether to multiply noise by learning_rate - not helpful
params['numQ']=1
params['alpha']=[0.3,0.06] # learning rates; 0.3 and 0.06 produce learning in 400 trials, slower for Q2 (D2 neurons)
params['beta']=0.9 # inverse temperature, controls exploration
params['beta_min']=0.5
params['gamma']=0.9 #discount factor
params['hist_len']=40
params['state_thresh']=[0.12,0.2] #similarity of noisy state to ideal state
#if lower state_creation_threshold for Q[0] compared to Q[1], then Q[0] will have fewer states
#possibly multiply state_thresh by learn_rate? to change dynamically?
params['sigma']=0.25 #std used in the Mahalanobis distance between noisy state and ideal state
params['time_inc']=0.1 #increment time since reward by this much when there is no reward
params['moving_avg_window']=3 #this is in units of trials; the actual window is this times the number of events per trial
params['decision_rule']=None #options: 'combo', 'delta', 'sumQ2'
params['Q2other']=0.1
params['forgetting']=0
params['reward_cues']=None #options: 'TSR', 'RewHx3', 'reward'
params['distance']='Euclidean'
params['split']=True
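
# Illustration only (hypothetical helpers, not the actual agent API, which lives
# elsewhere): a minimal sketch of how an agent might use beta (inverse temperature)
# for softmax exploration and alpha/gamma for a tabular Q-learning update.
def _softmax_action(q_values, beta):
    import numpy as np #local import so this sketch does not change module-level imports
    p = np.exp(beta*np.asarray(q_values,dtype=float))
    p = p/p.sum() #higher beta -> sharper (more exploitative) action distribution
    return np.random.choice(len(q_values), p=p)

def _q_update(q_sa, reward, q_next_max, alpha, gamma):
    #one-step TD update: Q(s,a) <- Q(s,a) + alpha*(reward + gamma*max Q(s',.) - Q(s,a))
    return q_sa + alpha*(reward + gamma*q_next_max - q_sa)
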
############### Make sure you have all the state transitions needed ##########
def validate_T(T,msg=''):
    print(msg)
    for st in T.keys():
        for newst_list in T[st].values():
            for newst in newst_list:
                if newst[0] not in T.keys():
                    print('new state', newst[0], 'not in T')
def validate_R(T,R,msg=''):
    print(msg)
    for st in T.keys():
        if st not in R.keys():
            print('state', st, 'in T, but not in R')
        else:
            for a in T[st].keys():
                if a not in R[st].keys():
                    print('state/action:', st, a, 'in T, but not in R')
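# The validators above assume the layout used for Tloc and R below:
#   T[state][action] = list of (new_state, probability) tuples
#   R[state][action] = list of (reward_value, probability) tuples
# e.g. a one-state toy world that always stays put and pays rwd['base']:
#   T = {0: {0: [(0,1)]}}
#   R = {0: {0: [(rwd['base'],1)]}}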
######### Parameters for the environment ##################################
act={'goMag':0,'goL':1,'goR':2,'press':3,'other':4} #other includes grooming, not moving
Hx_len=4 #length of the press history; this sets the depth of the enumeration loops below. Hx_len=3 works better than 4
hx_values=['L','R'] #possible characters in the history
start_hx='-'*Hx_len #starting press history is 'empty'
'''
Worse performance if no '-'
hx_values=['L','R'] #possible characters in the history
import numpy
#starting press history is random
start_presses=numpy.random.randint(len(hx_values),size=Hx_len)
start_hx=''.join([hx_values[s] for s in start_presses])
'''
### Enumerate all possible combinations of press history (complete and partial)
sequences={}
value=0
if Hx_len==3:
    for c1 in hx_values:
        for c2 in hx_values:
            for c3 in hx_values:
                sequences[c1+c2+c3]=value
                value+=1 #states are numbered only so they can serve as values in the states dictionary
    for c3 in hx_values:
        sequences['--'+c3]=value
        value+=1
        for c2 in hx_values:
            sequences['-'+c2+c3]=value
            value+=1
    sequences['---']=value
    params['events_per_trial']=6
elif Hx_len==4:
    for c1 in hx_values:
        for c2 in hx_values:
            for c3 in hx_values:
                for c4 in hx_values:
                    sequences[c1+c2+c3+c4]=value
                    value+=1 #states are numbered only so they can serve as values in the states dictionary
    for c4 in hx_values:
        sequences['---'+c4]=value
        value+=1
        for c3 in hx_values:
            sequences['--'+c3+c4]=value
            value+=1
            for c2 in hx_values:
                sequences['-'+c2+c3+c4]=value
                value+=1
    sequences['----']=value
    params['events_per_trial']=7
else:
    print('unanticipated Hx_len in press history')
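
# Hedged alternative (defined but not used here): the nested loops above could be
# generalized to any Hx_len with itertools.product. This produces the same set of
# history strings; the numbering order may differ, which only matters for the
# values stored in the states dictionary.
def enumerate_hx(hx_len, values):
    import itertools
    seqs = {}
    for k in range(hx_len, -1, -1):  #k = number of real presses in the history
        pad = '-'*(hx_len - k)       #leading '-' marks presses not yet made
        for combo in itertools.product(values, repeat=k):
            seqs[pad + ''.join(combo)] = len(seqs)
    return seqs
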
#create state dictionary
states={'loc':{'mag':0,'Llever':1,'Rlever':2,'other':3},
        'hx': sequences}
params['state_units']={'loc':False,'hx':False}
#some convenient variables
loc=states['loc'] #used only to define R and T
hx=states['hx'] #used only to define R and T
Tloc={loc[location]:{} for location in loc} #dictionaries to improve readability/prevent mistakes
#Transition matrix - NOTE, only location transitions are enumerated here
#The transition for lever-press history is a function specified in the environment
for location in ['Llever','Rlever','other','mag']:
    Tloc[loc[location]][act['goL']]=[(loc['Llever'],1)]
    Tloc[loc[location]][act['goR']]=[(loc['Rlever'],1)]
    Tloc[loc[location]][act['goMag']]=[(loc['mag'],1)]
for location in ['Llever','Rlever','other','mag']:
    for action in ['press','other']:
        Tloc[loc[location]][act[action]]=[(loc[location],1)]
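# Resulting structure, e.g. for the magazine location (loc['mag']==0):
#   Tloc[0] == {act['goL']:[(1,1)], act['goR']:[(2,1)], act['goMag']:[(0,1)],
#               act['press']:[(0,1)], act['other']:[(0,1)]}
# i.e., go-actions move the animal deterministically; press/other leave it in place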
#where to start episodes, and also re-start trial after reward
start=(states['loc']['mag'],states['hx'][start_hx])
#put some environment values into dictionary for ease of param passing
env_params={'start':start,'hx_len':Hx_len,'hx_act':'press'}
#Reward matrix: enumerates all states.
#Would be nice to avoid such enumeration and create function similar to T
R={}
for k in Tloc.keys(): #T determines which state pairs need reward values
    for st in states['hx'].values():
        R[(k,st)]={a:[(rwd['base'],1)] for a in act.values()} #default: cost of basic action
if Hx_len==3:
    R[(loc['Rlever'],states['hx']['LLR'])][act['press']]=[(rwd['reward'],0.95),(rwd['base'],0.05)] #95% reward for pressing after the correct sequence
elif Hx_len==4:
    for location in loc:
        R[(loc[location],states['hx']['LLRR'])][act['goMag']]=[(rwd['reward'],0.95),(rwd['base'],0.05)] #95% reward for going to the magazine after the correct press sequence
else:
    print('unanticipated Hx_len in reward assignment')
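
# Hedged sketch (hypothetical helper, not part of the environment API): how a
# (value, probability) outcome list such as the R entries above might be sampled.
def _sample_outcome(outcomes):
    #outcomes is a list of (value, probability) tuples whose probabilities sum to 1
    import numpy as np
    idx = np.random.choice(len(outcomes), p=[p for v,p in outcomes])
    return outcomes[idx][0]
# e.g. _sample_outcome([(rwd['reward'],0.95),(rwd['base'],0.05)]) returns 15 about 95% of the time
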
if __name__== '__main__':
    ######## Make sure all needed transitions have been created
    validate_T(Tloc,msg='validate Tloc')
    validate_R(Tloc,R,msg='validate R')
    print('press history length=',Hx_len,', start press',start_hx,', hx values',hx_values)
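# Usage note (hedged): other scripts presumably consume this module via, e.g.,
#   from SequenceTaskParam import params, env_params, Tloc, R, act, states
# running this file directly only performs the validation checks above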