-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLinearRegression.java
154 lines (134 loc) · 5.41 KB
/
LinearRegression.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package loadbalancer;
import java.util.List;
/*************************************************************************
* Compilation: javac LinearRegression.java
* Execution: java LinearRegression
*
* Compute least squares solution to y = beta * x + alpha.
* Simple linear regression.
*
* // TODO: rename beta and alpha to slope and intercept.
*
*************************************************************************/
/**
 * The {@code LinearRegression} class performs a simple linear regression
 * on a set of <em>N</em> data points (<em>y<sub>i</sub></em>, <em>x<sub>i</sub></em>).
 * That is, it fits a straight line <em>y</em> = α + β <em>x</em>
 * (where <em>y</em> is the response variable, <em>x</em> is the predictor variable,
 * α is the <em>y</em>-intercept, and β is the slope)
 * that minimizes the sum of squared residuals of the linear regression model.
 * It also computes associated statistics, including the coefficient of
 * determination <em>R</em><sup>2</sup> and the standard errors of the
 * estimates for the slope and <em>y</em>-intercept.
 * <p>
 * Every statistic is computed once in the constructor and stored in a
 * {@code final} field; the accessors only read those fields.
 *
 * @author Robert Sedgewick
 * @author Kevin Wayne
 */
public class LinearRegression {
    private final double alpha, beta;  // intercept and slope of the fitted line
    private final double r2;           // coefficient of determination
    private final double svar;         // residual variance estimate, rss / (n - 2)
    private final double svar0, svar1; // variance estimates of the intercept and slope

    /**
     * Performs a linear regression on the data points {@code (y.get(i), x.get(i))}.
     * <p>
     * Degenerate inputs do not throw. With fewer than two points, or when all
     * {@code x} values are equal, the slope and intercept are not meaningful
     * (typically NaN or infinite). With fewer than three points the standard
     * errors are NaN, because the residual variance has no degrees of freedom.
     *
     * @param x the values of the predictor variable
     * @param y the corresponding values of the response variable
     * @throws IllegalArgumentException if the two lists have different sizes
     * @throws NullPointerException if either list, or any of their elements, is null
     */
    public LinearRegression(List<Double> x, List<Double> y) {
        if (x.size() != y.size()) {
            throw new IllegalArgumentException("array lengths are not equal");
        }
        // Unbox each value exactly once, walking the lists with iterators, so a
        // sequential List (e.g. LinkedList) costs O(n) rather than O(n^2) via get(i).
        double[] xs = toArray(x);
        double[] ys = toArray(y);
        int n = xs.length;

        // first pass: means
        double sumx = 0.0, sumy = 0.0;
        for (int i = 0; i < n; i++) {
            sumx += xs[i];
            sumy += ys[i];
        }
        double xbar = sumx / n;
        double ybar = sumy / n;

        // second pass: centered sums of squares and cross-products
        // (centering first avoids the cancellation in sum(x^2) - n*xbar^2)
        double xxbar = 0.0, yybar = 0.0, xybar = 0.0;
        for (int i = 0; i < n; i++) {
            double dx = xs[i] - xbar;
            double dy = ys[i] - ybar;
            xxbar += dx * dx;
            yybar += dy * dy;
            xybar += dx * dy;
        }
        beta = xybar / xxbar;
        alpha = ybar - beta * xbar;

        // third pass: residual and regression sums of squares
        double rss = 0.0; // residual sum of squares
        double ssr = 0.0; // regression sum of squares
        for (int i = 0; i < n; i++) {
            double fit = beta * xs[i] + alpha;
            rss += (fit - ys[i]) * (fit - ys[i]);
            ssr += (fit - ybar) * (fit - ybar);
        }
        r2 = ssr / yybar;

        int degreesOfFreedom = n - 2;
        // Without a residual degree of freedom the variance is undefined; report NaN
        // instead of an accidental 0/0 (NaN) or tiny/0 (Infinity) from round-off.
        svar = degreesOfFreedom > 0 ? rss / degreesOfFreedom : Double.NaN;
        svar1 = svar / xxbar;
        svar0 = svar / n + xbar * xbar * svar1;
    }

    /** Unboxes {@code values} into a new array in a single iteration. */
    private static double[] toArray(List<Double> values) {
        double[] out = new double[values.size()];
        int i = 0;
        for (double v : values) {
            out[i++] = v;
        }
        return out;
    }

    /**
     * Returns the <em>y</em>-intercept α of the best-fit line <em>y</em> = α + β <em>x</em>.
     * @return the <em>y</em>-intercept α of the best-fit line <em>y = α + β x</em>
     */
    public double intercept() {
        return alpha;
    }

    /**
     * Returns the slope β of the best-fit line <em>y</em> = α + β <em>x</em>.
     * @return the slope β of the best-fit line <em>y</em> = α + β <em>x</em>
     */
    public double slope() {
        return beta;
    }

    /**
     * Returns the coefficient of determination <em>R</em><sup>2</sup>,
     * computed as the regression sum of squares divided by the total sum of squares.
     * @return the coefficient of determination <em>R</em><sup>2</sup>; normally between
     *         0 and 1, and NaN when all <em>y</em> values are equal
     */
    public double R2() {
        return r2;
    }

    /**
     * Returns the standard error of the estimate for the intercept.
     * @return the standard error of the estimate for the intercept (NaN for fewer than 3 points)
     */
    public double interceptStdErr() {
        return Math.sqrt(svar0);
    }

    /**
     * Returns the standard error of the estimate for the slope.
     * @return the standard error of the estimate for the slope (NaN for fewer than 3 points)
     */
    public double slopeStdErr() {
        return Math.sqrt(svar1);
    }

    /**
     * Returns the expected response {@code y} given the value of the predictor
     * variable {@code x}.
     * @param x the value of the predictor variable
     * @return the expected response {@code y} given the value of the predictor
     *         variable {@code x}
     */
    public double predict(double x) {
        return beta * x + alpha;
    }

    /**
     * Returns a string representation of the simple linear regression model.
     * @return a string representation of the simple linear regression model,
     *         including the best-fit line and the coefficient of determination <em>R</em><sup>2</sup>
     */
    @Override
    public String toString() {
        return String.format("%.2f N + %.2f", slope(), intercept())
                + " (R^2 = " + String.format("%.3f", R2()) + ")";
    }
}