Gradient Descent Algorithm via Julia
This tutorial illustrates how to use Julia to implement the gradient descent algorithm for simple linear regression.
Tags: Tutorial, Algorithm, Julia
1 Load Julia modules
```julia
using RDatasets
using DataFrames

mtcars = dataset("datasets", "mtcars")
```
32 rows × 12 columns (omitted printing of 4 columns)

 | Model | MPG | Cyl | Disp | HP | DRat | WT | QSec |
---|---|---|---|---|---|---|---|---|
 | String | Float64 | Int64 | Float64 | Int64 | Float64 | Float64 | Float64 |
1 | Mazda RX4 | 21.0 | 6 | 160.0 | 110 | 3.9 | 2.62 | 16.46 |
2 | Mazda RX4 Wag | 21.0 | 6 | 160.0 | 110 | 3.9 | 2.875 | 17.02 |
3 | Datsun 710 | 22.8 | 4 | 108.0 | 93 | 3.85 | 2.32 | 18.61 |
4 | Hornet 4 Drive | 21.4 | 6 | 258.0 | 110 | 3.08 | 3.215 | 19.44 |
5 | Hornet Sportabout | 18.7 | 8 | 360.0 | 175 | 3.15 | 3.44 | 17.02 |
6 | Valiant | 18.1 | 6 | 225.0 | 105 | 2.76 | 3.46 | 20.22 |
7 | Duster 360 | 14.3 | 8 | 360.0 | 245 | 3.21 | 3.57 | 15.84 |
8 | Merc 240D | 24.4 | 4 | 146.7 | 62 | 3.69 | 3.19 | 20.0 |
9 | Merc 230 | 22.8 | 4 | 140.8 | 95 | 3.92 | 3.15 | 22.9 |
10 | Merc 280 | 19.2 | 6 | 167.6 | 123 | 3.92 | 3.44 | 18.3 |
11 | Merc 280C | 17.8 | 6 | 167.6 | 123 | 3.92 | 3.44 | 18.9 |
12 | Merc 450SE | 16.4 | 8 | 275.8 | 180 | 3.07 | 4.07 | 17.4 |
13 | Merc 450SL | 17.3 | 8 | 275.8 | 180 | 3.07 | 3.73 | 17.6 |
14 | Merc 450SLC | 15.2 | 8 | 275.8 | 180 | 3.07 | 3.78 | 18.0 |
15 | Cadillac Fleetwood | 10.4 | 8 | 472.0 | 205 | 2.93 | 5.25 | 17.98 |
16 | Lincoln Continental | 10.4 | 8 | 460.0 | 215 | 3.0 | 5.424 | 17.82 |
17 | Chrysler Imperial | 14.7 | 8 | 440.0 | 230 | 3.23 | 5.345 | 17.42 |
18 | Fiat 128 | 32.4 | 4 | 78.7 | 66 | 4.08 | 2.2 | 19.47 |
19 | Honda Civic | 30.4 | 4 | 75.7 | 52 | 4.93 | 1.615 | 18.52 |
20 | Toyota Corolla | 33.9 | 4 | 71.1 | 65 | 4.22 | 1.835 | 19.9 |
21 | Toyota Corona | 21.5 | 4 | 120.1 | 97 | 3.7 | 2.465 | 20.01 |
22 | Dodge Challenger | 15.5 | 8 | 318.0 | 150 | 2.76 | 3.52 | 16.87 |
23 | AMC Javelin | 15.2 | 8 | 304.0 | 150 | 3.15 | 3.435 | 17.3 |
24 | Camaro Z28 | 13.3 | 8 | 350.0 | 245 | 3.73 | 3.84 | 15.41 |
25 | Pontiac Firebird | 19.2 | 8 | 400.0 | 175 | 3.08 | 3.845 | 17.05 |
26 | Fiat X1-9 | 27.3 | 4 | 79.0 | 66 | 4.08 | 1.935 | 18.9 |
27 | Porsche 914-2 | 26.0 | 4 | 120.3 | 91 | 4.43 | 2.14 | 16.7 |
28 | Lotus Europa | 30.4 | 4 | 95.1 | 113 | 3.77 | 1.513 | 16.9 |
29 | Ford Pantera L | 15.8 | 8 | 351.0 | 264 | 4.22 | 3.17 | 14.5 |
30 | Ferrari Dino | 19.7 | 6 | 145.0 | 175 | 3.62 | 2.77 | 15.5 |
⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
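
Only two of these columns are used below: MPG as the response and Disp (engine displacement) as the predictor. As an optional check of the data, the following sketch summarizes the two variables with DataFrames' describe (this step is an addition, not part of the original tutorial):

```julia
using DataFrames

# Summary statistics for the response (MPG) and the predictor (Disp)
describe(mtcars[:, [:MPG, :Disp]])
```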
2 Julia Function for Gradient Descent
- x, y: the predictor and response vectors
- learn_rate: the size of the step the algorithm takes along the slope of the MSE function
- conv_threshold: threshold on the drop in MSE used to declare convergence (see the update equations sketched below)
- n: the number of observations used to average the squared errors
- max_iter: the maximum number of iterations before the algorithm stops
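
The updates implemented in the function follow from the mean squared error of a simple linear model. Using the same symbols as the code, with η denoting learn_rate, and with the constant factor from differentiating the square absorbed into the learning rate (as the code does):

$$
\hat{y}_i = \alpha + \beta x_i, \qquad
\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2
$$

$$
\alpha \leftarrow \alpha - \eta \cdot \frac{1}{n}\sum_{i=1}^{n}\left(\hat{y}_i - y_i\right), \qquad
\beta \leftarrow \beta - \eta \cdot \frac{1}{n}\sum_{i=1}^{n}\left(\hat{y}_i - y_i\right) x_i
$$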
```julia
function gradientDesc(x, y, learn_rate, conv_threshold, n, max_iter)
    # Random starting values for the intercept (α) and slope (β)
    β = rand(Float64, 1)[1]
    α = rand(Float64, 1)[1]
    ŷ = α .+ β .* x
    # Mean squared error at the starting values (note: not updated inside the loop)
    MSE = sum((y .- ŷ).^2)/n
    converged = false
    iterations = 0
    while converged == false
        # Implement the gradient descent algorithm
        β_new = β - learn_rate*((1/n)*(sum((ŷ .- y) .* x)))
        α_new = α - learn_rate*((1/n)*(sum(ŷ .- y)))
        α = α_new
        β = β_new
        ŷ = β.*x .+ α
        MSE_new = sum((y.-ŷ).^2)/n
        # Converged once the drop in MSE relative to the starting value is within the threshold
        if (MSE - MSE_new) <= conv_threshold
            converged = true
            println("Optimal intercept: $α; Optimal slope: $β")
        end
        # Stop once the iteration budget is exhausted
        iterations += 1
        if iterations > max_iter
            converged = true
            println("Optimal intercept: $α; Optimal slope: $β")
        end
    end
end
```
gradientDesc (generic function with 1 method)
```julia
gradientDesc(mtcars[:,:Disp], mtcars[:,:MPG], 0.0000293, 0.001, 32, 2500000)
```
Optimal intercept: 29.599851506041713; Optimal slope: -0.0412151089535404
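
Convergence here is slow because Disp takes large values (roughly 70–470 in the table above), which forces a very small learning rate and a budget of millions of iterations. As an aside that is not part of the original tutorial, one common remedy is to standardize the predictor first, which usually allows a much larger learning rate and far fewer iterations; the settings below are illustrative only:

```julia
using Statistics

# Standardize the predictor so its scale no longer forces a tiny learning rate
x = mtcars[:, :Disp]
y = mtcars[:, :MPG]
x_std = (x .- mean(x)) ./ std(x)

# Illustrative hyperparameters: larger learning rate, much smaller iteration budget
gradientDesc(x_std, y, 0.01, 1e-9, 32, 100_000)

# The printed coefficients are on the standardized scale; they map back via
#   slope     = slope_std / std(x)
#   intercept = intercept_std - slope_std * mean(x) / std(x)
```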
3 Comparison with linear regression
```julia
using GLM

linearRegressor = lm(@formula(MPG ~ Disp), mtcars)
```
```
StatsModels.TableRegressionModel{LinearModel{GLM.LmResp{Vector{Float64}}, GLM.DensePredChol{Float64, LinearAlgebra.CholeskyPivoted{Float64, Matrix{Float64}}}}, Matrix{Float64}}

MPG ~ 1 + Disp

Coefficients:
───────────────────────────────────────────────────────────────────────────
                  Coef.  Std. Error      t  Pr(>|t|)  Lower 95%   Upper 95%
───────────────────────────────────────────────────────────────────────────
(Intercept)  29.5999     1.22972     24.07    <1e-20  27.0884    32.1113
Disp         -0.0412151  0.00471183  -8.75    <1e-09  -0.050838  -0.0315923
───────────────────────────────────────────────────────────────────────────
```
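
The two approaches agree: reading the fitted coefficients off the GLM model gives the same intercept (≈ 29.5999) and slope (≈ -0.0412) that gradientDesc reported above. A minimal sketch using GLM's coef accessor (the variable names are illustrative):

```julia
using GLM

# Read the OLS estimates off the fitted model: [intercept, slope]
α_ols, β_ols = coef(linearRegressor)
println("OLS intercept: $α_ols; OLS slope: $β_ols")
```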