Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added proof of concept for torch.fx to MDF #492

Open
wants to merge 1 commit into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion examples/PyTorch/simple_pytorch_to_mdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torchviz import make_dot
import netron
from modeci_mdf.interfaces.pytorch import pytorch_to_mdf
from modeci_mdf.interfaces.pytorch import pytorch_fx_to_mdf
import os


Expand Down Expand Up @@ -74,6 +75,33 @@ def main():
)
print("Passed all comparison tests!")

print("Comparing FX to mdf translation")
# Pytorch FX to MDF
fx_mdf_model, fx_params_dict = pytorch_fx_to_mdf(
model=model,
args=(input_images),
)

# Get the graph
fx_mdf_graph = fx_mdf_model.graphs[0]

# Add inputs to the parameters dict so we can feed this to the EvaluableGraph for initialization of graph input.
fx_params_dict["input1"] = input_images.numpy()

# Evaluate the model via the MDF scheduler
eg = EvaluableGraph(graph=fx_mdf_graph, verbose=False)
eg.evaluate(initializer=fx_params_dict)
fx_output_mdf = eg.output_enodes[0].get_output()

print("Evaluated the graph in PyTorch FX, output: %s" % (_val_info(fx_output_mdf)))

# Make sure the results are the same between PyTorch and MDF
assert np.allclose(
output.detach().numpy(),
fx_output_mdf,
)
print("Passed all comparison tests!")

# Output the model to JSON
mdf_model.to_json_file("simple_pytorch_to_mdf.json")

Expand All @@ -90,7 +118,10 @@ def main():
)
onnx_model = onnx.load("simple_pytorch_to_mdf.onnx")
onnx.checker.check_model(onnx_model)
sess = rt.InferenceSession("simple_pytorch_to_mdf.onnx")
sess = rt.InferenceSession(
"simple_pytorch_to_mdf.onnx",
providers=["AzureExecutionProvider", "CPUExecutionProvider"],
)
res = sess.run(None, {sess.get_inputs()[0].name: input_images.numpy()})
print("Exported to MDF and ONNX")

Expand Down
5 changes: 4 additions & 1 deletion src/modeci_mdf/functions/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ def predict_with_onnxruntime(model_def, *inputs) -> Dict[str, np.array]:
A dict of output values, keys are output names for the model. Values are
the output values of the model.
"""
sess = ort.InferenceSession(model_def.SerializeToString())
sess = ort.InferenceSession(
model_def.SerializeToString(),
providers=["AzureExecutionProvider", "CPUExecutionProvider"],
)
names = [i.name for i in sess.get_inputs()]
dinputs = {name: input for name, input in zip(names, inputs)}
res = sess.run(None, dinputs)
Expand Down
1 change: 1 addition & 0 deletions src/modeci_mdf/interfaces/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Import and export code for `PyTorch <https://pytorch.org>`_ models"""

from .importer import pytorch_to_mdf
from .importer import pytorch_fx_to_mdf

from .exporter import mdf_to_pytorch
from . import mod_torch_builtins
288 changes: 288 additions & 0 deletions src/modeci_mdf/interfaces/pytorch/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
import onnx.defs

import torch
import torch.fx
from torch.fx.node import Node as fx_Node
import numpy as np


# We need to monkey patch the torch._C.Node class to add a __getitem__ method
# This is for torch 2.0
Expand Down Expand Up @@ -576,6 +580,290 @@ def pytorch_to_mdf(
return mdf_model, params_dict


############
# TORCH.FX TO MDF
############
class FXPortMapper:
r"""
A simple class that handles mapping Torch fx input\ouput ids to MDF InputPort\OutputPort ids. It keeps track of
annoying details like graph level inputs and stuff.
"""

def __init__(self, model):
self.trace = torch.fx.symbolic_trace(model)
self.fx_dict = dict(self.trace.named_parameters())
# Keep generate special names for all the graph inputs and parameters
self.graph_inputs = self._get_graph_inputs_dict()

def id_to_port(self, id: str):
"""Turn unique TorchScript output and input value names into valid MDF input and outport names"""

new_name = str(id).replace(".", "_")

# Remove :: from ids, these cause issues with parsing in the execution engine
new_name = new_name.replace("::", "_")

# Renive aby "-" from names, these cause issues with parsing in the execution engine
new_name = new_name.replace("-", "_")

# If the first character is a digit, precede with an underscore so this can never be interpreted
# as number down the line.
if new_name[0].isdigit():
new_name = "_" + new_name

return new_name

def port_to_id(self, name: str):
"""Transform a port name back to is TorchScript ID"""

# If first character is underscore, remove it
id = name
if name[0] == "_":
id = name[1:]

# Replace any remaining underscores with '.'
id = id.replace("_", ".")

# If this is a numeric id, make it an int again
if id[0].isdigit():
id = int(id)

# If this id is actually a debugName from a graph input, use that
for input_id, debug_name in self.graph_inputs.items():
if debug_name == id:
return input_id

return id

def _get_graph_inputs_dict(self) -> Dict[str, str]:
"""
Create a dict mapping graph input torch.Node ids to default names. The default names are just:
- input1
- input2
- etc.

Any parameters for the model will also be graph inputs but their node.debugName() will be used
instead.
"""
graph_inputs = {i + 1: value for i, value in enumerate(self.fx_dict.keys())}
return graph_inputs

def get_params_dict(self):
fx_params_dict = {self.id_to_port(k): v.data for k, v in self.fx_dict.items()}
return fx_params_dict


def parse_shape_type(prev_in_shape, prev_in_type):
try:
inp_dtype = str(prev_in_type).replace("torch.", "")
except RuntimeError:
inp_dtype = str(prev_in_type.getElementType())

try:
shape = tuple(prev_in_shape) if prev_in_shape else None
except RuntimeError:
shape = None

return shape, inp_dtype


def parameter_id_generator():
arguments = ["alpha", "beta", "transB", "extra"]
index = 0
while True:
yield arguments[index]
index = index + 1


# Since currently using "onnx::" functions currently argument names need to be specific to the respective onnx functions
def get_argument_id(function_name):
func_vs_arglist = {
"onnx::Gemm": ["A", "B", "C"],
"onnx::Relu": ["X"],
"onnx::Reshape": ["data", "shape"],
}
if function_name in func_vs_arglist:
if not hasattr(get_argument_id, "index"):
get_argument_id.index = {}
if function_name not in get_argument_id.index:
get_argument_id.index[function_name] = 0
else:
get_argument_id.index[function_name] = (
get_argument_id.index[function_name] + 1
) % len(func_vs_arglist[function_name])
return func_vs_arglist[function_name][get_argument_id.index[function_name]]
else:
return None


# https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments.
"""

def __init__(self, mod, portmapper_obj):
self.mod = mod
self.graph = torch.fx.symbolic_trace(mod).graph
self.modules = dict(self.mod.named_modules())
self.pm = portmapper_obj
# Due to some reason info about weigths and biases of each node is not present in any of its attributes
# So, this is a hack to get the assosciated weights and biases to a particular layer
# eg if node name is fc1 and fc1 is present as a substring in the fx_dict keys then we simply map
# the node name with the keys matched eg - fc1:[fc1.weight,fc1.bias], view[] ,.....
self.node_wbs = {
node.name: [key for key in self.pm.fx_dict.keys() if node.name in key]
for node in self.graph.nodes
}

def propagate(self, *args):
args_iter = iter(args)
env: Dict[str, fx_Node] = {}
prev_in_shape = args[0].shape
prev_in_type = args[0].dtype
mdf_model = Model(id="sample_model")
mdf_graph = Graph(id="sample_graph")
mdf_model.graphs.append(mdf_graph)

def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])

for node in self.graph.nodes:
para_gen = parameter_id_generator() # initiate generator for parameter id

mdf_node = Node(id=node.name)
shape, type = parse_shape_type(prev_in_shape, prev_in_type)

argument = {}
for ip in node.all_input_nodes:
mdf_node.input_ports.append(
InputPort(id=ip.name, shape=shape, type=type)
)

func = None
arg_values = list(node.args) + self.node_wbs[node.name]

if node.op == "placeholder":
result = next(args_iter)

elif node.op == "call_function":
func = "onnx::Relu"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need a more general way to identify the appropriate ONNX function for the torch op. I think this might be something that the torch and ONNX folks are working on as well, see here. It looks like they are developing a converter between torch fx IR to onnxscript. I didn't know that ONNX now has a thing called ONNX script. Allowing for composing ONNX ops in Python. Might be useful for re-writing the execution engine in fact.

Anyway, I think this issue needs to be fixed before merging.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did go through their implementation and they seem to get the appropriate function from here. But cannot figure out how to use a key in _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE dictionary to get the value.

argument[get_argument_id(func)] = str(arg_values[0])
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))

elif node.op == "call_method":
func = "onnx::Reshape"
param_id = next(para_gen)
argument[get_argument_id(func)] = str(arg_values[0])
argument[get_argument_id(func)] = param_id
mdf_node.parameters.append(
Parameter(
id=param_id, value=np.array([arg_values[1], arg_values[2]])
)
)
mdf_node.parameters.append(Parameter(id="allowzero", value=0))
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)

elif node.op == "call_module":
func = "onnx::Gemm"
value = 1.0
for val in arg_values:
if val in self.node_wbs[node.name]:
shape, type = parse_shape_type(
self.pm.fx_dict[val].shape, self.pm.fx_dict[val].dtype
)
mdf_node.input_ports.append(
InputPort(
id=self.pm.id_to_port(val), shape=shape, type=type
)
)
val = self.pm.id_to_port(val)

argument[get_argument_id(func)] = str(val)
temp_id = next(para_gen)
if (
temp_id == "transB"
): # value 1.0 gives an error, something related to type mismatch for transB
value = 1
mdf_node.parameters.append(Parameter(id=temp_id, value=value))
result = self.modules[node.target](
*load_arg(node.args), **load_arg(node.kwargs)
)

mdf_node.parameters.append(
Parameter(id=next(para_gen), value=None, args=argument, function=func)
)

shape, type = parse_shape_type(result.shape, result.dtype)
mdf_node.output_ports.append(
OutputPort(
id=node.name,
value=mdf_node.parameters[
-1
].id, # last parameter id is the value of the output node
shape=shape,
type=type,
)
)

# construct edge
from_id = node.name
from_port = mdf_node.output_ports[0].id
to_id = node.next.name
to_port = mdf_node.output_ports[0].id

mdf_edge = Edge(
id=f"{from_id}_{to_id}",
sender=from_id,
sender_port=f"{from_port}",
receiver=to_id,
receiver_port=f"{to_port}",
)

# not to make an node for constants/inputs/output nodes
if node.op not in ["placeholder", "output"]:
mdf_graph.nodes.append(mdf_node)
# since the last node is the output we wont make an edge for the last second node
if node.next.op != "output":
mdf_graph.edges.append(mdf_edge)

env[node.name] = result

prev_in_shape = result.shape
prev_in_type = result.dtype

return mdf_model


def pytorch_fx_to_mdf(
model: Union[Callable, torch.nn.Module],
args: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
) -> Union[Model, Graph]:
r"""
Convert a PyTorch model to an MDF model. This function will invoke `torch.fx` on the
model to convert to MDF counterpart in node to node manner

Args:
model: The pytorch model to translate into MDF.
args: The input arguments for this model. A nn.Module is passed then the model will be traced with these
inputs. If a ScriptModule is passed, they are still needed to determine input shapes.

Returns:
The translated MDF model
"""
pm = FXPortMapper(model)
fx_params_dict = pm.get_params_dict()

exec = ShapeProp(model, pm)
mdf_model = exec.propagate(args)

return mdf_model, fx_params_dict


if __name__ == "__main__":

def simple(x, y):
Expand Down