Staging Branch for NeurIPS MCEIF Paper - WIP #572

Open

wants to merge 30 commits into base: master
Changes from all commits (30 commits)
de36144
Create __init__.py (#502)
SamWitty Jan 16, 2024
aef8c9b
Setup experiment configurations (#504)
SamWitty Jan 17, 2024
ee60ad4
More experiment configs (#505)
SamWitty Jan 18, 2024
47e0e86
ignore experiments and datasets folder
agrawalraj Jan 18, 2024
3c3562e
ignore experiments and datasets folder (#507)
agrawalraj Jan 18, 2024
3727fce
Merge branch 'staging-robust-icml' of github.com:BasisResearch/causal…
agrawalraj Jan 18, 2024
325d4b0
Refactor of `influence_approx.py` (#509)
agrawalraj Jan 19, 2024
72452d1
Finite Difference Baseline (#508)
azane Jan 19, 2024
3ecead9
Merge branch 'staging-robust-icml' of github.com:BasisResearch/causal…
agrawalraj Jan 22, 2024
bd63526
Integrates FD for Squared Density Into Experiment (#510)
azane Jan 24, 2024
b40b5cb
Merge branch 'staging-robust-icml' of github.com:BasisResearch/causal…
agrawalraj Jan 24, 2024
8535309
initial experiments
agrawalraj Jan 25, 2024
4025e8f
uncomitted changes
agrawalraj Jan 25, 2024
da3d5d5
Implement TMLE estimator using Influence Functions (#484)
SamWitty Jan 25, 2024
8a3f390
Merge branch 'staging-robust-icml' of github.com:BasisResearch/causal…
agrawalraj Jan 25, 2024
a331035
quality experiment
agrawalraj Jan 26, 2024
eac68e6
still memory leak
agrawalraj Jan 26, 2024
2a93228
hacky workaround for memory issues
SamWitty Jan 28, 2024
111db71
ran experiment and results
SamWitty Jan 28, 2024
eb10112
ran experiment more
SamWitty Jan 29, 2024
d5e833c
ran for 100
SamWitty Jan 29, 2024
0eecd6b
progress on opt functional
SamWitty Jan 29, 2024
079d4bd
progress on alternative portfolio allocation model/functional
SamWitty Jan 30, 2024
73398d1
first pass at markowitz experiment
SamWitty Jan 30, 2024
4b01d34
nits
SamWitty Jan 30, 2024
2b1f3f4
got markowitx working
SamWitty Jan 30, 2024
a06df47
update results
SamWitty Feb 2, 2024
a46cc00
updated figs
SamWitty Feb 2, 2024
9763659
reran experiments with IHDP
SamWitty Oct 22, 2024
ff63c18
rename
SamWitty Oct 22, 2024
4 changes: 4 additions & 0 deletions .gitignore
@@ -3,6 +3,10 @@
# C extensions
*.so

# Data and experiment folders
docs/examples/robust_paper/datasets/
docs/examples/robust_paper/experiments/

# Packages
*.egg
*.egg-info
218 changes: 217 additions & 1 deletion chirho/robust/handlers/estimators.py
@@ -1,18 +1,234 @@
import copy
import warnings
from typing import Any, Callable, TypeVar

import torch
import torchopt
from typing_extensions import ParamSpec

from chirho.robust.handlers.predictive import PredictiveFunctional
from chirho.robust.internals.utils import make_functional_call
from chirho.robust.ops import Functional, Point, influence_fn

P = ParamSpec("P")
S = TypeVar("S")
T = TypeVar("T")


def tmle_scipy_optimize_wrapper(
    packed_influence: torch.Tensor, log_jitter: float = 1e-6
) -> torch.Tensor:
    import numpy as np
    import scipy
    from scipy.optimize import LinearConstraint

    # Turn things into numpy. This makes us sad... :(
    D = packed_influence.detach().numpy()

    N, L = D.shape[0], D.shape[1]

    def loss(epsilon):
        correction = 1 + D.dot(epsilon)

        return -np.sum(np.log(np.maximum(correction, log_jitter)))

    positive_density_constraint = LinearConstraint(
        D, -1 * np.ones(N), np.inf * np.ones(N)
    )

    epsilon_solve = scipy.optimize.minimize(
        loss, np.zeros(L, dtype=np.float64), constraints=positive_density_constraint
    )

    if not epsilon_solve.success:
        warnings.warn("TMLE optimization did not converge.", RuntimeWarning)

    # Convert epsilon back to torch. This makes us happy... :)
    packed_epsilon = torch.tensor(epsilon_solve.x, dtype=packed_influence.dtype)

    return packed_epsilon


# TODO: revert influence_estimator to influence_fn and use handlers for influence_fn
def tmle(
    functional: Functional[P, S],
    test_point: Point,
    learning_rate: float = 1e-5,
    n_grad_steps: int = 100,
    n_tmle_steps: int = 1,
    num_nmc_samples: int = 1000,
    num_grad_samples: int = 1000,
    log_jitter: float = 1e-6,
    verbose: bool = False,
    influence_estimator: Callable[
        [Functional[P, S], Point[T]], Functional[P, S]
    ] = influence_fn,
    **influence_kwargs,
) -> Functional[P, S]:
    from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood

    def _solve_epsilon(prev_model: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        # find epsilon that maximizes the corrected likelihood of the test data

        influence_at_test = influence_estimator(
            functional, test_point, **influence_kwargs
        )(prev_model)(*args, **kwargs)

        flat_influence_at_test, _ = torch.utils._pytree.tree_flatten(influence_at_test)

        N = flat_influence_at_test[0].shape[0]

        packed_influence_at_test = torch.concatenate(
            [i.reshape(N, -1) for i in flat_influence_at_test]
        )

        packed_epsilon = tmle_scipy_optimize_wrapper(packed_influence_at_test)

        return packed_epsilon

    def _solve_model_projection(
        packed_epsilon: torch.Tensor,
        prev_model: torch.nn.Module,
        *args,
        **kwargs,
    ) -> torch.nn.Module:
        prev_params, functional_model = make_functional_call(
            PredictiveFunctional(prev_model, num_samples=num_grad_samples)
        )
        prev_params = {k: v.detach() for k, v in prev_params.items()}

        # Sample data from the model. Note that we only sample once during projection.
        with torch.no_grad():
            data: Point[T] = functional_model(prev_params, *args, **kwargs)
            data = {k: v.detach() for k, v in data.items() if k in test_point}

        batched_log_prob: torch.nn.Module = BatchedNMCLogMarginalLikelihood(
            prev_model, num_samples=num_nmc_samples
        )

        _, log_p_phi = make_functional_call(batched_log_prob)

        influence_at_data = influence_estimator(functional, data, **influence_kwargs)(
            prev_model
        )(*args, **kwargs)
        flat_influence_at_data, _ = torch.utils._pytree.tree_flatten(influence_at_data)
        N_x = flat_influence_at_data[0].shape[0]

        packed_influence_at_data = torch.concatenate(
            [i.reshape(N_x, -1) for i in flat_influence_at_data]
        ).detach()

        log_likelihood_correction = torch.log(
            torch.maximum(
                1 + packed_influence_at_data.mv(packed_epsilon),
                torch.tensor(log_jitter),
            )
        ).detach()
        if verbose:
            influence_at_test = influence_estimator(
                functional, test_point, **influence_kwargs
            )(prev_model)(*args, **kwargs)
            flat_influence_at_test, _ = torch.utils._pytree.tree_flatten(
                influence_at_test
            )
            N = flat_influence_at_test[0].shape[0]

            packed_influence_at_test = torch.concatenate(
                [i.reshape(N, -1) for i in flat_influence_at_test]
            ).detach()

            log_likelihood_correction_at_test = torch.log(
                torch.maximum(
                    1 + packed_influence_at_test.mv(packed_epsilon),
                    torch.tensor(log_jitter),
                )
            )

            print("previous log prob at test", log_p_phi(prev_params, test_point).sum())
            print(
                "new log prob at test",
                (
                    log_p_phi(prev_params, test_point)
                    + log_likelihood_correction_at_test
                ).sum(),
            )

        log_p_epsilon_at_data = (
            log_likelihood_correction + log_p_phi(prev_params, data)
        ).detach()

        def loss(new_params):
            log_p_phi_at_data = log_p_phi(new_params, data)
            return torch.sum((log_p_phi_at_data - log_p_epsilon_at_data) ** 2)

        grad_fn = torch.func.grad(loss)

        new_params = {
            k: v.clone().detach().requires_grad_(True) for k, v in prev_params.items()
        }

        optimizer = torchopt.adam(lr=learning_rate)

        optimizer_state = optimizer.init(new_params)

        for i in range(n_grad_steps):
            grad = grad_fn(new_params)
            if verbose and i % 100 == 0:
                print(f"inner_iteration_{i}_loss", loss(new_params))
                for parameter_name, parameter in prev_model.named_parameters():
                    parameter.data = new_params[f"model.{parameter_name}"]

                estimate = functional(prev_model)(*args, **kwargs)
                assert isinstance(estimate, torch.Tensor)
                print(
                    f"inner_iteration_{i}_estimate",
                    estimate.detach().item(),
                )
            updates, optimizer_state = optimizer.update(
                grad, optimizer_state, inplace=False
            )
            new_params = torchopt.apply_updates(new_params, updates)

        for parameter_name, parameter in prev_model.named_parameters():
            parameter.data = new_params[f"model.{parameter_name}"]

        return prev_model

    def _corrected_functional(*models: Callable[P, Any]) -> Callable[P, S]:
        assert len(models) == 1
        model = models[0]

        assert isinstance(model, torch.nn.Module)

        def _estimator(*args, **kwargs) -> S:
            tmle_model = copy.deepcopy(model)

            for _ in range(n_tmle_steps):
                packed_epsilon = _solve_epsilon(tmle_model, *args, **kwargs)

                tmle_model = _solve_model_projection(
                    packed_epsilon, tmle_model, *args, **kwargs
                )
            return functional(tmle_model)(*args, **kwargs)

        return _estimator

    return _corrected_functional


# TODO: revert influence_estimator to influence_fn and use handlers for influence_fn
def one_step_corrected_estimator(
    functional: Functional[P, S],
    *test_points: Point[T],
    influence_estimator: Callable[
        [Functional[P, S], Point[T]], Functional[P, S]
    ] = influence_fn,
    **influence_kwargs,
) -> Functional[P, S]:
    """
@@ -30,7 +30,7 @@ def one_step_corrected_estimator(
    """
    influence_kwargs_one_step = influence_kwargs.copy()
    influence_kwargs_one_step["pointwise_influence"] = False
-    eif_fn = influence_fn(functional, *test_points, **influence_kwargs_one_step)
+    eif_fn = influence_estimator(functional, *test_points, **influence_kwargs_one_step)

    def _corrected_functional(*model: Callable[P, Any]) -> Callable[P, S]:
        plug_in_estimator = functional(*model)
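For orientation, here is a minimal, self-contained sketch of the convex step performed by tmle_scipy_optimize_wrapper from this file, run on a synthetic influence matrix rather than on influence functions computed from a real model; the shapes and values below are illustrative assumptions, not part of this PR.

import torch

from chirho.robust.handlers.estimators import tmle_scipy_optimize_wrapper

# Synthetic stand-in for a packed influence matrix: 50 test points and 3 packed
# parameter dimensions (a real matrix would come from influence_fn).
torch.manual_seed(0)
packed_influence = 0.1 * torch.randn(50, 3, dtype=torch.float64)

# epsilon maximizes sum(log(1 + D @ epsilon)) over the test points, subject to
# the corrected density weights 1 + D @ epsilon staying nonnegative.
epsilon = tmle_scipy_optimize_wrapper(packed_influence)

corrected_weights = 1 + packed_influence.mv(epsilon)
print(epsilon, corrected_weights.min())  # minimum weight should be >= 0

tmle then projects the model onto this epsilon-corrected density by gradient descent in _solve_model_projection and, after n_tmle_steps rounds, evaluates the original functional on the corrected model; one_step_corrected_estimator instead adds the averaged influence-function correction at the test points to the plug-in estimate.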
46 changes: 46 additions & 0 deletions docs/examples/robust_paper/analytic_eif.py
@@ -0,0 +1,46 @@
import torch
from typing import Tuple
from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood


def analytic_eif_expected_density(test_data, plug_in, model, *args, **kwargs):
    # Analytic EIF of the expected density functional: phi(x) = 2 * (p(x) - plug_in),
    # where p is the model's marginal density and plug_in is its plug-in estimate.
    log_marginal_prob_at_points = BatchedNMCLogMarginalLikelihood(model, num_samples=1)(
        test_data, *args, **kwargs
    )
    analytic_eif_at_test_pts = 2 * (torch.exp(log_marginal_prob_at_points) - plug_in)
    analytic_correction = analytic_eif_at_test_pts.mean()
    return analytic_correction, analytic_eif_at_test_pts


def analytic_eif_ate_causal_glm(
    test_data, point_estimates
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the analytic EIF for the ATE for a ``CausalGLM`` model.

    :param test_data: Dictionary containing test data with keys "X", "A", and "Y"
    :param point_estimates: Estimated parameters of the model with keys "propensity_weights",
        "outcome_weights", "treatment_weight", and "intercept"
    :type point_estimates: dict
    :return: Tuple of the analytic EIF averaged over test,
        and the analytic EIF evaluated pointwise at each test point
    :rtype: Tuple[torch.Tensor, torch.Tensor]
    """
    assert "propensity_weights" in point_estimates, "propensity_weights not found"
    assert "outcome_weights" in point_estimates, "outcome_weights not found"
    assert "treatment_weight" in point_estimates, "treatment_weight not found"
    assert "intercept" in point_estimates, "intercept not found"
    assert test_data.keys() == {"X", "A", "Y"}, "test_data has unexpected keys"

    X = test_data["X"]
    A = test_data["A"]
    Y = test_data["Y"]
    pi_X = torch.sigmoid(X.mv(point_estimates["propensity_weights"]))
    mu_X = (
        X.mv(point_estimates["outcome_weights"])
        + A * point_estimates["treatment_weight"]
        + point_estimates["intercept"]
    )
    analytic_eif_at_test_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X)
    analytic_correction = analytic_eif_at_test_pts.mean()
    return analytic_correction, analytic_eif_at_test_pts
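A minimal sketch of the calling convention for analytic_eif_ate_causal_glm on synthetic tensors; the data, the parameter values, and the assumption that the repository root is on PYTHONPATH (so docs.examples.robust_paper is importable) are illustrative, not part of this PR.

import torch

from docs.examples.robust_paper.analytic_eif import analytic_eif_ate_causal_glm

torch.manual_seed(0)
n, d = 200, 3
X = torch.randn(n, d)
A = torch.bernoulli(torch.sigmoid(X[:, 0]))    # synthetic binary treatment
Y = X[:, 0] + 2.0 * A + 0.1 * torch.randn(n)   # synthetic outcome

# Hypothetical point estimates matching the keys the function asserts on.
point_estimates = {
    "propensity_weights": torch.tensor([1.0, 0.0, 0.0]),
    "outcome_weights": torch.tensor([1.0, 0.0, 0.0]),
    "treatment_weight": torch.tensor(2.0),
    "intercept": torch.tensor(0.0),
}

correction, eif_pointwise = analytic_eif_ate_causal_glm(
    {"X": X, "A": A, "Y": Y}, point_estimates
)
# A one-step corrected ATE adds `correction` to the plug-in estimate,
# which for this linear CausalGLM is the treatment_weight itself.
print(correction.item(), eif_pointwise.shape)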