
Why is this boolean used in this Bayes classifier? (Python question)


I'm studying GANs (and I'm a beginner in Python), and I found this piece of code in earlier exercises that I don't understand. Specifically, I don't understand why the boolean expression on the 9th line (Xk = X[Y == k]) is used, for the reasons listed after the code.

class BayesClassifier:
  def fit(self, X, Y):
    # assume classes are numbered 0...K-1
    self.K = len(set(Y))

    self.gaussians = []
    self.p_y = np.zeros(self.K)
    for k in range(self.K):
      Xk = X[Y == k]
      self.p_y[k] = len(Xk)
      mean = Xk.mean(axis=0)
      cov = np.cov(Xk.T)
      g = {'m': mean, 'c': cov}
      self.gaussians.append(g)
    # normalize p(y)
    self.p_y /= self.p_y.sum()
  1. That boolean returns 0 or 1 depending on whether Y == k is true, so Xk will always be the first or the second value of the X list. I don't see the use of that.
  2. On the 10th line, len(Xk) will always be 1, so why use that expression instead of a literal 1?
  3. The mean and covariance on the following lines are calculated from only one value each time.

I feel that I'm not understanding something very basic.


Solution

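  • You should take into account that X and Y are NumPy arrays, not scalars, and that some operators are overloaded for them, in particular == and Boolean-based indexing. == is an element-wise comparison, not a comparison of the whole array. In your code k is a plain integer, so Y == k compares every element of Y with k; in the example here k is an array, which is compared with Y element by element in the same way.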

    See how it works:

    In [9]: Y = np.array([0,1,2])
    In [10]: k = np.array([0,1,3])
    In [11]: Y==k
    Out[11]: array([ True,  True, False])

    So, the result of == is a Boolean array.

    In [12]: X=np.array([0,2,4])
    In [13]: X[Y==k]
    Out[13]: array([0, 2])

    The result is an array containing the elements of X at the positions where the condition is True.

    Hence len(Xk) is the number of elements of X whose corresponding entry in Y equals k, i.e. the number of samples in class k, which is generally not 1.
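
To tie this back to the loop in the question, here is a minimal sketch of what happens for a single scalar k. The toy data is made up for illustration: X is assumed to be a 2-D array with one row per sample, and Y a 1-D array with one label per row. The names mask, count, mean and cov are just local helpers, not part of the original class.

```python
import numpy as np

# Hypothetical toy data: 6 samples with 2 features each, labels in {0, 1}.
X = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0],
              [7.0, 8.0],
              [9.0, 10.0],
              [11.0, 12.0]])
Y = np.array([0, 1, 0, 1, 0, 1])

k = 0                    # a plain integer, as in `for k in range(self.K)`
mask = (Y == k)          # k is compared with every label, giving one boolean per sample
Xk = X[mask]             # keeps only the rows of X whose label equals k

count = len(Xk)          # number of samples in class k
mean = Xk.mean(axis=0)   # per-feature mean over those rows
cov = np.cov(Xk.T)       # feature-by-feature covariance matrix of class k

print(mask)              # [ True False  True False  True False]
print(Xk)                # the three rows labelled 0
print(count, mean)       # 3 [5. 6.]
print(cov.shape)         # (2, 2)
```

So Xk is a whole sub-array of samples rather than a single value, which is why len(Xk), the mean and the covariance are all meaningful per-class statistics.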