Tags: machine-learning, neural-network, julia, flux

Error training using Flux.train! in Julia


Problem

I have the code below, which tries to train a neural network in Julia, but I get an error whenever I call Flux.train!:

using Flux, CSV, DataFrames, Random, Statistics, Plots

# Load the data
data = CSV.read("daily_stock_returns.csv")

# Split the data into training and testing sets
train_data = data[1:800, :]
test_data = data[801:end, :]

# Define the input and output variables
inputs = Matrix(train_data[:, 2:end])
outputs = Matrix(train_data[:, 1])

# Define the neural network architecture
n_inputs = size(inputs, 2)
n_hidden = 10
n_outputs = 1
model = Chain(
    Dense(n_inputs, n_hidden, relu),
    Dense(n_hidden, n_outputs)
)

# Define the loss function
loss(x, y) = Flux.mse(model(x), y)

# Define the optimizer
optimizer = ADAM()

# Train the model
n_epochs = 100
for epoch in 1:n_epochs
    Flux.train!(loss, params(model), [(inputs, outputs)], optimizer)
end

# Test the model
test_inputs = Matrix(test_data[:, 2:end])
test_outputs = Matrix(test_data[:, 1])
predictions = Flux.predict(model, test_inputs)
test_loss = Flux.mse(predictions, test_outputs)

# Print the test loss
println("Test loss: $test_loss")

Error

When I run this cell of code:

# Train the model
n_epochs = 100
for epoch in 1:n_epochs
    Flux.train!(loss, params(model), [(inputs, outputs)], optimizer)
end

I get the following error:

UndefVarError: params not defined

Stacktrace:
 [1] top-level scope
   @ .\In[16]:4
 [2] eval
   @ .\boot.jl:368 [inlined]
 [3] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
   @ Base .\loading.jl:1428

I have tried uninstalling and reinstalling, with no improvement. How do I fix this error?


Solution

  • This happens because Flux.params (see the docs) isn't exported by default, so using Flux does not bring the bare name params into scope. Either replace params with Flux.params (which, as far as I recall, is how the docs usually write it), or keep using Flux and add using Flux: params next to it. Note that using Flux: params on its own would import only params and not Chain, Dense, etc. The second option is probably less ideal, because it adds a very generic name to your namespace.
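For context, here is a minimal sketch of the training script with Flux.params qualified. It assumes Flux 0.13 or 0.14 (where the implicit Flux.params / Flux.train! API is available) and that the first CSV column is the target and the remaining columns are numeric features. Once the params error is fixed, a few other things in the original would probably fail next, so the sketch also changes them: Flux's Dense layers expect features × samples, so the row-per-sample matrices are transposed with permutedims; Flux does not, as far as I know, provide a Flux.predict function, so the model is simply called on the test inputs; and newer Flux versions spell the optimiser Adam (ADAM is the older name). The function name train_model and its keyword arguments are hypothetical, introduced only to make the sketch self-contained.

```julia
using Flux, DataFrames
# Alternative to qualifying the name: keep `using Flux` and also add
# `using Flux: params`, then call `params(model)` unqualified.

# Hypothetical helper (name and keywords are illustrative). Assumes Flux 0.13/0.14,
# where the implicit Flux.params / Flux.train! API exists, and a DataFrame whose
# first column is the target and whose remaining columns are numeric features.
function train_model(data::DataFrame; n_train = 800, n_hidden = 10, n_epochs = 100)
    train_data = data[1:n_train, :]
    test_data  = data[n_train+1:end, :]

    # Flux layers expect one column per sample (features × samples), so the
    # row-per-sample tables are transposed; Float32 matches Flux's default weights.
    inputs  = Float32.(permutedims(Matrix(train_data[:, 2:end])))
    outputs = Float32.(permutedims(train_data[:, 1]))  # 1 × n_train row matrix

    model = Chain(
        Dense(size(inputs, 1), n_hidden, relu),
        Dense(n_hidden, 1),
    )

    loss(x, y) = Flux.mse(model(x), y)
    opt = Adam()  # older Flux versions spell this ADAM()

    for epoch in 1:n_epochs
        # Qualified name: `params` is not brought into scope by `using Flux`
        Flux.train!(loss, Flux.params(model), [(inputs, outputs)], opt)
    end

    test_inputs  = Float32.(permutedims(Matrix(test_data[:, 2:end])))
    test_outputs = Float32.(permutedims(test_data[:, 1]))
    predictions  = model(test_inputs)  # calling the model produces its predictions
    test_loss    = Flux.mse(predictions, test_outputs)
    return model, predictions, test_loss
end
```

With the data from the question this would be called roughly as train_model(CSV.read("daily_stock_returns.csv", DataFrame)) after using CSV; recent CSV.jl versions require that DataFrame sink argument. On Flux 0.14 and later the documentation recommends the explicit style instead (opt_state = Flux.setup(Adam(), model), with a loss that takes the model as its first argument), which avoids Flux.params altogether.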