import ctypes, pycbc.libutils
from pycbc.types import zeros
from .core import _BaseFFT, _BaseIFFT
import pycbc.scheme as _scheme
# Load the MKL runtime library through pycbc's ctypes helper. Without it this
# backend cannot work, so importing this module must fail.
lib = pycbc.libutils.get_ctypes_library('mkl_rt', [])
if lib is None:
    raise ImportError("MKL FFT backend unavailable: could not load the "
                      "'mkl_rt' shared library")
# MKL DFTI constants. These integers are passed straight through ctypes, so
# they must match the DFTI_CONFIG_PARAM / DFTI_CONFIG_VALUE enums declared in
# Intel MKL's mkl_dfti.h header exactly.
# NOTE(review): the per-group labels below follow the enum layout in
# mkl_dfti.h; verify them against the MKL version in use.

# Configuration parameters: the setting being changed (the second argument of
# DftiSetValue).
DFTI_FORWARD_DOMAIN = 0
DFTI_DIMENSION = 1
DFTI_LENGTHS = 2
DFTI_PRECISION = 3
DFTI_FORWARD_SCALE = 4
DFTI_BACKWARD_SCALE = 5
DFTI_NUMBER_OF_TRANSFORMS = 7
DFTI_COMPLEX_STORAGE = 8
DFTI_REAL_STORAGE = 9
DFTI_CONJUGATE_EVEN_STORAGE = 10
DFTI_PLACEMENT = 11
DFTI_INPUT_STRIDES = 12
DFTI_OUTPUT_STRIDES = 13
DFTI_INPUT_DISTANCE = 14
DFTI_OUTPUT_DISTANCE = 15
DFTI_WORKSPACE = 17
DFTI_ORDERING = 18
DFTI_TRANSPOSE = 19
DFTI_DESCRIPTOR_NAME = 20
DFTI_PACKED_FORMAT = 21
DFTI_COMMIT_STATUS = 22
DFTI_VERSION = 23
DFTI_NUMBER_OF_USER_THREADS = 26
DFTI_THREAD_LIMIT = 27

# Configuration values: what a parameter is set to (the third argument of
# DftiSetValue), or a forward domain passed to DftiCreateDescriptor_*.
# DFTI_COMMIT_STATUS values
DFTI_COMMITTED = 30
DFTI_UNCOMMITTED = 31
# Forward domains
DFTI_COMPLEX = 32
DFTI_REAL = 33
# DFTI_PRECISION values
DFTI_SINGLE = 35
DFTI_DOUBLE = 36
# DFTI_COMPLEX_STORAGE / DFTI_CONJUGATE_EVEN_STORAGE values
DFTI_COMPLEX_COMPLEX = 39
DFTI_COMPLEX_REAL = 40
# DFTI_REAL_STORAGE values
DFTI_REAL_COMPLEX = 41
DFTI_REAL_REAL = 42
# DFTI_PLACEMENT values
DFTI_INPLACE = 43
DFTI_NOT_INPLACE = 44
# DFTI_ORDERING values
DFTI_ORDERED = 48
DFTI_BACKWARD_SCRAMBLED = 49
# DFTI_WORKSPACE / DFTI_TRANSPOSE values
DFTI_ALLOW = 51
DFTI_AVOID = 52
DFTI_NONE = 53
# DFTI_PACKED_FORMAT values (not valid for DFTI_CONJUGATE_EVEN_STORAGE)
DFTI_CCS_FORMAT = 54
DFTI_PACK_FORMAT = 55
DFTI_PERM_FORMAT = 56
DFTI_CCE_FORMAT = 57
# Forward-domain lookup, indexed as mkl_domain[input kind][output kind].
# Real<->complex transforms (in either direction) use the REAL forward domain;
# complex->complex transforms use the COMPLEX domain.
mkl_domain = dict(
    real=dict(complex=DFTI_REAL),
    complex=dict(real=DFTI_REAL, complex=DFTI_COMPLEX),
)

# 1-D descriptor constructors, indexed by pycbc precision string.
mkl_descriptor = dict(
    single=lib.DftiCreateDescriptor_s_1d,
    double=lib.DftiCreateDescriptor_d_1d,
)
def check_status(status):
    """Raise a Python exception if an MKL DFTI call reported an error.

    Parameters
    ----------
    status : int
        Return value of an MKL DFTI function. Zero (or any falsy value)
        means success, and the function returns ``None``.

    Raises
    ------
    RuntimeError
        If ``status`` is non-zero. The message is the text returned by
        ``DftiErrorMessage`` for that status.
    """
    if not status:
        return
    lib.DftiErrorMessage.restype = ctypes.c_char_p
    msg = lib.DftiErrorMessage(status)
    # A c_char_p result is bytes, or None for a NULL pointer. Decode it so
    # the exception text is readable instead of a b'...' repr on Python 3.
    if msg is None:
        msg = 'MKL DFTI call failed with status %d' % status
    elif isinstance(msg, bytes):
        msg = msg.decode('utf-8', 'replace')
    raise RuntimeError(msg)
def create_descriptor(size, idtype, odtype, inplace):
    """Create and commit an MKL DFTI descriptor for a one-off 1-D transform.

    Parameters
    ----------
    size : int
        Transform length passed to ``DftiCreateDescriptor_*_1d``.
    idtype, odtype : dtype
        Input and output dtypes. They select the forward domain (via
        ``mkl_domain``) and the precision (via ``mkl_descriptor``).
    inplace : bool
        Whether input and output share memory.

    Returns
    -------
    ctypes.c_void_p
        Committed descriptor handle. The caller is responsible for releasing
        it with ``lib.DftiFreeDescriptor``.

    Raises
    ------
    RuntimeError
        If any MKL call reports a non-zero status. A descriptor that was
        created but could not be configured is freed first.
    """
    # One-element probe vectors expose pycbc's kind/precision strings for
    # the requested dtypes.
    invec = zeros(1, dtype=idtype)
    outvec = zeros(1, dtype=odtype)
    domain = mkl_domain[str(invec.kind)][str(outvec.kind)]
    create = mkl_descriptor[invec.precision]
    create.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int,
                       ctypes.c_long]

    desc = ctypes.c_void_p(1)
    check_status(create(ctypes.byref(desc), domain, size))
    try:
        placement = DFTI_INPLACE if inplace else DFTI_NOT_INPLACE
        check_status(lib.DftiSetValue(desc, DFTI_PLACEMENT, placement))
        nthreads = _scheme.mgr.state.num_threads
        check_status(lib.DftiSetValue(desc, DFTI_THREAD_LIMIT, nthreads))
        # DFTI_CONJUGATE_EVEN_STORAGE takes DFTI_COMPLEX_COMPLEX or
        # DFTI_COMPLEX_REAL. DFTI_CCS_FORMAT is a DFTI_PACKED_FORMAT value,
        # so it was rejected here. Match the class-based API instead.
        check_status(lib.DftiSetValue(desc, DFTI_CONJUGATE_EVEN_STORAGE,
                                      DFTI_COMPLEX_COMPLEX))
        check_status(lib.DftiCommitDescriptor(desc))
    except Exception:
        # Do not leak the half-configured descriptor.
        lib.DftiFreeDescriptor(ctypes.byref(desc))
        raise
    return desc
def fft(invec, outvec, prec, itype, otype):
    """Compute a forward FFT of ``invec`` into ``outvec`` using MKL.

    A DFTI descriptor is created for this single call and is always freed
    afterwards. ``prec``, ``itype`` and ``otype`` are not used here
    (presumably they are kept for a common backend signature).

    Raises
    ------
    RuntimeError
        If descriptor setup or the transform itself reports an MKL error.
    """
    # The transform length is the longer of the two arrays. Identical data
    # pointers select an in-place transform.
    descr = create_descriptor(max(len(invec), len(outvec)), invec.dtype,
                              outvec.dtype, (invec.ptr == outvec.ptr))
    f = lib.DftiComputeForward
    f.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
    try:
        status = f(descr, invec.ptr, outvec.ptr)
    finally:
        # Free the descriptor even if ctypes rejects an argument.
        lib.DftiFreeDescriptor(ctypes.byref(descr))
    check_status(status)
def ifft(invec, outvec, prec, itype, otype):
    """Compute a backward (inverse) FFT of ``invec`` into ``outvec`` using MKL.

    A DFTI descriptor is created for this single call and is always freed
    afterwards. ``prec``, ``itype`` and ``otype`` are not used here
    (presumably they are kept for a common backend signature).

    Raises
    ------
    RuntimeError
        If descriptor setup or the transform itself reports an MKL error.
    """
    # The transform length is the longer of the two arrays. Identical data
    # pointers select an in-place transform.
    descr = create_descriptor(max(len(invec), len(outvec)), invec.dtype,
                              outvec.dtype, (invec.ptr == outvec.ptr))
    f = lib.DftiComputeBackward
    f.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
    try:
        status = f(descr, invec.ptr, outvec.ptr)
    finally:
        # Free the descriptor even if ctypes rejects an argument.
        lib.DftiFreeDescriptor(ctypes.byref(descr))
    check_status(status)
# Class based API
def _get_desc(fftobj):
    """Build and commit an MKL DFTI descriptor for a class-based FFT plan.

    Parameters
    ----------
    fftobj : FFT or IFFT
        Plan object. The attributes read are ``invec``, ``outvec``,
        ``size``, ``inplace`` and ``nbatch``, plus ``idist`` and ``odist``
        when ``nbatch > 1``.

    Returns
    -------
    ctypes.c_void_p
        Committed descriptor handle.

    Raises
    ------
    RuntimeError
        If any MKL call reports an error. A descriptor that was created but
        not fully configured is freed before the error propagates.
    """
    desc = ctypes.c_void_p(1)
    domain = mkl_domain[str(fftobj.invec.kind)][str(fftobj.outvec.kind)]
    f = mkl_descriptor[fftobj.invec.precision]
    f.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_long]
    check_status(f(ctypes.byref(desc), domain, int(fftobj.size)))
    try:
        # Configure the descriptor for the exact kind of transform requested.
        lib.DftiSetValue.argtypes = [ctypes.c_void_p, ctypes.c_int,
                                     ctypes.c_int]
        # Only matters for R2C / C2R transforms.
        check_status(lib.DftiSetValue(desc, DFTI_CONJUGATE_EVEN_STORAGE,
                                      DFTI_COMPLEX_COMPLEX))
        # In-place or out-of-place:
        if fftobj.inplace:
            placement = DFTI_INPLACE
        else:
            placement = DFTI_NOT_INPLACE
        check_status(lib.DftiSetValue(desc, DFTI_PLACEMENT, placement))
        # Batched transform: the number of transforms and the spacing
        # between consecutive inputs/outputs.
        if fftobj.nbatch > 1:
            check_status(lib.DftiSetValue(desc, DFTI_NUMBER_OF_TRANSFORMS,
                                          fftobj.nbatch))
            check_status(lib.DftiSetValue(desc, DFTI_INPUT_DISTANCE,
                                          fftobj.idist))
            check_status(lib.DftiSetValue(desc, DFTI_OUTPUT_DISTANCE,
                                          fftobj.odist))
        # Knowing how many threads are allowed may help MKL choose a better
        # transform.
        nthreads = _scheme.mgr.state.num_threads
        check_status(lib.DftiSetValue(desc, DFTI_THREAD_LIMIT, nthreads))
        # Everything is configured, so commit.
        check_status(lib.DftiCommitDescriptor(desc))
    except Exception:
        # Do not leak the partially configured descriptor.
        lib.DftiFreeDescriptor(ctypes.byref(desc))
        raise
    return desc
class FFT(_BaseFFT):
    """Forward FFT plan executed with a committed MKL DFTI descriptor."""

    def __init__(self, invec, outvec, nbatch=1, size=None):
        super(FFT, self).__init__(invec, outvec, nbatch, size)
        # Cache the raw data pointers and the compute entry point that
        # execute() uses.
        self.iptr = self.invec.ptr
        self.optr = self.outvec.ptr
        self._efunc = lib.DftiComputeForward
        self._efunc.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
                                ctypes.c_void_p]
        self.desc = _get_desc(self)

    def execute(self):
        """Run the planned forward transform from ``invec`` into ``outvec``.

        Raises
        ------
        RuntimeError
            If MKL reports an error for the computation.
        """
        check_status(self._efunc(self.desc, self.iptr, self.optr))
class IFFT(_BaseIFFT):
    """Backward (inverse) FFT plan executed with a committed MKL DFTI descriptor."""

    def __init__(self, invec, outvec, nbatch=1, size=None):
        super(IFFT, self).__init__(invec, outvec, nbatch, size)
        # Cache the raw data pointers and the compute entry point that
        # execute() uses.
        self.iptr = self.invec.ptr
        self.optr = self.outvec.ptr
        self._efunc = lib.DftiComputeBackward
        self._efunc.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
                                ctypes.c_void_p]
        self.desc = _get_desc(self)

    def execute(self):
        """Run the planned backward transform from ``invec`` into ``outvec``.

        Raises
        ------
        RuntimeError
            If MKL reports an error for the computation.
        """
        check_status(self._efunc(self.desc, self.iptr, self.optr))