長短期記憶網路LSTM(matlab)
阿新 • • 發佈:2019-01-22
% Gradient-descent update of the LSTM weight matrices for time step n.
%
% Every section below handles one weight matrix. For each output unit `num`
% it builds one gradient row from the output error (error(num,1) pushed back
% through column `num` of W_preh_h), scales that row by the m-th input or
% state value to fill row m of the matching delta_* matrix, and then applies
% the delta immediately with step size lr.
%
% Fix: the derivative of the tanh block-input activation used to be written
% as 1-tanh(temp.^2) (temp squared *before* tanh). The derivative of
% tanh(x) is 1-tanh(x).^2, which is the form the same expressions already
% use for the cell-state term, so all three occurrences now use it.
%
% NOTE(review): the gate vectors (input_gate, forget_gate, output_gate, gate
% and the *_gate_input pre-activations) are defined outside this fragment;
% the element-wise products assume they are row vectors with the same
% length as cell_state(:,n)' - confirm against the forward pass.
% NOTE(review): exp(-x).*(g.^2) equals the sigmoid derivative when
% g = 1./(1+exp(-x)); presumably that is how the gates are computed.

% Derivative of tanh at the current cell state, as a row vector. It does not
% depend on `num` or `m`, so it is computed once for all sections.
cell_tanh_deriv=(1-tanh(cell_state(:,n)).^2)';

if(n~=1)
    %% Update W_input_x
    % Pre-activation of the block input: current input plus recurrent term.
    temp=train(1:input_num,n)'*W_input_x+h_state(:,n-1)'*W_input_h;
    input_tanh_deriv=1-tanh(temp).^2;
    for num=1:output_num
        err_row=2*(W_preh_h(:,num)*error(num,1))';
        row_grad=err_row.*output_gate.*cell_tanh_deriv.*input_gate.*input_tanh_deriv;
        for m=1:data_length
            delta_weight_input_x(m,:)=row_grad*train(m,n);
        end
        W_input_x=W_input_x-lr*delta_weight_input_x;
    end

    %% Update W_forgetgate_x
    for num=1:output_num
        err_row=2*(W_preh_h(:,num)*error(num,1))';
        row_grad=err_row.*output_gate.*cell_tanh_deriv.*cell_state(:,n-1)' ...
            .*exp(-forget_gate_input).*(forget_gate.^2);
        for m=1:data_length
            delta_weight_forgetgate_x(m,:)=row_grad*train(m,n);
        end
        W_forgetgate_x=W_forgetgate_x-lr*delta_weight_forgetgate_x;
    end

    %% Update W_inputgate_c
    for num=1:output_num
        err_row=2*(W_preh_h(:,num)*error(num,1))';
        row_grad=err_row.*output_gate.*cell_tanh_deriv.*gate ...
            .*exp(-input_gate_input).*(input_gate.^2);
        for m=1:cell_num
            delta_weight_inputgate_c(m,:)=row_grad*cell_state(m,n-1);
        end
        W_inputgate_c=W_inputgate_c-lr*delta_weight_inputgate_c;
    end

    %% Update W_forgetgate_c
    for num=1:output_num
        err_row=2*(W_preh_h(:,num)*error(num,1))';
        row_grad=err_row.*output_gate.*cell_tanh_deriv.*cell_state(:,n-1)' ...
            .*exp(-forget_gate_input).*(forget_gate.^2);
        for m=1:cell_num
            delta_weight_forgetgate_c(m,:)=row_grad*cell_state(m,n-1);
        end
        W_forgetgate_c=W_forgetgate_c-lr*delta_weight_forgetgate_c;
    end

    %% Update W_outputgate_c
    % The output gate multiplies tanh(cell_state) directly, so this row uses
    % tanh(cell_state) itself rather than its derivative.
    for num=1:output_num
        err_row=2*(W_preh_h(:,num)*error(num,1))';
        row_grad=err_row.*tanh(cell_state(:,n))' ...
            .*exp(-output_gate_input).*(output_gate.^2);
        for m=1:cell_num
            delta_weight_outputgate_c(m,:)=row_grad*cell_state(m,n-1);
        end
        W_outputgate_c=W_outputgate_c-lr*delta_weight_outputgate_c;
    end

    %% Update W_input_h
    % temp is recomputed here, so it already reflects the W_input_x update
    % made at the top of this branch (unchanged from the original ordering).
    temp=train(1:input_num,n)'*W_input_x+h_state(:,n-1)'*W_input_h;
    input_tanh_deriv=1-tanh(temp).^2;
    for num=1:output_num
        err_row=2*(W_preh_h(:,num)*error(num,1))';
        row_grad=err_row.*output_gate.*cell_tanh_deriv.*input_gate.*input_tanh_deriv;
        % NOTE(review): rows are indexed 1:output_num while h_state(:,n-1)
        % multiplies W_input_h above - confirm the row count is intended.
        for m=1:output_num
            delta_weight_input_h(m,:)=row_grad*h_state(m,n-1);
        end
        W_input_h=W_input_h-lr*delta_weight_input_h;
    end
else
    %% Update W_input_x (first time step: no previous hidden/cell state)
    temp=train(1:input_num,n)'*W_input_x;
    input_tanh_deriv=1-tanh(temp).^2;
    for num=1:output_num
        err_row=2*(W_preh_h(:,num)*error(num,1))';
        row_grad=err_row.*output_gate.*cell_tanh_deriv.*input_gate.*input_tanh_deriv;
        for m=1:data_length
            delta_weight_input_x(m,:)=row_grad*train(m,n);
        end
        W_input_x=W_input_x-lr*delta_weight_input_x;
    end
end
% Replace W_preh_h only now: every update above reads the current W_preh_h.
% weight_preh_h_temp is produced outside this fragment.
W_preh_h=weight_preh_h_temp;
end
% NOTE(review): the statements below appear to repeat the body of the earlier
% `if(n~=1)` block, but no opening `if` is visible for this copy, so the
% `else` and closing `end` further down have no visible opener - this looks
% like a duplicated paste; confirm before running.
%
% Each section handles one weight matrix: for every output unit `num` it
% fills the matching delta_* matrix row by row from the output error
% (error(num,1) pushed back through column `num` of W_preh_h) and then
% applies the delta with step size lr.
%% Update weight_input_x
% temp: current input column times W_input_x plus the previous hidden state
% times W_input_h.
% NOTE(review): the last factor below is 1-tanh(temp.^2), which squares temp
% before applying tanh; the derivative of tanh would be 1-tanh(temp).^2
% (the form used for the cell-state factor in the same expression) - verify.
temp=train(1:input_num,n)'*W_input_x+h_state(:,n-1)'*W_input_h;
for num=1:output_num
for m=1:data_length
delta_weight_input_x(m,:)=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp.^2))*train(m,n);
end
W_input_x=W_input_x-lr*delta_weight_input_x;
end
%% Update weight_forgetgate_x
% NOTE(review): exp(-x).*(g.^2) equals the sigmoid derivative when
% g = 1./(1+exp(-x)); presumably that is how forget_gate is computed.
for num=1:output_num
for m=1:data_length
delta_weight_forgetgate_x(m,:)=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2)*train(m,n);
end
W_forgetgate_x=W_forgetgate_x-lr*delta_weight_forgetgate_x;
end
%% Update weight_inputgate_c
% `gate` is defined outside this fragment.
for num=1:output_num
for m=1:cell_num
delta_weight_inputgate_c(m,:)=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*gate.*exp(-input_gate_input).*(input_gate.^2)*cell_state(m,n-1);
end
W_inputgate_c=W_inputgate_c-lr*delta_weight_inputgate_c;
end
%% Update weight_forgetgate_c
% Same gradient row as the forget-gate input weights, scaled by the previous
% cell state entry cell_state(m,n-1) instead of train(m,n).
for num=1:output_num
for m=1:cell_num
delta_weight_forgetgate_c(m,:)=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2)*cell_state(m,n-1);
end
W_forgetgate_c=W_forgetgate_c-lr*delta_weight_forgetgate_c;
end
%% Update weight_outputgate_c
% This row uses tanh(cell_state(:,n)) itself rather than its derivative.
for num=1:output_num
for m=1:cell_num
delta_weight_outputgate_c(m,:)=2*(W_preh_h(:,num)*error(num,1))'.*tanh(cell_state(:,n))'.*exp(-output_gate_input).*(output_gate.^2)*cell_state(m,n-1);
end
W_outputgate_c=W_outputgate_c-lr*delta_weight_outputgate_c;
end
%% Update weight_input_h
% temp is recomputed here, so it already reflects the W_input_x update made
% in the first section above.
% NOTE(review): same 1-tanh(temp.^2) factor as above - verify.
temp=train(1:input_num,n)'*W_input_x+h_state(:,n-1)'*W_input_h;
for num=1:output_num
for m=1:output_num
delta_weight_input_h(m,:)=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp.^2))*h_state(m,n-1);
end
W_input_h=W_input_h-lr*delta_weight_input_h;
end
else
%% Update weight_input_x
% This branch computes temp without the h_state(:,n-1) recurrent term.
% NOTE(review): same 1-tanh(temp.^2) factor as above - verify.
temp=train(1:input_num,n)'*W_input_x;
for num=1:output_num
for m=1:data_length
delta_weight_input_x(m,:)=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp.^2))*train(m,n);
end
W_input_x=W_input_x-lr*delta_weight_input_x;
end
end
% W_preh_h is replaced only after all updates above have read its current
% value; weight_preh_h_temp is produced outside this fragment.
W_preh_h=weight_preh_h_temp;
end