Add TensorFlow implementations of ARADv1 and ARADv2.
Changes also include:

- Refactor of `TFModelBase` into `PyRIIDModel`.
- Saving all models as either HDF or ONNX.

Co-authored-by: Tyler Morrow <[email protected]>
alanjvano and tymorrow committed Dec 15, 2023
1 parent 307eef8 commit 49560f4
Showing 9 changed files with 1,509 additions and 982 deletions.
83 changes: 83 additions & 0 deletions examples/modeling/arad.py
@@ -0,0 +1,83 @@
# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
# Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
"""This example demonstrates how to use the PyRIID implementations of ARAD.
"""
import numpy as np
import pandas as pd

from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer
from riid.data.synthetic.static import StaticSynthesizer
from riid.models.neural_nets.arad import ARAD, ARADv1TF, ARADv2TF

# Config
rng = np.random.default_rng(42)
OOD_QUANTILE = 0.99
VERBOSE = True
# Some of the following parameters are set low because this example runs on GitHub Actions and
# we don't want it taking a bunch of time.
# When running this locally, change the values per their corresponding comment, otherwise
# the results likely will not be meaningful.
EPOCHS = 5 # Change this to 20+
N_MIXTURES = 50  # Change this to 1000+
TRAIN_SAMPLES_PER_SEED = 5 # Change this to 20+
TEST_SAMPLES_PER_SEED = 5

# Generate training data
fg_seeds_ss, bg_seeds_ss = get_dummy_seeds(n_channels=128, rng=rng).split_fg_and_bg()
mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng).generate(N_MIXTURES)
static_synth = StaticSynthesizer(
    samples_per_seed=TRAIN_SAMPLES_PER_SEED,
    snr_function_args=(0, 0),
    return_fg=False,
    return_gross=True,
    rng=rng,
)
_, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
gross_train_ss.normalize()

# Train the models
print("Training ARADv1...")
arad_v1 = ARAD(model=ARADv1TF())
arad_v1.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad_v1.predict(gross_train_ss)
v1_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

print("Training ARADv2...")
arad_v2 = ARAD(model=ARADv2TF())
arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
arad_v2.predict(gross_train_ss)
v2_ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)

# Generate test data
static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED
_, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
test_ss.normalize()

# Predict
arad_v1_reconstructions = arad_v1.predict(test_ss, verbose=True)
arad_v1_ood = test_ss.info.recon_error.values > v1_ood_threshold
arad_v1_false_positive_rate = arad_v1_ood.mean()
arad_v1_mean_recon_error = test_ss.info.recon_error.values.mean()

arad_v2_reconstructions = arad_v2.predict(test_ss, verbose=True)
arad_v2_ood = test_ss.info.recon_error.values > v2_ood_threshold
arad_v2_false_positive_rate = arad_v2_ood.mean()
arad_v2_mean_recon_error = test_ss.info.recon_error.values.mean()

results = {
    "ARADv1": {
        "ood_threshold": f"KLD={v1_ood_threshold:.4f}",
        "mean_recon_error": arad_v1_mean_recon_error,
        "false_positive_rate": arad_v1_false_positive_rate,
    },
    "ARADv2": {
        "ood_threshold": f"JSD={v2_ood_threshold:.4f}",
        "mean_recon_error": arad_v2_mean_recon_error,
        "false_positive_rate": arad_v2_false_positive_rate,
    },
}
print(f"Target False Positive Rate: {1-OOD_QUANTILE:.4f}")
print(pd.DataFrame.from_dict(results))
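
Since this commit also introduces H5/ONNX serialization (see riid/models/__init__.py below), a natural follow-on to the example is persisting the trained models. The following is a sketch only, assuming ARAD inherits save and load from PyRIIDModel; the file names are hypothetical:

# Hypothetical follow-on to the example above (not part of the commit).
# The file extension selects the serializer: .h5 for Keras HDF5, .onnx via tf2onnx.
arad_v2.save("arad_v2.h5")
arad_v2.save("arad_v2.onnx")

# Reloading from .onnx backs the model with an onnxruntime session instead of Keras.
arad_v2_reloaded = ARAD(model=ARADv2TF())
arad_v2_reloaded.load("arad_v2.onnx")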
22 changes: 12 additions & 10 deletions riid/data/synthetic/seed.py
@@ -157,14 +157,15 @@ def generate(self, config: Union[str, dict],
class SeedMixer():
"""Randomly mixes seeds in a `SampleSet` """
def __init__(self, seeds_ss: SampleSet, mixture_size: int = 2, dirichlet_alpha: float = 2.0,
-                 restricted_isotope_pairs: List[Tuple[str, str]] = [], random_state: int = None):
+                 restricted_isotope_pairs: List[Tuple[str, str]] = [], rng: Generator = None):
"""
Args:
-            seeds_ss: `SampleSet` of `n` seed spectra where `n` >= `mixture_size`.
+            seeds_ss: `SampleSet` of `n` seed spectra where `n` >= `mixture_size`
mixture_size: number of templates to mix
dirichlet_alpha: Dirichlet parameter controlling the nature of proportions
-            restricted_isotope_pairs:
-            random_state:
+            restricted_isotope_pairs: list of 2-tuples containing pairs of isotope strings that
+                are not to be mixed together
+            rng: NumPy random number generator, useful for experiment repeatability
Raises:
AssertionError when `mixture_size` is less than 2
Expand All @@ -175,7 +176,10 @@ def __init__(self, seeds_ss: SampleSet, mixture_size: int = 2, dirichlet_alpha:
self.mixture_size = mixture_size
self.dirichlet_alpha = dirichlet_alpha
self.restricted_isotope_pairs = restricted_isotope_pairs
-        self.random_state = random_state
+        if rng is None:
+            self.rng = np.random.default_rng()
+        else:
+            self.rng = rng

self._check_seeds()

@@ -247,8 +251,6 @@ def __call__(self, n_samples: int, max_batch_size: int = 100) -> Iterator[Sample
raise ValueError("Number of Dirichlet alphas does not equal the number of seeds.")
seed_to_alpha = {s: a for s, a in zip(seeds, self.dirichlet_alpha)}

-        rng = np.random.default_rng(self.random_state)

n_samples_produced = 0
while n_samples_produced < n_samples:
batch_size = n_samples - n_samples_produced
@@ -262,20 +264,20 @@
np.array(isotope_probas.copy()),
restricted_isotope_bidict,
self.mixture_size,
-                rng
+                self.rng,
)
for _ in range(batch_size)
]
seed_choices = [
-                [isotope_to_seeds[i][rng.choice(len(isotope_to_seeds[i]))] for i in c]
+                [isotope_to_seeds[i][self.rng.choice(len(isotope_to_seeds[i]))] for i in c]
for c in isotope_choices
]
batch_dirichlet_alphas = np.array([
[seed_to_alpha[i] for i in s]
for s in seed_choices
])
seed_ratios = [
-                rng.dirichlet(
+                self.rng.dirichlet(
alpha=alpha
) for alpha in batch_dirichlet_alphas
]
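
The upshot of replacing random_state with rng is that callers now inject a NumPy Generator, making mixture generation reproducible across runs and shareable across synthesizers. A brief sketch using only the API shown above (seed values are arbitrary):

import numpy as np
from riid.data.synthetic import get_dummy_seeds
from riid.data.synthetic.seed import SeedMixer

_, bg_seeds_ss = get_dummy_seeds(rng=np.random.default_rng(42)).split_fg_and_bg()

# Two mixers seeded identically produce identical mixtures.
mix_a = SeedMixer(bg_seeds_ss, mixture_size=3, rng=np.random.default_rng(7)).generate(10)
mix_b = SeedMixer(bg_seeds_ss, mixture_size=3, rng=np.random.default_rng(7)).generate(10)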
5 changes: 5 additions & 0 deletions riid/losses/__init__.py
@@ -209,6 +209,11 @@ def jensen_shannon_divergence(p, q):
return jsd


+def jensen_shannon_distance(p, q):
+    """Compute the Jensen-Shannon distance (square root of the JS divergence)."""
+    divergence = jensen_shannon_divergence(p, q)
+    return tf.math.sqrt(divergence)


def chi_squared_diff(spectra, reconstructed_spectra):
"""Compute the Chi-Squared test.
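
For context, the Jensen-Shannon distance added here is the square root of the Jensen-Shannon divergence; unlike the divergence, it is a true metric (symmetric and satisfying the triangle inequality), which makes it a reasonable reconstruction-error measure. A standalone numeric sketch of the relationship, assuming normalized probability vectors with no zero entries:

import tensorflow as tf

p = tf.constant([0.25, 0.25, 0.50])
q = tf.constant([0.10, 0.40, 0.50])
m = (p + q) / 2.0

# JS divergence: mean KL divergence of p and q from their midpoint m.
kl_pm = tf.reduce_sum(p * tf.math.log(p / m))
kl_qm = tf.reduce_sum(q * tf.math.log(q / m))
jsd = (kl_pm + kl_qm) / 2.0

js_distance = tf.math.sqrt(jsd)  # the quantity returned by jensen_shannon_distance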
75 changes: 65 additions & 10 deletions riid/models/__init__.py
@@ -2,13 +2,17 @@
# Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
"""This module contains the base TFModel class."""
+import json
import os
import uuid
import warnings
from enum import Enum

import numpy as np
+import onnxruntime
import pandas as pd
import tensorflow as tf
+import tf2onnx

import riid
from riid.data.labeling import label_to_index_element
@@ -23,10 +27,11 @@ class ModelInput(Enum):
ForegroundSpectrum = 2


-class TFModelBase:
+class PyRIIDModel:
"""Base class for TensorFlow models."""

    CUSTOM_OBJECTS = {"multi_f1": multi_f1, "single_f1": single_f1}
+    SUPPORTED_SAVE_EXTS = {"H5": ".h5", "ONNX": ".onnx"}

def __init__(self, *args, **kwargs):
self._info = {}
@@ -107,37 +112,87 @@ def save(self, file_path: str):
"""Save the model to a file.
Args:
-            file_path: file path at which to save the model
+            file_path: file path at which to save the model, can be either .h5 or
+                .onnx format
Raises:
`ValueError` when the given file path already exists
"""
        if os.path.exists(file_path):
            raise ValueError("Path already exists.")

+        root, ext = os.path.splitext(file_path)
+        if ext.lower() not in self.SUPPORTED_SAVE_EXTS.values():
+            raise NameError("Model must be an .onnx or .h5 file.")
+
        warnings.filterwarnings("ignore")

-        self.model.save(file_path, save_format="h5")
-        pd.DataFrame([[v] for v in self.info.values()], self.info.keys()).to_hdf(file_path, "_info")
+        if ext.lower() == self.SUPPORTED_SAVE_EXTS["H5"]:
+            self.model.save(file_path, save_format="h5")
+            pd.DataFrame(
+                [[v] for v in self.info.values()],
+                self.info.keys()
+            ).to_hdf(file_path, "_info")
+        else:
+            model_path = root + self.SUPPORTED_SAVE_EXTS["ONNX"]
+            model_info_path = root + "_info.json"
+
+            model_info_df = pd.DataFrame(
+                [[v] for v in self.info.values()],
+                self.info.keys()
+            )
+            model_info_df[0].to_json(model_info_path, indent=4)
+
+            tf2onnx.convert.from_keras(
+                self.model,
+                input_signature=None,
+                output_path=model_path
+            )

warnings.resetwarnings()

def load(self, file_path: str):
"""Load the model from a file.
Args:
-            file_path: file path from which to load the model
+            file_path: file path from which to load the model, must be either an
+                .h5 or .onnx file
        """
+        root, ext = os.path.splitext(file_path)
+        if ext.lower() not in self.SUPPORTED_SAVE_EXTS.values():
+            raise NameError("Model must be an .onnx or .h5 file.")
+
        warnings.filterwarnings("ignore", category=DeprecationWarning)

-        self.model = tf.keras.models.load_model(
-            file_path,
-            custom_objects=self.CUSTOM_OBJECTS
-        )
-        self._info = pd.read_hdf(file_path, "_info")[0].to_dict()
+        if ext.lower() == self.SUPPORTED_SAVE_EXTS["H5"]:
+            self.model = tf.keras.models.load_model(
+                file_path,
+                custom_objects=self.CUSTOM_OBJECTS
+            )
+            self._info = pd.read_hdf(file_path, "_info")[0].to_dict()
+        else:
+            model_path = root + self.SUPPORTED_SAVE_EXTS["ONNX"]
+            model_info_path = root + "_info.json"
+            with open(model_info_path) as fin:
+                model_info = json.load(fin)
+            self._info = model_info
+            self.onnx_session = onnxruntime.InferenceSession(model_path)

warnings.resetwarnings()

+    def get_predictions(self, x, **kwargs):
+        if self.model is not None:
+            outputs = self.model.predict(x, **kwargs)
+        elif self.onnx_session is not None:
+            outputs = self.onnx_session.run(
+                [self.onnx_session.get_outputs()[0].name],
+                {self.onnx_session.get_inputs()[0].name: x.astype(np.float32)}
+            )[0]
+        else:
+            raise ValueError("No model found with which to obtain predictions.")
+        return outputs

def serialize(self) -> bytes:
"""Convert model to a bytes object.
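
Taken together, save, load, and get_predictions give every PyRIIDModel subclass a format-agnostic persistence path. A hypothetical round trip, assuming model is a trained subclass instance and x is a NumPy array shaped like the training spectra:

# Sketch only (file names hypothetical). Saving to .onnx converts the Keras model
# via tf2onnx and writes the info dictionary beside it as my_model_info.json.
model.save("my_model.onnx")

# Loading an .onnx file creates an onnxruntime InferenceSession rather than a
# Keras model; get_predictions dispatches to whichever backend is present.
model.load("my_model.onnx")
outputs = model.get_predictions(x)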
6 changes: 3 additions & 3 deletions riid/models/bayes.py
@@ -8,10 +8,10 @@
import tensorflow_probability as tfp

from riid.data.sampleset import SampleSet
-from riid.models import TFModelBase
+from riid.models import PyRIIDModel


-class PoissonBayesClassifier(TFModelBase):
+class PoissonBayesClassifier(PyRIIDModel):
"""This Poisson-Bayes classifier calculates the conditional Poisson log probability of each
seed spectrum given the measurement.
@@ -139,7 +139,7 @@ def predict(self, gross_ss: SampleSet, bg_ss: SampleSet,
bg_spectra = tf.convert_to_tensor(bg_ss.spectra.values, dtype=tf.float32)
bg_lts = tf.convert_to_tensor(bg_ss.info.live_time.values, dtype=tf.float32)

-        prediction_probas = self.model.predict((
+        prediction_probas = self.get_predictions((
gross_spectra, gross_lts, bg_spectra, bg_lts
), batch_size=512, verbose=verbose)

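
The substantive change here is routing prediction through get_predictions instead of calling self.model.predict directly, so in principle a PoissonBayesClassifier restored from ONNX can serve predictions the same way as one backed by a live Keras model. A usage sketch, assuming pb_classifier has already been constructed and fit on seed spectra (construction and fit are not shown in this diff):

# Hypothetical variable names; per the signature above, predict takes gross and
# background SampleSets and computes class probabilities for each gross spectrum.
pb_classifier.predict(gross_ss, bg_ss)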