diff --git a/.gitignore b/.gitignore index c1a76ff..8757073 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /target */tmp scratch.rs -data_prep \ No newline at end of file +data_prep +*.csv \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index a8b5987..341af2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -288,6 +288,27 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "either" version = "1.9.0" @@ -661,6 +682,7 @@ dependencies = [ "clap", "crfsuite", "criterion", + "csv", "lazy_static", "unicode-normalization", "xml-rs", diff --git a/Cargo.toml b/Cargo.toml index ebb1fa8..c13a049 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ unicode-normalization = "0.1.22" clap = { version = "4.4.6", features = ["derive"] } crfsuite = "0.3.1" xml-rs = "0.8.19" +csv = "1.3.0" [dev-dependencies] criterion = "0.5.1" diff --git a/src/lib.rs b/src/lib.rs index d78472d..ed48626 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,33 +10,35 @@ pub mod train; use abbreviations::{DIRECTIONALS, STREET_NAMES}; -pub enum Tag { - AddressNumberPrefix, - AddressNumber, - AddressNumberSuffix, - StreetNamePreModifier, - StreetNamePreDirectional, - StreetNamePreType, - StreetName, - StreetNamePostType, - StreetNamePostDirectional, - SubaddressType, - SubaddressIdentifier, - BuildingName, - OccupancyType, - OccupancyIdentifier, - CornerOf, - LandmarkName, - PlaceName, - StateName, - ZipCode, - USPSBoxType, - USPSBoxID, - USPSBoxGroupType, - USPSBoxGroupID, - IntersectionSeparator, - 
Recipient, - NotAddress, +lazy_static! { + pub static ref TAGS: [&'static str; 26] = [ + "AddressNumberPrefix", + "AddressNumber", + "AddressNumberSuffix", + "StreetNamePreModifier", + "StreetNamePreDirectional", + "StreetNamePreType", + "StreetName", + "StreetNamePostType", + "StreetNamePostDirectional", + "SubaddressType", + "SubaddressIdentifier", + "BuildingName", + "OccupancyType", + "OccupancyIdentifier", + "CornerOf", + "LandmarkName", + "PlaceName", + "StateName", + "ZipCode", + "USPSBoxType", + "USPSBoxID", + "USPSBoxGroupType", + "USPSBoxGroupID", + "IntersectionSeparator", + "Recipient", + "NotAddress", + ]; } lazy_static! { @@ -55,6 +57,20 @@ pub fn parse(address: &str) -> Vec<(String, String)> { zip_tokens_and_tags(tokens, tags) } +pub fn parse_addresses(addresses: Vec<&str>) -> Vec<Vec<(String, String)>> { + addresses + .iter() // .iter is 42% faster than .par_iter() + .map(|address| parse(address)) + .collect() +} + +pub fn parse_addresses_from_txt(file_path: &str) -> Vec<Vec<(String, String)>> { + let raw_data = std::fs::read_to_string(file_path).unwrap(); + let data: Vec<&str> = raw_data.lines().collect(); + + parse_addresses(data) +} + pub fn zip_tokens_and_tags(tokens: Vec<String>, tags: Vec<String>) -> Vec<(String, String)> { tokens.into_iter().zip(tags.into_iter()).collect() } diff --git a/src/main.rs b/src/main.rs index a3aed34..e49f361 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use clap::Parser; -use us_addrs::parse; use us_addrs::train::train_model; +use us_addrs::{parse, parse_addresses_from_txt, TAGS}; // use std::path::PathBuf; @@ -8,6 +8,7 @@ use us_addrs::train::train_model; enum USAddrsCli { Train(TrainArgs), Parse(ParseArgs), + ParseFile(ParseFileArgs), } #[derive(Parser)] @@ -22,12 +23,43 @@ struct ParseArgs { address: String, } +#[derive(Parser)] +struct ParseFileArgs { + #[clap(short, long)] + file_path: String, + export_path: String, +} + fn main() { match USAddrsCli::parse() { USAddrsCli::Train(args) => match train_model(&args.export_path) { Ok(()) => println!("Trained 
model"), Err(e) => println!("Error training model: {}", e), }, + USAddrsCli::ParseFile(args) => { + let parsed_addresses = parse_addresses_from_txt(&args.file_path); + // write as CSV with Tags as columns + let mut wtr = csv::Writer::from_path(&args.export_path).unwrap(); + + wtr.write_record(TAGS.iter()).unwrap(); + + for tagged_address in parsed_addresses { + let mut record = Vec::new(); + + for tag in TAGS.iter() { + if let Some((token, _)) = tagged_address + .iter() + .find(|&(_, token_tag)| *token_tag == *tag) + { + record.push(token.to_string()); + } else { + record.push("".to_string()); + } + } + wtr.write_record(&record).unwrap(); + } + wtr.flush().unwrap(); + } USAddrsCli::Parse(args) => { let parsed = parse(&args.address); println!("{:?}", parsed);