[SD-119] Implement layer execution latency measurements for Pytorch #48
Changes to `pyproject.toml`:

```diff
@@ -7,4 +7,6 @@ requires-python = ">=3.11"
 dependencies = [
     "dvc-s3>=3.2.0",
     "pandas>=2.2.3",
+    "pillow>=11.1.0",
+    "torch==2.2.2",
 ]
```
New file added (119 lines): the layer-latency benchmark script.

```python
"""Measure per-layer execution latency for a PyTorch model using forward hooks."""

import datetime
import json
import time
from functools import partial
from typing import Any

import torch


def load_model(model_name: str) -> Any:
    """Load a model from PyTorch Hub.

    Args:
        model_name: Name of the model.
            It should be the same as the name used on PyTorch Hub.

    Raises:
        ValueError: If loading the model from PyTorch Hub fails.

    Returns:
        The PyTorch model.
    """
    try:
        return torch.hub.load("pytorch/vision:v0.10.0", model_name, pretrained=True)
    except Exception as err:
        raise ValueError(
            f"Model name: {model_name} is most likely incorrect. "
            "Please refer to https://pytorch.org/hub/ to get the model name."
        ) from err


def get_layers(model: torch.nn.Module, name_prefix: str = "") -> list[tuple[str, torch.nn.Module]]:
    """Recursively get all layers in a PyTorch model.

    Args:
        model: The PyTorch model to collect layers from.
        name_prefix: Used to identify the parent layer. Defaults to "".

    Returns:
        A list of tuples containing the layer name and the layer.
    """
    children = list(model.named_children())

    if len(children) == 0:  # No children: this is a leaf layer.
        result = [(name_prefix, model)]
    else:
        # If there are children, iterate over each child.
        result = []
        for child_name, child in children:
            # Recursively call get_layers on the child, appending the current
            # child's name to the name_prefix.
            layers = get_layers(child, name_prefix + "_" + child_name)
            result.extend(layers)

    return result


def define_and_register_hooks(model: torch.nn.Module) -> dict[str, tuple]:
    """Define and register layer hooks.

    Returns:
        The dictionary that the hooks populate with per-layer timings.
    """
    layer_time_dict: dict[str, tuple] = {}

    for layer_name, layer in get_layers(model):
        layer.register_forward_pre_hook(partial(layer_time_pre_hook, layer_time_dict, layer_name))
```
Review thread on `layer.register_forward_pre_hook(...)`:

- "Why are hooks being used? Does the profiler API not work? I think it will provide much better results on CPU and GPU."
- "As far as I understand, the autograd profiler gives us the wrong resolution (it gives latencies by operation type, not by layer). But @osw282 can give more context, I guess."
- "Maybe add some error handling for layers that don't support hooks? I don't think Sequential supports them, for one."
- "From what I've read and tried out, the profiler is mainly meant to inspect the cost of individual operators in a model. Each layer can involve multiple operations, so it doesn't directly correspond to the layer-level timings we are after."
- "I don't think Sequential is a layer; it's more like a module in PyTorch that can consist of multiple layers, like convolution, maxpool, etc. What we want is to measure the time of these individual layers rather than the Sequential module as a whole. I tested the code on resnet18, and it seems to register hooks on all the layers we're interested in. @OCarrollM, does that answer your question?"
- "It would be interesting to explore something like https://github.com/awwong1/torchprof for the autograd profiler. It provides results for layer-wise latency and has a mode to show latencies for low-level operations as well."
- "I couldn't get the torchprof example code running, and I'm not really sure how to interpret the results. They used AlexNet as the example; I assume this is the layer-level trace. Not sure what they meant by feature 0-12."
- "I don't mean to use this package itself, as it is outdated and does not work with PyTorch > 1.9. It's something to explore down the road: building a similar package that extends the current autograd profiler to provide similar output."
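For context on the granularity point discussed above, a minimal sketch (not part of this PR) of what the profiler reports; its rows are keyed by operator rather than by named layer. The model and input shape here are illustrative:

```python
import torch
import torchvision
from torch.profiler import profile, ProfilerActivity

model = torchvision.models.resnet18().eval()
x = torch.randn(1, 3, 224, 224)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.no_grad():
        model(x)

# Output rows are operator-level (e.g. aten::conv2d, aten::batch_norm),
# aggregated across the whole model, rather than one row per named
# module such as layer1_0_conv1; that is the resolution issue raised above.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```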
```python
        # (define_and_register_hooks, continued)
        layer.register_forward_hook(partial(layer_time_hook, layer_time_dict, layer_name))

    return layer_time_dict


def layer_time_pre_hook(layer_time_dict, layer_name, module, input) -> None:
    """Forward pass pre-hook.

    Args:
        layer_time_dict: Dictionary to save the hook function output.
        layer_name: The name of the layer the hook is registered on.
        module: The module the hook is registered on.
        input: Tuple containing the input arguments to the module's forward method.
    """
    # Record the wall-clock start time; the post-hook turns this into a duration.
    layer_time_dict[layer_name] = (time.time(), datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
```
Review thread on the `time.time()` timestamps:

- "I am not sure if … Reference: …"
- "I think we don't need to worry about the CPU implementation of timing too much, as the hardware we're using has CUDA, and CUDA events will be used in the next ticket."
- "Could remove …"
- "That was my thinking; this ticket is to find a way to measure model layers' execution time. SD-125 will use CUDA events."
- "We also want …"
- "N.B. these functions will need to use CUDA events from SD-118 in the actual benchmark script."

osw282 marked this conversation as resolved.
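As referenced above, SD-125/SD-118 will switch to CUDA events. A hedged sketch of how the same pre/post hook pair could use `torch.cuda.Event` in place of `time.time()`; the function names are hypothetical, and this is an assumption about the follow-up ticket, not its actual implementation:

```python
import torch


def cuda_layer_time_pre_hook(layer_time_dict, layer_name, module, input) -> None:
    # Hypothetical sketch: record a CUDA event at layer entry instead of
    # a wall-clock timestamp.
    start = torch.cuda.Event(enable_timing=True)
    start.record()
    layer_time_dict[layer_name] = start


def cuda_layer_time_hook(layer_time_dict, layer_name, module, input, output) -> None:
    end = torch.cuda.Event(enable_timing=True)
    end.record()
    # elapsed_time() requires both events to have completed. Synchronizing
    # here is simple but serializes the GPU; a real implementation would
    # defer the sync until after the forward pass.
    torch.cuda.synchronize()
    start = layer_time_dict[layer_name]
    layer_time_dict[layer_name] = start.elapsed_time(end)  # milliseconds
```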
```python
def layer_time_hook(layer_time_dict, layer_name, module, input, output) -> None:
    """Forward pass hook.

    Args:
        layer_time_dict: Dictionary to save the hook function output.
        layer_name: The name of the layer the hook is registered on.
        module: The module the hook is registered on.
        input: Tuple containing the input arguments to the module's forward method.
        output: The output tensor from the forward method.
    """
    # Replace the stored start timestamp with the elapsed time, keeping the
    # human-readable start time recorded by the pre-hook.
    layer_time_dict[layer_name] = (time.time() - layer_time_dict[layer_name][0], layer_time_dict[layer_name][1])


def get_layer_execution_time(model_name, input_shape, num_inference_cycles) -> None:
    """Benchmark function.

    This function outputs a JSON file containing the recorded execution time
    for every layer over all inference cycles run.

    Args:
        model_name: The name of the model to benchmark.
        input_shape: Shape of the input tensor for inference.
        num_inference_cycles: Number of inference cycles to run.
    """
    all_cycle_measurements = {}
    model = load_model(model_name)
    x = torch.randn(input_shape)

    layer_time_dict = define_and_register_hooks(model)

    for i in range(num_inference_cycles):
        model(x)
        # Copy the timings so the next cycle does not overwrite this cycle's
        # entry (the hooks mutate layer_time_dict in place).
        all_cycle_measurements[f"cycle_{i}"] = dict(layer_time_dict)

    with open(
        f"{datetime.datetime.now().strftime('%Y_%m_%d_%H:%M:%S')}_{model_name}"
        f"_inference_trace_{num_inference_cycles}_cycles.json",
        "w",
    ) as f:
        json.dump(all_cycle_measurements, f)
```
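A minimal usage sketch for the script above; the model name matches the resnet18 test mentioned in the review, while the input shape and cycle count are illustrative:

```python
if __name__ == "__main__":
    # resnet18 expects a (batch, 3, 224, 224) input.
    get_layer_execution_time("resnet18", (1, 3, 224, 224), 10)
```

The layer keys in the output JSON are the underscore-joined module paths produced by `get_layers`, e.g. `_conv1`, `_layer1_0_conv1`, and `_fc` for resnet18.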
Review thread on the new `pillow` dependency:

- "Is this package being used anywhere?"
- "It's an implicit requirement by PyTorch to run resnet18."
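For context, the standard PyTorch Hub example for resnet18 preprocesses input images via PIL, which is why `pillow` is pulled in even though this script feeds the model random tensors. A sketch along the lines of that example (the image path is a placeholder):

```python
import torch
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open("dog.jpg")  # placeholder image path
x = preprocess(img).unsqueeze(0)  # add a batch dimension

model = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=True).eval()
with torch.no_grad():
    logits = model(x)
```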