
PSNR/MSE calculation for two images


I want to write a function that takes two images, a reference and an encoded one, and evaluates the (R)MSE and PSNR for each component (R, G, B, Y, Cb, Cr). To do that, I extract all components and then convert RGB -> YCbCr. I want to calculate the (R)MSE and PSNR without using a built-in function.

import os, sys, subprocess, csv, datetime
from PIL import Image

############ Functions Definitions ############

# Extracts the R, G, B values of each pixel in the input file and calculates the
# Y, Cb, Cr components. Returns a dictionary whose keys are (x, y) pixel
# coordinates and whose values are (R, G, B, Y, Cb, Cr) tuples.
def rgb_calc(ref_file):
  img = Image.open(ref_file)
  width, height = img.size
  print(width)
  print(height)
  rgb_dict = {}
  for x in range (width):
    for y in range(height):
      r, g, b = img.load()[x, y]
      lum = 0.299 * r + 0.587 * g + 0.114 * b
      cb = 128 - 0.168736 * r - 0.331264 * g + 0.5 * b
      cr = 128 + 0.5 * r - 0.418688 * g - 0.081312 * b
      print("X {} Y {} R {} G {} B {} Y {} Cb {} Cr {}".format(x, y, r, g, b, lum, cb, cr))
      rgb_dict[(x, y)] = (r, g, b, lum, cb, cr)
  return rgb_dict

############ MAIN FUNCTION ############

r_img = sys.argv[1]
p_img = sys.argv[2]

ref_img = Image.open(r_img)
proc_img = Image.open(p_img)

resolution_ref = ref_img.size
resolution_proc = proc_img.size

if resolution_ref == resolution_proc:
  ycbcr_ref = rgb_calc(r_img)
  ycbcr_proc = rgb_calc(p_img)  # pass the file name, as for the reference image
else:
  exit(0)

I want to write a new function and eventually output the average PSNR for each component and an average for the whole image.
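Since the (R)MSE and PSNR should be computed without a built-in function, here is a minimal pure-Python sketch that works on two dictionaries in the format rgb_calc() returns. It assumes 8-bit data (a peak value of 255); the names mse_psnr, mean_psnr and COMPONENTS are hypothetical, and averaging the per-component PSNRs is only one possible reading of "an average for the whole image".

```python
import math

# Order of the values stored in each rgb_calc() tuple.
COMPONENTS = ("R", "G", "B", "Y", "Cb", "Cr")

def mse_psnr(ref, proc, peak=255.0):
    """Per-component (MSE, RMSE, PSNR) for two dicts mapping (x, y) to
    (R, G, B, Y, Cb, Cr) tuples. Assumes 8-bit data unless peak is changed."""
    if ref.keys() != proc.keys():
        raise ValueError("images must have the same pixel coordinates")
    n = len(ref)
    sq_err = [0.0] * len(COMPONENTS)
    # Accumulate the squared error of every component over all pixels.
    for xy, ref_px in ref.items():
        proc_px = proc[xy]
        for i in range(len(COMPONENTS)):
            d = ref_px[i] - proc_px[i]
            sq_err[i] += d * d
    results = {}
    for name, total in zip(COMPONENTS, sq_err):
        mse = total / n
        rmse = math.sqrt(mse)
        # Identical components give MSE 0, so the PSNR is infinite.
        psnr = math.inf if mse == 0 else 10 * math.log10(peak ** 2 / mse)
        results[name] = (mse, rmse, psnr)
    return results

def mean_psnr(results):
    # Plain average of the per-component PSNRs (inf if any component is identical).
    return sum(psnr for _, _, psnr in results.values()) / len(results)
```

Usage would be along the lines of stats = mse_psnr(ycbcr_ref, ycbcr_proc). Note that this still loops over every pixel in Python, so it is slow on 8 Mpx images.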

Is there a way to speed up my process?

Currently, the img.load() calls take around 10-11 seconds per 8 Mpx image, and creating the dictionary takes an additional 6 seconds. So just extracting these values and building the two dictionaries already takes 32 seconds.


Solution

  • First of all, do the img.load() outside the loop!

    def rgb_calc(ref_file):
      img = Image.open(ref_file)
      width, height = img.size
      print(width)
      print(height)
      rgb_dict = {}
      rgb = img.load()
      for x in range(width):
        for y in range(height):
          r, g, b = rgb[x, y]
          lum = 0.299 * r + 0.587 * g + 0.114 * b
          cb = 128 - 0.168736 * r - 0.331264 * g + 0.5 * b
          cr = 128 + 0.5 * r - 0.418688 * g - 0.081312 * b
          rgb_dict[(x, y)] = (r, g, b, lum, cb, cr)
      return rgb_dict
    

    But this is only the start. The next thing I would do (but I'm no expert!) is use a numpy array instead of a dict indexed by (x, y).


    EDIT

    I tried to speed things up using a numpy ndarray (N-dimensional array) but got stuck, so I asked a specific question and got the answer that resolved it (a ×15 speed-up!): “numpy.ndarray with shape (height, width, n) from n values per Image pixel”

    Here it is, adapted to your needs, and with some details of your original code fixed:

    import sys
    import numpy as np
    from PIL import Image
    
    def get_rgbycbcr(img: Image.Image):
        R, G, B = np.array(img).transpose(2, 0, 1)[:3]  # ignore alpha if present
        Y = 0.299 * R + 0.587 * G + 0.114 * B
        Cb = 128 - 0.168736 * R - 0.331264 * G + 0.5 * B
        Cr = 128 + 0.5 * R - 0.418688 * G - 0.081312 * B
        # Stack to shape (6, height, width), then transpose to (width, height, 6).
        return np.array([R, G, B, Y, Cb, Cr], dtype=float).transpose(2, 1, 0)
    
    r_img = sys.argv[1]
    p_img = sys.argv[2]
    
    ref_img  = Image.open(r_img)
    proc_img = Image.open(p_img)
    
    resolution_ref  = ref_img.size
    resolution_proc = proc_img.size
    
    if resolution_ref == resolution_proc:
        ycbcr_ref  = get_rgbycbcr(ref_img)
        ycbcr_proc = get_rgbycbcr(proc_img)
    else:
        exit(0)
    

    What you are left with now is a numpy array of shape (width, height, 6). I don't think you need the original RGB data in there (you can get it from the image at any time); if so, you can change the code to reduce 6 to 3. You can index ycbcr_ref, e.g., as ycbcr_ref[x, y] and get an array of length 6 containing the same data you had in the tuples stored in the dictionary. But you can also extract slices, specifically along this “axis” (numpy terminology) of length 6, and do operations on them, like

    y_mean = ycbcr_ref[:, :, 3].mean()
    

    It's absolutely worthwhile to learn how to use numpy!

    I'll help you with one detail: unless you tell it otherwise, numpy stores data with the slowest-varying index (AKA axis) first and the fastest-varying last. Since images are stored by rows, an image read into numpy has to be indexed like arr[y, x] unless you transpose() it. Transposing shuffles the axes. In your case you have 3 axes numbered 0, 1, 2; e.g., .transpose(1, 0, 2) exchanges x and y, while .transpose(2, 0, 1) makes the pixel channels the “outer” (slowest-varying) index.
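A quick way to check this axis order: the sketch below builds a tiny hypothetical 4×2 image, for which PIL reports the size as (width, height) while numpy gives the shape (height, width, channels).

```python
import numpy as np
from PIL import Image

# Hypothetical 4-pixel-wide, 2-pixel-tall black RGB image with one red pixel.
img = Image.new("RGB", (4, 2))
img.putpixel((3, 1), (255, 0, 0))    # PIL addresses pixels as (x, y)

arr = np.array(img)
# img.size is (4, 2), i.e. (width, height),
# but arr.shape is (2, 4, 3), i.e. (height, width, channels).
red = arr[1, 3]                      # the red pixel, indexed as arr[y, x]

xy = arr.transpose(1, 0, 2)          # shape (4, 2, 3): now indexed as xy[x, y]
planes = arr.transpose(2, 0, 1)      # shape (3, 2, 4): planes[0] is the R plane
```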