Skip to content

Commit

Permalink
add lae2 (2x2 Hermitian eigvals)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgates3 committed Jun 19, 2024
1 parent 83360ed commit 655c75d
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 3 deletions.
13 changes: 13 additions & 0 deletions include/lapack/fortran.h
Original file line number Diff line number Diff line change
Expand Up @@ -20382,6 +20382,19 @@ void LAPACK_zunhr_col(
lapack_complex_double* T, lapack_int const* ldt,
lapack_complex_double* D, lapack_int* info );

//--------------------
#define LAPACK_slae2 LAPACK_GLOBAL( slae2, SLAE2 )
void LAPACK_slae2(
float const* a, float const* b, float const* c,
float* rt1, float* rt2
);

#define LAPACK_dlae2 LAPACK_GLOBAL( dlae2, DLAE2 )
void LAPACK_dlae2(
double const* a, double const* b, double const* c,
double* rt1, double* rt2
);

//--------------------
#define LAPACK_slaev2 LAPACK_GLOBAL( slaev2, SLAEV2 )
void LAPACK_slaev2(
Expand Down
11 changes: 11 additions & 0 deletions include/lapack/wrappers.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3637,6 +3637,17 @@ void lacpy(
std::complex<double> const* A, int64_t lda,
std::complex<double>* B, int64_t ldb );

// -----------------------------------------------------------------------------
void lae2(
float a, float b, float c,
float* rt1,
float* rt2 );

void lae2(
double a, double b, double c,
double* rt1,
double* rt2 );

// -----------------------------------------------------------------------------
int64_t laed4(
int64_t n, int64_t i,
Expand Down
73 changes: 73 additions & 0 deletions src/lae2.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include "lapack.hh"
#include "lapack/fortran.h"

#include <vector>

namespace lapack {

//------------------------------------------------------------------------------
/// @ingroup heev_computational
void lae2(
float a, float b, float c,
float* rt1,
float* rt2 )
{
LAPACK_slae2(
&a, &b, &c, rt1, rt2 );
}

//------------------------------------------------------------------------------
/// Computes the eigenvalues of a 2-by-2 symmetric matrix
/// [ a b ]
/// [ b c ].
/// On return, rt1 is the eigenvalue of larger absolute value, and rt2
/// is the eigenvalue of smaller absolute value.
///
/// Overloaded versions are available for
/// `float`, `double`, `std::complex<float>`, and `std::complex<double>`.
///
/// @param[in] a
/// The (1, 1) element of the 2-by-2 matrix.
///
/// @param[in] b
/// The (1, 2) and (2, 1) elements of the 2-by-2 matrix.
///
/// @param[in] c
/// The (2, 2) element of the 2-by-2 matrix.
///
/// @param[out] rt1
/// The eigenvalue of larger absolute value.
///
/// @param[out] rt2
/// The eigenvalue of smaller absolute value.
///
//------------------------------------------------------------------------------
/// @par Further Details
///
/// rt1 is accurate to a few ulps barring over/underflow.
///
/// rt2 may be inaccurate if there is massive cancellation in the
/// determinant a*c-b*b; higher precision or correctly rounded or
/// correctly truncated arithmetic would be needed to compute rt2
/// accurately in all cases.
///
/// Overflow is possible only if rt1 is within a factor of 5 of overflow.
/// Underflow is harmless if the input data is 0 or exceeds
/// underflow_threshold / macheps.
///
/// @ingroup heev_computational
void lae2(
double a, double b, double c,
double* rt1,
double* rt2 )
{
LAPACK_dlae2(
&a, &b, &c, rt1, rt2 );
}

} // namespace lapack
1 change: 1 addition & 0 deletions test/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ def filter_csv( values, csv ):
[ 'heevr', gen + dtype + align + n + jobz + uplo + vl + vu ],
[ 'heevr', gen + dtype + align + n + jobz + uplo + il + iu ],
[ 'hetrd', gen + dtype + align + n + uplo ],
[ 'lae2', gen + dtype_real ], # 2x2, eigvals only
[ 'laev2', gen + dtype ], # 2x2
[ 'ungtr', gen + dtype + align + n + uplo ],
[ 'unmtr', gen + dtype_real + align + mn + uplo + side + trans ], # real does trans = N, T, C
Expand Down
1 change: 1 addition & 0 deletions test/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ std::vector< testsweeper::routines_t > routines = {
{ "", nullptr, Section::newline },

{ "heevr", test_heevr, Section::heev }, // backwards error check
{ "lae2", test_lae2, Section::heev }, // backwards error check
{ "laev2", test_laev2, Section::heev }, // backwards error check
{ "", nullptr, Section::newline },

Expand Down
1 change: 1 addition & 0 deletions test/test.hh
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ void test_heevx ( Params& params, bool run );
void test_heevd ( Params& params, bool run );
void test_heevr ( Params& params, bool run );
void test_hetrd ( Params& params, bool run );
void test_lae2 ( Params& params, bool run );
void test_laev2 ( Params& params, bool run );
void test_sturm ( Params& params, bool run );
void test_ungtr ( Params& params, bool run );
Expand Down
108 changes: 108 additions & 0 deletions test/test_lae2.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include "test.hh"
#include "lapack.hh"
#include "lapack/flops.hh"
#include "print_matrix.hh"
#include "error.hh"
#include "lapacke_wrappers.hh"

#include <vector>

//------------------------------------------------------------------------------
template< typename scalar_t >
void test_lae2_work( Params& params, bool run )
{
using real_t = blas::real_type< scalar_t >;
using blas::conj;
using lapack::Job, lapack::Uplo;

// Constants
const real_t eps = std::numeric_limits< real_t >::epsilon();

// get & mark input values
params.dim.m() = 2;
params.dim.n() = 2;
real_t tol = params.tol() * eps;
int verbose = params.verbose();
params.matrix.mark();

// mark non-standard output values
params.error.name( "Lambda" );

if (! run)
return;

//---------- setup
int64_t n = 2;
int64_t lda = 2;
std::vector< scalar_t > A( lda*n );

lapack::generate_matrix( params.matrix, n, n, &A[0], lda );

// A = [ a b ], stored column-wise.
// [ conj(b) c ]
scalar_t a = A[ 0 ];
scalar_t b = A[ 2 ];
scalar_t c = A[ 3 ];
A[ 1 ] = conj( b );

real_t rt1, rt2, rt1_ref, rt2_ref, cs1;
scalar_t sn1;

if (verbose >= 2) {
printf( "A = " ); print_matrix( n, n, &A[0], lda );
}

//---------- run test
testsweeper::flush_cache( params.cache() );
double time = testsweeper::get_wtime();
// no info returned
lapack::lae2( a, b, c, &rt1, &rt2 );
time = testsweeper::get_wtime() - time;

params.time() = time;

std::vector< real_t > Lambda{ rt1, rt2 };

if (verbose >= 2) {
printf( "Lambda = " ); print_vector( n, &Lambda[0], 1 );
}

if (params.check() == 'y') {
//---------- run reference, using laev2
testsweeper::flush_cache( params.cache() );
time = testsweeper::get_wtime();
lapack::laev2( a, b, c, &rt1_ref, &rt2_ref, &cs1, &sn1 );
time = testsweeper::get_wtime() - time;

params.ref_time() = time;

//---------- check error compared to reference
std::vector< real_t > Lambda_ref{ rt1_ref, rt2_ref };
real_t error = rel_error( Lambda, Lambda_ref );
params.error() = error;
params.okay() = (error < tol);
}
}

//------------------------------------------------------------------------------
void test_lae2( Params& params, bool run )
{
switch (params.datatype()) {
case testsweeper::DataType::Single:
test_lae2_work< float >( params, run );
break;

case testsweeper::DataType::Double:
test_lae2_work< double >( params, run );
break;

default:
throw std::runtime_error( "unknown datatype" );
break;
}
}
9 changes: 6 additions & 3 deletions tools/header_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,15 @@
# end
# end
# end
print( '#define LAPACK_' + func + ' LAPACK_GLOBAL(' + func + ',' + func.upper() + ')', file=output )
print( retval + ' LAPACK_' + func + '(\n ' + ', '.join( args ) + '\n);\n', file=output )
print( '#define LAPACK_' + func + ' LAPACK_GLOBAL( ' + func + ', '
+ func.upper() + ' )', file=output )
print( retval + ' LAPACK_' + func + '(\n ' + ', '.join( args )
+ '\n);\n', file=output )
else:
print( '// skipping, file not found:', filename )
print( '// skipping, file not found:', filename, file=output )
print( '// #define LAPACK_' + func + ' LAPACK_GLOBAL(' + func + ',' + func.upper() + ')', file=output )
print( '// #define LAPACK_' + func + ' LAPACK_GLOBAL(' + func + ','
+ func.upper() + ')', file=output )
print( '// void LAPACK_' + func + '( ... );\n', file=output )
# end
# end
Expand Down

0 comments on commit 655c75d

Please sign in to comment.