-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcomplex.h
60 lines (49 loc) · 2.98 KB
/
complex.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// Complex numbers with custom real types
#pragma once
#include "arith.h"
#include <iostream>
#include <cmath>
namespace mandelbrot {
using std::ostream;
// We need our own complex template to handle custom scalars: std::complex is only
// specified for float, double and long double. Ours has very few operations.
//
// Complex<S> stores a complex number as two parts of the real scalar type S.
// The friend helpers call sqr, twice, half and ldexp on S (and copysign in operator<<).
// Those names are looked up only when a helper is instantiated; presumably arith.h
// provides them for each supported S (confirm when adding a new scalar type).
// The helpers are friends defined in-class, so calls find them via argument-dependent lookup.
template<class S> struct Complex {
  typedef S Real;
  S r, i;  // real and imaginary parts

  __host__ __device__ Complex() : r(0), i(0) {}
  // Non-explicit: integer constants such as 0 or 1 convert to Complex implicitly.
  __host__ __device__ Complex(const int r) : r(r), i(0) {}
  // Explicit: a bare S does not convert to Complex implicitly.
  __host__ __device__ explicit Complex(const S& r) : r(r), i(0) {}
  __host__ __device__ Complex(const S& r, const S& i) : r(r), i(i) {}

  __host__ __device__ Complex operator-() const { return Complex(-r, -i); }
  __host__ __device__ Complex operator+(const Complex& z) const { return Complex(r + z.r, i + z.i); }
  __host__ __device__ Complex operator-(const Complex& z) const { return Complex(r - z.r, i - z.i); }

  // Compound assignment returns *this, like built-in types, so it can be chained.
  // Callers that ignored the former void result are unaffected.
  __host__ __device__ Complex& operator+=(const Complex& z) { r += z.r; i += z.i; return *this; }
  __host__ __device__ Complex& operator-=(const Complex& z) { r -= z.r; i -= z.i; return *this; }

  // Schoolbook product: 4 muls, 2 adds.
  __host__ __device__ Complex operator*(const Complex& z) const { return Complex(r*z.r - i*z.i, r*z.i + i*z.r); }
  // Real scalar times complex: (a*z.r, a*z.i).
  __host__ __device__ friend Complex operator*(const S a, const Complex& z) { return Complex(a*z.r, a*z.i); }

  // sqr(z) = (r^2 - i^2, 2*r*i), assuming sqr(x) = x*x and twice(x) = 2*x on S.
  __host__ __device__ friend Complex sqr(const Complex& z) { return Complex(sqr(z.r) - sqr(z.i), twice(z.r*z.i)); }
  __host__ __device__ friend Complex conj(const Complex& z) { return Complex(z.r, -z.i); }
  __host__ __device__ friend Complex left(const Complex& z) { return Complex(-z.i, z.r); }  // iz
  __host__ __device__ friend Complex right(const Complex& z) { return Complex(z.i, -z.r); }  // -iz
  // Componentwise twice / half / ldexp of both parts; these give 2z, z/2 and z*2^e
  // under the usual meaning of the corresponding S helpers.
  __host__ __device__ friend Complex twice(const Complex& z) { return Complex(twice(z.r), twice(z.i)); }
  __host__ __device__ friend Complex half(const Complex& z) { return Complex(half(z.r), half(z.i)); }
  __host__ __device__ friend Complex ldexp(const Complex& z, int e) { return Complex(ldexp(z.r, e), ldexp(z.i, e)); }
  // Componentwise (Hadamard) products: (z.r*w.r, z.i*w.i) and (sqr(z.r), sqr(z.i)).
  __host__ __device__ friend Complex hadamard(const Complex& z, const Complex& w) { return Complex(z.r*w.r, z.i*w.i); }
  __host__ __device__ friend Complex hadamard_sqr(const Complex& z) { return Complex(sqr(z.r), sqr(z.i)); }

  // Exact componentwise comparison. For IEEE floating S this means 0 == -0, and a NaN
  // in either part compares unequal. The operand is taken by reference (no copy of S),
  // and the operators are host/device like the other arithmetic members.
  __host__ __device__ bool operator==(const Complex& z) const { return r == z.r && i == z.i; }
  __host__ __device__ bool operator!=(const Complex& z) const { return !(*this == z); }

  // Prints Python-style "a+bj" / "a-bj". The '+' is emitted only when copysign reports a
  // positive sign for z.i, so a negative-zero imaginary part prints as "-0j", not "+-0j".
  friend std::ostream& operator<<(std::ostream& out, const Complex& z) {
    out << z.r;
    if (copysign(S(1), z.i) > 0) out << '+';
    return out << z.i << 'j';
  }
};
// Scales z by the diagonal complex factor a(1 + sign*i), i.e. returns a(1+i)z for
// sign == +1 and a(1-i)z for sign == -1. Cost: 2 adds to form the rotated components,
// then 2 muls to scale them by a.
template<int sign, class S> __host__ __device__ static inline Complex<S> diag(const S& a, const Complex<S>& z) {
  static_assert(sign == 1 || sign == -1, "diag: sign must be +1 or -1");
  if constexpr (sign == -1) {
    // (1 - i)(x + iy) = (x + y) + i(y - x)
    const S re = z.r + z.i;
    const S im = z.i - z.r;
    return a * Complex<S>(re, im);
  } else {
    // (1 + i)(x + iy) = (x - y) + i(y + x)
    const S re = z.r - z.i;
    const S im = z.i + z.r;
    return a * Complex<S>(re, im);
  }
}
// Magnitude |z| of a double-precision complex number. hypot computes sqrt(re^2 + im^2)
// without undue overflow or underflow in the intermediate squares.
static inline double abs(const Complex<double> z) {
  const double re = z.r;
  const double im = z.i;
  return hypot(re, im);
}
// Squared magnitude |z|^2 = re^2 + im^2 (no square root). Unlike abs, this may overflow
// or underflow for extreme magnitudes, assuming sqr from arith.h squares its argument.
static inline double sqr_abs(const Complex<double> z) {
  return sqr(z.r)
       + sqr(z.i);
}
// Reciprocal 1/z = conj(z) / |z|^2, computed as a real scale factor times conj(z).
// NOTE(review): the inner inv(double) is resolved from arith.h (presumably 1/x); z == 0
// yields non-finite components.
static inline Complex<double> inv(const Complex<double> z) {
  const double scale = inv(sqr_abs(z));
  return scale * conj(z);
}
} // namespace mandelbrot