Skip to content

Commit

Permalink
Merge pull request #1147 from AMS-Hippo/master
Browse files Browse the repository at this point in the history
Update parametric_umap.py
  • Loading branch information
lmcinnes authored Aug 18, 2024
2 parents 106dd9a + 825deac commit c72ac2f
Showing 1 changed file with 105 additions and 0 deletions.
105 changes: 105 additions & 0 deletions umap/parametric_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@
)
raise ImportError("umap.parametric_umap requires Keras") from None

torch_imported = True
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.onnx
import torchvision
except ImportError:
warn(
"""Torch and ONNX required for exporting to those formats."""
)
torch_imported = False


class ParametricUMAP(UMAP):
def __init__(
Expand Down Expand Up @@ -340,6 +354,19 @@ def save(self, save_location, verbose=True):
if verbose:
print("Pickle of ParametricUMAP model saved to {}".format(model_output))

def to_ONNX(self, save_location):
""" Exports trained parametric UMAP as ONNX.
"""
# Extract encoder
km = self.encoder
# Extract weights
pm = PumapNet(self.dims[0], self.n_components)
pm = weight_copier(km, pm)

# Put in ONNX
dummy_input = torch.randn(1, self.dims[0])
# Invoke export
return torch.onnx.export(pm, dummy_input, save_location)

def get_graph_elements(graph_, n_epochs):
"""
Expand Down Expand Up @@ -1019,3 +1046,81 @@ def _parametric_reconstruction_loss(self, y, y_pred):
loss = self.parametric_reconstruction_loss_fn(
y["reconstruction"], y_pred["reconstruction"])
return loss * self.parametric_reconstruction_loss_weight

##################################################
# 1. Pytorch version of parametric UMAP network. #
##################################################

if torch_imported:

class PumapNet(nn.Module):


def __init__(self, indim, outdim):

super(PumapNet, self).__init__()
self.dense1 = nn.Linear(indim, 100)
self.dense2 = nn.Linear(100, 100)
self.dense3 = nn.Linear(100, 100)
self.dense4 = nn.Linear(100, outdim)

"""
Creates the same network as the one used by parametric UMAP.
Note: shape of network is fixed.
Parameters
----------
indim : int
dimension of input to network.
outdim : int
dimension of output of network.
"""

def forward(self, x):
x = self.dense1(x)
x = F.relu(x)
x = self.dense2(x)
x = F.relu(x)
x = self.dense3(x)
x = F.relu(x)
x = self.dense4(x)
x = F.relu(x)
return x

######################
# 2. Copying weights #
######################

def weight_copier(km, pm):
""" Copies weights from a parametric UMAP encoder to pytorch.
Parameters
----------
km : encoder extracted from parametric UMAP.
pm: a PumapNet object. Will be overwritten.
Returns
-------
pm : PumapNet Object.
Net with copied weights.
"""
kweights = km.get_weights()
n_layers = int(len(kweights)/2) # The actual number of layers

# Get the names of the pytorch layers
all_keys = [x for x in pm.state_dict().keys()]
pm_names = [all_keys[2*i].split(".")[0] for i in range(4)]

# Set a variable for the state dict
pyt_state_dict = pm.state_dict()

for i in range(n_layers):
pyt_state_dict[pm_names[i] + ".bias"] = kweights[2*i + 1]
pyt_state_dict[pm_names[i] + ".weight"] = np.transpose(kweights[2*i])

for key in pyt_state_dict.keys():
pyt_state_dict[key] = torch.from_numpy(pyt_state_dict[key])

# Update
pm.load_state_dict(pyt_state_dict)
return pm
else:
pass

0 comments on commit c72ac2f

Please sign in to comment.