python tensorflow machine-learning keras lstm

Making One Step Forecast Predictions


I'm building an LSTM model and training it on a TSLA data set I found on Kaggle. When I call model.predict, does the prediction give me the stock price for the next day, and is this a one-step forecast? Also, when I print the output of model.predict I get a huge list, so I use NumPy's argmax function to reduce it to a single number. Here is the code:

import tensorflow as tf 
from tensorflow.keras.layers import LSTM, Dense, Dropout, Input
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.optimizers import Adam


df = pd.read_csv('TSLA.csv')

series = df['Close'].values.reshape(-1, 1)

scaler = StandardScaler()
scaler.fit(series[:len(series)//2])
series = scaler.transform(series).flatten()

X = []
Y = []
T = 10
D = 1

for t in range(len(series) - T):

    X.append(series[t:t+T])
    Y.append(series[t+T])

X = np.array(X).reshape(-1, T, D)
Y = np.array(Y)
N = len(X)

print(X.shape, Y.shape)

model = tf.keras.Sequential([
    Input(shape=(T, D)),
    LSTM(50),
    Dense(100, activation='relu'),
    Dropout(0.25),
    Dense(1)
])

model.compile(optimizer=Adam(learning_rate=0.01), loss='mse')
r = model.fit(X[:-N//2], Y[:-N//2], validation_data=(X[-N//2:], Y[-N//2:]), epochs=200)

plt.plot(r.history['loss'])
plt.plot(r.history['val_loss'])
plt.show()

preds = model.predict(X)
outs = preds[:,0]

print(outs)
print(np.argmax(outs))

Solution

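  • Argmax doesn't make sense here. Yes, each value is a one-step forecast, but the 90 values are 90 separate next-day predictions, one for each window in the data you passed in. When you run this:

    preds = model.predict(X)

    you get the predicted next-day value for each of the 90 windows in X, not a single forecast for tomorrow. The values are also still in the scaler's standardized units, so pass them through scaler.inverse_transform to get prices. This line:

    print(np.argmax(outs))

    only returns the index of the largest prediction, which is not a price.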

    By the way, you can pull stock prices directly from Python, so you don't need a CSV (provided the Yahoo data source still works with your pandas-datareader version).

    pip install pandas-datareader
    
    from pandas_datareader import data as wb
    
    ticker = wb.DataReader('TSLA', start='2015-1-1', data_source='yahoo')
    print(ticker)
    
                      High         Low  ...      Volume   Adj Close
    Date                                ...                        
    2015-01-02   44.650002   42.652000  ...  23822000.0   43.862000
    2015-01-05   43.299999   41.431999  ...  26842500.0   42.018002
    2015-01-06   42.840000   40.841999  ...  31309500.0   42.256001
    2015-01-07   42.956001   41.956001  ...  14842000.0   42.189999
    2015-01-08   42.759998   42.001999  ...  17212500.0   42.124001
                    ...         ...  ...         ...         ...
    2020-10-13  448.890015  436.600006  ...  34463700.0  446.649994
    2020-10-14  465.899994  447.350006  ...  48045400.0  461.299988
    2020-10-15  456.570007  442.500000  ...  35672400.0  448.880005
    2020-10-16  455.950012  438.850006  ...  32620000.0  439.670013
    2020-10-19  447.000000  437.649994  ...   9422697.0  442.840607
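
Coming back to the prediction question, here is a rough NumPy-only sketch of the difference between the argmax index and an actual next-day forecast. It is only an illustration under some assumptions: the prices are made up, the mean/std arithmetic stands in for StandardScaler, and `stand_in_predict` is a hypothetical placeholder for `model.predict` (it just repeats the last value of each window), so swap in your trained model and scaler.

```python
import numpy as np

T, D = 10, 1  # window length and feature count, as in the question

# Made-up closing prices, purely for illustration (not real TSLA data).
close = np.linspace(100.0, 129.0, 30)

# Standardize with statistics from the first half only, mirroring
# scaler.fit(series[:len(series)//2]) followed by scaler.transform(series).
mean = close[: len(close) // 2].mean()
std = close[: len(close) // 2].std()
series = (close - mean) / std

# Each row of X holds T consecutive days; Y is the day right after that window.
X = np.array([series[t : t + T] for t in range(len(series) - T)]).reshape(-1, T, D)
Y = series[T:]


def stand_in_predict(windows):
    """Hypothetical stand-in for model.predict: repeats each window's last value.

    Returns shape (n_windows, 1), the same shape a Dense(1) head produces.
    """
    return windows[:, -1, :]


preds = stand_in_predict(X)  # one scaled forecast per window
outs = preds[:, 0]

# argmax is the position of the largest forecast, not a price.
largest_at = int(np.argmax(outs))

# The forecast for the day after the data ends uses the final T observations,
# a window that the loop above never puts into X.
last_window = series[-T:].reshape(1, T, D)
next_scaled = stand_in_predict(last_window)[0, 0]

# Undo the standardization (what scaler.inverse_transform would do).
next_price = next_scaled * std + mean

print(X.shape, Y.shape, outs.shape)
print("argmax index:", largest_at)
print("next-day forecast in price units:", next_price)
```

With the real model, the same idea would be `model.predict(series[-T:].reshape(1, T, D))` followed by `scaler.inverse_transform` on the result. Note that the last row of X ends one day before the end of the series, so the last entry of `model.predict(X)` is a forecast for the final day you already have, not for tomorrow.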