← 返回首页
// ArrayFire unit tests for element-wise math functions (generated through the
// MATH_TEST macro), logical NOT on b8 arrays, and the % operator on
// s64/f16/f32/f64 arrays. TEST / ASSERT_EQ / FAIL look like GoogleTest
// macros; the headers that would confirm this are not visible (see below).
//
// NOTE(review): this copy is damaged. It looks like an HTML text extraction
// of the original source and will not compile as-is. Restore it from the
// upstream ArrayFire repository rather than hand-editing it.
//   * Line breaks are collapsed. The license, includes, using-declarations,
//     helpers, the MATH_TEST macro and several TESTs share one physical line.
//     As a result the `#include` tokens are not directives, and everything
//     after the first `//` on that line is a single line comment.
//   * Text inside angle brackets is gone. This covers every `#include` target
//     and the template argument lists, e.g. `typedef std::complex
//     complex_float;`, `template T sigmoid`, `dtype_traits::af_type`,
//     `vector h_a`, `a.host()` and `std::vector aData`.
//   * A span is missing between `FAIL()` and `199711L`. That span holds at
//     least the tail of MATH_TEST and presumably an `#if __cplusplus > ...`
//     opener (only the matching `#endif` survives). It also holds the
//     definitions of MATH_TESTS_CPLX, MATH_TESTS_ALL, MATH_TESTS_LIMITS_REAL,
//     MATH_TESTS_LIMITS_CPLX and MATH_TESTS_REAL. Those macros are invoked
//     below but defined nowhere in the visible text.
/******************************************************* * Copyright (c) 2025, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include // This makes the macros cleaner using af::array; using af::dim4; using af::dtype_traits; using af::exception; using af::randu; using half_float::half; using std::abs; using std::endl; using std::vector; const int num = 10000; const float hlf_err = 1e-2; const float flt_err = 1e-3; const double dbl_err = 1e-6; typedef std::complex complex_float; typedef std::complex complex_double; /* Host-side reference implementations: sigmoid(x) = 1 / (1 + exp(-x)) and rsqrt(x) = 1 / sqrt(x), both converted back to T. */ template T sigmoid(T in) { return T(1.0 / (1.0 + std::exp(-in))); } template T rsqrt(T in) { return T(1.0 / sqrt(in)); } /* MATH_TEST(T, func, err, lo, hi) defines TEST(Math, func_T). It builds num inputs as (hi - lo) * randu(num, ty) + lo + err and casts them to the af_dtype taken from dtype_traits. It then applies the ArrayFire func on the device and the same func element-wise on a host copy, and compares the two with ASSERT_VEC_ARRAY_NEAR at tolerance err. Any af::exception thrown along the way fails the test. */ #define MATH_TEST(T, func, err, lo, hi) \ TEST(Math, func##_##T) { \ try { \ SUPPORTED_TYPE_CHECK(T); \ af_dtype ty = (af_dtype)dtype_traits::af_type; \ array a = (hi - lo) * randu(num, ty) + lo + err; \ a = a.as(ty); \ eval(a); \ array b = func(a); \ vector h_a(a.elements()); \ a.host(&h_a[0]); \ for (size_t i = 0; i < h_a.size(); i++) { h_a[i] = func(h_a[i]); } \ \ ASSERT_VEC_ARRAY_NEAR(h_a, dim4(h_a.size()), b, err); \ } catch (exception & ex) { FAIL() /* NOTE(review): source text was lost between FAIL() and 199711L; see the header note above this block. */ 199711L || _MSC_VER >= 1800 MATH_TESTS_CPLX(asin) MATH_TESTS_CPLX(acos) MATH_TESTS_CPLX(atan) MATH_TESTS_ALL(asinh) MATH_TESTS_ALL(atanh) MATH_TESTS_LIMITS_REAL(acosh, 1, 5) MATH_TESTS_LIMITS_CPLX(acosh, 1, 5) MATH_TESTS_LIMITS_REAL(round, -10, 10) MATH_TESTS_REAL(cbrt) MATH_TESTS_REAL(expm1) MATH_TESTS_REAL(log1p) MATH_TESTS_REAL(erf) MATH_TESTS_REAL(erfc) #endif /* Logical NOT on a random 5x5 b8 array: every element of (a XOR !a) must be true. Both host copies are released with af_free_host. */ TEST(Math, Not) { array a = randu(5, 5, b8); array b = !a; char *ha = a.host(); char *hb = b.host(); for (int i = 0; i < a.elements(); i++) { ASSERT_EQ(ha[i] ^ hb[i], true); } af_free_host(ha); af_free_host(hb); } /* % on 2x2 integer arrays, with results compared against s64 constants. 3 % 2 must be 1 and -3 % 2 must be -1, i.e. the remainder takes the sign of the dividend. */ TEST(Math, Modulus) { af::dim4 shape(2, 2); std::vector 
/* NOTE(review): this line continues the declaration `std::vector aData` from the previous line. Its element type was stripped; it is presumably a 64-bit integer type to match s64 (confirm upstream). */ aData{3, 3, 3, 3}; std::vector bData{2, 2, 2, 2}; auto a = af::array(shape, aData.data(), afHost); auto b = af::array(shape, bData.data(), afHost); auto rem = a % b; auto neg_rem = -a % b; ASSERT_ARRAYS_EQ(af::constant(1, shape, s64), rem); ASSERT_ARRAYS_EQ(af::constant(-1, shape, s64), neg_rem); } /* Same remainder checks (3 % 2 == 1, -3 % 2 == -1) for f16, f32 and f64. The SUPPORTED_TYPE_CHECK(half_float::half) guard presumably skips the test where f16 is unsupported. The final assertion checks that the f32 result converted to f16 equals the f16 result. */ TEST(Math, ModulusFloat) { SUPPORTED_TYPE_CHECK(half_float::half); af::dim4 shape(2, 2); auto a = af::constant(3, shape, af::dtype::f16); auto b = af::constant(2, shape, af::dtype::f16); auto a32 = af::constant(3, shape, af::dtype::f32); auto b32 = af::constant(2, shape, af::dtype::f32); auto a64 = af::constant(3, shape, af::dtype::f64); auto b64 = af::constant(2, shape, af::dtype::f64); auto rem = a % b; auto rem32 = a32 % b32; auto rem64 = a64 % b64; auto neg_rem = -a % b; auto neg_rem32 = -a32 % b32; auto neg_rem64 = -a64 % b64; ASSERT_ARRAYS_EQ(af::constant(1, shape, af::dtype::f16), rem); ASSERT_ARRAYS_EQ(af::constant(1, shape, af::dtype::f32), rem32); ASSERT_ARRAYS_EQ(af::constant(1, shape, af::dtype::f64), rem64); ASSERT_ARRAYS_EQ(af::constant(-1, shape, af::dtype::f16), neg_rem); ASSERT_ARRAYS_EQ(af::constant(-1, shape, af::dtype::f32), neg_rem32); ASSERT_ARRAYS_EQ(af::constant(-1, shape, af::dtype::f64), neg_rem64); ASSERT_ARRAYS_EQ(rem32.as(f16), rem); }