From 8c937e36a11c019a336b095f8183fcbb80ddb21a Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 9 Jan 2025 22:18:49 +0200 Subject: [PATCH] Fix cache retrieval (#718) --- Project.toml | 2 +- src/array.jl | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 927b0208..d58da3ad 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AMDGPU" uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e" authors = ["Julian P Samaroo ", "Valentin Churavy ", "Anton Smirnov "] -version = "1.2.0" +version = "1.2.1" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/array.jl b/src/array.jl index 6780f0b8..6cb4855a 100644 --- a/src/array.jl +++ b/src/array.jl @@ -6,10 +6,15 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N} function ROCArray{T, N, B}(::UndefInitializer, dims::Dims{N}) where {T, N, B <: Mem.AbstractAMDBuffer} @assert isbitstype(T) "ROCArray only supports bits types" sz::Int64 = prod(dims) * sizeof(T) - return GPUArrays.cached_alloc((ROCArray, AMDGPU.device(), T, B, sz)) do + x = GPUArrays.cached_alloc((ROCArray, AMDGPU.device(), T, B, sz)) do @debug "Allocate `T=$T`, `dims=$dims`: $(Base.format_bytes(sz))" data = DataRef(pool_free, pool_alloc(B, sz)) return finalizer(unsafe_free!, new{T, N, B}(data, dims, 0)) + end + return if size(x) != dims + reshape(x, dims) + else + x end::ROCArray{T, N, B} end