Hi all,

Since PyTorch does not offer an efficient function that copies data according to a list of non-contiguous indices, I want to implement such a function myself. Specifically, I wrote the following code, which works correctly:
import torch
import triton
import triton.language as tl
import time


@triton.jit
def copy_selected_indices(
    input_tensor_ptr,    # Pointer to the input tensor
    output_tensor_ptr,   # Pointer to the output tensor
    input_indices_ptr,   # Pointer to the element indices to read from
    output_indices_ptr,  # Pointer to the element indices to write to
    n_elements,          # Number of elements to copy
    BLOCK_SIZE: tl.constexpr,  # Block size for parallelism
):
    # Get the program ID
    pid = tl.program_id(axis=0)
    # Calculate the offsets into the indices arrays
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Load the indices, masking out-of-range offsets
    mask = offset < n_elements
    input_indices = tl.load(input_indices_ptr + offset, mask=mask)
    output_indices = tl.load(output_indices_ptr + offset, mask=mask)
    # Load the data from the input tensor
    input_data = tl.load(input_tensor_ptr + input_indices[:, None], mask=mask[:, None])
    # Store the data in the output tensor
    tl.store(output_tensor_ptr + output_indices[:, None], input_data, mask=mask[:, None])


if __name__ == '__main__':
    # Example usage
    hidden_dim = 5120
    input_tensor = torch.ones((100, hidden_dim), device='cuda:2')
    output_tensor = torch.zeros((100, hidden_dim), device='cuda:2')

    # Expand the selected rows into flat element indices
    input_indices = []
    input_rows = [1, 3, 5, 7, 9, 12, 18, 22, 96, 98, 99]
    for i in input_rows:
        input_indices.extend(list(range(i * hidden_dim, (i + 1) * hidden_dim)))
    input_indices = torch.as_tensor(input_indices, device='cuda:2', dtype=torch.int)

    output_indices = []
    output_rows = [1, 3, 5, 7, 9, 12, 18, 22, 96, 98, 99]
    for i in output_rows:
        output_indices.extend(list(range(i * hidden_dim, (i + 1) * hidden_dim)))
    output_indices = torch.as_tensor(output_indices, device='cuda:2', dtype=torch.int)

    input_tensor = input_tensor.contiguous()
    output_tensor = output_tensor.contiguous()

    n_elements = input_indices.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    block_size = 512

    # Warm-up launch (triggers Triton JIT compilation)
    copy_selected_indices[grid](input_tensor, torch.empty_like(input_tensor), input_indices, output_indices, n_elements, BLOCK_SIZE=block_size)
    torch.cuda.synchronize('cuda:2')

    start_time = time.monotonic()
    copy_selected_indices[grid](input_tensor, output_tensor, input_indices, output_indices, n_elements, BLOCK_SIZE=block_size)
    # Synchronize so the timing covers the kernel, not just the launch
    torch.cuda.synchronize('cuda:2')
    print(time.monotonic() - start_time)
    print(output_tensor)
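For a plain-PyTorch reference point (not part of the kernel above; it reuses input_tensor, output_tensor, input_rows and output_rows from the script, and the names in_rows / out_rows are ones I introduce here), I believe the same row copy can be expressed with built-in indexing:

# Reference baseline using built-in PyTorch ops (assumes the tensors and row
# lists defined in the script above).
in_rows = torch.tensor(input_rows, device='cuda:2')    # rows to read (int64 by default)
out_rows = torch.tensor(output_rows, device='cuda:2')  # rows to write

# Advanced indexing: gather the selected rows, then scatter them into place.
output_tensor[out_rows] = input_tensor[in_rows]

# Equivalent explicit form using index_select / index_copy_.
output_tensor.index_copy_(0, out_rows, input_tensor.index_select(0, in_rows))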
In the Triton code above, I copy a list of rows from input_tensor to output_tensor. Is there a more efficient way to implement the same function?
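One direction that might be worth measuring (a sketch under my assumptions, not a verified speedup): since whole rows are copied and both tensors are contiguous, the kernel could take the row indices directly and let each program handle one (row, column-block) pair, so the hidden_dim-sized element-index tensors never have to be built or loaded. The kernel name copy_selected_rows and the 2D grid below are mine:

import triton
import triton.language as tl

@triton.jit
def copy_selected_rows(
    input_ptr,        # pointer to the contiguous input tensor
    output_ptr,       # pointer to the contiguous output tensor
    input_rows_ptr,   # row indices to read from
    output_rows_ptr,  # row indices to write to
    hidden_dim,       # number of columns per row
    BLOCK_SIZE: tl.constexpr,
):
    row_id = tl.program_id(axis=0)     # which selected row this program handles
    col_block = tl.program_id(axis=1)  # which column block within that row
    in_row = tl.load(input_rows_ptr + row_id)
    out_row = tl.load(output_rows_ptr + row_id)
    cols = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cols < hidden_dim
    vals = tl.load(input_ptr + in_row * hidden_dim + cols, mask=mask)
    tl.store(output_ptr + out_row * hidden_dim + cols, vals, mask=mask)

# Launch: one program per (selected row, column block). in_rows / out_rows are
# row-index tensors (not flat element indices) on the same CUDA device.
# grid = (in_rows.numel(), triton.cdiv(hidden_dim, 512))
# copy_selected_rows[grid](input_tensor, output_tensor, in_rows, out_rows, hidden_dim, BLOCK_SIZE=512)

With this layout each program reads a contiguous BLOCK_SIZE stretch of one row, so loads should stay coalesced, and only len(input_rows) indices are transferred instead of len(input_rows) * hidden_dim. Whether it actually beats the element-index version would need to be benchmarked.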