Skip to content

Commit

Permalink
In tutorials/quantize_vit, extract common methods to util.py (#238)
Browse files Browse the repository at this point in the history
* Extract common methods to util.py

* Update tutorials/quantize_vit/run_vit_b.py

Co-authored-by: Mark Saroufim <[email protected]>

* Update tutorials/quantize_vit/run_vit_b_quant.py

Co-authored-by: Mark Saroufim <[email protected]>

* amend

* amend

* Include the torchao utils

---------

Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
  • Loading branch information
3 people authored May 20, 2024
1 parent adfe570 commit f0f00ce
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 48 deletions.
26 changes: 26 additions & 0 deletions torchao/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch


def benchmark_model(model, num_runs, input_tensor):
    """Return the mean time of one ``model(input_tensor)`` call, in milliseconds.

    The model is invoked ``num_runs`` times back to back. When CUDA is
    available the whole loop is bracketed by a pair of CUDA events; otherwise
    a host wall-clock timer is used so the helper also runs on CPU-only hosts.

    Args:
        model: Callable invoked as ``model(input_tensor)``.
        num_runs: Number of timed invocations; must be at least 1.
        input_tensor: The single positional argument passed to ``model``.

    Returns:
        Elapsed time divided by ``num_runs``, in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        # Guard the division below and reject meaningless negative counts.
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    if not torch.cuda.is_available():
        # No CUDA device: CUDA events cannot be used, so time on the host.
        import time

        start_time = time.perf_counter()
        for _ in range(num_runs):
            with torch.autograd.profiler.record_function("timed region"):
                model(input_tensor)
        # perf_counter reports seconds; convert to milliseconds to match
        # the unit of torch.cuda.Event.elapsed_time.
        return (time.perf_counter() - start_time) * 1000.0 / num_runs

    # Drain previously queued GPU work so it is not attributed to this run.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    # benchmark
    for _ in range(num_runs):
        # Label each call so it can be located in a profiler trace.
        with torch.autograd.profiler.record_function("timed region"):
            model(input_tensor)

    end_event.record()
    # Wait for the end event to complete before reading the elapsed time.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` under ``torch.profiler`` and save a Chrome trace.

    Args:
        path: Destination passed to ``export_chrome_trace``.
        fn: Callable to profile.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn(*args, **kwargs)`` returns.
    """
    activities = [torch.profiler.ProfilerActivity.CPU]
    # Only ask for CUDA activity when a CUDA device is actually present.
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    # Export once the profiling context has been exited.
    prof.export_chrome_trace(path)
    return result
26 changes: 2 additions & 24 deletions tutorials/quantize_vit/run_vit_b.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
import torchvision.models.vision_transformer as models

from torchao.utils import benchmark_model, profiler_runner
torch.set_float32_matmul_precision("high")
# Load Vision Transformer model
# `pretrained=True` is torchvision's deprecated legacy flag; IMAGENET1K_V1 is
# the checkpoint that flag maps to, so the loaded weights are the same.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

Expand All @@ -12,30 +14,6 @@

# Replace the eager model with its torch.compile wrapper, using the
# 'max-autotune' compilation mode.
model = torch.compile(model, mode='max-autotune')

def benchmark_model(model, num_runs, input_tensor):
    """Return the mean time of one ``model(input_tensor)`` call, in milliseconds.

    The model is invoked ``num_runs`` times back to back. When CUDA is
    available the whole loop is bracketed by a pair of CUDA events; otherwise
    a host wall-clock timer is used so the helper also runs on CPU-only hosts.

    Args:
        model: Callable invoked as ``model(input_tensor)``.
        num_runs: Number of timed invocations; must be at least 1.
        input_tensor: The single positional argument passed to ``model``.

    Returns:
        Elapsed time divided by ``num_runs``, in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        # Guard the division below and reject meaningless negative counts.
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    if not torch.cuda.is_available():
        # No CUDA device: CUDA events cannot be used, so time on the host.
        import time

        start_time = time.perf_counter()
        for _ in range(num_runs):
            with torch.autograd.profiler.record_function("timed region"):
                model(input_tensor)
        # perf_counter reports seconds; convert to milliseconds to match
        # the unit of torch.cuda.Event.elapsed_time.
        return (time.perf_counter() - start_time) * 1000.0 / num_runs

    # Drain previously queued GPU work so it is not attributed to this run.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    # benchmark
    for _ in range(num_runs):
        # Label each call so it can be located in a profiler trace.
        with torch.autograd.profiler.record_function("timed region"):
            model(input_tensor)

    end_event.record()
    # Wait for the end event to complete before reading the elapsed time.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` under ``torch.profiler`` and save a Chrome trace.

    Args:
        path: Destination passed to ``export_chrome_trace``.
        fn: Callable to profile.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn(*args, **kwargs)`` returns.
    """
    activities = [torch.profiler.ProfilerActivity.CPU]
    # Only ask for CUDA activity when a CUDA device is actually present.
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    # Export once the profiling context has been exited.
    prof.export_chrome_trace(path)
    return result

# Must run with no_grad when optimizing for inference
with torch.no_grad():
# warmup
Expand Down
26 changes: 2 additions & 24 deletions tutorials/quantize_vit/run_vit_b_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torchao
import torchvision.models.vision_transformer as models

from torchao.utils import benchmark_model, profiler_runner
torch.set_float32_matmul_precision("high")
# Load Vision Transformer model
# `pretrained=True` is torchvision's deprecated legacy flag; IMAGENET1K_V1 is
# the checkpoint that flag maps to, so the loaded weights are the same.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

Expand All @@ -19,30 +21,6 @@

# Replace the eager model with its torch.compile wrapper, using the
# 'max-autotune' compilation mode.
model = torch.compile(model, mode='max-autotune')

def benchmark_model(model, num_runs, input_tensor):
    """Return the mean time of one ``model(input_tensor)`` call, in milliseconds.

    The model is invoked ``num_runs`` times back to back. When CUDA is
    available the whole loop is bracketed by a pair of CUDA events; otherwise
    a host wall-clock timer is used so the helper also runs on CPU-only hosts.

    Args:
        model: Callable invoked as ``model(input_tensor)``.
        num_runs: Number of timed invocations; must be at least 1.
        input_tensor: The single positional argument passed to ``model``.

    Returns:
        Elapsed time divided by ``num_runs``, in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        # Guard the division below and reject meaningless negative counts.
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    if not torch.cuda.is_available():
        # No CUDA device: CUDA events cannot be used, so time on the host.
        import time

        start_time = time.perf_counter()
        for _ in range(num_runs):
            with torch.autograd.profiler.record_function("timed region"):
                model(input_tensor)
        # perf_counter reports seconds; convert to milliseconds to match
        # the unit of torch.cuda.Event.elapsed_time.
        return (time.perf_counter() - start_time) * 1000.0 / num_runs

    # Drain previously queued GPU work so it is not attributed to this run.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    # benchmark
    for _ in range(num_runs):
        # Label each call so it can be located in a profiler trace.
        with torch.autograd.profiler.record_function("timed region"):
            model(input_tensor)

    end_event.record()
    # Wait for the end event to complete before reading the elapsed time.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` under ``torch.profiler`` and save a Chrome trace.

    Args:
        path: Destination passed to ``export_chrome_trace``.
        fn: Callable to profile.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn(*args, **kwargs)`` returns.
    """
    activities = [torch.profiler.ProfilerActivity.CPU]
    # Only ask for CUDA activity when a CUDA device is actually present.
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    # Export once the profiling context has been exited.
    prof.export_chrome_trace(path)
    return result

# Must run with no_grad when optimizing for inference
with torch.no_grad():
# warmup
Expand Down

0 comments on commit f0f00ce

Please sign in to comment.