Basically, I want to know what the fit() function does in general, but especially in the pieces of code below.
I'm taking the Machine Learning A-Z course because I'm pretty new to machine learning (I just started). I know some basic conceptual terms, but not the technical part.
CODE 1:
import numpy as np
from sklearn.impute import SimpleImputer
missingvalues = SimpleImputer(missing_values = np.nan, strategy = 'mean')
missingvalues = missingvalues.fit(X[:, 1:3])
X[:, 1:3] = missingvalues.transform(X[:, 1:3])
Another example where I have the same doubt:
CODE 2:
from sklearn.preprocessing import StandardScaler
sc_X = StandardScaler()
print(sc_X)
X_train = sc_X.fit_transform(X_train)
print(X_train)
X_test = sc_X.transform(X_test)
I think that if I understand the general purpose of this function and what exactly it does, I'll be good to go. But I'd certainly also like to know what it is doing in this code.
A good place to double-check this is the scikit-learn tutorial: https://scikit-learn.org/stable/tutorial/basic/tutorial.html
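In scikit-learn, the fit method is always the step where an object learns something from the data: a model learns its parameters, and a preprocessor such as SimpleImputer or StandardScaler learns the statistics it needs.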
You normally have the following steps:
1. Train your model on the training data (X_train) with fit
2. Make predictions on the test data (X_test) with predict
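The two steps above can be sketched with a small example. This is only an illustration: LinearRegression is a stand-in model that does not appear in your code, and the toy data (points on the line y = 2x + 1) is made up.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical toy data on the line y = 2 * x + 1, split by hand into train and test
X_train = np.array([[0.0], [1.0], [2.0], [3.0]])
y_train = np.array([1.0, 3.0, 5.0, 7.0])
X_test = np.array([[4.0], [5.0]])

model = LinearRegression()

# Step 1: fit learns the model parameters (slope and intercept) from the training data
model.fit(X_train, y_train)
print(model.coef_, model.intercept_)  # slope close to 2, intercept close to 1

# Step 2: predict applies what was learned to new, unseen data
y_pred = model.predict(X_test)
print(y_pred)  # close to 9 and 11
```

Note that predict never looks at y values: everything it needs was stored on the object during fit.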
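In your first example, missingvalues.fit(X[:, 1:3]) trains the SimpleImputer on your data X, using only columns 1 and 2 (the slice 1:3 stops before index 3). With strategy = 'mean', what it learns is the mean of each of those columns. Then transform uses what was learned to replace the missing values in those columns, and you overwrite that part of X with the result.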
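Here is a minimal sketch of the first example, assuming a small made-up array in place of your dataset. It shows what fit stores (the statistics_ attribute holds the learned column means) and that fit returns the imputer itself, which is why missingvalues = missingvalues.fit(...) works.

```python
import numpy as np
from sklearn.impute import SimpleImputer

# Hypothetical data: columns 1 and 2 each contain one missing value
X = np.array([
    [0.0, 10.0, 100.0],
    [1.0, np.nan, 200.0],
    [2.0, 30.0, np.nan],
])

imputer = SimpleImputer(missing_values=np.nan, strategy="mean")

# fit: learn the mean of each selected column, ignoring NaNs.
# X[:, 1:3] selects columns 1 and 2 only.
fitted = imputer.fit(X[:, 1:3])
print(fitted is imputer)    # True: fit returns the same object
print(imputer.statistics_)  # learned column means: 20.0 and 150.0

# transform: replace each NaN with the mean learned for its column
X[:, 1:3] = imputer.transform(X[:, 1:3])
print(X)
```

After the transform, the NaN in column 1 becomes 20.0 and the NaN in column 2 becomes 150.0, while column 0 is left untouched.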
In your second example, you train the StandardScaler with X_train and use that training for both data sets, X_train and X_test. The StandardScaler learns the mean and standard deviation of each column from X_train only. That means if it learned that 10 has to be converted to 2, it will convert 10 to 2 in both X_train and X_test.
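A minimal sketch of the second example, using hypothetical one-column data sets. fit_transform is fit followed by transform on the same data; afterwards, transform on X_test reuses the mean_ and scale_ (standard deviation) learned from X_train instead of computing new ones.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical one-column data sets
X_train = np.array([[2.0], [4.0], [6.0]])
X_test = np.array([[4.0], [10.0]])

sc_X = StandardScaler()

# fit_transform: learn mean_ and scale_ from X_train, then scale X_train with them
X_train_scaled = sc_X.fit_transform(X_train)
print(sc_X.mean_, sc_X.scale_)  # mean 4.0, standard deviation about 1.633

# transform only: reuse the values learned from X_train, no re-fitting on X_test
X_test_scaled = sc_X.transform(X_test)

# The same formula (x - mean) / std, with the training statistics, is used for both sets
manual_test = (X_test - sc_X.mean_) / sc_X.scale_
print(np.allclose(X_test_scaled, manual_test))  # True

# 4.0 appears in both sets and is mapped to the same scaled value in each
print(X_train_scaled[1, 0], X_test_scaled[0, 0])
```

Fitting only on X_train is deliberate: the test set should be treated as unseen data, so it must not influence the learned mean and standard deviation.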