diff --git a/ducho/multimodal/visual/VisualFeatureExtractor.py b/ducho/multimodal/visual/VisualFeatureExtractor.py index 3a9aaa4..b04ce51 100644 --- a/ducho/multimodal/visual/VisualFeatureExtractor.py +++ b/ducho/multimodal/visual/VisualFeatureExtractor.py @@ -104,7 +104,7 @@ def extract_feature(self, image): elif 'transformers' in self._backend_libraries_list: # converting the input image tensor - outcome of the pre-processor - in a set. - model_input = {'pixel_values': image} + model_input = {'pixel_values': image[0]} model_input = {k: torch.tensor(v).to(self._device) for k, v in model_input.items()} model_output = getattr(self._model(**model_input), self._output_layer.lower()) return model_output.detach().cpu().numpy()