Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement to_writer method and clean up RustCycle csv serde #95

Merged
merged 3 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions rust/fastsim-core/fastsim-proc-macros/src/add_pyo3_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_resource")]
pub fn from_resource_py(filepath: &PyAny) -> anyhow::Result<Self> {
Self::from_resource(PathBuf::extract(filepath)?)
pub fn from_resource_py(filepath: &PyAny) -> PyResult<Self> {
Self::from_resource(PathBuf::extract(filepath)?).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object to a file.
Expand All @@ -228,8 +228,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
/// * `filepath`: `str | pathlib.Path` - The filepath at which to write the object
///
#[pyo3(name = "to_file")]
pub fn to_file_py(&self, filepath: &PyAny) -> anyhow::Result<()> {
self.to_file(PathBuf::extract(filepath)?)
pub fn to_file_py(&self, filepath: &PyAny) -> PyResult<()> {
self.to_file(PathBuf::extract(filepath)?).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object from a file.
Expand All @@ -241,8 +241,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_file")]
pub fn from_file_py(filepath: &PyAny) -> anyhow::Result<Self> {
Self::from_file(PathBuf::extract(filepath)?)
pub fn from_file_py(filepath: &PyAny) -> PyResult<Self> {
Self::from_file(PathBuf::extract(filepath)?).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object into a string
Expand All @@ -252,8 +252,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
/// * `format`: `str` - The target format, any of those listed in [`ACCEPTED_STR_FORMATS`](`SerdeAPI::ACCEPTED_STR_FORMATS`)
///
#[pyo3(name = "to_str")]
pub fn to_str_py(&self, format: &str) -> anyhow::Result<String> {
self.to_str(format)
pub fn to_str_py(&self, format: &str) -> PyResult<String> {
self.to_str(format).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object from a string
Expand All @@ -265,14 +265,14 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_str")]
pub fn from_str_py(contents: &str, format: &str) -> anyhow::Result<Self> {
Self::from_str(contents, format)
pub fn from_str_py(contents: &str, format: &str) -> PyResult<Self> {
Self::from_str(contents, format).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object to a JSON string
#[pyo3(name = "to_json")]
pub fn to_json_py(&self) -> anyhow::Result<String> {
self.to_json()
pub fn to_json_py(&self) -> PyResult<String> {
self.to_json().map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object to a JSON string
Expand All @@ -283,14 +283,14 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_json")]
pub fn from_json_py(json_str: &str) -> anyhow::Result<Self> {
Self::from_json(json_str)
pub fn from_json_py(json_str: &str) -> PyResult<Self> {
Self::from_json(json_str).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object to a YAML string
#[pyo3(name = "to_yaml")]
pub fn to_yaml_py(&self) -> anyhow::Result<String> {
self.to_yaml()
pub fn to_yaml_py(&self) -> PyResult<String> {
self.to_yaml().map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object from a YAML string
Expand All @@ -301,14 +301,14 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_yaml")]
pub fn from_yaml_py(yaml_str: &str) -> anyhow::Result<Self> {
Self::from_yaml(yaml_str)
pub fn from_yaml_py(yaml_str: &str) -> PyResult<Self> {
Self::from_yaml(yaml_str).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Write (serialize) an object to bincode-encoded `bytes`
#[pyo3(name = "to_bincode")]
pub fn to_bincode_py<'py>(&self, py: Python<'py>) -> anyhow::Result<&'py PyBytes> {
Ok(PyBytes::new(py, &self.to_bincode()?))
pub fn to_bincode_py<'py>(&self, py: Python<'py>) -> PyResult<&'py PyBytes> {
PyResult::Ok(PyBytes::new(py, &self.to_bincode()?)).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

/// Read (deserialize) an object from bincode-encoded `bytes`
Expand All @@ -319,8 +319,8 @@ pub fn add_pyo3_api(attr: TokenStream, item: TokenStream) -> TokenStream {
///
#[staticmethod]
#[pyo3(name = "from_bincode")]
pub fn from_bincode_py(encoded: &PyBytes) -> anyhow::Result<Self> {
Self::from_bincode(encoded.as_bytes())
pub fn from_bincode_py(encoded: &PyBytes) -> PyResult<Self> {
Self::from_bincode(encoded.as_bytes()).map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}
});

Expand Down
71 changes: 32 additions & 39 deletions rust/fastsim-core/src/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ impl RustCycleCache {
Ok(dict)
}

#[pyo3(name = "to_csv")]
pub fn to_csv_py(&self) -> PyResult<String> {
self.to_csv().map_err(|e| PyIOError::new_err(format!("{:?}", e)))
}

#[pyo3(name = "modify_by_const_jerk_trajectory")]
pub fn modify_by_const_jerk_trajectory_py(
&mut self,
Expand Down Expand Up @@ -650,19 +655,25 @@ impl SerdeAPI for RustCycle {
Ok(())
}

fn to_file<P: AsRef<Path>>(&self, filepath: P) -> anyhow::Result<()> {
let filepath = filepath.as_ref();
let extension = filepath
.extension()
.and_then(OsStr::to_str)
.with_context(|| format!("File extension could not be parsed: {filepath:?}"))?;
match extension.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => serde_yaml::to_writer(&File::create(filepath)?, self)?,
"json" => serde_json::to_writer(&File::create(filepath)?, self)?,
"bin" => bincode::serialize_into(&File::create(filepath)?, self)?,
"csv" => self.write_csv(&mut csv::Writer::from_path(filepath)?)?,
fn to_writer<W: std::io::Write>(&self, wtr: W, format: &str) -> anyhow::Result<()> {
match format.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => serde_yaml::to_writer(wtr, self)?,
"json" => serde_json::to_writer(wtr, self)?,
"bin" => bincode::serialize_into(wtr, self)?,
"csv" => {
let mut wtr = csv::Writer::from_writer(wtr);
for i in 0..self.len() {
wtr.serialize(RustCycleElement {
time_s: self.time_s[i],
mps: self.mps[i],
grade: Some(self.grade[i]),
road_type: Some(self.road_type[i]),
})?;
}
wtr.flush()?
}
_ => bail!(
"Unsupported format {extension:?}, must be one of {:?}",
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_BYTE_FORMATS
),
}
Expand All @@ -674,11 +685,7 @@ impl SerdeAPI for RustCycle {
match format.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => self.to_yaml()?,
"json" => self.to_json()?,
"csv" => {
let mut wtr = csv::Writer::from_writer(Vec::with_capacity(self.len()));
self.write_csv(&mut wtr)?;
String::from_utf8(wtr.into_inner()?)?
}
"csv" => self.to_csv()?,
_ => {
bail!(
"Unsupported format {format:?}, must be one of {:?}",
Expand All @@ -696,7 +703,7 @@ impl SerdeAPI for RustCycle {
match format.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => Self::from_yaml(contents)?,
"json" => Self::from_json(contents)?,
"csv" => Self::from_csv_str(contents, "".to_string())?,
"csv" => Self::from_reader(contents.as_ref().as_bytes(), "csv")?,
_ => bail!(
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_STR_FORMATS
Expand Down Expand Up @@ -791,37 +798,23 @@ impl RustCycle {
.and_then(OsStr::to_str)
.with_context(|| format!("Could not parse cycle name from filepath: {filepath:?}"))?
.to_string();
let file = File::open(filepath).with_context(|| {
if !filepath.exists() {
format!("File not found: {filepath:?}")
} else {
format!("Could not open file: {filepath:?}")
}
})?;
let mut cyc = Self::from_reader(file, "csv")?;
let mut cyc = Self::from_file(filepath)?;
cyc.name = name;
Ok(cyc)
}

/// Load cycle from CSV string
pub fn from_csv_str<S: AsRef<str>>(csv_str: S, name: String) -> anyhow::Result<Self> {
let mut cyc = Self::from_reader(csv_str.as_ref().as_bytes(), "csv")?;
let mut cyc = Self::from_str(csv_str, "csv")?;
cyc.name = name;
Ok(cyc)
}

/// Write cycle data to a CSV writer
fn write_csv<W: std::io::Write>(&self, wtr: &mut csv::Writer<W>) -> anyhow::Result<()> {
for i in 0..self.len() {
wtr.serialize(RustCycleElement {
time_s: self.time_s[i],
mps: self.mps[i],
grade: Some(self.grade[i]),
road_type: Some(self.road_type[i]),
})?;
}
wtr.flush()?;
Ok(())
/// Write (serialize) cycle to a CSV string
pub fn to_csv(&self) -> anyhow::Result<String> {
let mut buf = Vec::with_capacity(self.len());
self.to_writer(&mut buf, "csv")?;
Ok(String::from_utf8(buf)?)
}

pub fn build_cache(&self) -> RustCycleCache {
Expand Down
14 changes: 9 additions & 5 deletions rust/fastsim-core/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> {
.extension()
.and_then(OsStr::to_str)
.with_context(|| format!("File extension could not be parsed: {filepath:?}"))?;
match extension.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => serde_yaml::to_writer(&File::create(filepath)?, self)?,
"json" => serde_json::to_writer(&File::create(filepath)?, self)?,
"bin" => bincode::serialize_into(&File::create(filepath)?, self)?,
self.to_writer(File::create(filepath)?, extension)
}

fn to_writer<W: std::io::Write>(&self, wtr: W, format: &str) -> anyhow::Result<()> {
match format.trim_start_matches('.').to_lowercase().as_str() {
"yaml" | "yml" => serde_yaml::to_writer(wtr, self)?,
"json" => serde_json::to_writer(wtr, self)?,
"bin" => bincode::serialize_into(wtr, self)?,
_ => bail!(
"Unsupported format {extension:?}, must be one of {:?}",
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_BYTE_FORMATS
),
}
Expand Down
Loading