1. 程式人生 > >tensorflow學習筆記(北京大學) 隨機畫點 完全解析

tensorflow學習筆記(北京大學) 隨機畫點 完全解析

#coding:utf-8
#0匯入模組 ,生成模擬資料集
#tensorflow學習筆記(北京大學)  隨機畫點  完全解析
#QQ群:476842922(歡迎加群討論學習)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
BATCH_SIZE = 30 
seed = 2 
#基於seed產生隨機數
rdm = np.random.RandomState(seed)
#隨機數返回300行2列的矩陣,表示300組座標點(x0,x1)作為輸入資料集
X = rdm.randn(300,2)
#從X這個300行2列的矩陣中取出一行,判斷如果兩個座標的平方和小於2,給Y賦值1,其餘賦值0
#作為輸入資料集的標籤(正確答案) Y_ = [int(x0*x0 + x1*x1 <2) for (x0,x1) in X] #遍歷Y中的每個元素,1賦值'red'其餘賦值'blue',這樣視覺化顯示時人可以直觀區分 Y_c = [['red' if y else 'blue'] for y in Y_] #對資料集X和標籤Y進行shape整理,第一個元素為-1表示n行2列,把Y整理為n行1列 X = np.vstack(X).reshape(-1,2)#表示n行2列 Y_ = np.vstack(Y_).reshape(-1,1)#n行1列 print(X) print(Y_) print
(Y_c) #用plt.scatter畫出資料集X各行中第0列元素和第1列元素的點即各行的(x0,x1),用各行Y_c對應的值表示顏色(c是color的縮寫) plt.scatter(X[:,0], X[:,1], c=np.squeeze(Y_c)) plt.show()

在這裡插入圖片描述