Skip to content

Commit

Permalink
Update stage 5 passed column
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhardcastle committed Nov 22, 2024
1 parent ad76c8c commit 83b7de8
Showing 1 changed file with 4 additions and 67 deletions.
71 changes: 4 additions & 67 deletions src/npc_sessions/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import pynwb
import upath
import zarr
from DynamicRoutingTask.Analysis import DynamicRoutingAnalysisUtils
from DynamicRoutingTask.Analysis.DynamicRoutingAnalysisUtils import DynRoutData

import npc_sessions.trials as TaskControl
Expand Down Expand Up @@ -1981,6 +1982,7 @@ def is_naive(self) -> bool:

@property
def is_stage_5_passed(self) -> bool:
"""Before this session, had the subject passed stage 5 (two sessions meeting crossmodal dprime threshold)?"""
if self.is_templeton:
logger.warning(
f"{self.id} is a Templeton session: returning is_stage_5_passed = False, but we don't currently track this"
Expand All @@ -2002,79 +2004,14 @@ def is_stage_5_passed(self) -> bool:
df = df[df["task version"].str.startswith("stage 5")]
if df.empty:
return False

# from https://github.com/samgale/DynamicRoutingTask/blob/aa781757b3e14895b50851d71583257e8fc39fb9/Analysis/DynamicRoutingAnalysisUtils.py#L282
def getPerformanceStats(df, sessions):
hits = []
dprimeSame = []
dprimeOther = []
for i in sessions:
if isinstance(df.loc[i, "hits"], str):
hits.append(
[int(s) for s in re.findall("[0-9]+", df.loc[i, "hits"])]
)
dprimeSame.append(
[
float(s)
for s in re.findall(
"-*[0-9].[0-9]*|nan", df.loc[i, "d' same modality"]
)
]
)
dprimeOther.append(
[
float(s)
for s in re.findall(
"-*[0-9].[0-9]*|nan",
df.loc[i, "d' other modality go stim"],
)
]
)
else:
hits.append(df.loc[i, "hits"])
dprimeSame.append(df.loc[i, "d' same modality"])
dprimeOther.append(df.loc[i, "d' other modality go stim"])
return hits, dprimeSame, dprimeOther

def getSessionsToPass(
mouseId, df, sessions, stage, hitThresh=100, dprimeThresh=1.5
):
sessionsToPass = np.nan
for sessionInd in sessions:
if sessionInd > sessions[0]:
hits, dprimeSame, dprimeOther = getPerformanceStats(
df, (sessionInd - 1, sessionInd)
)
if (
stage in (1, 2)
and all(h[0] >= hitThresh for h in hits)
and all(d[0] >= dprimeThresh for d in dprimeSame)
) or (
stage == 5
and np.all(
np.sum(
(np.array(dprimeSame) >= dprimeThresh)
& (np.array(dprimeOther) >= dprimeThresh),
axis=1,
)
> 3
)
):
sessionsToPass = np.where(sessions == sessionInd)[0][0] + 1
break
if np.isnan(sessionsToPass):
if stage in (1, 2) and mouseId in (614910, 684071, 682893):
sessionsToPass = len(sessions)
return sessionsToPass

return np.isnan(
getSessionsToPass(
DynamicRoutingAnalysisUtils.getSessionsToPass(
mouseId=int(self.subject.id),
df=df,
sessions=np.where(
[
str(d).split(" ")[0]
<= self.session_start_time.strftime("%Y-%m-%d")
< self.session_start_time.strftime("%Y-%m-%d")
for d in df["start time"].values
]
)[0],
Expand Down

0 comments on commit 83b7de8

Please sign in to comment.