【異常檢測】DAGMM:結合深度自編碼器和GMM的端到端無監督網路(二):程式碼實戰(PyTorch)
程式碼部分基於PyTorch1.6.0,使用網路入侵異常檢測資料集KDDCUP99來訓練和評測,完整程式碼見:GitHub。
文章目錄
1.網路部分:
網路部分的實現較為簡單,基本上就是DAE結構接一個全連線結構,最後輸出一個softmax Tensor。要注意好Compression network和Estimation network的輸入輸出的Tensor shape:
class DAGMM(nn.Module):
    """DAGMM: compression network (deep autoencoder) + estimation network.

    forward(x) returns (enc, dec, z, gamma):
      enc:   encoder output,           shape [batch, zc_dim]
      dec:   reconstruction of x,      shape [batch, input_dim]
      z:     enc + 2 recon features,   shape [batch, zc_dim + 2]
      gamma: softmax GMM memberships,  shape [batch, n_gmm]
    """

    def __init__(self, hyp):
        super(DAGMM, self).__init__()
        # Encoder: input_dim -> hidden1 -> hidden2 -> hidden3 -> zc_dim.
        self.encoder = nn.Sequential(
            nn.Linear(hyp['input_dim'], hyp['hidden1_dim']),
            nn.Tanh(),
            nn.Linear(hyp['hidden1_dim'], hyp['hidden2_dim']),
            nn.Tanh(),
            nn.Linear(hyp['hidden2_dim'], hyp['hidden3_dim']),
            nn.Tanh(),
            nn.Linear(hyp['hidden3_dim'], hyp['zc_dim']),
        )
        # Decoder mirrors the encoder back up to input_dim.
        self.decoder = nn.Sequential(
            nn.Linear(hyp['zc_dim'], hyp['hidden3_dim']),
            nn.Tanh(),
            nn.Linear(hyp['hidden3_dim'], hyp['hidden2_dim']),
            nn.Tanh(),
            nn.Linear(hyp['hidden2_dim'], hyp['hidden1_dim']),
            nn.Tanh(),
            nn.Linear(hyp['hidden1_dim'], hyp['input_dim']),
        )
        # Estimation network: z (zc_dim + 2 reconstruction features)
        # -> softmax over the n_gmm mixture components.
        self.estimation = nn.Sequential(
            nn.Linear(hyp['zc_dim'] + 2, hyp['hidden3_dim']),
            nn.Tanh(),
            nn.Dropout(p=hyp['dropout']),
            nn.Linear(hyp['hidden3_dim'], hyp['n_gmm']),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        # Two per-sample reconstruction-quality features appended to the code.
        rec_cosine = F.cosine_similarity(x, dec, dim=1)
        rec_euclidean = F.pairwise_distance(x, dec, p=2)
        z = torch.cat([enc, rec_euclidean.unsqueeze(-1), rec_cosine.unsqueeze(-1)], dim=1)
        gamma = self.estimation(z)
        return enc, dec, z, gamma
2.GMM引數計算:
網路輸出GMM中的component的概率分佈,用於GMM均值,協方差等引數的計算,參照論文中的公式(5)。程式碼中這部分包括後面計算似然函式的部分涉及到大量數學計算,我自己實驗是在CPU上計算效率比GPU高,所以在完整程式碼的訓練部分將網路引數從GPU搬回到CPU。
def get_gmm_param(gamma, z):
    """Estimate GMM parameters from soft memberships (paper eq. 5).

    gamma: [N, n_gmm] softmax membership weights.
    z:     [N, z_dim] latent representations.
    Returns (ceta, mean, cov):
      ceta: [n_gmm]               mixture weights
      mean: [n_gmm, z_dim]        component means
      cov:  [n_gmm, z_dim, z_dim] component covariances
    """
    N = gamma.shape[0]
    gamma_sum = torch.sum(gamma, dim=0)  # total membership per component
    ceta = gamma_sum / N
    # Membership-weighted mean of z per component.
    mean = torch.sum(gamma.unsqueeze(-1) * z.unsqueeze(1), dim=0) / gamma_sum.unsqueeze(-1)
    # Membership-weighted outer products of the centered latents.
    diff = z.unsqueeze(1) - mean.unsqueeze(0)       # [N, n_gmm, z_dim]
    outer = diff.unsqueeze(-1) * diff.unsqueeze(-2)  # [N, n_gmm, z_dim, z_dim]
    cov = torch.sum(gamma.unsqueeze(-1).unsqueeze(-1) * outer, dim=0) \
        / gamma_sum.unsqueeze(-1).unsqueeze(-1)
    return ceta, mean, cov
3.似然函式和總體損失的計算:
在獲得GMM引數後,可由論文中的公式(6)寫出似然函式。損失函式參照公式(7),注意有三項,分別是網路輸出的重構誤差,似然函式損失以及用來防止矩陣計算不可逆的項。
def reconstruct_error(x, x_hat):
    """Mean per-sample L2 reconstruction error.

    Equivalent to averaging torch.dist(x[i], x_hat[i]) over the batch,
    but vectorized instead of looping over samples in Python.

    x, x_hat: [batch, dim] tensors.
    Returns a scalar tensor.
    """
    # Per-row L2 norm of the residual, then the batch mean.
    return torch.norm(x - x_hat, p=2, dim=1).mean()
def sample_energy(ceta, mean, cov, zi, n_gmm, bs):
    """Energy E(z) of one sample under the fitted GMM (paper eq. 6).

    ceta: [n_gmm] mixture weights.
    mean: [n_gmm, z_dim] component means.
    cov:  [n_gmm, z_dim, z_dim] component covariances.
    zi:   [z_dim, 1] a single latent sample as a column vector.
    n_gmm: number of mixture components.
    bs:   batch size (unused here; kept for call-site compatibility).
    Returns the scalar -log(sum_k ceta_k * N(zi | mean_k, cov_k)).
    """
    e = torch.tensor(0.0)
    # Tiny ridge keeps the covariance matrices invertible.
    cov_eps = torch.eye(mean.shape[1]) * (1e-12)
    for k in range(n_gmm):
        miu_k = mean[k].unsqueeze(1)
        d_k = zi - miu_k
        inv_cov = torch.inverse(cov[k] + cov_eps)
        # Gaussian density: exp(-0.5 d^T Sigma^-1 d) / sqrt(|2*pi*Sigma|).
        # torch.chain_matmul is deprecated; plain matmul is equivalent here.
        e_k = torch.exp(-0.5 * (torch.t(d_k) @ inv_cov @ d_k))
        # abs() guards against a slightly negative determinant from round-off.
        e_k = e_k / torch.sqrt(torch.abs(torch.det(2 * math.pi * cov[k])))
        e_k = e_k * ceta[k]
        e += e_k.squeeze()
    return -torch.log(e)
def loss_func(x, dec, gamma, z):
    """Total DAGMM training loss (paper eq. 7).

    Combines three terms:
      1. the reconstruction error,
      2. the mean sample energy under the fitted GMM (weight 0.1),
      3. a penalty on the inverse covariance diagonals that keeps the
         covariance matrices away from singularity (weight 0.005).

    Returns (loss, recon_error, mean_energy, penalty).
    """
    bs, n_gmm = gamma.shape[0], gamma.shape[1]
    # 1. reconstruction error
    recon_error = reconstruct_error(x, dec)
    # 2. GMM parameters estimated from the soft memberships
    ceta, mean, cov = get_gmm_param(gamma, z)
    # (move ceta/mean/cov to the right device here if training on GPU)
    # 3. likelihood term: sum the per-sample energies
    energy_sum = torch.tensor(0.0)
    for i in range(z.shape[0]):
        zi = z[i].unsqueeze(1)
        energy_sum = energy_sum + sample_energy(ceta, mean, cov, zi, n_gmm, bs)
    # Penalize near-zero covariance diagonals (keeps inversion stable).
    penalty = torch.tensor(0.0)
    for k in range(n_gmm):
        penalty = penalty + torch.sum(1 / torch.diagonal(cov[k], 0))
    loss = recon_error + (0.1 / z.shape[0]) * energy_sum + 0.005 * penalty
    return loss, recon_error, energy_sum / z.shape[0], penalty
4.KDDCUP99資料集預處理和劃分:
下載連結(外網不好下載,我上傳至CSDN資源中免費下載)
該資料集是從一個模擬的美國空軍區域網上採集來的 9 個星期的網路連線資料, 分成具有標識的訓練資料和未加標識的測試資料。測試資料和訓練資料有著不同的概率分佈, 測試資料包含了一些未出現在訓練資料中的攻擊型別, 這使得入侵檢測更具有現實性。
該資料集的特徵描述和分析可以參考這個連結:
KDDCUP99 網路入侵資料集描述
資料集中normal標識佔總體的20%,所以將有normal標示的視為異常樣本,對字元型離散特徵做one-hot編碼,對連續數值特徵做歸一化處理,如下:
import numpy as np
import pandas as pd
import os
# Load the 10% KDDCUP99 subset and name its 41 features plus the label column.
data = pd.read_csv(os.path.join(data_dir,"kddcup.data_10_percent"), header=None,names=['duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised', 'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate', 'type'])
# Relabel: 1 = anomaly ("normal." rows are the minority class in this subset), 0 = otherwise.
data.loc[data["type"] != "normal.", 'type'] = 0
data.loc[data["type"] == "normal.", 'type'] = 1
# One-hot encode the three categorical features and drop the originals.
one_hot_protocol = pd.get_dummies(data["protocol_type"])
one_hot_service = pd.get_dummies(data["service"])
one_hot_flag = pd.get_dummies(data["flag"])
data = data.drop("protocol_type", axis=1)
data = data.drop("service", axis=1)
data = data.drop("flag", axis=1)
data = pd.concat([one_hot_protocol, one_hot_service, one_hot_flag, data], axis=1)
# Continuous columns to min-max normalize.
cols_to_norm = ["duration", "src_bytes", "dst_bytes", "wrong_fragment", "urgent",
                "hot", "num_failed_logins", "num_compromised", "num_root",
                "num_file_creations", "num_shells", "num_access_files", "count", "srv_count",
                "serror_rate", "srv_serror_rate", "rerror_rate", "srv_rerror_rate", "same_srv_rate",
                "diff_srv_rate", "srv_diff_host_rate", "dst_host_count", "dst_host_srv_count", "dst_host_same_srv_rate",
                "dst_host_diff_srv_rate", "dst_host_same_src_port_rate", "dst_host_srv_diff_host_rate",
                "dst_host_serror_rate", "dst_host_srv_serror_rate", "dst_host_rerror_rate", "dst_host_srv_rerror_rate"]
# Normalization statistics come from the type==0 rows only (the class the
# model trains on); the scaling is then applied to every row.
min_cols = data.loc[data["type"] == 0, cols_to_norm].min()
max_cols = data.loc[data["type"] == 0, cols_to_norm].max()
# Guard against zero-range columns: a feature constant on the type==0
# subset would otherwise divide by zero and fill the column with NaN/inf;
# a range of 1 maps such a column to a constant 0 instead.
range_cols = (max_cols - min_cols).replace(0, 1)
data.loc[:, cols_to_norm] = (data[cols_to_norm] - min_cols) / range_cols
np.savez_compressed("kdd_cup", kdd=np.array(data))  # persist for the Dataset class
proportions = data["type"].value_counts()
print(proportions)
print("Anomaly Percentage", proportions[1] / proportions.sum())
採取的資料劃分策略是從分離的正常樣本和異常樣本中,各抽取80%,shuffle後作為訓練集,剩下的作為測試集。在測試的時候,先根據訓練集資料中的似然函式值E(z)從小到大排序,擷取百分之八十分位的值(該資料集中正常樣本:異常樣本 = 4:1)作為異常預測的閾值。
class kddcup99_Dataset(Dataset):
    """KDDCUP99 dataset with a per-class `ratio` train/test split.

    Loads the array saved by the preprocessing script (last column is the
    label: 0 = majority class, 1 = anomaly), splits the two classes
    separately at `ratio` so both splits keep the class balance, and
    shuffles each split. `mode` ('train' or anything else for test)
    selects which split __len__/__getitem__ serve.
    """

    def __init__(self, data_dir=None, mode='train', ratio=0.8):
        self.mode = mode
        data = np.load(data_dir, allow_pickle=True)['kdd']
        # The original code bounced through torch.from_numpy and back to
        # numpy via np.concatenate; the split/shuffle only needs numpy.
        data = np.float32(data)
        normal_data = data[data[:, -1] == 0]
        abnormal_data = data[data[:, -1] == 1]
        # Split each class at `ratio`.
        train_normal_mark = int(normal_data.shape[0] * ratio)
        train_abnormal_mark = int(abnormal_data.shape[0] * ratio)
        self.train_data = np.concatenate(
            (normal_data[:train_normal_mark, :],
             abnormal_data[:train_abnormal_mark, :]), axis=0)
        np.random.shuffle(self.train_data)
        self.test_data = np.concatenate(
            (normal_data[train_normal_mark:, :],
             abnormal_data[train_abnormal_mark:, :]), axis=0)
        np.random.shuffle(self.test_data)

    def __len__(self):
        # Size of the split selected by `mode`.
        if self.mode == 'train':
            return self.train_data.shape[0]
        else:
            return self.test_data.shape[0]

    def __getitem__(self, index):
        # Returns (features, label) as float32 numpy values.
        if self.mode == 'train':
            return self.train_data[index, :-1], self.train_data[index, -1]
        else:
            return self.test_data[index, :-1], self.test_data[index, -1]
def get_loader(hyp, mode='train'):
    """Build and return a DataLoader over the requested KDDCUP99 split.

    hyp must provide 'data_dir', 'ratio' and 'batch_size'.
    Returns (data_loader, dataset_size); the training split is shuffled
    each epoch, the test split keeps its order.
    """
    dataset = kddcup99_Dataset(hyp['data_dir'], mode, hyp['ratio'])
    loader = DataLoader(dataset=dataset,
                        batch_size=hyp['batch_size'],
                        shuffle=(mode == 'train'))
    return loader, len(dataset)