Skip to content

Commit

Permalink
Merge branch 'main' into 254-default-values-for-essential_covariates-…
Browse files Browse the repository at this point in the history
…in-treatmenteffect
  • Loading branch information
MarIniOnz committed Jan 15, 2025
2 parents 9012622 + 816ee6f commit d319202
Show file tree
Hide file tree
Showing 14 changed files with 1,549 additions and 735 deletions.
1,378 changes: 952 additions & 426 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ description = "Limebit MedModels Crate"
[workspace.dependencies]
hashbrown = { version = "0.14.5", features = ["serde"] }
serde = { version = "1.0.203", features = ["derive"] }
polars = { version = "0.40.0", features = ["polars-io"] }
polars = { version = "0.45.0", features = ["polars-io", "dtype-full"] }
chrono = { version = "0.4.38", features = ["serde"] }

medmodels = { version = "0.1.2", path = "crates/medmodels" }
Expand Down
16 changes: 12 additions & 4 deletions crates/medmodels-core/src/medrecord/example_dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,61 @@ const PATIENT_PROCEDURE: &[u8] = include_bytes!("./synthetic_data/patient_proced
impl MedRecord {
pub fn from_example_dataset() -> Self {
let cursor = Cursor::new(DIAGNOSIS_DATA);
let diagnosis = CsvReadOptions::default()
let mut diagnosis = CsvReadOptions::default()
.with_has_header(true)
.into_reader_with_file_handle(cursor)
.finish()
.expect("DataFrame can be built");
diagnosis.rechunk_mut();
let diagnosis_ids = diagnosis
.column("diagnosis_code")
.expect("Column must exist")
.as_materialized_series()
.iter()
.map(|value| MedRecordAttribute::try_from(value).expect("AnyValue can be converted"))
.collect::<Vec<_>>();

let cursor = Cursor::new(DRUG_DATA);
let drug = CsvReadOptions::default()
let mut drug = CsvReadOptions::default()
.with_has_header(true)
.into_reader_with_file_handle(cursor)
.finish()
.expect("DataFrame can be built");
drug.rechunk_mut();
let drug_ids = drug
.column("drug_code")
.expect("Column must exist")
.as_materialized_series()
.iter()
.map(|value| MedRecordAttribute::try_from(value).expect("AnyValue can be converted"))
.collect::<Vec<_>>();

let cursor = Cursor::new(PATIENT_DATA);
let patient = CsvReadOptions::default()
let mut patient = CsvReadOptions::default()
.with_has_header(true)
.into_reader_with_file_handle(cursor)
.finish()
.expect("DataFrame can be built");
patient.rechunk_mut();
let patient_ids = patient
.column("patient_id")
.expect("Column must exist")
.as_materialized_series()
.iter()
.map(|value| MedRecordAttribute::try_from(value).expect("AnyValue can be converted"))
.collect::<Vec<_>>();

let cursor = Cursor::new(PROCEDURE_DATA);
let procedure = CsvReadOptions::default()
let mut procedure = CsvReadOptions::default()
.with_has_header(true)
.into_reader_with_file_handle(cursor)
.finish()
.expect("DataFrame can be built");
procedure.rechunk_mut();
let procedure_ids = procedure
.column("procedure_code")
.expect("Column must exist")
.as_materialized_series()
.iter()
.map(|value| MedRecordAttribute::try_from(value).expect("AnyValue can be converted"))
.collect::<Vec<_>>();
Expand Down
14 changes: 7 additions & 7 deletions crates/medmodels-core/src/medrecord/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,16 +807,16 @@ mod test {
}

fn create_nodes_dataframe() -> Result<DataFrame, PolarsError> {
let s0 = Series::new("index", &["0", "1"]);
let s1 = Series::new("attribute", &[1, 2]);
DataFrame::new(vec![s0, s1])
let s0 = Series::new("index".into(), &["0", "1"]);
let s1 = Series::new("attribute".into(), &[1, 2]);
DataFrame::new(vec![s0.into(), s1.into()])
}

fn create_edges_dataframe() -> Result<DataFrame, PolarsError> {
let s0 = Series::new("from", &["0", "1"]);
let s1 = Series::new("to", &["1", "0"]);
let s2 = Series::new("attribute", &[1, 2]);
DataFrame::new(vec![s0, s1, s2])
let s0 = Series::new("from".into(), &["0", "1"]);
let s1 = Series::new("to".into(), &["1", "0"]);
let s2 = Series::new("attribute".into(), &[1, 2]);
DataFrame::new(vec![s0.into(), s1.into(), s2.into()])
}

fn create_medrecord() -> MedRecord {
Expand Down
61 changes: 36 additions & 25 deletions crates/medmodels-core/src/medrecord/polars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ impl<'a> TryFrom<AnyValue<'a>> for MedRecordValue {
fn try_from(value: AnyValue<'a>) -> Result<Self, Self::Error> {
match value {
AnyValue::String(value) => Ok(MedRecordValue::String(value.into())),
AnyValue::StringOwned(value) => Ok(MedRecordValue::String(value.into())),
AnyValue::StringOwned(value) => Ok(MedRecordValue::String((*value).into())),
AnyValue::Int8(value) => Ok(MedRecordValue::Int(value.into())),
AnyValue::Int16(value) => Ok(MedRecordValue::Int(value.into())),
AnyValue::Int32(value) => Ok(MedRecordValue::Int(value.into())),
Expand Down Expand Up @@ -88,7 +88,7 @@ impl<'a> TryFrom<AnyValue<'a>> for MedRecordAttribute {
fn try_from(value: AnyValue<'a>) -> Result<Self, Self::Error> {
match value {
AnyValue::String(value) => Ok(MedRecordAttribute::String(value.into())),
AnyValue::StringOwned(value) => Ok(MedRecordAttribute::String(value.into())),
AnyValue::StringOwned(value) => Ok(MedRecordAttribute::String((*value).into())),
AnyValue::Int8(value) => Ok(MedRecordAttribute::Int(value.into())),
AnyValue::Int16(value) => Ok(MedRecordAttribute::Int(value.into())),
AnyValue::Int32(value) => Ok(MedRecordAttribute::Int(value.into())),
Expand All @@ -105,9 +105,13 @@ impl<'a> TryFrom<AnyValue<'a>> for MedRecordAttribute {
}

pub(crate) fn dataframe_to_nodes(
nodes: DataFrame,
mut nodes: DataFrame,
index_column_name: &str,
) -> Result<Vec<(NodeIndex, Attributes)>, MedRecordError> {
if nodes.max_n_chunks() > 1 {
nodes.rechunk_mut();
}

let attribute_column_names = nodes
.get_column_names()
.into_iter()
Expand All @@ -122,13 +126,14 @@ pub(crate) fn dataframe_to_nodes(
index_column_name
))
})?
.as_materialized_series()
.iter();

let mut columns = nodes
.columns(&attribute_column_names)
.expect("Attribute columns must exist")
.iter()
.map(|s| s.iter())
.map(|s| s.as_materialized_series().iter())
.zip(attribute_column_names)
.collect::<Vec<_>>();

Expand All @@ -140,7 +145,7 @@ pub(crate) fn dataframe_to_nodes(
.iter_mut()
.map(|(column, column_name)| {
Ok((
(*column_name).into(),
(***column_name).into(),
column.next().expect("msg").try_into()?,
))
})
Expand All @@ -151,10 +156,14 @@ pub(crate) fn dataframe_to_nodes(
}

pub(crate) fn dataframe_to_edges(
edges: DataFrame,
mut edges: DataFrame,
source_index_column_name: &str,
target_index_column_name: &str,
) -> Result<Vec<(NodeIndex, NodeIndex, Attributes)>, MedRecordError> {
if edges.max_n_chunks() > 1 {
edges.rechunk_mut();
}

let attribute_column_names = edges
.get_column_names()
.into_iter()
Expand All @@ -169,6 +178,7 @@ pub(crate) fn dataframe_to_edges(
source_index_column_name
))
})?
.as_materialized_series()
.iter();
let target_index = edges
.column(target_index_column_name)
Expand All @@ -178,13 +188,14 @@ pub(crate) fn dataframe_to_edges(
target_index_column_name
))
})?
.as_materialized_series()
.iter();

let mut columns = edges
.columns(&attribute_column_names)
.expect("Attribute columns must exist")
.iter()
.map(|s| s.iter())
.map(|s| s.as_materialized_series().iter())
.zip(attribute_column_names)
.collect::<Vec<_>>();

Expand All @@ -198,7 +209,7 @@ pub(crate) fn dataframe_to_edges(
.iter_mut()
.map(|(column, column_name)| {
Ok((
(*column_name).into(),
(***column_name).into(),
column
.next()
.expect("Should have as many iterations as rows")
Expand Down Expand Up @@ -293,7 +304,7 @@ mod test {

#[test]
fn test_from_anyvalue_datetime() {
let any_value = AnyValue::Datetime(0, polars::prelude::TimeUnit::Microseconds, &None);
let any_value = AnyValue::Datetime(0, polars::prelude::TimeUnit::Microseconds, None);

let value = MedRecordValue::try_from(any_value).unwrap();

Expand All @@ -304,7 +315,7 @@ mod test {
value
);

let any_value = AnyValue::Datetime(0, polars::prelude::TimeUnit::Milliseconds, &None);
let any_value = AnyValue::Datetime(0, polars::prelude::TimeUnit::Milliseconds, None);

let value = MedRecordValue::try_from(any_value).unwrap();

Expand All @@ -315,7 +326,7 @@ mod test {
value
);

let any_value = AnyValue::Datetime(0, polars::prelude::TimeUnit::Nanoseconds, &None);
let any_value = AnyValue::Datetime(0, polars::prelude::TimeUnit::Nanoseconds, None);

let value = MedRecordValue::try_from(any_value).unwrap();

Expand All @@ -338,9 +349,9 @@ mod test {

#[test]
fn test_dataframe_to_nodes() {
let s0 = Series::new("index", &["0", "1"]);
let s1 = Series::new("attribute", &[1, 2]);
let nodes_dataframe = DataFrame::new(vec![s0, s1]).unwrap();
let s0 = Series::new("index".into(), &["0", "1"]);
let s1 = Series::new("attribute".into(), &[1, 2]);
let nodes_dataframe = DataFrame::new(vec![s0.into(), s1.into()]).unwrap();

let nodes = dataframe_to_nodes(nodes_dataframe, "index").unwrap();

Expand All @@ -355,9 +366,9 @@ mod test {

#[test]
fn test_invalid_dataframe_to_nodes() {
let s0 = Series::new("index", &["0", "1"]);
let s1 = Series::new("attribute", &[1, 2]);
let nodes_dataframe = DataFrame::new(vec![s0, s1]).unwrap();
let s0 = Series::new("index".into(), &["0", "1"]);
let s1 = Series::new("attribute".into(), &[1, 2]);
let nodes_dataframe = DataFrame::new(vec![s0.into(), s1.into()]).unwrap();

// Providing the wrong index column name should fail
assert!(dataframe_to_nodes(nodes_dataframe, "wrong_column")
Expand All @@ -366,10 +377,10 @@ mod test {

#[test]
fn test_dataframe_to_edges() {
let s0 = Series::new("source", &["0", "1"]);
let s1 = Series::new("target", &["1", "0"]);
let s2 = Series::new("attribute", &[1, 2]);
let edges_dataframe = DataFrame::new(vec![s0, s1, s2]).unwrap();
let s0 = Series::new("source".into(), &["0", "1"]);
let s1 = Series::new("target".into(), &["1", "0"]);
let s2 = Series::new("attribute".into(), &[1, 2]);
let edges_dataframe = DataFrame::new(vec![s0.into(), s1.into(), s2.into()]).unwrap();

let edges = dataframe_to_edges(edges_dataframe, "source", "target").unwrap();

Expand All @@ -392,10 +403,10 @@ mod test {

#[test]
fn test_invalid_dataframe_to_edges() {
let s0 = Series::new("source", &["0", "1"]);
let s1 = Series::new("target", &["1", "0"]);
let s2 = Series::new("attribute", &[1, 2]);
let edges_dataframe = DataFrame::new(vec![s0, s1, s2]).unwrap();
let s0 = Series::new("source".into(), &["0", "1"]);
let s1 = Series::new("target".into(), &["1", "0"]);
let s2 = Series::new("attribute".into(), &[1, 2]);
let edges_dataframe = DataFrame::new(vec![s0.into(), s1.into(), s2.into()]).unwrap();

// Providing the wrong source index column name should fail
assert!(
Expand Down
6 changes: 6 additions & 0 deletions medmodels/treatment_effect/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class TreatmentEffectBuilder:
The TreatmentEffectBuilder class is used to build a TreatmentEffect object with
the desired configurations for the treatment effect estimation using a builder
pattern.
By default, it configures a static treatment effect estimation. To configure a
time-dependent treatment effect estimation, the time_attribute must be set.
"""

treatment: Group
Expand Down Expand Up @@ -97,6 +100,9 @@ def with_time_attribute(
) -> TreatmentEffectBuilder:
"""Sets the time attribute to be used in the treatment effect estimation.
It turns the treatment effect estimation from a static to a time-dependent
analysis.
Args:
attribute (MedRecordAttribute): The time attribute.
Expand Down
Loading

0 comments on commit d319202

Please sign in to comment.