Skip to content

Commit

Permalink
test accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaellaude committed Nov 2, 2023
1 parent 0124347 commit 75bb181
Show file tree
Hide file tree
Showing 14 changed files with 28,691 additions and 830 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ cargo run -- parse --address '33 Nassau Avenue, Brooklyn, NY'
or by importing the crate and using the `parse` function

```rust
let addresses = read_to_string("tests/test_addrs.txt").expect("Could not read file");
let addresses = read_to_string("tests/test_data/test_addrs.txt").expect("Could not read file");
let addresses: Vec<&str> = addresses.lines().collect();
for (i, addr) in addresses.iter().enumerate() {
let parsed = parse(addr);
Expand Down
2 changes: 1 addition & 1 deletion benches/clean_test_addresses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn bench(c: &mut Criterion) {
});

c.bench_function("clean_address_batch", |b| {
let data = read_to_string("tests/test_addrs.txt").expect("Could not read file");
let data = read_to_string("tests/test_data/test_addrs.txt").expect("Could not read file");
let data: Vec<&str> = data.lines().collect();
b.iter_batched(
|| data.clone(),
Expand Down
Binary file modified model/test_usaddr.crfsuite
Binary file not shown.
85 changes: 72 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crfsuite::Attribute;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use unicode_normalization::UnicodeNormalization;
use xml::reader::{EventReader, XmlEvent};
mod abbreviations;
pub mod train;

Expand Down Expand Up @@ -56,7 +59,7 @@ pub fn zip_tokens_and_tags(tokens: Vec<String>, tags: Vec<String>) -> Vec<(Strin
tokens.into_iter().zip(tags.into_iter()).collect()
}

pub fn get_address_features(tokens: &Vec<String>) -> Vec<Vec<Attribute>> {
pub fn get_address_features(tokens: &[String]) -> Vec<Vec<Attribute>> {
let xseq = tokens
.iter()
.map(|token| get_token_features(token))
Expand All @@ -78,16 +81,20 @@ pub fn add_feature_context(features: Vec<Vec<Attribute>>) -> Vec<Vec<Attribute>>

let n_features = features.len();

if n_features == 1 {
return features;
}

// 1. Collect new attributes
let mut new_attributes = Vec::new();
for idx in 0..n_features {
let mut current_attrs = Vec::new();
if idx == 0 {
current_attrs.extend(get_new_attributes(&features[idx + 1], "next"));
} else if idx == 1 {
current_attrs.push(Attribute::new("previous.address.start", 1f64));
current_attrs.push(Attribute::new("previous_address.start", 1f64));
} else if idx == n_features - 2 {
current_attrs.push(Attribute::new("next.address.end", 1f64));
current_attrs.push(Attribute::new("next_address.end", 1f64));
} else if idx == n_features - 1 {
current_attrs.extend(get_new_attributes(&features[idx - 1], "previous"));
} else {
Expand Down Expand Up @@ -124,7 +131,10 @@ pub fn get_token_features(token: &str) -> Vec<Attribute> {
let mut n_chars = 0;
let mut numeric_digits = 0;
let mut has_vowels = false;
let mut last_char = None;
let mut endsinpunc = false;
let mut ends_in_period = false;
let mut trailing_zeros = false;
let mut token_clean = String::with_capacity(token.len());

for c in token.chars() {
n_chars += 1;
Expand All @@ -134,17 +144,17 @@ pub fn get_token_features(token: &str) -> Vec<Attribute> {
if "aeiou".contains(c) {
has_vowels = true;
}
last_char = Some(c);
if c.is_alphanumeric() {
token_clean.push(c);
}
}

if let Some(last_char) = token.chars().last() {
endsinpunc = last_char.is_ascii_punctuation();
ends_in_period = last_char == '.';
trailing_zeros = last_char == '0';
}

let token_clean = token
.to_lowercase()
.chars()
.filter(|c| c.is_alphanumeric())
.collect::<String>();
let endsinpunc = last_char.map_or(false, |c| c.is_ascii_punctuation());
let ends_in_period = last_char.map_or(false, |c| c == '.');
let trailing_zeros = last_char.map_or(false, |c| c == '0');
let digits = match numeric_digits {
d if d == n_chars => "all_digits",
d if d > 0 => "some_digits",
Expand Down Expand Up @@ -249,3 +259,52 @@ pub fn remove_insignificant_punctuation(address: &str) -> String {
}
output
}

pub fn read_xml_tagged_addresses(file_path: &str) -> (Vec<String>, Vec<Vec<String>>) {
let file = File::open(file_path);
let file = match file {
Ok(file) => file,
Err(e) => {
eprintln!("Error opening file: {}", e);
std::process::exit(1);
}
};
let file = BufReader::new(file);

let parser = EventReader::new(file);

let mut addresses: Vec<String> = Vec::new();
let mut tags: Vec<Vec<String>> = Vec::new();

let mut address: Vec<String> = Vec::new();
let mut yseq: Vec<String> = Vec::new();

for e in parser {
match e {
Ok(XmlEvent::StartElement { name, .. }) => {
if name.local_name == "AddressString" {
address.clear();
yseq.clear();
} else {
yseq.push(name.local_name.to_string());
}
}
Ok(XmlEvent::Characters(s)) => {
address.push(s);
}
Ok(XmlEvent::EndElement { name }) => {
if name.local_name == "AddressString" {
addresses.push(address.join(" "));
tags.push(yseq.clone());
}
}
Err(e) => {
eprintln!("Error: {e}");
break;
}
_ => {}
}
}

(addresses, tags)
}
4 changes: 2 additions & 2 deletions src/train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use xml::reader::{EventReader, XmlEvent};

use crate::{get_address_features, tokenize};

pub fn train_model(file_path: &str) -> std::io::Result<()> {
pub fn train_model(export_path: &str) -> std::io::Result<()> {
let file = File::open("training/labeled.xml")?;
let file = BufReader::new(file);

Expand Down Expand Up @@ -58,7 +58,7 @@ pub fn train_model(file_path: &str) -> std::io::Result<()> {
}
}

match trainer.train(file_path, -1) {
match trainer.train(export_path, -1) {
Ok(()) => (),
Err(e) => println!("Error training model: {}", e),
}
Expand Down
70 changes: 0 additions & 70 deletions tests/integration_tests.rs

This file was deleted.

Loading

0 comments on commit 75bb181

Please sign in to comment.