Tags: python, optimization, opencl, fft, pyopencl

Optimization of PyOpenCL program / FFT


General Overview of Program: Most of the code here defines the FrameProcessor class. An instance is initialized for a given data shape, generally 2048xN, and its proc_frame method then processes each frame using a series of kernels. For each vector of length 2048 the program will:

  1. Apply a Hanning window (elementwise multiplication of the length-2048 vector by a length-2048 window)
  2. Do a linear interpolation to remap the values from the non-linear spectrometer bins the signal is sampled on to a grid that is linear in wavenumber (not too important a detail, but I figured it would be good to include in case it was unclear)
  3. Apply an FFT
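For reference, the three steps above can be sketched on the CPU with NumPy. This is only a minimal sketch for checking GPU output, not the code used in the project: `process_frame_reference` is a hypothetical name, and the frame is assumed to be laid out as (number of A-lines, 2048) with `k_raw` and `k_lin` holding the raw and linearized wavenumber grids.

```python
import numpy as np

def process_frame_reference(frame, k_raw, k_lin):
    """CPU reference: window, resample to linear wavenumber, then FFT.

    frame : (n_lines, alen) real array, one A-line per row (assumed layout)
    k_raw : (alen,) wavenumber of each spectrometer bin
    k_lin : (alen,) target grid, linear in wavenumber
    """
    frame = np.asarray(frame, dtype=np.float32)
    alen = frame.shape[1]

    # Step 1: Hanning window, broadcast over every A-line.
    windowed = frame * np.hanning(alen).astype(np.float32)

    # Step 2: linear interpolation. np.interp needs increasing sample
    # positions, so sort the raw grid (and the data) first.
    order = np.argsort(k_raw)
    resampled = np.stack(
        [np.interp(k_lin, k_raw[order], line[order]) for line in windowed]
    )

    # Step 3: FFT along each A-line.
    return np.fft.fft(resampled, axis=1)
```

Comparing this output against the GPU result for a single frame is a cheap way to confirm that an optimization has not changed the numbers.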

Problem: I want to go faster! The code below is not performing poorly, but for this project I need it to be as fast as it can possibly be. However, I am unsure how I might improve it further, so I'm looking for suggestions on relevant reading, alternative libraries I should use, changes to code structure, etc.

Current Performance: On my rig with a GeForce RTX 2080 the benchmarks I get (with n=60, which seems to give best performance) are:

With n = 60 
Average framerate over 1000 frames: 740Hz
Effective A-line rate over 1000 frames: 44399Hz

It seems that the FFT is a large bottleneck here. When I run the example without running the FFT I get these results:

With n = 60 
Average framerate over 1000 frames: 2494Hz
Effective A-line rate over 1000 frames: 149652Hz
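To put the two benchmarks side by side, the per-frame cost of the FFT step can be estimated from the two frame rates. The sketch below is plain arithmetic on the numbers quoted above; `stage_cost` is a hypothetical helper, and the "FFT step" here means everything in the FFT wrapper, including its host-device transfers, not just the transform kernel.

```python
def stage_cost(rate_with_hz, rate_without_hz):
    """Return (seconds per frame spent in the stage, share of total frame time)."""
    t_with = 1.0 / rate_with_hz        # frame time with the stage enabled
    t_without = 1.0 / rate_without_hz  # frame time with the stage skipped
    cost = t_with - t_without
    return cost, cost / t_with

cost_s, share = stage_cost(740.0, 2494.0)
print(f"FFT step: {cost_s * 1e3:.2f} ms per frame ({share:.0%} of frame time)")
# prints "FFT step: 0.95 ms per frame (70% of frame time)"
```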

However, I don't know how to improve upon the performance of the Reikna FFT plan I'm using! The docs don't seem to mention any steps for optimization, and I've had even worse performance using gpyfft (code in the test-gpyfft branch of the GitHub repo).

Profiling: Results of using cProfile on the proc_frame function are shown below:

Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(reshape)
        4    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
        1    0.000    0.000    0.000    0.000 <generated code>:101(set_args)
        2    0.000    0.000    0.000    0.000 <generated code>:4(enqueue_knl_kernel_fft)
        1    0.000    0.000    0.000    0.000 <generated code>:71(set_args)
        1    0.000    0.000    0.002    0.002 <string>:1(<module>)
        3    0.000    0.000    0.000    0.000 __init__.py:1288(result)
        1    0.000    0.000    0.000    0.000 __init__.py:1294(result)
        3    0.000    0.000    0.001    0.000 __init__.py:1522(enqueue_copy)
        1    0.000    0.000    0.000    0.000 __init__.py:222(wrap_in_tuple)
       10    0.000    0.000    0.000    0.000 __init__.py:277(name)
        3    0.000    0.000    0.000    0.000 __init__.py:281(default)
        2    0.000    0.000    0.000    0.000 __init__.py:285(annotation)
        9    0.000    0.000    0.000    0.000 __init__.py:289(kind)
        1    0.000    0.000    0.000    0.000 __init__.py:375(__init__)
        2    0.000    0.000    0.000    0.000 __init__.py:575(wrapper)
        2    0.000    0.000    0.000    0.000 __init__.py:596(parameters)
        1    0.000    0.000    0.000    0.000 __init__.py:659(_bind)
        1    0.000    0.000    0.000    0.000 __init__.py:787(bind)
        2    0.000    0.000    0.000    0.000 __init__.py:833(kernel_set_args)
        2    0.000    0.000    0.000    0.000 __init__.py:837(kernel_call)
        1    0.000    0.000    0.000    0.000 _asarray.py:16(asarray)
        1    0.000    0.000    0.000    0.000 _internal.py:830(npy_ctypes_check)
        1    0.000    0.000    0.000    0.000 abc.py:137(__instancecheck__)
        1    0.000    0.000    0.000    0.000 api.py:376(empty_like)
        1    0.000    0.000    0.000    0.000 api.py:405(to_device)
        3    0.000    0.000    0.000    0.000 api.py:466(_synchronize)
        2    0.000    0.000    0.000    0.000 api.py:678(prepared_call)
        2    0.000    0.000    0.000    0.000 api.py:688(__call__)
        2    0.000    0.000    0.000    0.000 api.py:779(__call__)
        2    0.000    0.000    0.000    0.000 array.py:1474(add_event)
        1    0.000    0.000    0.000    0.000 array.py:28(f_contiguous_strides)
        1    0.000    0.000    0.000    0.000 array.py:38(c_contiguous_strides)
        1    0.000    0.000    0.000    0.000 array.py:393(__init__)
        3    0.000    0.000    0.000    0.000 array.py:48(equal_strides)
        1    0.000    0.000    0.000    0.000 array.py:520(flags)
        1    0.000    0.000    0.000    0.000 array.py:580(set)
        1    0.000    0.000    0.000    0.000 array.py:59(is_f_contiguous_strides)
        1    0.000    0.000    0.000    0.000 array.py:61(_dtype_is_object)
        1    0.000    0.000    0.000    0.000 array.py:63(is_c_contiguous_strides)
        1    0.000    0.000    0.001    0.001 array.py:635(_get)
        1    0.000    0.000    0.000    0.000 array.py:68(__init__)
        1    0.000    0.000    0.001    0.001 array.py:689(get)
        1    0.000    0.000    0.000    0.000 computation.py:620(__call__)
        2    0.000    0.000    0.000    0.000 computation.py:641(__call__)
        1    0.000    0.000    0.000    0.000 dtypes.py:75(normalize_type)
        1    0.000    0.000    0.001    0.001 frameprocessor.py:130(FFT)
        1    0.000    0.000    0.001    0.001 frameprocessor.py:137(interp_hann)
        1    0.000    0.000    0.002    0.002 frameprocessor.py:146(proc_frame)
        1    0.000    0.000    0.000    0.000 frameprocessor.py:20(npcast)
        1    0.000    0.000    0.000    0.000 frameprocessor.py:23(rshp)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:197(_reshape_dispatcher)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:202(reshape)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:55(_wrapfunc)
        1    0.000    0.000    0.000    0.000 ocl.py:109(allocate)
        1    0.000    0.000    0.000    0.000 ocl.py:112(_copy_array)
        2    0.000    0.000    0.000    0.000 ocl.py:223(_prepared_call)
        2    0.000    0.000    0.000    0.000 ocl.py:225(<listcomp>)
        1    0.000    0.000    0.000    0.000 ocl.py:28(__init__)
        1    0.000    0.000    0.001    0.001 ocl.py:63(get)
        1    0.000    0.000    0.000    0.000 ocl.py:88(array)
        1    0.000    0.000    0.000    0.000 signature.py:308(bind_with_defaults)
        1    0.000    0.000    0.000    0.000 {built-in method _abc._abc_instancecheck}
        1    0.000    0.000    0.002    0.002 {built-in method builtins.exec}
        4    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
       12    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
       37    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        2    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
       14    0.000    0.000    0.000    0.000 {built-in method builtins.len}
        6    0.000    0.000    0.000    0.000 {built-in method builtins.next}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.setattr}
        1    0.000    0.000    0.000    0.000 {built-in method numpy.array}
        1    0.000    0.000    0.000    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
        1    0.000    0.000    0.000    0.000 {built-in method numpy.empty}
        2    0.001    0.001    0.001    0.001 {built-in method pyopencl._cl._enqueue_read_buffer}
        1    0.000    0.000    0.000    0.000 {built-in method pyopencl._cl._enqueue_write_buffer}
        4    0.000    0.000    0.000    0.000 {built-in method pyopencl._cl.enqueue_nd_range_kernel}
        2    0.000    0.000    0.000    0.000 {built-in method pyopencl._cl.get_cl_header_version}
        6    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 {method 'astype' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        3    0.000    0.000    0.000    0.000 {method 'pop' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'reshape' of 'numpy.ndarray' objects}
        2    0.000    0.000    0.000    0.000 {method 'values' of 'mappingproxy' objects}
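Because the listing above covers a single call and is sorted by name, almost every row rounds to 0.000. A sketch of an alternative, using only the standard library, is to profile many calls and sort by cumulative time. `profile_calls` and `work` are hypothetical names; `work` stands in for a call such as `fp.proc_frame(data)`. Note that cProfile only sees host-side time: kernel launches are enqueued asynchronously, so GPU time tends to show up in whichever call blocks, such as the buffer read.

```python
import cProfile
import io
import pstats

def profile_calls(func, repeats=100, top=10):
    """Profile `repeats` calls of func; return the top rows by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(repeats):
        func()
    profiler.disable()

    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    stats.sort_stats("cumulative").print_stats(top)
    return buf.getvalue()

def work():
    # Stand-in workload; replace with the function under test.
    return sum(i * i for i in range(10_000))

report = profile_calls(work)
print(report)
```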

Code: The code as well as the supplementary files can be accessed here: https://github.com/mswallac/PyMotionOCT and is also shown below for convenience.

# -*- coding: utf-8 -*-
"""
Created on Wed Jan 15 10:17:16 2020

@author: Mike
"""

import numpy as np
import pyopencl as cl
from pyopencl import cltypes
from pyopencl import array
from reikna.fft import FFT
from reikna import cluda
import time
import matplotlib.pyplot as plt

class FrameProcessor():
    
    def npcast(self,inp,dt):
        return np.asarray(inp).astype(dt)
    
    def rshp(self,inp,shape):
        return np.reshape(inp,shape,'C')
    
    def __init__(self,nlines):
        
        # Define data formatting
        n = nlines # number of A-lines per frame
        alen = 2048 # length of A-line / # of spec. bins
        self.dshape = (alen*n,)
        self.dt_prefft = np.float32
        self.dt_fft = np.complex64
        self.data_prefft = self.npcast(np.zeros(self.dshape),self.dt_prefft)
        self.data_fft = self.npcast(np.zeros(self.dshape),self.dt_fft)
        
        # Load spectrometer bins and prepare for interpolation / hanning operation
        hanning_win = self.npcast(np.hanning(2048),self.dt_prefft)
        lam = self.npcast(np.load('lam.npy'),self.dt_prefft)
        lmax = np.max(lam)
        lmin = np.min(lam)
        kmax = 1/lmin
        kmin = 1/lmax
        self.d_l = (lmax - lmin)/alen
        self.d_k = (kmax - kmin)/alen
        self.k_raw = self.npcast([1/x for x in (lam)],self.dt_prefft)
        self.k_lin = self.npcast([kmax-(i*self.d_k) for i in range(alen)],self.dt_prefft)
        
        # Find nearest neighbors for interpolation prep.
        nn0 = np.zeros((2048,),np.int32)
        nn1 = np.zeros((2048,),np.int32)
        for i in range(0,2048):
            res = np.abs(self.k_raw-self.k_lin[i])
            minind = np.argmin(res)
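            if i==0:
                # First sample: use the first two bins as neighbors
                nn0[i]=0
                nn1[i]=1
            elif res[minind]>=0:
                # elif keeps the i==0 case above from being overwritten with nn0=-1.
                # Note: res holds absolute differences, so this branch is always taken.
                nn0[i]=minind-1
                nn1[i]=minind
            else:
                nn0[i]=minind
                nn1[i]=minind+1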
            
        self.nn0=nn0
        self.nn1=nn1
        
        # Initialize PyOpenCL platform, device, context, queue
        self.platform = cl.get_platforms()
        self.platform = self.platform[0]
        self.device = self.platform.get_devices()
        self.device = self.device[0]
        self.context = cl.Context([self.device])
        self.queue = cl.CommandQueue(self.context)
        
        # POCL input buffers
        mflags = cl.mem_flags
        self.win_g = cl.Buffer(self.context, mflags.READ_ONLY | mflags.COPY_HOST_PTR, hostbuf=hanning_win)
        self.nn0_g = cl.Buffer(self.context, mflags.READ_ONLY | mflags.COPY_HOST_PTR, hostbuf=self.nn0)
        self.nn1_g = cl.Buffer(self.context, mflags.READ_ONLY | mflags.COPY_HOST_PTR, hostbuf=self.nn1)
        self.k_lin_g = cl.Buffer(self.context, mflags.READ_ONLY | mflags.COPY_HOST_PTR, hostbuf=self.k_lin)
        self.k_raw_g = cl.Buffer(self.context, mflags.READ_ONLY | mflags.COPY_HOST_PTR, hostbuf=self.k_raw)
        self.d_k_g = cl.Buffer(self.context, mflags.READ_ONLY | mflags.COPY_HOST_PTR, hostbuf=self.d_k)
        
        # POCL output buffers
        self.npres_interp = self.npcast(np.zeros(self.dshape),self.dt_prefft)
        self.npres_hann = self.npcast(np.zeros(self.dshape),self.dt_prefft)
        self.result_interp = cl.Buffer(self.context, cl.mem_flags.COPY_HOST_PTR, hostbuf=self.npres_interp)
        self.result_hann = cl.Buffer(self.context, cl.mem_flags.COPY_HOST_PTR, hostbuf=self.npres_hann)
        
        # Define POCL global / local work group sizes
        self.global_wgsize = (2048,n)
        self.local_wgsize = (512,1)
        
        # Initialize Reikna API, thread, FFT plan, output memory
        self.api = cluda.ocl_api()
        self.thr = self.api.Thread.create()
        self.result = self.npcast(np.zeros((2048,n)),self.dt_fft)
        self.fft = FFT(self.result,axes=(0,)).compile(self.thr)
        
        # kernels for hanning window, and interpolation
        self.program = cl.Program(self.context, """
        __kernel void hann(__global float *inp, __global const float *win, __global float *res)
        {
            int i = get_global_id(0)+(get_global_size(0)*get_global_id(1));
            int j = get_local_id(0)+(get_group_id(0)*get_local_size(0));
            res[i] = inp[i]*win[j];
        }
        
        __kernel void interp(__global float *y,__global const int *nn0,__global const int *nn1,
                             __global const float *k_raw,__global const float *k_lin,__global float *res)
        {
            int i_shift = (get_global_size(0)*get_global_id(1));
            int i_glob = get_global_id(0)+i_shift;
            int i_loc = get_local_id(0)+(get_group_id(0)*get_local_size(0));
            float x1 = k_raw[nn0[i_loc]];
            float x2 = k_raw[nn1[i_loc]];
            float y1 = y[i_shift+nn0[i_loc]];
            float y2 = y[i_shift+nn1[i_loc]];
            float x = k_lin[i_loc];
            res[i_glob]=y1+((x-x1)*((y2-y1)/(x2-x1))); 
        }
        """).build()
        
        self.hann = self.program.hann
        self.interp = self.program.interp
    
    # Wraps FFT kernel
    def FFT(self,data):
        inp = self.thr.to_device(self.npcast(data,self.dt_fft))
        self.fft(inp,inp,inverse=0)
        self.result = inp.get()
        return
    
    # Wraps interpolation and hanning window kernels
    def interp_hann(self,data):
        self.data_pfg = cl.Buffer(self.context, cl.mem_flags.COPY_HOST_PTR, hostbuf=data)
        self.hann.set_args(self.data_pfg,self.win_g,self.result_hann)
        cl.enqueue_nd_range_kernel(self.queue,self.hann,self.global_wgsize,self.local_wgsize)
        self.interp.set_args(self.result_hann,self.nn0_g,self.nn1_g,self.k_raw_g,self.k_lin_g,self.result_interp)
        cl.enqueue_nd_range_kernel(self.queue,self.interp,self.global_wgsize,self.local_wgsize)
        cl.enqueue_copy(self.queue,self.npres_interp,self.result_interp)
        return
    
    def proc_frame(self,data):
        self.interp_hann(data)
        # Use -1 so the reshape does not depend on a module-level variable n
        self.FFT(self.rshp(self.npres_interp,(2048,-1)))
        return self.result
        
if __name__ == '__main__':
    n=60
    nframes=1000 # number of timed frames, so the printout matches the loop
    fp = FrameProcessor(n)
    data1 = np.load('data.npy').flatten()
    times = []
    data = fp.npcast(data1[0:2048*n],fp.dt_prefft)
    for i in range(nframes):
        # perf_counter is a monotonic, high-resolution clock suited to short intervals
        t=time.perf_counter()
        res = fp.proc_frame(data)
        times.append(time.perf_counter()-t)
    res = np.reshape(res,(2048,n),'C')
    avginterval = np.mean(times)
    frate=(1/avginterval)
    afrate=frate*n
    print('With n = %d '%n)
    print('Average framerate over %d frames: %.0fHz'%(nframes,frate))
    print('Effective A-line rate over %d frames: %.0fHz'%(nframes,afrate))

Edit: updated code and benchmark.
Edit 2: added cProfile result.


Solution

  • Copying my reply from the Reikna group for reference.

    1. Create a reikna Thread object from whatever pyopencl queue you want it to use (probably the one associated with the arrays you want to pass to FFT).
    2. Create an FFT computation based on this Thread.
    3. Pass your pyopencl arrays to it without any conversion. (You can create a reikna array based on the buffer from a pyopencl array by passing it as the base_data keyword, but if FFT is all you need, that is not necessary.)

    Reikna threads are wrappers on top of pyopencl context + queue, and reikna arrays are subclasses of pyopencl arrays, so the interop should be pretty simple.

    Applying this (in a quick and dirty way, feel free to improve), I get: https://gist.github.com/fjarri/f781d3695b7c6678856110cced95be40 . Basically, the changes are:

    • creating a Thread out of the existing queue (self.thr = self.api.Thread(self.queue))
    • using the PyOpenCL buffer in FFT without copying it to CPU.
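The two changes listed above can be sketched as follows. This is a minimal illustration rather than the contents of the linked gist: `build_fft_on_queue` and `fft_on_device` are hypothetical helper names, the shape and axes are taken from the question's code, and the usage part is left as comments because it needs a working OpenCL device.

```python
import numpy as np

def build_fft_on_queue(queue, shape, dtype=np.complex64, axes=(0,)):
    """Compile a Reikna FFT bound to an existing PyOpenCL command queue.

    build_fft_on_queue is a hypothetical helper name, not part of Reikna.
    """
    # Local imports so this module loads even where no OpenCL stack is installed.
    from reikna import cluda
    from reikna.fft import FFT

    api = cluda.ocl_api()
    thr = api.Thread(queue)  # wrap the existing queue instead of Thread.create()
    template = np.empty(shape, dtype)  # describes shape, dtype and layout only
    return FFT(template, axes=axes).compile(thr)


def fft_on_device(fft, dev_arr):
    """Run the compiled FFT in place on a device array; no host copy is made."""
    fft(dev_arr, dev_arr, inverse=0)
    return dev_arr


# Usage sketch (assumes an OpenCL device is available):
#   import pyopencl as cl
#   import pyopencl.array as cl_array
#   ctx = cl.create_some_context()
#   queue = cl.CommandQueue(ctx)
#   fft = build_fft_on_queue(queue, (2048, 60))
#   dev = cl_array.to_device(queue, np.zeros((2048, 60), np.complex64))
#   spectrum = fft_on_device(fft, dev).get()  # one device-to-host copy at the end
```

Keeping the data on the device between the interpolation kernel and the FFT removes one download and one upload per frame, which is where the original version spent much of its time.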

    The results I get:

    $ python frameprocessor.py # original version
    With n = 60 
    Average framerate over 1000 frames: 434Hz
    Effective A-line rate over 1000 frames: 26012Hz
    $ python frameprocessor2.py # modified version
    With n = 60 
    Average framerate over 1000 frames: 2191Hz
    Effective A-line rate over 1000 frames: 131478Hz