python scikit-learn logistic-regression

Fit a logistic regression for each group using scikit-learn in Python


I am trying to follow this post, but instead of fitting a bivariate linear model, I want to fit a logistic regression with many more X variables. A basic example:

df = pd.DataFrame({'group': [1,2,3,4,5,6],
                   'var1': [9,5,3,1,2,3],
                   'var2': [3,4,1,6,4,9],
                   'var3': [3,5,2,8,4,4],
                   'closed': [0,1,1,1,0,0],
                   'date': [2020, 2020, 2021, 2022, 2021, 2021]
                   })

def GroupRegress(data, yvar, xvars):
    Y = data[yvar]
    X = data[xvars]
    #X['intercept'] = 1.
    result = LogisticRegression(Y, X).fit()
    return result.params


df.groupby('group').apply(GroupRegress, 'closed', ['X'])

This results in an error, "TypeError: __init__() takes from 1 to 2 positional arguments but 3 were given". More importantly, I don't have just four X variables, but closer to 20. Is there a way to fit a model for each group ID with many X variables?

My non-grouped model looks like this:

X = df.drop('closed', axis=1)
X['date']=X['date'].map(dt.datetime.toordinal)

y = df['closed']

X_train, X_test, y_train, y_test = tts(X, y, test_size=0.20)

model = LogisticRegression()

model.fit(X_train, y_train)

Solution

  • You are passing X and y in the wrong place: they must be given to the fit() method of a LogisticRegression() object, not to its constructor. Also, each group must contain at least one example of each class, so I have modified the 'group' and 'closed' variables in your example.

    import pandas as pd
    from sklearn.linear_model import LogisticRegression
    
    df = pd.DataFrame(
        {
            "group": [1, 1, 2, 2, 3, 3],
            "var1": [9, 5, 3, 1, 2, 3],
            "var2": [3, 4, 1, 6, 4, 9],
            "var3": [3, 5, 2, 8, 4, 4],
            "closed": [0, 1, 0, 1, 0, 1],
            "date": [2020, 2020, 2021, 2022, 2021, 2021],
        }
    )
    
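    def GroupRegress(data, yvar, xvars):
        Y = data[yvar]
        X = data[xvars]
        # No explicit intercept column is needed: LogisticRegression fits one by default.
        # The data goes to fit(), not to the constructor.
        result = LogisticRegression().fit(X, Y)
        return result  # the fitted model for this group
    
    # Returns one fitted LogisticRegression object per group.
    df.groupby("group").apply(GroupRegress, "closed", ["var1", "var2", "var3"])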
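For the "closer to 20" X variables in the question, the predictor list can be built from the column names instead of being typed out, and the per-group coefficients can be collected into one table. The following is a minimal sketch under some assumptions, not a definitive implementation: `fit_group_coefs` is a hypothetical helper name, the 'date' column is left out (it would need to be converted to a numeric feature first), and groups that contain only one class are filled with NaN rather than fitted.

```python
import pandas as pd
from sklearn.linear_model import LogisticRegression

df = pd.DataFrame(
    {
        "group": [1, 1, 2, 2, 3, 3],
        "var1": [9, 5, 3, 1, 2, 3],
        "var2": [3, 4, 1, 6, 4, 9],
        "var3": [3, 5, 2, 8, 4, 4],
        "closed": [0, 1, 0, 1, 0, 1],
        "date": [2020, 2020, 2021, 2022, 2021, 2021],
    }
)

# Treat every column except the group key, the target, and 'date' as a
# predictor, so the list scales to many X variables without typing them out.
xvars = [c for c in df.columns if c not in ("group", "closed", "date")]


def fit_group_coefs(data, yvar, xvars):
    """Hypothetical helper: fit one model and return its coefficients."""
    labels = ["intercept", *xvars]
    y = data[yvar]
    if y.nunique() < 2:
        # LogisticRegression needs both classes; return NaNs for such groups.
        return pd.Series(float("nan"), index=labels)
    model = LogisticRegression().fit(data[xvars], y)
    return pd.Series([model.intercept_[0], *model.coef_[0]], index=labels)


# Iterate over the groups explicitly and stack the results, one row per group.
rows = {
    key: fit_group_coefs(part, "closed", xvars)
    for key, part in df.groupby("group")
}
coefs = pd.DataFrame(rows).T
coefs.index.name = "group"
print(coefs)
```

Iterating over `df.groupby("group")` directly keeps the sketch independent of how `apply` treats the grouping column, and returning a Series per group yields one row of coefficients per group ID.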