Skip to content

Commit

Permalink
Multi-function nodes with dictionary format
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickstock committed Apr 15, 2021
1 parent db95db9 commit 6372ce4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 33 deletions.
43 changes: 14 additions & 29 deletions mdf_to_pytorch/example_mdfs/mlp_classifier/mlp_classifier.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,23 @@
}
}
},
"mlp_hidden_layer": {
"mlp_hidden_layer_with_relu": {
"parameters": {
"weight": "weights.mlp_classifier.graphs.mlp_classifier.nodes.mlp_hidden_layer_with_relu.parameters.weight",
"bias": "weights.mlp_classifier.graphs.mlp_classifier.nodes.mlp_hidden_layer_with_relu.parameters.bias"
},
"functions":{
"Linear_2": {
"function": "linear"
}
},
"input_ports": {
"in_1": {
"shape": 128
}
},
"output_ports": {
"out_1": {
"shape": 128,
"value":"Relu_2"
}
}
},
"mlp_relu_2": {
"functions": {
"function": "linear",
"args":{
"variable0":"in_1"
}
},
"Relu_2": {
"function": "relu"
"function": "relu",
"args":{
"variable0":"Linear_2"
}
}
},
"input_ports": {
Expand All @@ -79,7 +70,7 @@
"output_ports": {
"out_1": {
"shape": 128,
"value":"Relu_1"
"value":"Relu_2"
}
}
},
Expand Down Expand Up @@ -133,23 +124,17 @@
},
"edge2": {
"sender": "mlp_relu_1",
"receiver": "mlp_hidden_layer",
"receiver": "mlp_hidden_layer_with_relu",
"sender_port": "out_1",
"receiver_port": "in_1"
},
"edge3": {
"sender": "mlp_hidden_layer",
"receiver": "mlp_relu_2",
"sender_port": "out_1",
"receiver_port": "in_1"
},
"edge4": {
"sender": "mlp_relu_2",
"sender": "mlp_hidden_layer_with_relu",
"receiver": "mlp_output_layer",
"sender_port": "out_1",
"receiver_port": "in_1"
},
"edge5": {
"edge4": {
"sender": "mlp_output_layer",
"receiver": "argmax_1",
"sender_port": "out_1",
Expand Down
32 changes: 28 additions & 4 deletions mdf_to_pytorch/mdf2torch/text_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,33 @@ def get_module_declaration_text(name, node_dict, dependency_graph, declared_modu

# Multi function node
else:

from toposort import toposort

# Need to put function calls in proper order
function_keys = set([key for key in functions.keys()])

function_graph = {}


for func_name, func_dict in functions.items():
if "args" in func_dict:
depends_on = func_dict["args"]["variable0"]

if depends_on in function_keys:

function_graph[func_name] = {depends_on}


function_graph = toposort(function_graph)
function_names = [list(e)[0] for e in list(function_graph)]

declaration_text += "\n\t\tself.function_list = []"
for function in functions:
function_name = next(iter(function.keys()))
function_type = function[function_name]["function"]


for function_name in function_names:
function = functions[function_name]
function_type = functions[function_name]["function"]

# Function is predefined
if (function_type in udf.__all__ or function_type in torch_builtins.__all__):
Expand All @@ -262,12 +285,13 @@ def get_module_declaration_text(name, node_dict, dependency_graph, declared_modu
declaration_text += "\n\t\tself.function_list.append({}())".format(function_type)

else:
constructor_call, func_class = generate_constructor_call(function, constructor_parameters)
constructor_call, func_class = generate_constructor_call((function_name, function), constructor_parameters)
declaration_text += "\n\t\tself.function_list.append({})".format(constructor_call)

initializer_call = generate_initializer_call(func_class, parameters, idx=True)
declaration_text += "\n{}".format(initializer_call)


declaration_text += "\n\t\tself.function = nn.Sequential(*self.function_list)"
forward_call, forward_signature = generate_module_forward_call(name, dependency_graph)
declaration_text+="\n{}".format(forward_call)
Expand Down

0 comments on commit 6372ce4

Please sign in to comment.