1. 程式人生 > >脈衝神經網路之Tempotron(二)

脈衝神經網路之Tempotron(二)

                        脈衝神經網路之Tempotron程式碼示例

        上一篇從原理的角度大致介紹了脈衝神經網路的神經元模型以及Tempotron監督學習方法,這一章記錄了Tempotron的程式碼實現。這份程式碼是使用matlab編寫,用脈衝神經網路實現對26個字母分類,我們從細節去解釋該程式碼:

function TempotronClassify()
% Tempotron: a neuron that learns spike timing-based decisions
% Rober Gutig 2006 Nature Neuroscience
clear; clc;
NumImages=26;
for i=1:NumImages
    ImageName=strcat('Icon16X16\Letter-',char('A'+i-1),'-black-icon');% 從icon16X16資料夾中讀取所有圖片
    ImageMatrix=imread(ImageName,'bmp');% 讀取圖片為灰度圖,儲存在矩陣中,取反
    ImageMatrix=~ImageMatrix;  % make the white pixel be 0, and black be 1;
    TrainPtns(:,i)=image2ptn(ImageMatrix);
end
上面程式碼片段就是從Icon16X16\資料夾中讀取圖片檔案,圖片從A到Z,這個程式碼的圖片資料很簡單,只有0和1,由黑白兩色構成字母,讀取的資料以16*16矩陣形式儲存,之後呼叫image2ptn()方法,將矩陣轉換。這個方法實現的就是對外界刺激進行編碼,也就是將圖片資料轉換成脈衝序列的形式。開啟該方法檢視程式碼:
%% convert a image to a spike train saved in a vector
RandParts=1;

spikeTrain=zeros(32,1);
if RandParts==1
    loadR=1;
    AR=A(:);
    if loadR==0
        R = randperm(size(A,1)*size(A,2));
        save('RandIndex','R');
    else
        load('RandIndex','R');
    end
    numRandNeus=32;
    for i=1:numRandNeus
        IndexR=R((1+(i-1)*8):(8+(i-1)*8));
        Bits=AR(IndexR);
        Bits=Bits';
        spikeTime=bi2de(Bits,'left-msb');    % 二進位制轉十進位制
        if spikeTime==0
            spikeTime=2^8;  % put 0 to the end of the interval
        end
        spikeTrain(i)=spikeTime;
    end
end
這裡的rand函式目的只是打亂16*16矩陣中的行,我是不太理解這個=-=,AR是個256*1的陣列,是將A(圖片矩陣16*16)展平(flatten)的結果,Bits=AR(IndexR);這句簡單的說就是從AR中隨機抽取8個數字,構成bits,然後將bits轉換成spikeTime(十進位制數字),代表脈衝發放的時間。綜上原先的圖片16*16矩陣被轉換成32*8矩陣,8代表8位二進位制編碼,範圍在0~255之間,也就是說時間視窗大小為256毫秒,然後將8為二進位制數轉為十進位制,代表脈衝發放的時間,該時間在(0~255)之間。總結起來,就是對於一張16*16的黑白圖片,轉換成了一個32*1的陣列,陣列中有32個時間視窗,每個元素代表脈衝傳送時間。
TrainPtns=TrainPtns*1e-3;  % scale to ms
nAfferents = size(TrainPtns,1);
nPtns = NumImages;
%nOutputs = 1;    %%%%%%%%%%   1
nOutputs = 5;

TrainPtns是32*26的矩陣,也就是將所有轉換後的spikeTrain拼接一起。nAfferents代表了輸入脈衝數量,這裡就是32,由這32個輸入確定一個輸出(字母)。輸出就定義為5,也就是5個二進位制數表示一個字母的編號。

loadData=0;% 是否載入已儲存的模型

V_thr = 1; V_rest = 0;
T = 256e-3;         % pattern duration ms
dt = 1e-3;
tau_m = 20e-3; % tau_m = 15e-3;???
tau_s = tau_m/4;
% K(t?ti)=V0(exp[?(t?ti)/τ]–exp[?(t?ti)/τs])
aa = exp(-(0:dt:3*tau_m)/tau_m)-exp(-(0:dt:3*tau_m)/tau_s);

V0 = 1/max(exp(-(0:dt:3*tau_m)/tau_m)-exp(-(0:dt:3*tau_m)/tau_s));
lmd = 2e-2;%1e-2/V0;   % optimal performance lmd=3e-3*T/(tau_m*nAfferents*V0)  1e-4/V0;
maxEpoch = 200;
mu = 0.99;  % momentum factor
上面定義了閾值電壓V_thr,這是按經驗來取,這裡取1,復位電壓則取0.定義了時間視窗大小為256毫秒,單位時間是1毫秒,aa則是為了核函式K的計算簡化定義。即預先定義公式:
K(t-ti)=V0(exp[-(t-ti)/τ]–exp[-(t-ti)/τs])
V0也被預先定義。最大訓練次數為200次(maxEpoch),lmd類似於學習率,而mu則是動量,該程式碼用了動量優化模型。
if loadData ==0 %初始化網路
    weights = 1e-2*randn(nAfferents,nOutputs);  % 1e-3*randn(nAfferents,1);
    save('weights0','weights');
else
    load('weights0','weights');
end
隨機初始化權重,或者讀取儲存的權重檔案。
Class = de2bi(1:26,'left-msb'); Class=Class';

correctRate=zeros(1,maxEpoch);
dw_Past=zeros(nAfferents,nPtns,nOutputs);  % momentum for accelerating learning.上一個權重的更新,用於動量計算
for epoch=1:maxEpoch    
    
    Class_Tr = false(nOutputs,nPtns);  % actual outputs of training
    for pp=1:nPtns                 
        for neuron=1:nOutputs
            Vmax=0; tmax=0;
            fired=false;        
            Vm1=zeros(1,256); indx1= 1; % trace pattern 1
            for t=dt:dt:T
                Vm = 0; 
                if fired==false
                    Tsyn=find(TrainPtns(:,pp)<=t+0.1*dt);    % no cut window
                else
                    Tsyn=find(TrainPtns(:,pp)<=t_fire+0.1*dt); % shut down inputs
                end
                if ~isempty(Tsyn)                    
                    A1=TrainPtns(:,pp);
                    A2=A1(Tsyn);
                    K =V0*(exp(-(t-A2)/tau_m)-exp(-(t-A2)/tau_s)); % the kernel value for each fired afferent
                    A1=weights(:,neuron);
                    firedWeights=A1(Tsyn);
                    Vm = Vm + firedWeights'*K ;
                end

                Vm = Vm + V_rest;
                if Vm>=V_thr && fired==false % fire
                    fired=true;
                    t_fire=t;
                    Class_Tr(neuron,pp)=true;
                end
                if Vm>Vmax
                    Vmax=Vm; tmax=t;
                end

                %if pp==1
                    Vm1(indx1)=Vm;
                    indx1=indx1+1;
                %end
            end

            %if pp==1
                figure(1); plot(dt:dt:T,Vm1);
                title(strcat('Image ',char('A'+pp-1),'; neuron: ',num2str(neuron))); drawnow;
            %end
            if Vmax<=0
                tmax=max(TrainPtns(:,pp));
            end
            
            if Class_Tr(neuron,pp)~=Class(neuron,pp) % error
                
                Tsyn=find(TrainPtns(:,pp)<=tmax+0.1*dt); 

                if ~isempty(Tsyn)                    
                    A1=TrainPtns(:,pp);
                    A2=A1(Tsyn);
                    K =V0*(exp(-(tmax-A2)/tau_m)-exp(-(tmax-A2)/tau_s)); % the kernel value for each fired afferent
                    A1=weights(:,neuron);
                    dwPst=dw_Past(:,pp,neuron);
                    if fired==false    % LTP
                        Dw=lmd*K;
                    else           % LTD
                        Dw=-1.1*lmd*K;
                    end
                    A1(Tsyn) = A1(Tsyn) + Dw + mu*dwPst(Tsyn);
                    weights(:,neuron)=A1;
                    dwPst(Tsyn)=Dw;
                    dw_Past(:,pp,neuron) = dwPst;
                end                
            end            
            
        end  % end of one neuron computation
        
   end % end of one image
   %CC=isequal(Class,Class_Tr);
   %correctRate(epoch)=sum(Class==Class_Tr)/length(Class);
   CC = bi2de(Class_Tr','left-msb');
end
save('TrainedWt','weights');
figure(2); plot(1:maxEpoch,correctRate,'-b.');
end
class代表最後的類別,由5位二進位制編碼表示。最外層迴圈是最大訓練次數,第二層nPtns代表了26個樣本,最後一層迴圈nOutputs為5,代表對每個輸出神經元進行一次計算。之後將Vmax,Tmax(代表最大膜電位出現時間)初始化為0,fired表示是否發出脈衝,這裡初始化為false,Vm1記錄了此輸出神經元膜電位(在一個時間視窗內)的變化,之後進入迴圈,迴圈一個時間視窗長度,對一個時間視窗內的每個單位時間,若該神經元已經發送過脈衝,則shut down所有脈衝輸入;若該神經元還未發出過脈衝(fired = false), 就執行Tsyn=find(TrainPtns(:,pp)<=t+0.1*dt);也就是在當前樣本(pp)的脈衝序列中,找到比時間t+0.1*dt小的脈衝傳送時間,也就是在t這個時間點之前有脈衝輸入發生,此時A2就儲存了該脈衝傳送的時間(也就是ti),之後就可呼叫K(t - ti)來計算此時的核函式;下一步就是找到傳送這個脈衝的輸入神經元和該輸出神經元之間連線的權重數值,用K*Weights就代表了該脈衝輸入對膜電位的影響,最後將該貢獻值累加到Vm上,得到此刻膜電位大小。

對整個時間視窗計算完後,將Vm加上覆位電壓,如果該值超過了閾值電壓,則發放脈衝;若該值超過了已有的Vmax,則更新Vmax。之後進入判斷語句:

if Class_Tr(neuron,pp)~=Class(neuron,pp) 神經元的輸出和實際輸出不相符時,分為兩個情況:

(1)實際輸出為0,但發放了脈衝。

此時應該抑制脈衝發放,於是權重應該減小:

Dw=-1.1*lmd*K;

(2)實際輸出為1,但沒有發放脈衝:

此時應該增大神經元的刺激,權重應該增加:

Dw=lmd*K;
最後儲存此次迴圈的權重。於是單層的脈衝神經網路監督學習就這樣完成了。下面附上完整程式碼:

function TempotronClassify()
% Tempotron: a neuron that learns spike timing-based decisions
% Rober Gutig 2006 Nature Neuroscience
clear; clc;
NumImages=26;
for i=1:NumImages
    ImageName=strcat('Icon16X16\Letter-',char('A'+i-1),'-black-icon');% 從icon16X16資料夾中讀取所有圖片
    ImageMatrix=imread(ImageName,'bmp');% 讀取圖片為灰度圖,儲存在矩陣中,取反
    ImageMatrix=~ImageMatrix;  % make the white pixel be 0, and black be 1;
    TrainPtns(:,i)=image2ptn(ImageMatrix);
end
TrainPtns=TrainPtns*1e-3;  % scale to ms
nAfferents = size(TrainPtns,1);
nPtns = NumImages;
%nOutputs = 1;    %%%%%%%%%%   1
nOutputs = 5;

loadData=0;% 是否載入已儲存的模型

V_thr = 1; V_rest = 0;
T = 256e-3;         % pattern duration ms
dt = 1e-3;
tau_m = 20e-3; % tau_m = 15e-3;???
tau_s = tau_m/4;
% K(t?ti)=V0(exp[?(t?ti)/τ]–exp[?(t?ti)/τs])
aa = exp(-(0:dt:3*tau_m)/tau_m)-exp(-(0:dt:3*tau_m)/tau_s);

V0 = 1/max(exp(-(0:dt:3*tau_m)/tau_m)-exp(-(0:dt:3*tau_m)/tau_s));
lmd = 2e-2;%1e-2/V0;   % optimal performance lmd=3e-3*T/(tau_m*nAfferents*V0)  1e-4/V0;
maxEpoch = 200;
mu = 0.99;  % momentum factor
% generate patterns (each pattern consists one spik-e per afferent)

if loadData ==0 %初始化網路
    weights = 1e-2*randn(nAfferents,nOutputs);  % 1e-3*randn(nAfferents,1);
    save('weights0','weights');
else
    load('weights0','weights');
end
%Class = logical(eye(nOutputs));     % desired class label for each pattern
%Class = false(1,26); Class(26)=true;
Class = de2bi(1:26,'left-msb'); Class=Class';

correctRate=zeros(1,maxEpoch);
dw_Past=zeros(nAfferents,nPtns,nOutputs);  % momentum for accelerating learning.上一個權重的更新,用於動量計算
for epoch=1:maxEpoch    
    
    Class_Tr = false(nOutputs,nPtns);  % actual outputs of training
    for pp=1:nPtns 
 %       Class_Tr = false(nOutputs,1);  % actual outputs of training
                
        for neuron=1:nOutputs
            Vmax=0; tmax=0;
            fired=false;        
            Vm1=zeros(1,256); indx1= 1; % trace pattern 1
            for t=dt:dt:T
                Vm = 0; 
                if fired==false
                    Tsyn=find(TrainPtns(:,pp)<=t+0.1*dt);    % no cut window
                else
                    Tsyn=find(TrainPtns(:,pp)<=t_fire+0.1*dt); % shut down inputs
                end
                if ~isempty(Tsyn)                    
                    A1=TrainPtns(:,pp);
                    A2=A1(Tsyn);
                    K =V0*(exp(-(t-A2)/tau_m)-exp(-(t-A2)/tau_s)); % the kernel value for each fired afferent
                    A1=weights(:,neuron);
                    firedWeights=A1(Tsyn);
                    Vm = Vm + firedWeights'*K ;
                end

                Vm = Vm + V_rest;
                if Vm>=V_thr && fired==false % fire
                    fired=true;
                    t_fire=t;
                    Class_Tr(neuron,pp)=true;
                end
                if Vm>Vmax
                    Vmax=Vm; tmax=t;
                end

                %if pp==1
                    Vm1(indx1)=Vm;
                    indx1=indx1+1;
                %end
            end

            %if pp==1
                figure(1); plot(dt:dt:T,Vm1);
                title(strcat('Image ',char('A'+pp-1),'; neuron: ',num2str(neuron))); drawnow;
            %end
            if Vmax<=0
                tmax=max(TrainPtns(:,pp));
            end
            
            if Class_Tr(neuron,pp)~=Class(neuron,pp) % error
                
                Tsyn=find(TrainPtns(:,pp)<=tmax+0.1*dt); 

                if ~isempty(Tsyn)                    
                    A1=TrainPtns(:,pp);
                    A2=A1(Tsyn);
                    K =V0*(exp(-(tmax-A2)/tau_m)-exp(-(tmax-A2)/tau_s)); % the kernel value for each fired afferent
                    A1=weights(:,neuron);
                    dwPst=dw_Past(:,pp,neuron);
                    if fired==false    % LTP
                        Dw=lmd*K;
                    else           % LTD
                        Dw=-1.1*lmd*K;
                    end
                    A1(Tsyn) = A1(Tsyn) + Dw + mu*dwPst(Tsyn);
                    weights(:,neuron)=A1;
                    dwPst(Tsyn)=Dw;
                    dw_Past(:,pp,neuron) = dwPst;
                end                
            end            
            
        end  % end of one neuron computation
        
   end % end of one image
   %CC=isequal(Class,Class_Tr);
   %correctRate(epoch)=sum(Class==Class_Tr)/length(Class);
   CC = bi2de(Class_Tr','left-msb');
end
save('TrainedWt','weights');
figure(2); plot(1:maxEpoch,correctRate,'-b.');
end

%%將圖片編碼為脈衝序列並儲存在向量中
function spikeTrain=image2ptn(A)  
%% convert a image to a spike train saved in a vector
RandParts=1;
% A1=A';
% B=[A1(:);A(:)];
% numPixels=length(B);
% numInputNeurons=numPixels/8; % 64 neurons
% spikeTrain=zeros(numInputNeurons,1);
% for i=1:numInputNeurons
%     Bits=B((1+(i-1)*8):(8+(i-1)*8));
%     Bits=Bits';
%     spikeTime=bi2de(Bits,'left-msb');    
%     if spikeTime==0
%         spikeTime=2^8;  % put 0 to the end of the interval
%     end
%     spikeTrain(i)=spikeTime;
% end

spikeTrain=zeros(32,1);
if RandParts==1
    loadR=1;
    AR=A(:);
    if loadR==0
        R = randperm(size(A,1)*size(A,2));
        save('RandIndex','R');
    else
        load('RandIndex','R');
    end
    numRandNeus=32;
    for i=1:numRandNeus
        IndexR=R((1+(i-1)*8):(8+(i-1)*8));
        Bits=AR(IndexR);
        Bits=Bits';
        spikeTime=bi2de(Bits,'left-msb');    % 二進位制轉十進位制
        if spikeTime==0
            spikeTime=2^8;  % put 0 to the end of the interval
        end
        spikeTrain(i)=spikeTime;
    end
end
end