Skip to content

Commit

Permalink
Fix torch.autoscale
Browse files Browse the repository at this point in the history
  • Loading branch information
honnibal committed Oct 1, 2024
1 parent 0870d30 commit d0611e5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions thinc/shims/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def predict(self, inputs: ArgsKwargs) -> Any:
# for mixed_precision. That doesn't seem to match the docs, and now
# it raises an error when moving from the deprecated function. So
# I've removed the argument but I'm not certain it's correct.
with torch.autocast(device_type="cuda"):
with torch.autocast(device_type="cuda", enabled=self._mixed_precision):
outputs = self._model(*inputs.args, **inputs.kwargs)
self._model.train()
return outputs
Expand All @@ -133,7 +133,7 @@ def begin_update(self, inputs: ArgsKwargs):
# for mixed_precision. That doesn't seem to match the docs, and now
# it raises an error when moving from the deprecated function. So
# I've removed the argument but I'm not certain it's correct.
with torch.autocast("cuda"):
with torch.autocast("cuda", enabled=self._mixed_precision):
output = self._model(*inputs.args, **inputs.kwargs)

def backprop(grads):
Expand Down

0 comments on commit d0611e5

Please sign in to comment.