-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
jaimerzp
committed
Jun 6, 2024
1 parent
03a3c3b
commit 0aa9b67
Showing
11 changed files
with
280 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ scikit-learn | |
setuptools_scm | ||
tables-io[full] | ||
deprecated | ||
multipledispatch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,3 +34,5 @@ | |
from . import packing_utils | ||
|
||
from . import test_funcs | ||
|
||
from . import projectors |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .projector_base import ProjectorBase | ||
from .projector_shifts import ProjectorShifts | ||
from .projector_moments import ProjectorMoments |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from ..ensemble import Ensemble | ||
import numpy as np | ||
from multipledispatch import dispatch | ||
|
||
|
||
class ProjectorBase(object): | ||
@dispatch() | ||
def __init__(self): | ||
self._project_base() | ||
self._project() | ||
|
||
@dispatch(np.ndarray, np.ndarray) | ||
def __init__(self, zs, pzs): | ||
self._project_base(zs, pzs) | ||
|
||
@dispatch(Ensemble) | ||
def __init__(self, ens): | ||
self._project_base(ens) | ||
|
||
@dispatch() | ||
def _project_base(self): | ||
raise NotImplementedError | ||
|
||
@dispatch(np.ndarray, np.ndarray) | ||
def _project_base(self, zs, pzs): | ||
self.pzs = self._normalize(pzs) | ||
self.z = zs | ||
self.pz_mean = np.mean(self.pzs, axis=0) | ||
self.prior = None | ||
|
||
@dispatch(qp.ensemble.Ensemble) | ||
def _project_base(self, ens, z=None): | ||
if z is None: | ||
z = np.linspace(0, 1.5, 45) | ||
self.z = z | ||
pzs = ens.pdf(z) | ||
pzs = ens.objdata()['pdfs'] | ||
self.pzs = self._normalize(pzs) | ||
self.pz_mean = np.mean(self.pzs, axis=0) | ||
self.prior = None | ||
|
||
def _normalize(self, pzs): | ||
norms = np.sum(pzs, axis=1) | ||
pzs /= norms[:, None] | ||
return pzs | ||
|
||
def evaluate_model(self, *args): | ||
raise NotImplementedError | ||
|
||
def get_prior(self): | ||
if self.prior is None: | ||
self.prior = self._get_prior() | ||
return self.prior | ||
|
||
def sample_prior(self): | ||
prior = self.get_prior() | ||
return prior.rvs() | ||
|
||
def save_prior(self): | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import numpy as np | ||
from ..ensemble import Ensemble | ||
from multipledispatch import dispatch | ||
from .projector_base import ProjectorBase | ||
from numpy.linalg import eig, cholesky | ||
from scipy.stats import multivariate_normal as mvn | ||
|
||
|
||
class ProjectorMoments(ProjectorBase): | ||
@dispatch() | ||
def __init__(self): | ||
self._project_base() | ||
self._project() | ||
|
||
@dispatch(np.ndarray, np.ndarray) | ||
def __init__(self, zs, pzs): | ||
self._project_base(zs, pzs) | ||
self._project() | ||
|
||
@dispatch(Ensemble) | ||
def __init__(self, ens): | ||
self._project_base(ens) | ||
self._project() | ||
|
||
def _project(self): | ||
self.pz_cov = self._get_cov() | ||
self.pz_chol = cholesky(self.pz_cov) | ||
|
||
def _get_cov(self): | ||
cov = np.cov(self.pzs, rowvar=False) | ||
if not self._is_pos_def(cov): | ||
print('Warning: Covariance matrix is not positive definite') | ||
print('The covariance matrix will be regularized') | ||
jitter = 1e-15 * np.eye(cov.shape[0]) | ||
w, v = eig(cov+jitter) | ||
w = np.real(np.abs(w)) | ||
v = np.real(v) | ||
cov = v @ np.diag(np.abs(w)) @ v.T | ||
cov = np.tril(cov) + np.triu(cov.T, 1) | ||
if not self._is_pos_def(cov): | ||
print('Warning: regularization failed') | ||
print('The covariance matrix will be diagonalized') | ||
cov = np.diag(np.diag(cov)) | ||
return cov | ||
|
||
def _is_pos_def(self, A): | ||
return np.all(np.linalg.eigvals(A) > 0) | ||
|
||
def evaluate_model(self, pz, alpha): | ||
z = pz[0] | ||
pz = pz[1] | ||
return [z, pz + self.pz_chol @ alpha] | ||
|
||
def _get_prior(self): | ||
return mvn(np.zeros_like(self.pz_mean), | ||
np.ones_like(self.pz_mean)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import numpy as np | ||
from ..ensemble import Ensemble | ||
from multipledispatch import dispatch | ||
from .projector_base import ProjectorBase | ||
from scipy.interpolate import interp1d | ||
from scipy.stats import multivariate_normal | ||
|
||
|
||
class ProjectorShifts(ProjectorBase): | ||
@dispatch() | ||
def __init__(self): | ||
self._project_base() | ||
self._project() | ||
|
||
@dispatch(np.ndarray, np.ndarray) | ||
def __init__(self, zs, pzs): | ||
self._project_base(zs, pzs) | ||
self._project() | ||
|
||
@dispatch(Ensemble) | ||
def __init__(self, ens): | ||
self._project_base(ens) | ||
self._project() | ||
|
||
def _project(self): | ||
self.shift = self._find_shift() | ||
|
||
def evaluate_model(self, pz, shift): | ||
z = pz[0] | ||
pz = pz[1] | ||
z_shift = z + shift | ||
pz_shift = interp1d(z_shift, pz, | ||
kind='linear', | ||
fill_value='extrapolate')(z) | ||
return [z, pz_shift] | ||
|
||
def _find_shift(self): | ||
stds = np.std(self.pzs, axis=1) # std of each pz | ||
s_stds = np.std(stds) # std of the z-std | ||
m_stds = np.mean(stds) # mean of the z-std | ||
return s_stds / m_stds | ||
|
||
def _get_prior(self): | ||
return multivariate_normal([0], [self.shift**2]) |
Empty file.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import qp | ||
import numpy as np | ||
import rail_projector.projectors as rp | ||
|
||
|
||
def make_qp_ens(file): | ||
zs = file['zs'] | ||
pzs = file['pzs'] | ||
dz = np.mean(np.diff(zs)) | ||
zs_edges = np.append(zs - dz/2, zs[-1] + dz/2) | ||
q = qp.Ensemble(qp.hist, data={"bins":zs_edges, "pdfs":pzs}) | ||
return q | ||
|
||
def test_base_from_qp(): | ||
file = np.load('rail_projector/tests/dummy.npz') | ||
ens = make_qp_ens(file) | ||
projector = rp.ProjectorBase(ens) | ||
m, n = projector.pzs.shape | ||
k, = projector.z.shape | ||
pzs = file['pzs'] | ||
pzs /= np.sum(pzs, axis=1)[:, None] | ||
assert n == k | ||
assert np.allclose(projector.pz_mean, np.mean(pzs, axis=0)) | ||
|
||
def test_base_from_arrs(): | ||
file = np.load('rail_projector/tests/dummy.npz') | ||
zs = file['zs'] | ||
pzs = file['pzs'] | ||
projector = rp.ProjectorBase(zs, pzs) | ||
m, n = projector.pzs.shape | ||
k, = projector.z.shape | ||
pzs = file['pzs'] | ||
pzs /= np.sum(pzs, axis=1)[:, None] | ||
assert n == k | ||
assert np.allclose(projector.pz_mean, np.mean(pzs, axis=0)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import qp | ||
import numpy as np | ||
import rail_projector.projectors as rp | ||
|
||
|
||
def make_qp_ens(file): | ||
zs = file['zs'] | ||
pzs = file['pzs'] | ||
dz = np.mean(np.diff(zs)) | ||
zs_edges = np.append(zs - dz/2, zs[-1] + dz/2) | ||
q = qp.Ensemble(qp.hist, data={"bins":zs_edges, "pdfs":pzs}) | ||
return q | ||
|
||
|
||
def make_projector(): | ||
file = np.load('rail_projector/tests/dummy.npz') | ||
ens = make_qp_ens(file) | ||
return rp.ProjectorMoments(ens) | ||
|
||
|
||
def test_prior(): | ||
projector = make_projector() | ||
prior = projector.get_prior() | ||
assert prior is not None | ||
|
||
|
||
def test_sample_prior(): | ||
projector = make_projector() | ||
pz = projector.sample_prior() | ||
assert len(pz) == len(projector.pz_mean) | ||
|
||
|
||
def test_model(): | ||
projector = make_projector() | ||
shift = projector.sample_prior() | ||
input = np.array([projector.z, projector.pz_mean]) | ||
output = projector.evaluate_model(input, shift) | ||
assert (projector.z == output[0]).all() | ||
assert len(output[1]) == len(projector.pz_mean) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import qp | ||
import numpy as np | ||
import rail_projector.projectors as rp | ||
|
||
|
||
def make_qp_ens(file): | ||
zs = file['zs'] | ||
pzs = file['pzs'] | ||
dz = np.mean(np.diff(zs)) | ||
zs_edges = np.append(zs - dz/2, zs[-1] + dz/2) | ||
q = qp.Ensemble(qp.hist, data={"bins":zs_edges, "pdfs":pzs}) | ||
return q | ||
|
||
|
||
def make_projector(): | ||
file = np.load('rail_projector/tests/dummy.npz') | ||
ens = make_qp_ens(file) | ||
return rp.ProjectorShifts(ens) | ||
|
||
|
||
def test_prior(): | ||
projector = make_projector() | ||
prior = projector.get_prior() | ||
assert prior is not None | ||
|
||
|
||
def test_sample_prior(): | ||
projector = make_projector() | ||
shift = projector.sample_prior() | ||
assert len([shift]) == len([projector.shift]) | ||
|
||
|
||
def test_model(): | ||
projector = make_projector() | ||
shift = projector.sample_prior() | ||
input = np.array([projector.z, projector.pz_mean]) | ||
output = projector.evaluate_model(input, shift) | ||
assert (projector.z == output[0]).all() | ||
assert len(output[1]) == len(projector.pz_mean) |