Tags: python, tensorflow, machine-learning, keras, feature-extraction

How to combine features extracted from two CNN models?


I have two CNN models that follow the same architecture. I trained cnn1 on 'train set 1' and cnn2 on 'train set 2'. Then I extracted features using the following code.

# cnn1

    model.pop()  # removes the softmax layer
    model.pop()  # removes the dropout layer
    model.pop()  # removes the activation layer
    model.pop()  # removes the batch-norm layer
    model.build()  # the last layer is now the Dense(512) layer
    features1 = model.predict(train_set_1)  # train_set_1 holds the 'train set 1' inputs
    print(features1.shape)  # (600, 512)

# cnn2

    model.pop()  # removes the softmax layer
    model.pop()  # removes the dropout layer
    model.pop()  # removes the activation layer
    model.pop()  # removes the batch-norm layer
    model.build()  # the last layer is now the Dense(512) layer
    features2 = model.predict(train_set_2)  # train_set_2 holds the 'train set 2' inputs
    print(features2.shape)  # (600, 512)

How can I combine features1 and features2 so that the output shape is (600, 1024)?


Solution

  • SIMPLEST SOLUTION:

    You can simply concatenate the outputs of the two networks along the feature axis:

    import numpy as np

    features = np.concatenate([features1, features2], axis=1)  # shape (600, 1024)
    

    ALTERNATIVE:

    Given two trained models with the same structure, whatever that structure is, you can combine them into a single model that outputs the concatenated features:

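    import numpy as np
    from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Concatenate
    from tensorflow.keras.models import Model

    # generate dummy data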
    n_sample = 600
    set1 = np.random.uniform(0,1, (n_sample,30))
    set2 = np.random.uniform(0,1, (n_sample,30))
    
    # model 1
    inp1 = Input((30,))
    x1 = Dense(512,)(inp1)
    x1 = Dropout(0.3)(x1)
    x1 = BatchNormalization()(x1)
    out1 = Dense(3, activation='softmax')(x1)
    m1 = Model(inp1, out1)
    # m1.fit(...)
    
    # model 2
    inp2 = Input((30,))
    x2 = Dense(512,)(inp2)
    x2 = Dropout(0.3)(x2)
    x2 = BatchNormalization()(x2)
    out2 = Dense(3, activation='softmax')(x2)
    m2 = Model(inp2, out2)
    # m2.fit(...)
    
    # concatenate the desired outputs
    # layers[0] is the InputLayer, so layers[1] is the Dense(512) layer
    concat = Concatenate()([m1.layers[1].output, m2.layers[1].output])
    merge = Model([m1.input, m2.input], concat)
    
    # make combined predictions
    merge.predict([set1,set2]).shape  # (n_sample, 1024)
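
As a quick sanity check of the simplest solution, here is a minimal sketch that uses random stand-in arrays (hypothetical data, not real CNN features) in place of features1 and features2. It assumes only NumPy and that both arrays have the same number of rows.

```python
import numpy as np

# Stand-ins for the real feature matrices (hypothetical random data).
rng = np.random.default_rng(0)
features1 = rng.uniform(0, 1, (600, 512))
features2 = rng.uniform(0, 1, (600, 512))

# axis=1 joins along the feature dimension; the row counts must match.
features = np.concatenate([features1, features2], axis=1)
print(features.shape)  # (600, 1024)

# The first 512 columns come from features1, the last 512 from features2.
assert np.array_equal(features[:, :512], features1)
assert np.array_equal(features[:, 512:], features2)
```

Keep in mind that axis=1 pairs row i of features1 with row i of features2, so the combined matrix is only meaningful if the two sets are aligned sample by sample.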