springboot下Filter的POST和GET過濾引數
阿新 • • 發佈:2019-01-14
//定義一個filter過濾器 import org.apache.commons.lang.StringUtils; import org.springframework.stereotype.Component; import org.apache.commons.lang.StringEscapeUtils; import javax.servlet.*; import javax.servlet.annotation.WebFilter; import javax.servlet.http.HttpServletRequest; import java.io.IOException; import java.util.Map; import java.util.Set; @Component @WebFilter(filterName = "ValidatorFilter" , urlPatterns = "/*") public class ValidatorFilter implements Filter { String[] strArr = {"\"","%","'"}; @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException{ HttpServletRequest httpServletRequest = (HttpServletRequest) request; String method = (httpServletRequest.getMethod()); Map<String, String[]> map = httpServletRequest.getParameterMap(); ServletRequest requestWrapper = null; GetParameterRequestWrapper requestWrapper1= null; if(httpServletRequest.getMethod().equals("POST")){ requestWrapper = new PostParameterRequestWrapper(httpServletRequest,method,map); chain.doFilter(requestWrapper, response); }else if(httpServletRequest.getMethod().equals("GET")){ requestWrapper1 = new GetParameterRequestWrapper((HttpServletRequest)request); Set<String> key = map.keySet(); for(String arr :strArr){ for(String k : key){ String[] arrValues = map.get(k); String newValues= StringUtils.join(arrValues); if(newValues.contains(arr)){ //對不合法引數轉義 String escape = StringEscapeUtils.escapeXml(arr); String s1 = newValues.replace(arr,escape); //重新put相同的key,替換對應的values requestWrapper1.addParameter(k, new String[]{s1}); } } } chain.doFilter(requestWrapper1, response); } } @Override public void destroy() { } @Override public void init(FilterConfig filterConfig) throws ServletException { } }
//get方式,修改請求域中的引數值,攔截不合法的引數,進行轉義 import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import java.util.*; class GetParameterRequestWrapper extends HttpServletRequestWrapper { private Map<String , String[]> params = new HashMap<String, String[]>(); @SuppressWarnings("unchecked") public GetParameterRequestWrapper(HttpServletRequest request) { super(request); this.params.putAll(request.getParameterMap()); } public GetParameterRequestWrapper(HttpServletRequest request , Map<String , Object> extendParams) { this(request); addAllParameters(extendParams); } @Override public String getParameter(String name) { String[] values = params.get(name); if (values == null || values.length == 0) { return null; } return values[0]; } public String[] getParameterValues(String name) { return params.get(name); } public void addAllParameters(Map<String , Object>otherParams) { for(Map.Entry<String , Object>entry : otherParams.entrySet()) { addParameter(entry.getKey() , entry.getValue()); } } public void addParameter(String name , Object value) { if(value != null) { if(value instanceof String[]) { params.put(name , (String[])value); }else if(value instanceof String) { params.put(name , new String[] {(String)value}); }else { params.put(name , new String[] {String.valueOf(value)}); } } } }
//post方式,修改請求域中的引數值,攔截不合法的引數,進行轉義 import org.apache.commons.lang.StringEscapeUtils; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import java.io.*; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.List; import java.util.Map; public class PostParameterRequestWrapper extends HttpServletRequestWrapper { private byte[] body; String[] strArr = {"\"","%","'"}; public PostParameterRequestWrapper(HttpServletRequest request, String method, Map<String, String[]> newParams) throws IOException { super(request); //獲取request域json型別引數 String param = getBodyString(request); //拆分json,引數屬性放一個List集合中 List<String> shuxing = new ArrayList<String>(); //拆分json,引數值放一個List集合中 List<String> values = new ArrayList<String>(); System.out.println("param "+param); if(param!= null && !param.equals("")){ String newParam = param.substring(1,param.length()-1); String[] arrParam = newParam.split(","); for(String arr : arrParam){ String[] newArr = arr.split(":"); //屬性 String par = newArr[0].trim(); if(par.contains("\"") && par.length()>2){ par = par.substring(1,par.length()-1); } shuxing.add(par); //值 if(newArr.length>1){ String par1 = newArr[1].trim(); if(par1.contains("\"") && par1.length()>2){ par1 = par1.substring(1,par1.length()-1); } values.add(par1); }else{ values.add(""); } } //對值裡面的不合法引數轉義 for(int i = 0;i<shuxing.size();i++){ for(String arr :strArr){ if(values.get(i).contains(arr)){ //對不合法引數values轉義 String newValues = StringEscapeUtils.escapeXml(arr); String s1 = values.get(i).replace(arr,newValues); values.set(i,s1); } } } StringBuffer bf =new StringBuffer(); //重組json字串 for(int k = 0;k<shuxing.size();k++){ if(k+1 != shuxing.size()){ bf.append("\""+shuxing.get(k)+"\""+":"+ "\""+ values.get(k)+"\""+","); }else{ bf.append("\""+shuxing.get(k)+"\""+":"+ "\""+values.get(k)+"\""); } } String sb = "{"+ bf.toString() +"}"; System.out.println("sb "+sb); body = sb.getBytes(Charset.forName("UTF-8")); } } /** * 獲取請求Body * * @param request * @return */ public String getBodyString(final ServletRequest request) { StringBuilder sb = new StringBuilder(); InputStream inputStream = null; BufferedReader reader = null; try { inputStream = cloneInputStream(request.getInputStream()); reader = new BufferedReader(new InputStreamReader(inputStream, Charset.forName("UTF-8"))); String line = ""; while ((line = reader.readLine()) != null) { sb.append(line); } } catch (IOException e) { e.printStackTrace(); } finally { if (inputStream != null) { try { inputStream.close(); } catch (IOException e) { e.printStackTrace(); } } if (reader != null) { try { reader.close(); } catch (IOException e) { e.printStackTrace(); } } } System.out.println("sb.toString " +sb.toString()); return sb.toString(); } /** * Description: 複製輸入流</br> * * @param inputStream * @return</br> */ public InputStream cloneInputStream(ServletInputStream inputStream) { ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); byte[] buffer = new byte[1024]; int len; try { while ((len = inputStream.read(buffer)) > -1) { byteArrayOutputStream.write(buffer, 0, len); } byteArrayOutputStream.flush(); } catch (IOException e) { e.printStackTrace(); } InputStream byteArrayInputStream = new ByteArrayInputStream(byteArrayOutputStream.toByteArray()); return byteArrayInputStream; } @Override public BufferedReader getReader() throws IOException { return new BufferedReader(new InputStreamReader(getInputStream())); } @Override public ServletInputStream getInputStream() throws IOException { final ByteArrayInputStream bais = new ByteArrayInputStream(body); return new ServletInputStream() { @Override public int read() throws IOException { return bais.read(); } @Override public boolean isFinished() { return false; } @Override public boolean isReady() { return false; } @Override public void setReadListener(ReadListener readListener) { } }; } }