// (Navigation residue from the web page this file was copied from: "← Back to home page".)
/*******************************************************
* Copyright (c) 2025, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include
#include
#include
#include
#include
#include
#include
#include
// This makes the macros cleaner
using af::array;
using af::dim4;
using af::dtype_traits;
using af::exception;
using af::randu;
using half_float::half;
using std::abs;
using std::endl;
using std::vector;
// Number of random samples generated per element-wise math test.
const int num = 10000;
// Absolute tolerances used when comparing device results with host
// reference results, per precision (half / single / double).
const float hlf_err  = 1e-2;
const float flt_err  = 1e-3;
const double dbl_err = 1e-6;
// Host-side complex element types for single and double precision.
using complex_float  = std::complex<float>;
using complex_double = std::complex<double>;
/// Host reference implementation of the logistic sigmoid, 1 / (1 + e^-in).
/// The expression uses double literals and std::exp, and the result is
/// converted back to T.
/// @param in  input value
/// @return    sigmoid(in) as a T
template <typename T>
T sigmoid(T in) {
    return T(1.0 / (1.0 + std::exp(-in)));
}
/// Host reference implementation of the reciprocal square root, 1 / sqrt(in).
/// @param in  input value (expected to be positive)
/// @return    1 / sqrt(in), converted to T
template <typename T>
T rsqrt(T in) {
    // Bring the std overloads into scope. Argument-dependent lookup still
    // finds overloads for non-standard element types (e.g. half_float::sqrt
    // for half). A bare unqualified call would rely on a global ::sqrt
    // happening to be visible.
    using std::sqrt;
    return T(1.0 / sqrt(in));
}
// NOTE(review): This region is corrupted in this copy of the file and will
// not compile as-is. Text inside angle brackets appears to have been stripped:
// the template arguments of `dtype_traits` and `vector` below are missing.
// The `catch` line at the end of the macro has also lost everything between
// `FAIL()` and what looks like the tail of an
// `#if __cplusplus > 199711L || _MSC_VER >= 1800` guard. That lost span
// presumably held `<< ex.what(); }`, the closing brace of the TEST body, the
// definitions of the MATH_TESTS_* helper macros used further down, and the
// `#if` matched by the trailing `#endif`.
// TODO: restore this block from the upstream source rather than guessing.
//
// MATH_TEST(T, func, err, lo, hi) defines a GoogleTest case named
// Math.<func>_<T> that:
//   * skips unless element type T is supported (SUPPORTED_TYPE_CHECK);
//   * builds `a` from `num` randu values as (hi - lo) * u + lo + err
//     (u presumably uniform in [0, 1)), cast to T's af_dtype;
//   * applies `func` on the device (b = func(a)) and, element by element,
//     to a host copy of `a`;
//   * compares the host and device results with tolerance `err`
//     (ASSERT_VEC_ARRAY_NEAR, presumably an element-wise near-equality check);
//   * turns any af::exception into a test failure.
#define MATH_TEST(T, func, err, lo, hi) \
TEST(Math, func##_##T) { \
try { \
SUPPORTED_TYPE_CHECK(T); \
af_dtype ty = (af_dtype)dtype_traits::af_type; \
array a = (hi - lo) * randu(num, ty) + lo + err; \
a = a.as(ty); \
eval(a); \
array b = func(a); \
vector h_a(a.elements()); \
a.host(&h_a[0]); \
for (size_t i = 0; i < h_a.size(); i++) { h_a[i] = func(h_a[i]); } \
\
ASSERT_VEC_ARRAY_NEAR(h_a, dim4(h_a.size()), b, err); \
} catch (exception & ex) { FAIL() 199711L || _MSC_VER >= 1800
// Tests for functions that apparently depend on a C++11-capable standard
// library (per the surviving `__cplusplus > 199711L || _MSC_VER >= 1800`
// fragment): inverse trig, inverse hyperbolic, cbrt, expm1, log1p, erf, erfc.
// The *_LIMITS_* variants presumably pass explicit lo/hi input bounds, e.g.
// acosh on [1, 5), since acosh is real-valued only for x >= 1.
MATH_TESTS_CPLX(asin)
MATH_TESTS_CPLX(acos)
MATH_TESTS_CPLX(atan)
MATH_TESTS_ALL(asinh)
MATH_TESTS_ALL(atanh)
MATH_TESTS_LIMITS_REAL(acosh, 1, 5)
MATH_TESTS_LIMITS_CPLX(acosh, 1, 5)
MATH_TESTS_LIMITS_REAL(round, -10, 10)
MATH_TESTS_REAL(cbrt)
MATH_TESTS_REAL(expm1)
MATH_TESTS_REAL(log1p)
MATH_TESTS_REAL(erf)
MATH_TESTS_REAL(erfc)
#endif
// Logical NOT on a b8 array must flip every element: for each position the
// input and output values XOR to 1 (true).
TEST(Math, Not) {
    array a = randu(5, 5, b8);
    array b = !a;

    // Copy the device data into std::vector buffers (b8 elements are read
    // as char, one byte each). The buffers are released automatically even
    // when a fatal ASSERT_* below returns early. The previous
    // host()/af_free_host pairing leaked both buffers on the first failure.
    std::vector<char> ha(static_cast<size_t>(a.elements()));
    std::vector<char> hb(static_cast<size_t>(b.elements()));
    a.host(ha.data());
    b.host(hb.data());

    ASSERT_EQ(ha.size(), hb.size());
    for (size_t i = 0; i < ha.size(); i++) {
        ASSERT_EQ(ha[i] ^ hb[i], true) << "mismatch at element " << i;
    }
}
// Integer modulus on 64-bit signed arrays: 3 % 2 == 1, and the result
// takes the sign of the dividend, so -3 % 2 == -1.
TEST(Math, Modulus) {
    af::dim4 shape(2, 2);
    // 64-bit signed host element type, so the arrays created from these
    // buffers compare against the s64 constants below.
    std::vector<long long> aData{3, 3, 3, 3};
    std::vector<long long> bData{2, 2, 2, 2};

    auto a       = af::array(shape, aData.data(), afHost);
    auto b       = af::array(shape, bData.data(), afHost);
    auto rem     = a % b;
    auto neg_rem = -a % b;

    ASSERT_ARRAYS_EQ(af::constant(1, shape, s64), rem);
    ASSERT_ARRAYS_EQ(af::constant(-1, shape, s64), neg_rem);
}
// Floating-point modulus in f16, f32 and f64: 3 % 2 == 1 and -3 % 2 == -1
// (sign follows the dividend). The f16 result must also equal the f32
// result converted to f16.
TEST(Math, ModulusFloat) {
    SUPPORTED_TYPE_CHECK(half_float::half);
    // The f64 operands below also need double-precision support. Skip on
    // devices without it instead of erroring out on the f64 operations
    // (presumably an af::exception on such devices).
    SUPPORTED_TYPE_CHECK(double);

    af::dim4 shape(2, 2);

    auto a   = af::constant(3, shape, af::dtype::f16);
    auto b   = af::constant(2, shape, af::dtype::f16);
    auto a32 = af::constant(3, shape, af::dtype::f32);
    auto b32 = af::constant(2, shape, af::dtype::f32);
    auto a64 = af::constant(3, shape, af::dtype::f64);
    auto b64 = af::constant(2, shape, af::dtype::f64);

    auto rem       = a % b;
    auto rem32     = a32 % b32;
    auto rem64     = a64 % b64;
    auto neg_rem   = -a % b;
    auto neg_rem32 = -a32 % b32;
    auto neg_rem64 = -a64 % b64;

    ASSERT_ARRAYS_EQ(af::constant(1, shape, af::dtype::f16), rem);
    ASSERT_ARRAYS_EQ(af::constant(1, shape, af::dtype::f32), rem32);
    ASSERT_ARRAYS_EQ(af::constant(1, shape, af::dtype::f64), rem64);
    ASSERT_ARRAYS_EQ(af::constant(-1, shape, af::dtype::f16), neg_rem);
    ASSERT_ARRAYS_EQ(af::constant(-1, shape, af::dtype::f32), neg_rem32);
    ASSERT_ARRAYS_EQ(af::constant(-1, shape, af::dtype::f64), neg_rem64);

    // Cross-precision consistency: f32 result narrowed to f16 matches the
    // native f16 computation.
    ASSERT_ARRAYS_EQ(rem32.as(f16), rem);
}