# Copyright (C) 2024 Y Ddraig Goch
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# =============================================================================
#
# Preamble
#
# =============================================================================
#
"""Cupy based CPU backend for PyCBC Array
"""
import cupy as cp
from pycbc.types.array import common_kind, complex128, float64
[docs]
def zeros(length, dtype=cp.float64):
return cp.zeros(length, dtype=dtype)
[docs]
def empty(length, dtype=cp.float64):
return cp.empty(length, dtype=dtype)
[docs]
def ptr(self):
return self.data.data.mem.ptr
[docs]
def dot(self, other):
return cp.dot(self._data,other)
[docs]
def min(self):
return self.data.min()
[docs]
def abs_max_loc(self):
if self.kind == 'real':
tmp = abs(self.data)
ind = cp.argmax(tmp)
return tmp[ind], ind
else:
tmp = self.data.real ** 2.0
tmp += self.data.imag ** 2.0
ind = cp.argmax(tmp)
return tmp[ind] ** 0.5, ind
[docs]
def cumsum(self):
return self.data.cumsum()
[docs]
def max(self):
return self.data.max()
[docs]
def max_loc(self):
ind = cp.argmax(self.data)
return self.data[ind], ind
[docs]
def take(self, indices):
return self.data.take(indices)
[docs]
def weighted_inner(self, other, weight):
""" Return the inner product of the array with complex conjugation.
"""
if weight is None:
return self.inner(other)
cdtype = common_kind(self.dtype, other.dtype)
if cdtype.kind == 'c':
acum_dtype = complex128
else:
acum_dtype = float64
return cp.sum(self.data.conj() * other / weight, dtype=acum_dtype)
[docs]
def abs_arg_max(self):
if self.dtype == cp.float32 or self.dtype == cp.float64:
return cp.argmax(abs(self.data))
else:
return abs_arg_max_complex(self._data)
[docs]
def inner(self, other):
""" Return the inner product of the array with complex conjugation.
"""
cdtype = common_kind(self.dtype, other.dtype)
if cdtype.kind == 'c':
return cp.sum(self.data.conj() * other, dtype=complex128)
else:
return inner_real(self.data, other)
[docs]
def vdot(self, other):
""" Return the inner product of the array with complex conjugation.
"""
return cp.vdot(self.data, other)
[docs]
def squared_norm(self):
""" Return the elementwise squared norm of the array """
return (self.data.real**2 + self.data.imag**2)
def numpy(self):
return cp.asnumpy(self.data)
def _copy(self, self_ref, other_ref):
self_ref[:] = other_ref[:]
def _getvalue(self, index):
return self._data[index]
[docs]
def sum(self):
if self.kind == 'real':
return cp.sum(self._data,dtype=float64)
else:
return cp.sum(self._data,dtype=complex128)
[docs]
def clear(self):
self[:] = 0
def _scheme_matches_base_array(array):
if isinstance(array, cp.ndarray):
return True
else:
return False
def _to_device(array):
return cp.asarray(array)
[docs]
def numpy(self):
return cp.asnumpy(self._data)
def _copy_base_array(array):
return array.copy()