程式人生 > 手寫SpringMVC

手寫SpringMVC

相信用過SpringMVC的同學都會對它愛不釋手,作為MVC框架,它用起來簡直是一種享受。用久了,相信你也會好奇它到底是怎麼實現的。今天我們就來揭開它神祕的面紗。

這裡我們通過一個簡單的例子來模擬SpringMVC的基本原理,希望能對好奇的同學有所幫助。

1. 在web.xml中配置前端控制器Servlet(DispatcherServlet)

<?xml version="1.0" encoding="UTF-8"?>
<web-app version="3.0" metadata-complete="false" xmlns="http://java.sun.com/xml/ns/javaee" xmlns:web="http://java.sun.com/xml/ns/javaee/web-app_3_0.xsd" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://java.sun.com/xml/ns/javaee 
        http://java.sun.com/xml/ns/javaee/web-app_3_0.xsd">

  <!-- Deployment descriptor for the hand-written Spring MVC demo. -->
  <display-name>Web Application</display-name>

  <!-- Registers the custom front controller. Note: this is a servlet, not a filter. -->
  <servlet>
    <servlet-name>springmvc</servlet-name>
    <servlet-class>com.cyq.DispatcherServlet</servlet-class>
    <!-- Name of the classpath properties resource that DispatcherServlet.init reads
         through getInitParameter("contextConfigLocation"). -->
    <init-param>
      <param-name>contextConfigLocation</param-name>
      <param-value>application.properties</param-value>
    </init-param>
    <!-- A non-negative value asks the container to initialize the servlet at startup
         instead of on the first request, so scanning/mapping happens eagerly. -->
    <load-on-startup>1</load-on-startup>
  </servlet>
  
  <!-- "/*" sends every request path to the dispatcher.
       NOTE(review): a "/*" path mapping also captures static resources and JSPs;
       Spring MVC setups commonly map "/" instead - confirm this is intended. -->
  <servlet-mapping>
    <servlet-name>springmvc</servlet-name>
    <url-pattern>/*</url-pattern>
  </servlet-mapping>


</web-app>

2.建立SpringMVC中常用的註解類

/**
 * Marks a field that the mini container should inject with a bean.
 *
 * <p>Retained at runtime so the dispatcher can discover it through reflection.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface GPAutowired {

    /**
     * Explicit name of the bean to inject. The empty default means the bean is looked up
     * by the fully-qualified name of the field's declared type.
     */
    String value() default "";
}



/**
 * Marks a class as a web controller: the dispatcher instantiates it and scans its public
 * methods for {@code @GPRequestMapping}.
 *
 * <p>Retained at runtime so the dispatcher can discover it through reflection.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface GPController {

    /**
     * Optional bean name. The dispatcher shown in this article does not read it; it names
     * controllers after their simple class name with a lower-cased first letter.
     */
    String value() default "";
}


/**
 * Maps a request path to a handler. On a controller class it supplies the URL prefix;
 * on a public method it supplies the path appended to that prefix.
 *
 * <p>Retained at runtime so the dispatcher can discover it through reflection.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface GPRequestMapping {

    /** Path fragment, e.g. {@code "/demo"} on a class or {@code "/query.json"} on a method. */
    String value() default "";
}


/**
 * Names the HTTP request parameter that should be bound to the annotated
 * handler-method parameter.
 *
 * <p>Retained at runtime so the dispatcher can discover it through reflection.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface GPRequestParam {

    /** Name of the request parameter, e.g. {@code "name"}. */
    String value() default "";
}


/**
 * Marks a class as a service bean that the dispatcher instantiates into its bean map.
 *
 * <p>Retained at runtime so the dispatcher can discover it through reflection.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface GPService {

    /**
     * Explicit bean name. When left empty, the instance is registered under the
     * fully-qualified name of each interface the class implements.
     */
    String value() default "";
}

3.建立Controller、Service

/**
 * Demo controller mounted under {@code /demo}. Each handler passes the {@code name}
 * request parameter to {@link DemoService} and writes the returned text to the response.
 */
@GPController
@GPRequestMapping("/demo")
public class DemoController {

    @GPAutowired
    private DemoService demoService;

    /** Handles {@code /demo/query.json}: responds with {@code demoService.get(name)}. */
    @GPRequestMapping("/query.json")
    public void query(HttpServletRequest request,
                      HttpServletResponse response,
                      @GPRequestParam("name") String name){
        writeText(response, demoService.get(name));
    }

    /** Handles {@code /demo/add.json}: responds with {@code demoService.add(name)}. */
    @GPRequestMapping("/add.json")
    public void add(HttpServletRequest request,
                      HttpServletResponse response,
                      @GPRequestParam("name") String name){
        writeText(response, demoService.add(name));
    }

    /**
     * Writes {@code text} to the response body as UTF-8 plain text.
     *
     * @throws java.io.UncheckedIOException if the response writer cannot be obtained or written
     */
    private void writeText(HttpServletResponse response, String text) {
        // Must be set before getWriter(): otherwise the writer uses ISO-8859-1 and non-Latin
        // names (e.g. Chinese) are garbled. Declaring text/plain also stops browsers from
        // rendering the echoed user input as HTML.
        response.setContentType("text/plain;charset=UTF-8");
        try {
            response.getWriter().write(text);
        } catch (IOException e) {
            // Surface the failure to the dispatcher instead of silently printing it.
            throw new java.io.UncheckedIOException("Failed to write response", e);
        }
    }

}
/**
 * Default {@code DemoService} implementation. Because {@code @GPService} carries no explicit
 * name, the dispatcher registers it under the name of each interface it implements.
 */
@GPService
public class DemoServiceImpl implements DemoService{

    /** Text placed in front of the name by {@link #get(String)}. */
    private static final String GET_PREFIX = "Hello ";

    /** Text placed in front of the name by {@link #add(String)}. */
    private static final String ADD_PREFIX = "Add success ";

    /** Returns a greeting for {@code name}; a null name is rendered as {@code "null"}. */
    public String get(String name) {
        return GET_PREFIX.concat(String.valueOf(name));
    }

    /** Returns an "add succeeded" message for {@code name}; null is rendered as {@code "null"}. */
    public String add(String name) {
        return ADD_PREFIX.concat(String.valueOf(name));
    }
}

4.建立DispatcherServlet,此為關鍵

/**
 * A minimal front controller that imitates Spring MVC's DispatcherServlet.
 *
 * <p>{@link #init(ServletConfig)} bootstraps everything once:
 * <ol>
 *   <li>load the properties resource named by the {@code contextConfigLocation} init-param;</li>
 *   <li>scan {@code scanPackage} for classes;</li>
 *   <li>instantiate {@code @GPController}/{@code @GPService} classes into a bean map;</li>
 *   <li>inject {@code @GPAutowired} fields;</li>
 *   <li>register each {@code @GPRequestMapping} method under its URL.</li>
 * </ol>
 * Requests are then routed to the matching handler method via reflection.
 *
 * <p>All maps are filled during {@code init} and only read while serving requests.
 */
public class DispatcherServlet extends HttpServlet {

    /** Settings loaded from the contextConfigLocation resource; {@code scanPackage} is required. */
    private final Properties properties = new Properties();

    /** Fully-qualified names of every class found under the scan package. */
    private final List<String> classNames = new ArrayList<String>();

    /** Bean container: bean name -> bean instance. */
    private final Map<String, Object> ioc = new HashMap<String, Object>();

    /** Request path (context path removed) -> handler method. */
    private final Map<String, Method> handlerMapping = new HashMap<String, Method>();

    /** Request path -> controller instance on which the handler method is invoked. */
    private final Map<String, Object> handlerBeans = new HashMap<String, Object>();

    /**
     * Dispatches the request. Any failure is reported to the client as HTTP 500
     * with the stack trace of the underlying exception.
     */
    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        try {
            this.doDispatch(request, response);
        } catch (Exception e) {
            // Method.invoke wraps the handler's own exception; report that one, not the wrapper.
            Throwable failure = (e instanceof InvocationTargetException && e.getCause() != null) ? e.getCause() : e;
            response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
            // NOTE(review): echoing stack traces to clients leaks internals; acceptable for a demo only.
            response.getWriter().write("500 Exception : \r\n" + Arrays.toString(failure.getStackTrace())
                    .replaceAll("\\[|\\]", "").replaceAll(",\\s", "\r\n"));
        }
    }

    /** GET requests are handled exactly like POST requests. */
    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        this.doPost(request, response);
    }

    /**
     * Resolves the handler for the request path and invokes it.
     * Writes a 404 response (and stops) when no handler is mapped.
     */
    private void doDispatch(HttpServletRequest request, HttpServletResponse response) throws IOException, InvocationTargetException, IllegalAccessException {
        String url = request.getRequestURI();
        String contextPath = request.getContextPath();
        // Remove only the leading context path; String.replace would drop every occurrence.
        if (url.startsWith(contextPath)) {
            url = url.substring(contextPath.length());
        }
        url = url.replaceAll("/+", "/");

        Method method = handlerMapping.get(url);
        if (method == null) {
            response.setStatus(HttpServletResponse.SC_NOT_FOUND);
            response.getWriter().write("404 Not Found!");
            return;
        }

        method.invoke(handlerBeans.get(url), resolveArguments(method, request, response));
    }

    /**
     * Builds the argument array for a handler method, parameter by parameter:
     * a {@code @GPRequestParam} parameter receives the first value of the named request
     * parameter (null if absent); otherwise the request or response object is injected when
     * its type fits. Any other parameter is left null, so primitive parameters are unsupported.
     */
    private Object[] resolveArguments(Method method, HttpServletRequest request, HttpServletResponse response) {
        Class<?>[] parameterTypes = method.getParameterTypes();
        java.lang.annotation.Annotation[][] parameterAnnotations = method.getParameterAnnotations();
        Object[] args = new Object[parameterTypes.length];
        for (int i = 0; i < parameterTypes.length; i++) {
            GPRequestParam requestParam = null;
            for (java.lang.annotation.Annotation annotation : parameterAnnotations[i]) {
                if (annotation instanceof GPRequestParam) {
                    requestParam = (GPRequestParam) annotation;
                }
            }
            if (requestParam != null) {
                args[i] = request.getParameter(requestParam.value());
            } else if (parameterTypes[i].isInstance(request)) {
                args[i] = request;
            } else if (parameterTypes[i].isInstance(response)) {
                args[i] = response;
            }
        }
        return args;
    }

    /**
     * Bootstraps the mini container.
     *
     * @throws ServletException if {@code scanPackage} is not configured
     * @throws IllegalStateException if the config resource, scan package, or a bean/dependency is missing
     */
    @Override
    public void init(ServletConfig servletConfig) throws ServletException {
        // GenericServlet only keeps the config when super.init is called; without it
        // getServletConfig() returns null and getServletContext() fails.
        super.init(servletConfig);

        // 1. Load the properties named by the contextConfigLocation init-param.
        doLoadConfig(servletConfig.getInitParameter("contextConfigLocation"));

        // 2. Collect every class under scanPackage.
        String scanPackage = properties.getProperty("scanPackage");
        if (scanPackage == null || scanPackage.trim().isEmpty()) {
            throw new ServletException("Property 'scanPackage' is missing from the configuration file");
        }
        doScanner(scanPackage.trim());

        // 3. Instantiate annotated classes into the bean map.
        doInstance();

        // 4. Inject @GPAutowired fields.
        doAutowired();

        // 5. Map request paths to handler methods.
        initHandlerMapping();

        System.out.println("springmvc started OK");
    }

    /**
     * Registers every public {@code @GPRequestMapping} method of each controller bean under
     * "/" + class-level path + "/" + method-level path, with repeated slashes collapsed.
     */
    private void initHandlerMapping() {
        for (Object bean : ioc.values()) {
            Class<?> clazz = bean.getClass();
            if (!clazz.isAnnotationPresent(GPController.class)) {
                continue;
            }

            String baseUrl = "";
            GPRequestMapping classMapping = clazz.getAnnotation(GPRequestMapping.class);
            if (classMapping != null) {
                baseUrl = classMapping.value();
            }

            for (Method method : clazz.getMethods()) {
                GPRequestMapping methodMapping = method.getAnnotation(GPRequestMapping.class);
                if (methodMapping == null) {
                    continue;
                }

                // Insert a separator so "query.json" (no leading slash) still maps to /demo/query.json.
                String path = methodMapping.value();
                String url = ("/" + baseUrl + (path.isEmpty() ? "" : "/" + path)).replaceAll("/+", "/");
                handlerMapping.put(url, method);
                handlerBeans.put(url, bean);

                System.out.println("mapped: " + url + "," + method);
            }
        }
    }

    /**
     * Injects every {@code @GPAutowired} field declared on each bean's class. The bean is
     * looked up by the annotation's value, or by the field type's fully-qualified name when empty.
     *
     * @throws IllegalStateException if no such bean exists
     */
    private void doAutowired() {
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Object bean = entry.getValue();
            for (Field field : bean.getClass().getDeclaredFields()) {
                GPAutowired autowired = field.getAnnotation(GPAutowired.class);
                if (autowired == null) {
                    continue;
                }
                String beanName = autowired.value().trim();
                if (beanName.isEmpty()) {
                    beanName = field.getType().getName();
                }
                Object dependency = ioc.get(beanName);
                if (dependency == null) {
                    // Fail at startup rather than with a NullPointerException on the first request.
                    throw new IllegalStateException("No bean named '" + beanName + "' to inject into "
                            + bean.getClass().getName() + "." + field.getName());
                }
                field.setAccessible(true);
                try {
                    field.set(bean, dependency);
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException("Cannot inject " + bean.getClass().getName() + "." + field.getName(), e);
                }
            }
        }
    }

    /**
     * Instantiates the scanned classes that carry {@code @GPController} or {@code @GPService}.
     * <ul>
     *   <li>Controllers are keyed by their simple name with a lower-cased first letter.</li>
     *   <li>A service with an explicit name is keyed by that name.</li>
     *   <li>Otherwise a service is keyed by each interface's fully-qualified name, sharing one
     *       instance; a service with no interfaces is keyed by its own class name.</li>
     * </ul>
     *
     * @throws IllegalStateException if a class cannot be loaded or instantiated
     */
    private void doInstance() {
        for (String className : classNames) {
            try {
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(GPController.class)) {
                    ioc.put(lowerFirstCase(clazz.getSimpleName()), clazz.getDeclaredConstructor().newInstance());
                }
                if (clazz.isAnnotationPresent(GPService.class)) {
                    Object instance = clazz.getDeclaredConstructor().newInstance();
                    String beanName = clazz.getAnnotation(GPService.class).value().trim();
                    Class<?>[] interfaces = clazz.getInterfaces();
                    if (!beanName.isEmpty()) {
                        ioc.put(beanName, instance);
                    } else if (interfaces.length == 0) {
                        ioc.put(clazz.getName(), instance);
                    } else {
                        for (Class<?> inter : interfaces) {
                            ioc.put(inter.getName(), instance);
                        }
                    }
                }
            } catch (Exception e) {
                // Fail fast: a half-built container would only surface later as NullPointerExceptions.
                throw new IllegalStateException("Failed to instantiate " + className, e);
            }
        }
    }

    /** Returns {@code className} with its first character lower-cased (null/empty returned as-is). */
    private String lowerFirstCase(String className) {
        if (className == null || className.isEmpty()) {
            return className;
        }
        return Character.toLowerCase(className.charAt(0)) + className.substring(1);
    }

    /**
     * Recursively records the fully-qualified name of every {@code .class} file found under the
     * package directory on the classpath. Only exploded directories (file: URLs) are supported.
     *
     * @throws IllegalStateException if the package is not a directory on the classpath
     */
    private void doScanner(String scanPackage) {
        // ClassLoader resource names are relative; a leading '/' makes standard loaders return null.
        URL url = this.getClass().getClassLoader().getResource(scanPackage.replace('.', '/'));
        if (url == null || !"file".equals(url.getProtocol())) {
            throw new IllegalStateException("Cannot scan package '" + scanPackage
                    + "': not a directory on the classpath (" + url + ")");
        }
        File dir;
        try {
            // toURI() decodes escapes such as %20 that url.getFile() would leave in the path.
            dir = new File(url.toURI());
        } catch (java.net.URISyntaxException e) {
            throw new IllegalStateException("Invalid resource URL: " + url, e);
        }
        File[] files = dir.listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + fileName);
            } else if (fileName.endsWith(".class")) {
                classNames.add(scanPackage + "." + fileName.substring(0, fileName.length() - ".class".length()));
            }
        }
    }

    /**
     * Loads the classpath properties resource {@code contextConfigLocation} into {@link #properties}.
     *
     * @throws IllegalStateException if the name is missing, the resource does not exist, or it cannot be read
     */
    private void doLoadConfig(String contextConfigLocation) {
        if (contextConfigLocation == null) {
            throw new IllegalStateException("Missing init-param 'contextConfigLocation'");
        }
        // Tolerate "/application.properties": ClassLoader resource names have no leading slash.
        String resource = contextConfigLocation.startsWith("/") ? contextConfigLocation.substring(1) : contextConfigLocation;
        InputStream resourceAsStream = this.getClass().getClassLoader().getResourceAsStream(resource);
        if (resourceAsStream == null) {
            throw new IllegalStateException("Config resource not found on classpath: " + contextConfigLocation);
        }
        try {
            properties.load(resourceAsStream);
        } catch (IOException e) {
            throw new IllegalStateException("Failed to read " + contextConfigLocation, e);
        } finally {
            try {
                resourceAsStream.close();
            } catch (IOException ignored) {
                // Closing a classpath stream after reading; nothing useful to do on failure.
            }
        }
    }
}

5. 新增配置檔案application.properties(放在classpath根目錄下)

# Root package that DispatcherServlet scans (recursively, including sub-packages) for annotated classes.
scanPackage=com.cyq

啟動工程後,就可以使用自己實現的SpringMVC了,例如訪問 /demo/query.json?name=Tom 會返回 Hello Tom。上面的例子只是粗略地模擬了SpringMVC的基本流程和原理,實際的SpringMVC要複雜得多,建議大家多去閱讀Spring原始碼。全文完。