From 583dbc0c2350efb1a620bf6da7ee29a90acbe607 Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Tue, 7 Mar 2023 20:38:35 -0800 Subject: [PATCH] fix encoding for inf2 (#534) * fix encoding * add comments --- engines/python/setup/djl_python/transformer-neuronx.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/engines/python/setup/djl_python/transformer-neuronx.py b/engines/python/setup/djl_python/transformer-neuronx.py index e9dd170dd..69d1d0cbb 100644 --- a/engines/python/setup/djl_python/transformer-neuronx.py +++ b/engines/python/setup/djl_python/transformer-neuronx.py @@ -151,8 +151,9 @@ def infer(self, inputs): f"{self.batch_size} batch size not equal to {len(input_text)} prompt size" ) with torch.inference_mode(): - input_ids = torch.as_tensor( - [self.tokenizer.encode(text) for text in input_text]) + # inf 2 needs padding + input_ids = self.tokenizer.batch_encode_plus( + input_text, return_tensors="pt", padding=True)['input_ids'] generated_sequence = self.model.sample( input_ids, sequence_length=seq_length) result = [