Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: making static analysis as the default option for the TEE #274

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions medmodels/treatment_effect/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,13 @@ def with_patients_group(self, group: Group) -> TreatmentEffectBuilder:
return self

def with_time_attribute(
self, attribute: MedRecordAttribute
self, attribute: Optional[MedRecordAttribute]
MarIniOnz marked this conversation as resolved.
Show resolved Hide resolved
) -> TreatmentEffectBuilder:
"""Sets the time attribute to be used in the treatment effect estimation.

Args:
attribute (MedRecordAttribute): The time attribute.
attribute (Optional[MedRecordAttribute]): The time attribute. If None,
there is no temporal analysis, but only static one.

Returns:
TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
Expand Down
207 changes: 144 additions & 63 deletions medmodels/treatment_effect/continuous_estimators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from math import sqrt
from typing import Literal, Set
from typing import Literal, Optional, Set

import numpy as np

from medmodels.medrecord.medrecord import MedRecord
from medmodels.medrecord.querying import EdgeOperand
from medmodels.medrecord.types import Group, MedRecordAttribute, NodeIndex
from medmodels.treatment_effect.temporal_analysis import find_reference_edge

Expand All @@ -15,7 +16,7 @@ def average_treatment_effect(
outcome_group: Group,
outcome_variable: MedRecordAttribute,
reference: Literal["first", "last"] = "last",
time_attribute: MedRecordAttribute = "time",
time_attribute: Optional[MedRecordAttribute] = None,
) -> float:
r"""Calculates the Average Treatment Effect (ATE) as the difference between the outcome means of the treated and control sets.

Expand Down Expand Up @@ -50,46 +51,76 @@ def average_treatment_effect(
exposure time. Options include "first" and "last". If "first", the function
returns the earliest exposure edge. If "last", the function returns the
latest exposure edge. Defaults to "last".
time_attribute (MedRecordAttribute, optional): The attribute in the edge that
contains the time information. Defaults to "time".
time_attribute (Optional[MedRecordAttribute], optional): The attribute in the
edge that contains the time information. If it is equal to None, there is
no time component in the data and all edges between the sets and the
outcomes are considered for the average treatment effect. Defaults to
None.

Returns:
float: The average treatment effect.

Raises:
ValueError: If the outcome variable is not numeric.
"""
treated_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
]
)
if time_attribute is not None:
treated_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
]
)
control_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
]
)

else:
edges_treated_outcomes = medrecord.select_edges(
lambda edge: query_edges_between_set_outcome(
edge, treatment_outcome_true_set, outcome_group
)
)
treated_outcomes = np.array(
[
medrecord.edge[edge_id][outcome_variable]
for edge_id in edges_treated_outcomes
]
)

edges_control_outcomes = medrecord.select_edges(
lambda edge: query_edges_between_set_outcome(
edge, control_outcome_true_set, outcome_group
)
)
control_outcomes = np.array(
[
medrecord.edge[edge_id][outcome_variable]
for edge_id in edges_control_outcomes
]
)

if not all(isinstance(i, (int, float)) for i in treated_outcomes):
raise ValueError("Outcome variable must be numeric")

control_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in control_outcomes):
raise ValueError("Outcome variable must be numeric")

Expand All @@ -103,7 +134,7 @@ def cohens_d(
outcome_group: Group,
outcome_variable: MedRecordAttribute,
reference: Literal["first", "last"] = "last",
time_attribute: MedRecordAttribute = "time",
time_attribute: Optional[MedRecordAttribute] = None,
add_correction: bool = False,
) -> float:
"""Calculates Cohen's D, the standardized mean difference between two sets, measuring the effect size of the difference between two outcome means.
Expand Down Expand Up @@ -141,8 +172,10 @@ def cohens_d(
exposure time. Options include "first" and "last". If "first", the function
returns the earliest exposure edge. If "last", the function returns the
latest exposure edge. Defaults to "last".
time_attribute (MedRecordAttribute, optional): The attribute in the edge that
contains the time information. Defaults to "time".
time_attribute (Optional[MedRecordAttribute], optional): The attribute in the
edge that contains the time information. If it is equal to None, there is
no time component in the data and all edges between the sets and the
outcomes are considered for the Cohen's D calculation. Defaults to None.
add_correction (bool, optional): Whether to apply a correction factor for small
sample sizes. When True, using Hedges' g formula instead of Cohens' D.
Defaults to False.
Expand All @@ -153,37 +186,62 @@ def cohens_d(
Raises:
ValueError: If the outcome variable is not numeric.
"""
treated_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
]
)
if time_attribute is not None:
treated_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute=time_attribute,
reference=reference,
)
][outcome_variable]
for node_index in treatment_outcome_true_set
]
)
control_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
]
)
else:
edges_treated_outcomes = medrecord.select_edges(
lambda edge: query_edges_between_set_outcome(
edge, treatment_outcome_true_set, outcome_group
)
)
treated_outcomes = np.array(
[
medrecord.edge[edge_id][outcome_variable]
for edge_id in edges_treated_outcomes
]
)

edges_control_outcomes = medrecord.select_edges(
lambda edge: query_edges_between_set_outcome(
edge, control_outcome_true_set, outcome_group
MarIniOnz marked this conversation as resolved.
Show resolved Hide resolved
)
)
control_outcomes = np.array(
[
medrecord.edge[edge_id][outcome_variable]
for edge_id in edges_control_outcomes
]
)
if not all(isinstance(i, (int, float)) for i in treated_outcomes):
raise ValueError("Outcome variable must be numeric")

control_outcomes = np.array(
[
medrecord.edge[
find_reference_edge(
medrecord,
node_index,
outcome_group,
time_attribute="time",
reference=reference,
)
][outcome_variable]
for node_index in control_outcome_true_set
]
)
if not all(isinstance(i, (int, float)) for i in control_outcomes):
raise ValueError("Outcome variable must be numeric")

Expand All @@ -207,3 +265,26 @@ def cohens_d(
"average_te": average_treatment_effect,
"cohens_d": cohens_d,
}


def query_edges_between_set_outcome(
edge: EdgeOperand, set: Set[NodeIndex], outcomes_group: Group
):
"""Query edges that connect a set of nodes to the outcomes group.

Args:
edge (EdgeOperand): The edge operand to query.
set (Set[NodeIndex]): A set of node indices representing the treated group that
also have the outcome.
outcomes_group (Group): The group of nodes that contain the outcome variable.
"""
list_nodes = list(set)
edge.either_or(
lambda edge: edge.source_node().index().is_in(list_nodes),
lambda edge: edge.target_node().index().is_in(list_nodes),
)

edge.either_or(
lambda edge: edge.source_node().in_group(outcomes_group),
lambda edge: edge.target_node().in_group(outcomes_group),
)
Loading
Loading