I am trying to figure out how to apply the knn.reg function to predict y (which in this case is the mpg of the Auto dataset) for a specific value of x (the 'horsepower' variable of the same dataset).
At first, I used the knn.reg function to build a kNN regression model with k=10, which looks like this:
#Preliminary setup
library(ISLR)
library(fastDummies)
library(leaps)
library(boot)
library(FNN)
library(caTools)
df<-Auto
df$origin <- as.factor(df$origin)
df <- dummy_cols(df, select_columns = "origin")
df <- df[,!(names(df) %in% c("name", "origin","origin_1"))]
#Attempted models
knn.model<-knn.reg(train=df$horsepower, y=df$mpg, k=10)
split<-sample.split(df$mpg, SplitRatio=0.8)
train=df[split,]
test=df[!split,]
# y must hold the responses for the training rows only, so use train$mpg rather than df$mpg
knn.model<-knn.reg(train=train[c('horsepower')], test=test[c('horsepower')], y=train$mpg, k=10)
I've tried two models that either include or exclude test data that is split from the original data, but I think I would like to use the entire dataset as the training data.
After constructing these models, I tried to use the predict() function to estimate the mpg of a vehicle when its horsepower is 200, which looked something like this:
mpg<-c(200)
predict(knn.model, newdata=mpg)
The problem with the predict() function, however, was that it gave an error telling me that predict() can't be applied to an object of class "knnRegCV".
I am unsure whether I should use a function other than predict(), or whether the code I have is missing something essential. I'd appreciate any suggestions or comments that can help me address this issue. Massive thank you in advance!
The function predict() does not have a method for the object that the knn.reg() function returns, but you can easily use the test= argument. Using your first knn.model:
knn.reg(train=df$horsepower, test=200, y=df$mpg, k=10)
# Prediction:
# [1] 12.45
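As a side note on why predict() fails: according to the FNN documentation, when test= is omitted knn.reg() performs leave-one-out cross-validation on the training data and returns an object of class "knnRegCV", while with test= supplied the predictions for the new points are stored in the pred component of the result. A minimal sketch of the difference, assuming the FNN and ISLR packages are installed (the object names cv_fit and fit_200 are only for illustration):

```r
library(FNN)
library(ISLR)

# No test= argument: leave-one-out cross-validation on the training data
cv_fit <- knn.reg(train = Auto$horsepower, y = Auto$mpg, k = 10)
class(cv_fit)        # "knnRegCV", the class named in the error message
length(cv_fit$pred)  # one cross-validated prediction per row of Auto

# With test=: predictions for the supplied points only
fit_200 <- knn.reg(train = Auto$horsepower, test = 200, y = Auto$mpg, k = 10)
class(fit_200)       # "knnReg" per the FNN documentation
fit_200$pred         # predicted mpg at horsepower = 200
```

So the object from the first model in the question holds cross-validated fits for the training rows, not a reusable model that could be applied to new data.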
Since you have only one predictor, create a one-column data frame to estimate more than one value:
pred <- data.frame(horsepower=c(100, 150, 200, 250))
knn.reg(train=df$horsepower, test=pred, y=df$mpg, k=10)
# Prediction:
# [1] 17.90 14.50 12.45 12.90
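If you want something that behaves a bit like predict(), one option is a small wrapper that calls knn.reg() with test= and returns only the pred component. This is a sketch under assumptions: predict_mpg is a hypothetical helper name (not part of FNN), it uses the full Auto dataset as training data, and it assumes the FNN and ISLR packages are installed.

```r
library(FNN)
library(ISLR)

# Hypothetical helper (not part of FNN): kNN prediction of mpg from horsepower,
# using the entire Auto dataset as training data.
predict_mpg <- function(new_hp, k = 10) {
  fit <- knn.reg(train = data.frame(horsepower = Auto$horsepower),
                 test  = data.frame(horsepower = new_hp),
                 y     = Auto$mpg,
                 k     = k)
  fit$pred  # numeric vector, one prediction per element of new_hp
}

predict_mpg(200)                    # single value
predict_mpg(c(100, 150, 200, 250))  # several values at once
```

Because kNN has no fitted parameters, the "model" is just the training data plus k, so recomputing the neighbours on each call costs little for a dataset of this size.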