Skip to content

Commit

Permalink
add parallel for loop with generic reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
matekelemen committed Mar 16, 2024
1 parent 358979a commit 682369c
Showing 1 changed file with 60 additions and 6 deletions.
66 changes: 60 additions & 6 deletions kratos/utilities/parallel_utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
// _|\_\_| \__,_|\__|\___/ ____/
// Multi-Physics
//
// License: BSD License
// Kratos default license: kratos/license.txt
// License: BSD License
// Kratos default license: kratos/license.txt
//
// Main authors: Riccardo Rossi
// Denis Demidov
Expand Down Expand Up @@ -67,10 +67,6 @@ namespace Kratos
class KRATOS_API(KRATOS_CORE) ParallelUtilities
{
public:
///@name Life Cycle
///@{

///@}
///@name Operations
///@{

Expand Down Expand Up @@ -281,6 +277,41 @@ class BlockPartition
return global_reducer.GetValue();
}

template <class TThreadLocalStorage,
class TFunction,
class TThreadLocalReduction,
std::enable_if_t<std::is_same_v<std::invoke_result_t<TThreadLocalReduction,TThreadLocalStorage&>,void>,bool> = true>
void for_each(const TThreadLocalStorage& rTls,
TFunction&& rFunction,
TThreadLocalReduction&& rTLSReducer)
{
// Check type requirements
static_assert(std::is_copy_constructible<TThreadLocalStorage>::value, "TThreadLocalStorage must be copy constructible!");

KRATOS_PREPARE_CATCH_THREAD_EXCEPTION

#pragma omp parallel
{
TThreadLocalStorage tls(rTls);

#pragma omp for
for (int i=0; i<mNchunks; ++i) {
KRATOS_TRY
for (auto it = mBlockPartition[i]; it != mBlockPartition[i+1]; ++it){
rFunction(*it, tls); // note that we pass the value to the function, not the iterator
} // for it in mBlockPartition[i]
KRATOS_CATCH_THREAD_EXCEPTION
} // for i in range(mNchunks)

#pragma omp critical
{
rTLSReducer(tls);
} // pragma omp critical
} // pragma omp parallel

KRATOS_CHECK_AND_THROW_THREAD_EXCEPTION
}

private:
int mNchunks;
std::array<TIterator, MaxThreads> mBlockPartition;
Expand Down Expand Up @@ -438,6 +469,29 @@ template <class TReducer,
return block_for_each<TReducer>(v.begin(), v.end(), tls, std::forward<TFunctionType>(func));
}

template <class TContainer,
class TThreadLocalStorage,
class TFunction,
class TThreadLocalReduction,
std::enable_if_t<std::is_same_v<std::invoke_result_t<TThreadLocalReduction,TThreadLocalStorage&>,void>,bool> = true>
void block_for_each(TContainer&& rContainer,
const TThreadLocalStorage& rTls,
TFunction&& rFunction,
TThreadLocalReduction&& rReduction)
{
using ContainerType = std::remove_reference_t<TContainer>;
using iterator_type = std::conditional_t<
std::is_const_v<ContainerType>,
typename ContainerType::const_iterator,
typename ContainerType::iterator
>;
return BlockPartition<iterator_type>(rContainer.begin(), rContainer.end()).for_each(
rTls,
std::forward<TFunction>(rFunction),
std::forward<TThreadLocalReduction>(rReduction)
);
}

//***********************************************************************************
//***********************************************************************************
//***********************************************************************************
Expand Down

0 comments on commit 682369c

Please sign in to comment.