diff --git a/examples/PyTorch/simple_pytorch_to_mdf.py b/examples/PyTorch/simple_pytorch_to_mdf.py index 7c59dabf..9ea841c6 100644 --- a/examples/PyTorch/simple_pytorch_to_mdf.py +++ b/examples/PyTorch/simple_pytorch_to_mdf.py @@ -7,6 +7,7 @@ from torchviz import make_dot import netron from modeci_mdf.interfaces.pytorch import pytorch_to_mdf +from modeci_mdf.interfaces.pytorch import pytorch_fx_to_mdf import os @@ -74,6 +75,33 @@ def main(): ) print("Passed all comparison tests!") + print("Comparing FX to mdf translation") + # Pytorch FX to MDF + fx_mdf_model, fx_params_dict = pytorch_fx_to_mdf( + model=model, + args=(input_images), + ) + + # Get the graph + fx_mdf_graph = fx_mdf_model.graphs[0] + + # Add inputs to the parameters dict so we can feed this to the EvaluableGraph for initialization of graph input. + fx_params_dict["input1"] = input_images.numpy() + + # Evaluate the model via the MDF scheduler + eg = EvaluableGraph(graph=fx_mdf_graph, verbose=False) + eg.evaluate(initializer=fx_params_dict) + fx_output_mdf = eg.output_enodes[0].get_output() + + print("Evaluated the graph in PyTorch FX, output: %s" % (_val_info(fx_output_mdf))) + + # Make sure the results are the same between PyTorch and MDF + assert np.allclose( + output.detach().numpy(), + fx_output_mdf, + ) + print("Passed all comparison tests!") + # Output the model to JSON mdf_model.to_json_file("simple_pytorch_to_mdf.json") @@ -90,7 +118,10 @@ def main(): ) onnx_model = onnx.load("simple_pytorch_to_mdf.onnx") onnx.checker.check_model(onnx_model) - sess = rt.InferenceSession("simple_pytorch_to_mdf.onnx") + sess = rt.InferenceSession( + "simple_pytorch_to_mdf.onnx", + providers=["AzureExecutionProvider", "CPUExecutionProvider"], + ) res = sess.run(None, {sess.get_inputs()[0].name: input_images.numpy()}) print("Exported to MDF and ONNX") diff --git a/src/modeci_mdf/functions/onnx.py b/src/modeci_mdf/functions/onnx.py index 72a833a2..73775fa6 100644 --- a/src/modeci_mdf/functions/onnx.py +++ 
# Imports required by the torch.fx importer below (added to importer.py by this change).
import torch
import torch.fx
from torch.fx.node import Node as fx_Node
import numpy as np


############
# TORCH.FX TO MDF
############
class FXPortMapper:
    r"""
    Map ``torch.fx`` names to MDF InputPort\OutputPort ids and back.

    Attributes set by ``__init__``:
        trace: the ``GraphModule`` returned by ``torch.fx.symbolic_trace(model)``.
        fx_dict: ``dict`` built from ``trace.named_parameters()`` (name -> parameter).
        graph_inputs: 1-based position -> parameter name, see ``_get_graph_inputs_dict``.
    """

    def __init__(self, model):
        self.trace = torch.fx.symbolic_trace(model)
        self.fx_dict = dict(self.trace.named_parameters())
        # Positional lookup table for every parameter of the traced model
        self.graph_inputs = self._get_graph_inputs_dict()

    def id_to_port(self, id: str):
        """
        Turn a torch name into a valid MDF port name.

        ``.``, ``::`` and ``-`` are each replaced by ``_`` because they cause issues with
        parsing in the execution engine. A name starting with a digit is prefixed with
        ``_`` so it can never be interpreted as a number down the line. An empty name is
        returned unchanged.
        """
        new_name = str(id).replace(".", "_").replace("::", "_").replace("-", "_")

        # Slicing (instead of indexing) keeps an empty name from raising IndexError
        if new_name[:1].isdigit():
            new_name = "_" + new_name

        return new_name

    def port_to_id(self, name: str):
        """
        Transform a port name back to a torch id.

        One leading underscore is dropped, the remaining underscores become ``.`` and a
        purely numeric result is converted to ``int``. If the result equals one of the
        parameter names stored in ``graph_inputs``, the corresponding 1-based position is
        returned instead. The mapping is lossy: underscores that were part of the original
        name also come back as dots.
        """
        # If first character is underscore, remove it
        torch_id = name[1:] if name.startswith("_") else name

        # Replace any remaining underscores with '.'
        torch_id = torch_id.replace("_", ".")

        # Only a fully numeric id is an int again; "0.weight" must stay a string
        if torch_id.isdecimal():
            torch_id = int(torch_id)

        # If this id is actually the name of a graph input, return its position
        for input_id, debug_name in self.graph_inputs.items():
            if debug_name == torch_id:
                return input_id

        return torch_id

    def _get_graph_inputs_dict(self) -> Dict[int, str]:
        """
        Create a dict mapping 1-based positions to the parameter names in ``fx_dict``,
        in the iteration order of ``fx_dict``.
        """
        return dict(enumerate(self.fx_dict, start=1))

    def get_params_dict(self):
        """Return ``{port name: parameter.data}`` for every entry of ``fx_dict``."""
        return {self.id_to_port(k): v.data for k, v in self.fx_dict.items()}


def parse_shape_type(prev_in_shape, prev_in_type):
    """
    Normalize a shape and a dtype for use in MDF ports.

    Args:
        prev_in_shape: any iterable of dimensions, or a falsy value.
        prev_in_type: a dtype-like object; its ``str()`` with ``"torch."`` removed is used.

    Returns:
        ``(shape, dtype)`` where ``shape`` is a tuple, or ``None`` when ``prev_in_shape``
        is falsy (this includes an empty shape) or converting it raises ``RuntimeError``,
        and ``dtype`` is a string.
    """
    try:
        inp_dtype = str(prev_in_type).replace("torch.", "")
    except RuntimeError:
        inp_dtype = str(prev_in_type.getElementType())

    try:
        shape = tuple(prev_in_shape) if prev_in_shape else None
    except RuntimeError:
        shape = None

    return shape, inp_dtype


def parameter_id_generator():
    """
    Yield MDF parameter ids: ``alpha``, ``beta``, ``transB``, ``extra`` and then
    ``extra_1``, ``extra_2``, ... so the generator never runs out of unique ids.
    """
    yield from ["alpha", "beta", "transB", "extra"]
    suffix = 1
    while True:
        yield f"extra_{suffix}"
        suffix += 1


# Since currently using "onnx::" functions currently argument names need to be specific to the respective onnx functions
_ONNX_ARGUMENT_NAMES = {
    "onnx::Gemm": ["A", "B", "C"],
    "onnx::Relu": ["X"],
    "onnx::Reshape": ["data", "shape"],
}


def get_argument_id(function_name, position=None):
    """
    Return the name of an argument of one of the known "onnx::" functions.

    Args:
        function_name: e.g. ``"onnx::Gemm"``.
        position: index of the wanted argument. When given, the lookup is stateless.
            When omitted, the legacy behaviour is used: a per-function counter stored
            on ``get_argument_id.index`` cycles through the argument names, one step
            per call, across all callers.

    Returns:
        The argument name, or ``None`` when ``function_name`` is not known.

    Raises:
        IndexError: if ``position`` is outside the function's argument list.
    """
    arg_names = _ONNX_ARGUMENT_NAMES.get(function_name)
    if arg_names is None:
        return None

    if position is not None:
        return arg_names[position]

    if not hasattr(get_argument_id, "index"):
        get_argument_id.index = {}
    counters = get_argument_id.index
    if function_name in counters:
        counters[function_name] = (counters[function_name] + 1) % len(arg_names)
    else:
        counters[function_name] = 0
    return arg_names[counters[function_name]]


# https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
class ShapeProp:
    """
    Shape propagation. This class symbolically traces ``mod`` and its `propagate`
    method executes the traced graph node-by-node with the given arguments, building
    the MDF graph along the way.
    """

    def __init__(self, mod, portmapper_obj):
        self.mod = mod
        traced = torch.fx.symbolic_trace(mod)
        self.graph = traced.graph
        # A plain callable has no named_modules(); fall back to the traced GraphModule
        module_source = mod if isinstance(mod, torch.nn.Module) else traced
        self.modules = dict(module_source.named_modules())
        self.pm = portmapper_obj
        # The weights and biases of a layer are not available on the fx node itself, so
        # they are looked up by name: a call_module node with target "fc1" owns every
        # parameter whose name starts with "fc1." (fc1 -> [fc1.weight, fc1.bias]).
        # Matching on the "<target>." prefix avoids false hits such as "fc1" in
        # "fc10.weight"; every other kind of node owns no parameters.
        param_names = list(self.pm.fx_dict)
        self.node_wbs = {
            node.name: (
                [key for key in param_names if key.startswith(f"{node.target}.")]
                if node.op == "call_module"
                else []
            )
            for node in self.graph.nodes
        }

    def propagate(self, *args):
        """
        Run the traced graph on ``args`` and return the corresponding MDF ``Model``.

        NOTE(review): every ``call_function`` node is emitted as ``onnx::Relu``, every
        ``call_method`` node as ``onnx::Reshape`` and every ``call_module`` node as
        ``onnx::Gemm``; the actual targets are not inspected.

        Args:
            *args: one example input per ``placeholder`` node, in graph order.

        Returns:
            A ``Model`` with id ``"sample_model"`` holding one ``Graph`` with id
            ``"sample_graph"``.

        Raises:
            ValueError: if fewer inputs than placeholders are supplied.
            NotImplementedError: for node ops other than placeholder, call_function,
                call_method, call_module and output.
        """
        if not args:
            raise ValueError("propagate() needs at least one example input")

        args_iter = iter(args)
        # Concrete value computed for each fx node, keyed by node name
        env: Dict[str, torch.Tensor] = {}
        mdf_model = Model(id="sample_model")
        mdf_graph = Graph(id="sample_graph")
        mdf_model.graphs.append(mdf_graph)

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        for node in self.graph.nodes:
            # Inputs and the output marker do not become MDF nodes
            if node.op == "placeholder":
                try:
                    env[node.name] = next(args_iter)
                except StopIteration:
                    raise ValueError(
                        f"No example input supplied for placeholder '{node.name}'"
                    ) from None
                continue
            if node.op == "output":
                continue

            param_ids = parameter_id_generator()
            mdf_node = Node(id=node.name)

            # One input port per producer, typed from the value that producer computed
            for src in node.all_input_nodes:
                in_shape, in_dtype = parse_shape_type(
                    env[src.name].shape, env[src.name].dtype
                )
                mdf_node.input_ports.append(
                    InputPort(id=src.name, shape=in_shape, type=in_dtype)
                )

            func_args = {}

            if node.op == "call_function":
                func = "onnx::Relu"
                func_args[get_argument_id(func, position=0)] = str(node.args[0])
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))

            elif node.op == "call_method":
                func = "onnx::Reshape"
                shape_param_id = next(param_ids)
                func_args[get_argument_id(func, position=0)] = str(node.args[0])
                func_args[get_argument_id(func, position=1)] = shape_param_id
                # Everything after the receiver is the target shape; it may also be
                # passed as one tuple/list, e.g. x.view((-1, 784))
                target_shape = node.args[1:]
                if len(target_shape) == 1 and isinstance(
                    target_shape[0], (tuple, list)
                ):
                    target_shape = target_shape[0]
                mdf_node.parameters.append(
                    Parameter(id=shape_param_id, value=np.array(list(target_shape)))
                )
                mdf_node.parameters.append(Parameter(id="allowzero", value=0))
                self_obj, *method_args = load_arg(node.args)
                result = getattr(self_obj, node.target)(
                    *method_args, **load_arg(node.kwargs)
                )

            elif node.op == "call_module":
                func = "onnx::Gemm"
                operands = [str(arg) for arg in node.args]
                # The layer's weights/biases become extra input ports of the node
                for key in self.node_wbs[node.name]:
                    w_shape, w_dtype = parse_shape_type(
                        self.pm.fx_dict[key].shape, self.pm.fx_dict[key].dtype
                    )
                    port_id = self.pm.id_to_port(key)
                    mdf_node.input_ports.append(
                        InputPort(id=port_id, shape=w_shape, type=w_dtype)
                    )
                    operands.append(port_id)
                for position, operand in enumerate(operands):
                    func_args[get_argument_id(func, position=position)] = operand
                # alpha, beta, transB -- always emitted, independent of the operand
                # count. transB has to be an int: 1.0 gives a type mismatch error.
                for attr_value in (1.0, 1.0, 1):
                    mdf_node.parameters.append(
                        Parameter(id=next(param_ids), value=attr_value)
                    )
                result = self.modules[node.target](
                    *load_arg(node.args), **load_arg(node.kwargs)
                )

            else:
                raise NotImplementedError(
                    f"Unsupported torch.fx node op '{node.op}' (node '{node.name}')"
                )

            # The function parameter is added last; the output port reads its value
            mdf_node.parameters.append(
                Parameter(id=next(param_ids), value=None, args=func_args, function=func)
            )

            out_shape, out_dtype = parse_shape_type(result.shape, result.dtype)
            mdf_node.output_ports.append(
                OutputPort(
                    id=node.name,
                    value=mdf_node.parameters[-1].id,
                    shape=out_shape,
                    type=out_dtype,
                )
            )
            mdf_graph.nodes.append(mdf_node)

            # One edge per consumer. The receiving input port carries the sender's node
            # name (see the input ports above). The graph output gets no edge.
            for user in node.users:
                if user.op == "output":
                    continue
                mdf_graph.edges.append(
                    Edge(
                        id=f"{node.name}_{user.name}",
                        sender=node.name,
                        sender_port=f"{node.name}",
                        receiver=user.name,
                        receiver_port=f"{node.name}",
                    )
                )

            env[node.name] = result

        return mdf_model


def pytorch_fx_to_mdf(
    model: Union[Callable, torch.nn.Module],
    args: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
) -> "Tuple[Model, Dict[str, torch.Tensor]]":
    r"""
    Convert a PyTorch model to an MDF model. This function will invoke `torch.fx` on the
    model to convert to MDF counterpart in node to node manner

    Args:
        model: The pytorch model to translate into MDF.
        args: The example input(s) for this model, either a single tensor or a
            tuple/list of tensors (one per model input). They are run through the model
            node by node to determine the shapes and types of the MDF ports.

    Returns:
        A tuple ``(mdf_model, fx_params_dict)``: the translated MDF model and a dict
        mapping MDF port names to the data of the model's parameters.

    Raises:
        ValueError: if no example input is given.
    """
    pm = FXPortMapper(model)
    fx_params_dict = pm.get_params_dict()

    # propagate() takes one positional argument per model input
    if args is None:
        example_inputs = ()
    elif isinstance(args, torch.Tensor):
        example_inputs = (args,)
    else:
        example_inputs = tuple(args)

    converter = ShapeProp(model, pm)
    mdf_model = converter.propagate(*example_inputs)

    return mdf_model, fx_params_dict