OGSBI MATLAB Code（轉載）
阿新 • 發佈：2019-01-03
註：程式碼轉載自 https://sites.google.com/site/zaiyang0248/publication
一、主程式
% OGSBI demo: simulate K narrowband sources impinging on an M-element
% uniform linear array (ULA), estimate their DOAs with OGSBI on a coarse
% grid plus off-grid correction, and plot the recovered spatial spectrum.
clear all; close all; clc;

% ---------------- simulation settings ----------------
resolution = 2;                 % grid spacing (degrees)
grid = (0:resolution:180)';     % candidate DOA grid (degrees)
SNR = 0;                        % signal-to-noise ratio (dB)
M = 8;                          % number of sensors
K = 2;                          % number of sources
N = length(grid);               % number of grid points
T = 200;                        % number of snapshots
issvd = true;                   % reduce the snapshots with an SVD before OGSBI
% issvd = false;

% true DOAs (degrees)
theta = [60.3 88.6]';

% Sensor positions (in half wavelengths) relative to the array centre;
% used for both the true array manifold and the grid dictionary.
mpos = (1:M)' - (M+1)/2;

% true Phi: M x K steering matrix of the true DOAs
Phi = exp(1i * pi * mpos * cos(theta' / 180 * pi));

% true signal
% random signals
amp = [1 1]';
if T == 1
    % single snapshot: unit-modulus signals with uniformly random phase
    X = exp(1i * 2 * pi * rand(K, 1));
else
    X1 = (randn(K, T) + 1i * randn(K, T)) / sqrt(2);
    X = diag(amp) * X1;
end
% % correlated signals
% amp = [1 1]';
% x1 = (randn(1,200) + 1i *randn(1,200))/sqrt(2);
% y1 = (randn(1,200) + 1i *randn(1,200))/sqrt(2);
% x2 = .99*x1 + sqrt(1-.99^2)*y1;
% X = diag(amp) * [x1; x2];

% observed signal: noise variance chosen to hit the requested SNR
Y = Phi * X;
sigma2 = 10^(-SNR/10) * norm(Y, 'fro')^2 / (M * T);
% circular complex Gaussian noise with variance sigma2
E = sqrt(sigma2 / 2) * (randn(M, T) + 1i * randn(M, T));
Y = Phi * X + E;

% ULA dictionary (origin at the array centre):
%   A(m,n) = steering vector of grid(n)
%   B(m,n) = derivative of A(m,n) with respect to the angle (radians)
A = exp(1i * pi * mpos * cos(grid' / 180 * pi));
B = -1i * pi * (mpos * sin(grid' / 180 * pi)) .* A;

% optional dimension reduction: keep the L dominant right singular vectors
% (L = K unless there are fewer snapshots than sources)
if issvd
    [~, ~, V] = svd(Y, 'econ');
    L = min(K, size(V, 2));
    V = V(:, 1:L);
    Y = Y * V;
end

% initialize parameters
params.Y = Y;
params.A = A;
params.B = B;
params.resolution = resolution / 180 * pi;   % grid spacing in radians
params.rho = 1e-2;
params.alpha = mean(abs(A' * Y), 2);
params.beta = zeros(N, 1);
params.K = K;
params.maxiter = 2000;
params.tolerance = 1e-3;
params.sigma2 = mean(var(Y)) / 100;          % initial noise-variance guess; OGSBI reads this field
% To treat the noise variance as known instead:
% params.isKnownNoiseVar = true;
% params.knownsigma2 = sigma2;

tstart = tic;
res = OGSBI(params);
time = toc(tstart);

% line plot: grid points shifted by the estimated off-grid offsets
xp_rec = grid + res.beta * 180 / pi;
if issvd
    % map the reduced-snapshot estimate back to all T snapshots
    x_rec = res.mu * V';
    xpower_rec = mean(abs(x_rec).^2, 2) + real(diag(res.Sigma)) * L / T;
else
    xpower_rec = mean(abs(res.mu).^2, 2) + real(diag(res.Sigma));
end
truePow = 10 * log10(amp.^2);
recPow = 10 * log10(xpower_rec);
figure(1000), plot(theta, truePow, 'bo', xp_rec, recPow, 'rx-');
axis([0, 180, min([truePow; recPow]), max([truePow; recPow]) + 3]);
xlabel('DOA (degrees)', 'fontsize',12);
ylabel('Power (dB)','fontsize',12);
legend('True DOAs','OGSBI spectral');
二、主程式所呼叫的函式 OGSBI
function res = OGSBI(paras)
% res = OGSBI(paras)
%
% OGSBI(paras) performs off-grid DOA estimation using Sparse Bayesian
% Inference: the signal on an N-point grid is estimated jointly with
% first-order off-grid offsets beta for the K most significant grid points.
%
% Input:
%   paras.Y:          M * T matrix, sensor measurements at all snapshots
%   paras.A:          M * N matrix, columns are the steering vectors for different directions
%   paras.B:          M * N matrix, columns are derivatives of the steering vectors wrt. different directions
%   paras.sigma2:     initialization of noise variance
%                     (optional; defaults to mean(var(Y))/100 when absent or empty)
%   paras.alpha:      initialization of alpha, length N
%   paras.beta:       initialization of beta, length N
%   paras.rho:        rho (divided by T internally; rho = 0 disables the alpha shrinkage step)
%   paras.resolution: grid resolution for the directions; every beta is kept in [-resolution/2, resolution/2]
%   paras.maxiter:    maximum number of iterations
%   paras.tolerance:  stopping criterion on the relative change of alpha
%                     (paras.tol is accepted as an alias when tolerance is absent)
%   paras.isKnownNoiseVar: true if known variance, false (default) if unknown
%   paras.knownsigma2: the known noise variance; read only when isKnownNoiseVar is true
%   paras.K:          number of sources (optional; default min(T, M-1), clamped to N)
%
% Output:
%   res.mu:        N * T posterior mean of the signal on the grid
%   res.Sigma:     N * N posterior covariance (shared by all snapshots)
%   res.sigma2:    estimated noise variance
%   res.sigma2seq: estimated noise variance at all iterations
%   res.alpha:     reconstructed alpha
%   res.beta:      reconstructed beta (nonzero only on the K selected grid points)
%   res.iter:      iterations used by the algorithm
%   res.ML:        objective value (log-likelihood terms plus prior terms) at all iterations
%
% Note: mu and Sigma are those of the last iteration; beta receives one more
% (5-sweep) refinement after the stopping test fires.
%
% Written by Zai Yang, 19 Jul, 2011
% reference:
%   Z. Yang, L. Xie, and C. Zhang, "Off-grid direction of arrival estimation
%   using sparse Bayesian inference", IEEE Trans. Signal Processing,
%   vol. 61, no. 1, pp. 38--43, 2013.

tiny = 1e-16;   % guards the division by alpha (named so the built-in eps is not shadowed)

Y = paras.Y;
A = paras.A;
B = paras.B;
[M, T] = size(Y);
N = size(A, 2);

% initial noise precision
if isfield(paras, 'sigma2') && ~isempty(paras.sigma2)
    sigma2_init = paras.sigma2;
else
    sigma2_init = mean(var(Y)) / 100;
end
alpha0 = 1 / sigma2_init;

rho = paras.rho / T;
beta = paras.beta(:);     % force column vectors so norm(alpha - alpha_last)
alpha = paras.alpha(:);   % stays a vector norm even for row-vector inputs
r = paras.resolution;
maxiter = paras.maxiter;
if isfield(paras, 'tolerance') && ~isempty(paras.tolerance)
    tol = paras.tolerance;
else
    tol = paras.tol;      % documented alias
end

if isfield(paras, 'isKnownNoiseVar') && ~isempty(paras.isKnownNoiseVar)
    isKnownNoiseVar = paras.isKnownNoiseVar;
else
    isKnownNoiseVar = false;
end
% (a, b) enter the alpha0 update and the objective as the Gamma-prior terms
% (a-1)*log(alpha0) - b*alpha0
if isKnownNoiseVar
    a = 1;
    b = T * M * paras.knownsigma2;
else
    a = 1e-4;
    b = 1e-4;
end

if isfield(paras, 'K') && ~isempty(paras.K)
    K = paras.K;
else
    K = min(T, M-1);
end
K = min(K, N);            % idx(1:K) below cannot exceed the grid size

idx = [];                 % grid points whose off-grid offset is being refined
BHB = B' * B;
converged = false;
iter_beta = 1;            % coordinate-descent sweeps per beta update
iter = 0;
ML = zeros(maxiter, 1);
alpha0seq = zeros(maxiter, 1);

while ~converged
    iter = iter + 1;

    % dictionary with the first-order off-grid correction on the active set
    Phi = A;
    Phi(:, idx) = A(:, idx) + B(:, idx) * diag(beta(idx));

    % posterior of the grid signal
    alpha_last = alpha;
    C = 1 / alpha0 * eye(M) + Phi * diag(alpha) * Phi';
    Cinv = inv(C);        % reused by the objective below
    Sigma = diag(alpha) - diag(alpha) * Phi' * Cinv * Phi * diag(alpha);
    mu = alpha0 * Sigma * Phi' * Y;
    gamma1 = 1 - real(diag(Sigma)) ./ (alpha + tiny);

    % update alpha
    musq = mean(abs(mu).^2, 2);
    alpha = musq + real(diag(Sigma));
    if rho ~= 0
        alpha = -.5 / rho + sqrt(.25 / rho^2 + alpha / rho);
    end

    % update alpha0 (noise precision)
    resid = Y - Phi * mu;
    alpha0 = (T * M + a - 1) / (norm(resid, 'fro')^2 + T / alpha0 * sum(gamma1) + b);
    alpha0seq(iter) = alpha0;

    % stopping criteria; the last beta update gets 5 sweeps instead of 1
    if norm(alpha - alpha_last) / norm(alpha_last) < tol || iter >= maxiter
        converged = true;
        iter_beta = 5;
    end

    % objective: sum over snapshots of real(y_t' * Cinv * y_t), vectorized
    quadForm = real(sum(sum(conj(Y) .* (Cinv * Y))));
    ML(iter) = -T * real(log(det(C))) - quadForm + (a-1) * log(alpha0) - b * alpha0 - rho * sum(alpha);

    % update beta on the K grid points with the largest alpha
    [~, order] = sort(alpha, 'descend');
    idx = order(1:K);
    beta_prev = beta;
    beta = zeros(N, 1);
    beta(idx) = beta_prev(idx);   % keep previous offsets only on the new active set

    Bi = B(:, idx);
    P = real(conj(BHB(idx, idx)) .* (mu(idx, :) * mu(idx, :)' + T * Sigma(idx, idx)));
    % v = sum_t real(conj(mu(idx,t)) .* (Bi' * (y_t - A*mu_t))) - T*real(diag(Bi'*A*Sigma(:,idx)))
    v = real(sum(conj(mu(idx, :)) .* (Bi' * (Y - A * mu)), 2)) ...
        - T * real(diag(Bi' * A * Sigma(:, idx)));

    beta_ls = P \ v;
    if any(abs(beta_ls) > r/2) || any(diag(P) == 0) || any(~isfinite(beta_ls))
        % The unconstrained solution leaves [-r/2, r/2] or is unusable
        % (singular P): fall back to clipped coordinate-wise updates.
        for sweep = 1:iter_beta
            for n = 1:K
                temp_beta = beta(idx);
                temp_beta(n) = 0;
                beta(idx(n)) = (v(n) - P(n, :) * temp_beta) / P(n, n);
                if beta(idx(n)) > r/2
                    beta(idx(n)) = r/2;
                end
                if beta(idx(n)) < -r/2
                    beta(idx(n)) = -r/2;
                end
                if P(n, n) == 0
                    beta(idx(n)) = 0;
                end
            end
        end
    else
        beta = zeros(N, 1);
        beta(idx) = beta_ls;
    end
end

res.mu = mu;
res.Sigma = Sigma;
res.beta = beta;
res.alpha = alpha;
res.iter = iter;
res.ML = ML(1:iter);
res.sigma2 = 1 / alpha0;
res.sigma2seq = 1 ./ alpha0seq(1:iter);
end