Implementing Linear Regression with NumPy
-
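Suppose we have a set of points in the 2-D plane. We want to find a linear function $y = wx + b$ that minimizes the sum of squared vertical distances (residuals) between the points and the line, $\sum_i (wx_i + b - y_i)^2$.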
We therefore define the per-point loss $\mathcal{L} = (wx + b - y)^2$. The following code averages this loss over all points, i.e. it computes the mean squared error:
```python
# compute loss: mean squared error of the line y = w * x + b over all points
def compute_error_for_line_given_points(b, w, points):
    totalError = 0
    for i in range(len(points)):
        x = points[i, 0]  # column 0 holds x
        y = points[i, 1]  # column 1 holds y
        totalError += (y - (w * x + b)) ** 2
    return totalError / float(len(points))  # average
```
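The loop above handles one point at a time. Because `points` is a NumPy array, the same mean squared error can be computed in one vectorized expression. Here is a minimal sketch. `compute_error_vectorized` is a hypothetical name, and it assumes `points` is an `(N, 2)` array with x in column 0 and y in column 1:

```python
import numpy as np

def compute_error_vectorized(b, w, points):
    # Hypothetical vectorized counterpart of compute_error_for_line_given_points.
    # Assumes points is an (N, 2) array: column 0 is x, column 1 is y.
    x = points[:, 0]
    y = points[:, 1]
    residuals = y - (w * x + b)     # one residual per point
    return np.mean(residuals ** 2)  # same average as the loop version

points = np.array([[1.0, 3.0], [2.0, 5.0], [3.0, 8.0]])
print(compute_error_vectorized(1.0, 2.0, points))  # prints 0.3333333333333333
```

With $w = 2, b = 1$ the predictions are 3, 5, 7, so the residuals are 0, 0, 1 and their mean square is $1/3$.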
Next, update $w$ and $b$ with gradient descent, where $lr$ is the learning rate:
$$
\begin{aligned}
w' &= w - lr \cdot \frac{\partial \mathcal{L}}{\partial w} \\
b' &= b - lr \cdot \frac{\partial \mathcal{L}}{\partial b}
\end{aligned}
$$

where $\frac{\partial \mathcal{L}}{\partial w} = 2x(wx + b - y)$ and $\frac{\partial \mathcal{L}}{\partial b} = 2(wx + b - y)$. The code averages these per-point gradients over all $N$ points before taking the step.
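One way to check the two partial derivatives is to compare them against central finite differences of the per-point loss. The sketch below does this for one arbitrarily chosen point. The helper names `loss`, `analytic_grads`, and `numeric_grads` are hypothetical, introduced only for this check:

```python
def loss(w, b, x, y):
    # Per-point squared error L = (w*x + b - y)^2
    return (w * x + b - y) ** 2

def analytic_grads(w, b, x, y):
    # dL/dw = 2*x*(w*x + b - y), dL/db = 2*(w*x + b - y)
    r = w * x + b - y
    return 2 * x * r, 2 * r

def numeric_grads(w, b, x, y, eps=1e-6):
    # Central differences approximate each partial derivative
    dw = (loss(w + eps, b, x, y) - loss(w - eps, b, x, y)) / (2 * eps)
    db = (loss(w, b + eps, x, y) - loss(w, b - eps, x, y)) / (2 * eps)
    return dw, db

w, b, x, y = 0.5, -1.0, 3.0, 2.0
print(analytic_grads(w, b, x, y))  # prints (-9.0, -3.0)
print(numeric_grads(w, b, x, y))   # close to (-9.0, -3.0)
```

Because the loss is quadratic in $w$ and $b$, the central difference matches the analytic gradient up to floating-point rounding.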
```python
# compute the averaged gradient and take one gradient-descent step
def step_gradient(b_current, w_current, points, learningRate):
    b_gradient = 0
    w_gradient = 0
    N = float(len(points))
    for i in range(len(points)):
        x = points[i, 0]
        y = points[i, 1]
        # accumulate the per-point gradients dL/db and dL/dw
        b_gradient += 2 * ((w_current * x) + b_current - y)
        w_gradient += 2 * x * ((w_current * x) + b_current - y)
    # average over all N points
    b_gradient = b_gradient / N
    w_gradient = w_gradient / N
    # move against the gradient
    new_b = b_current - (learningRate * b_gradient)
    new_w = w_current - (learningRate * w_gradient)
    return [new_b, new_w]
```
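The same step can be written without a Python loop by letting NumPy compute all residuals at once. This is a sketch under the same `(N, 2)` array assumption. `step_gradient_vectorized` is a hypothetical name, and it should return the same values as `step_gradient`:

```python
import numpy as np

def step_gradient_vectorized(b_current, w_current, points, learning_rate):
    # Hypothetical vectorized counterpart of step_gradient.
    # Assumes points is an (N, 2) array: column 0 is x, column 1 is y.
    x = points[:, 0]
    y = points[:, 1]
    residuals = w_current * x + b_current - y  # (w*x + b - y) for every point
    b_gradient = np.mean(2 * residuals)        # average of dL/db over N points
    w_gradient = np.mean(2 * x * residuals)    # average of dL/dw over N points
    new_b = b_current - learning_rate * b_gradient
    new_w = w_current - learning_rate * w_gradient
    return [new_b, new_w]

points = np.array([[1.0, 3.0], [2.0, 5.0], [3.0, 7.0]])
new_b, new_w = step_gradient_vectorized(0.0, 0.0, points, 0.1)
print(new_b, new_w)
```

Starting from $b = w = 0$, the averaged gradients are $-10$ for $b$ and $-68/3$ for $w$, so one step with $lr = 0.1$ gives $b' = 1$ and $w' = 68/30 \approx 2.267$.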
Finally, choose the number of iterations and repeatedly update $w$ and $b$:
```python
import numpy as np

def gradient_descent_runner(points, starting_b, starting_w, learning_rate, num_iterations):
    # num_iterations: number of gradient-descent steps to run
    b = starting_b
    w = starting_w
    for i in range(num_iterations):
        b, w = step_gradient(b, w, np.array(points), learning_rate)
    return [b, w]
```
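To see whether the iterations are actually converging, it can help to record the loss after every step. The sketch below is a hypothetical variant of `gradient_descent_runner` that does this. It inlines the vectorized update so the block stands on its own, and it runs on assumed noise-free points on $y = 2x + 1$, not on the article's `data.txt`:

```python
import numpy as np

def gradient_descent_with_history(points, b, w, learning_rate, num_iterations):
    # Hypothetical variant of gradient_descent_runner that also records the
    # mean squared error after each step.
    x, y = points[:, 0], points[:, 1]
    history = []
    for _ in range(num_iterations):
        residuals = w * x + b - y  # computed once, so b and w update simultaneously
        b -= learning_rate * np.mean(2 * residuals)
        w -= learning_rate * np.mean(2 * x * residuals)
        history.append(np.mean((w * x + b - y) ** 2))
    return b, w, history

# Assumed synthetic data: 20 points exactly on y = 2x + 1
x = np.linspace(0, 1, 20)
points = np.column_stack([x, 2 * x + 1])
b, w, history = gradient_descent_with_history(points, 0.0, 0.0, 0.5, 2000)
print(b, w, history[-1])
```

If the recorded loss grows instead of shrinking, the learning rate is usually too large for the scale of the data.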
The main function:
```python
import numpy as np
from random import random

def run():
    # data.txt: one "x,y" pair per line
    points = np.genfromtxt("data.txt", delimiter=",")
    learning_rate = 0.0001
    initial_b = random()  # random starting intercept in [0, 1)
    initial_w = random()  # random starting slope in [0, 1)
    num_iterations = 1000
    print("Starting gradient descent at b = {0}, w = {1}, error = {2}"
          .format(initial_b, initial_w,
                  compute_error_for_line_given_points(initial_b, initial_w, points)))
    print("Running...")
    [b, w] = gradient_descent_runner(points, initial_b, initial_w, learning_rate, num_iterations)
    print("After {0} iterations at b = {1}, w = {2}, error = {3}"
          .format(num_iterations, b, w,
                  compute_error_for_line_given_points(b, w, points)))

if __name__ == "__main__":
    run()
```
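Because `data.txt` is not included here, one way to sanity-check the whole pipeline is to fit assumed synthetic data and compare the result with NumPy's closed-form least-squares fit, `np.polyfit(x, y, 1)`, which returns `[slope, intercept]`. The function name `fit_line_gd`, the data, and the learning rate below are illustrative assumptions, not the article's settings:

```python
import numpy as np

def fit_line_gd(points, learning_rate, num_iterations, b=0.0, w=0.0):
    # Batch gradient descent on the MSE, same update rule as step_gradient
    x, y = points[:, 0], points[:, 1]
    for _ in range(num_iterations):
        residuals = w * x + b - y
        b, w = (b - learning_rate * np.mean(2 * residuals),
                w - learning_rate * np.mean(2 * x * residuals))
    return b, w

# Assumed synthetic data standing in for data.txt: y = 3x + 4 plus Gaussian noise
rng = np.random.default_rng(0)
x = rng.uniform(0, 1, size=100)
y = 3 * x + 4 + rng.normal(scale=0.1, size=100)
points = np.column_stack([x, y])

b, w = fit_line_gd(points, learning_rate=0.5, num_iterations=5000)
w_ref, b_ref = np.polyfit(x, y, 1)  # closed-form least squares for comparison
print(f"GD:      b = {b:.4f}, w = {w:.4f}")
print(f"polyfit: b = {b_ref:.4f}, w = {w_ref:.4f}")
```

Both methods minimize the same squared-error objective, so once gradient descent has converged the two results should agree closely. A large gap suggests too few iterations or an unsuitable learning rate.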