MBGD（小批量梯度下降）程式碼
阿新 • • 發佈:2018-12-31
import numpy as np
import random


def gen_line_data(sample_num=100):
    """Generate noiseless samples of the plane y = 3*x1 + 4*x2.

    x1 runs evenly over [0, 9] and x2 over [4, 13], each with ``sample_num``
    points, so x2 == x1 + 4 on every row.

    Returns:
        x: array of shape (sample_num, 2) whose columns are x1 and x2.
        y: 1-D array of shape (sample_num,) holding the exact targets.
    """
    x1 = np.linspace(0, 9, sample_num)
    x2 = np.linspace(4, 13, sample_num)
    x = np.column_stack((x1, x2))
    y = x @ np.array([3, 4])  # 1-D target vector, one entry per row of x
    return x, y


def mbgd(x, y, step_size=0.01, max_iter_count=10000, batch_size=0.2,
         tol=0.001, verbose=True):
    """Fit linear weights ``w`` (y ~= x @ w) with mini-batch gradient descent.

    Each iteration draws ``ceil(sample_num * batch_size)`` distinct rows at
    random. It takes one step along the negative gradient of that batch's
    mean squared error, then evaluates the loss on the full data set.

    Args:
        x: 2-D array-like of shape (sample_num, dim).
        y: 1-D array-like of length sample_num.
        step_size: learning rate applied to the batch-averaged gradient.
        max_iter_count: hard cap on the number of iterations.
        batch_size: fraction of the samples used per batch, in (0, 1].
        tol: stop as soon as the full-data loss is <= ``tol``.
        verbose: print the iteration count and loss after every step.

    Returns:
        1-D float array of length dim with the fitted weights.

    Raises:
        ValueError: if ``x`` has no rows or ``batch_size`` is outside (0, 1].
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    sample_num, dim = x.shape
    if sample_num == 0:
        raise ValueError("x must contain at least one sample")
    if not 0 < batch_size <= 1:
        raise ValueError("batch_size must be a fraction in (0, 1]")
    batch_len = int(np.ceil(sample_num * batch_size))

    w = np.random.randn(dim)
    loss = float("inf")  # guarantees at least one iteration
    iter_count = 0
    while loss > tol and iter_count < max_iter_count:
        # Randomly choose batch_len distinct sample indices (no replacement).
        index = random.sample(range(sample_num), batch_len)
        batch_x = x[index]
        batch_y = y[index]

        # One update per batch: every residual uses the pre-step weights,
        # and the gradient is averaged over the batch (not the full data set).
        residual = batch_y - batch_x @ w
        w += step_size * (batch_x.T @ residual) / batch_len

        # Full-data mean squared error, further divided by dim so the
        # stopping threshold keeps the scale used by the original script.
        loss = np.mean((x @ w - y) ** 2) / dim
        if verbose:
            print("iter_count: ", iter_count, "the loss:", loss)
        iter_count += 1
    return w


if __name__ == '__main__':
    x, y = gen_line_data()
    w = mbgd(x, y)
    print(w)  # should come out very close to [3, 4]