Skip to content

Commit

Permalink
add laev2 (2x2 Hermitian eig) and lasr (apply multiple Givens)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgates3 committed Jun 19, 2024
1 parent 4abdfe6 commit 83360ed
Show file tree
Hide file tree
Showing 13 changed files with 908 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ docs/doxygen/errors.txt
docs/html/
files.txt
include/lapack/defines.h
install*
issues/
lib/*.a
lib/*.so
lib/pkgconfig/*.pc
make.inc
test/tester
tools/gen
wiki/
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ add_library(
src/lacp2.cc
src/lacpy.cc
src/laed4.cc
src/laev2.cc
src/lag2c.cc
src/lag2d.cc
src/lag2s.cc
Expand Down Expand Up @@ -334,6 +335,7 @@ add_library(
src/lascl.cc
src/laset.cc
src/lassq.cc
src/lasr.cc
src/laswp.cc
src/lauum.cc
src/opgtr.cc
Expand Down
90 changes: 90 additions & 0 deletions include/lapack/fortran.h
Original file line number Diff line number Diff line change
Expand Up @@ -20382,6 +20382,96 @@ void LAPACK_zunhr_col(
lapack_complex_double* T, lapack_int const* ldt,
lapack_complex_double* D, lapack_int* info );

//--------------------
#define LAPACK_slaev2 LAPACK_GLOBAL( slaev2, SLAEV2 )
void LAPACK_slaev2(
float const* a, float const* b, float const* c,
float* rt1, float* rt2,
float* cs1, float* sn1
);

#define LAPACK_dlaev2 LAPACK_GLOBAL( dlaev2, DLAEV2 )
void LAPACK_dlaev2(
double const* a, double const* b, double const* c,
double* rt1, double* rt2,
double* cs1, double* sn1
);

#define LAPACK_claev2 LAPACK_GLOBAL( claev2, CLAEV2 )
void LAPACK_claev2(
lapack_complex_float const* a,
lapack_complex_float const* b,
lapack_complex_float const* c,
float* rt1, float* rt2,
float* cs1, lapack_complex_float* sn1
);

#define LAPACK_zlaev2 LAPACK_GLOBAL( zlaev2, ZLAEV2 )
void LAPACK_zlaev2(
lapack_complex_double const* a,
lapack_complex_double const* b,
lapack_complex_double const* c,
double* rt1, double* rt2,
double* cs1, lapack_complex_double* sn1
);

//--------------------
#define LAPACK_slasr_base LAPACK_GLOBAL( slasr, SLASR )
void LAPACK_slasr_base(
char const* side, char const* pivot, char const* direction,
lapack_int const* m, lapack_int const* n,
float const* C, float const* S,
float* A, lapack_int const* lda
#ifdef LAPACK_FORTRAN_STRLEN_END
, size_t side_len, size_t pivot_len, size_t direction_len
#endif
);

#define LAPACK_dlasr_base LAPACK_GLOBAL( dlasr, DLASR )
void LAPACK_dlasr_base(
char const* side, char const* pivot, char const* direction,
lapack_int const* m, lapack_int const* n,
double const* C, double const* S,
double* A, lapack_int const* lda
#ifdef LAPACK_FORTRAN_STRLEN_END
, size_t side_len, size_t pivot_len, size_t direction_len
#endif
);

#define LAPACK_clasr_base LAPACK_GLOBAL( clasr, CLASR )
void LAPACK_clasr_base(
char const* side, char const* pivot, char const* direction,
lapack_int const* m, lapack_int const* n,
float const* C, float const* S,
lapack_complex_float* A, lapack_int const* lda
#ifdef LAPACK_FORTRAN_STRLEN_END
, size_t side_len, size_t pivot_len, size_t direction_len
#endif
);

#define LAPACK_zlasr_base LAPACK_GLOBAL( zlasr, ZLASR )
void LAPACK_zlasr_base(
char const* side, char const* pivot, char const* direction,
lapack_int const* m, lapack_int const* n,
double const* C, double const* S,
lapack_complex_double* A, lapack_int const* lda
#ifdef LAPACK_FORTRAN_STRLEN_END
, size_t side_len, size_t pivot_len, size_t direction_len
#endif
);
#ifdef LAPACK_FORTRAN_STRLEN_END
#define LAPACK_slasr( ... ) LAPACK_slasr_base( __VA_ARGS__, 1, 1, 1 )
#define LAPACK_dlasr( ... ) LAPACK_dlasr_base( __VA_ARGS__, 1, 1, 1 )
#define LAPACK_clasr( ... ) LAPACK_clasr_base( __VA_ARGS__, 1, 1, 1 )
#define LAPACK_zlasr( ... ) LAPACK_zlasr_base( __VA_ARGS__, 1, 1, 1 )
#else
#define LAPACK_slasr( ... ) LAPACK_slasr_base( __VA_ARGS__ )
#define LAPACK_dlasr( ... ) LAPACK_dlasr_base( __VA_ARGS__ )
#define LAPACK_clasr( ... ) LAPACK_clasr_base( __VA_ARGS__ )
#define LAPACK_zlasr( ... ) LAPACK_zlasr_base( __VA_ARGS__ )
#endif


#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
46 changes: 46 additions & 0 deletions include/lapack/util.hh
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,52 @@ inline RowCol char2rowcol( char ch )
return val;
}

// -----------------------------------------------------------------------------
// check_ortho (LAPACK testing zunt01)
enum class Pivot : char {
Variable = 'V',
Top = 'T',
Bottom = 'B',
};

extern const char* Pivot_help;

//--------------------
inline char to_char( Pivot value )
{
return char( value );
}

inline const char* to_c_string( Pivot value )
{
switch (value) {
case Pivot::Variable: return "variable";
case Pivot::Top: return "top";
case Pivot::Bottom: return "bottom";
}
return "?";
}

inline std::string to_string( Pivot value )
{
return to_c_string( value );
}

inline void from_string( std::string const& str, Pivot* val )
{
std::string str_ = str;
std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower );

if (str_ == "v" || str_ == "variable")
*val = Pivot::Variable;
else if (str_ == "t" || str_ == "top")
*val = Pivot::Top;
else if (str_ == "b" || str_ == "bottom")
*val = Pivot::Bottom;
else
throw Error( "unknown Pivot: " + str );
}

//------------------------------------------------------------------------------
// For %lld printf-style printing, cast to llong; guaranteed >= 64 bits.
using llong = long long;
Expand Down
54 changes: 54 additions & 0 deletions include/lapack/wrappers.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3654,6 +3654,35 @@ int64_t laed4(
double rho,
double* lambda );

// -----------------------------------------------------------------------------
void laev2(
float a, float b, float c,
float* rt1,
float* rt2,
float* cs1,
float* sn1 );

void laev2(
double a, double b, double c,
double* rt1,
double* rt2,
double* cs1,
double* sn1 );

void laev2(
std::complex<float> a, std::complex<float> b, std::complex<float> c,
float* rt1,
float* rt2,
float* cs1,
std::complex<float>* sn1 );

void laev2(
std::complex<double> a, std::complex<double> b, std::complex<double> c,
double* rt1,
double* rt2,
double* cs1,
std::complex<double>* sn1 );

// -----------------------------------------------------------------------------
int64_t lag2c(
int64_t m, int64_t n,
Expand Down Expand Up @@ -4398,6 +4427,31 @@ void lassq(
double* scale,
double* sumsq );

// -----------------------------------------------------------------------------
void lasr(
lapack::Side side, lapack::Pivot pivot, lapack::Direction direction,
int64_t m, int64_t n,
float const* C, float const* S,
float* A, int64_t lda );

void lasr(
lapack::Side side, lapack::Pivot pivot, lapack::Direction direction,
int64_t m, int64_t n,
double const* C, double const* S,
double* A, int64_t lda );

void lasr(
lapack::Side side, lapack::Pivot pivot, lapack::Direction direction,
int64_t m, int64_t n,
float const* C, float const* S,
std::complex<float>* A, int64_t lda );

void lasr(
lapack::Side side, lapack::Pivot pivot, lapack::Direction direction,
int64_t m, int64_t n,
double const* C, double const* S,
std::complex<double>* A, int64_t lda );

// -----------------------------------------------------------------------------
void laswp(
int64_t n,
Expand Down
125 changes: 125 additions & 0 deletions src/laev2.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) 2017-2023, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include "lapack.hh"
#include "lapack/fortran.h"

#include <vector>

namespace lapack {

//------------------------------------------------------------------------------
/// @ingroup heev_computational
void laev2(
float a, float b, float c,
float* rt1,
float* rt2,
float* cs1,
float* sn1 )
{
LAPACK_slaev2(
&a, &b, &c, rt1, rt2, cs1, sn1 );
}

//------------------------------------------------------------------------------
/// @ingroup heev_computational
void laev2(
double a, double b, double c,
double* rt1,
double* rt2,
double* cs1,
double* sn1 )
{
LAPACK_dlaev2(
&a, &b, &c, rt1, rt2, cs1, sn1 );
}

//------------------------------------------------------------------------------
/// @ingroup heev_computational
void laev2(
std::complex<float> a, std::complex<float> b, std::complex<float> c,
float* rt1,
float* rt2,
float* cs1,
std::complex<float>* sn1 )
{
LAPACK_claev2(
(lapack_complex_float*) &a,
(lapack_complex_float*) &b,
(lapack_complex_float*) &c,
rt1, rt2, cs1,
(lapack_complex_float*) sn1 );
}

//------------------------------------------------------------------------------
/// Computes the eigendecomposition of a 2-by-2 Hermitian matrix
///
/// [ a b ]
/// [ conj( b ) c ].
///
/// On return, rt1 is the eigenvalue of larger absolute value, rt2 is the
/// eigenvalue of smaller absolute value, and (cs1, sn1) is the unit right
/// eigenvector for rt1, giving the decomposition
///
/// [ cs1 conj( sn1 ) ] [ a b ] [ cs1 -conj( sn1 ) ] = [ rt1 0 ]
/// [ -sn1 cs1 ] [ conj(b) c ] [ sn1 cs1 ] [ 0 rt2 ].
///
/// Overloaded versions are available for
/// `float`, `double`, `std::complex<float>`, and `std::complex<double>`.
///
/// @param[in] a
/// The (1, 1) element of the 2-by-2 matrix.
///
/// @param[in] b
/// The (1, 2) element and the conjugate of the (2, 1) element of
/// the 2-by-2 matrix.
///
/// @param[in] c
/// The (2, 2) element of the 2-by-2 matrix.
///
/// @param[out] rt1
/// The eigenvalue of larger absolute value.
///
/// @param[out] rt2
/// The eigenvalue of smaller absolute value.
///
/// @param[out] cs1
///
/// @param[out] sn1
/// The vector (cs1, sn1) is a unit right eigenvector for rt1.
///
//------------------------------------------------------------------------------
/// @par Further Details
///
/// rt1 is accurate to a few ulps barring over/underflow.
///
/// rt2 may be inaccurate if there is massive cancellation in the
/// determinant a*c-b*b; higher precision or correctly rounded or
/// correctly truncated arithmetic would be needed to compute rt2
/// accurately in all cases.
///
/// cs1 and sn1 are accurate to a few ulps barring over/underflow.
///
/// Overflow is possible only if rt1 is within a factor of 5 of overflow.
/// Underflow is harmless if the input data is 0 or exceeds
/// underflow_threshold / macheps.
///
/// @ingroup heev_computational
void laev2(
std::complex<double> a, std::complex<double> b, std::complex<double> c,
double* rt1,
double* rt2,
double* cs1,
std::complex<double>* sn1 )
{
LAPACK_zlaev2(
(lapack_complex_double*) &a,
(lapack_complex_double*) &b,
(lapack_complex_double*) &c,
rt1, rt2, cs1,
(lapack_complex_double*) sn1 );
}

} // namespace lapack
Loading

0 comments on commit 83360ed

Please sign in to comment.