用300行程式碼模擬Spring核心原理
一,實現思路
1,配置階段
配置web.xml | DispatcherServlet |
---|---|
設定init-param | contextConfigLocation=classpath:application.xml |
設定url-pattern | /* |
配置Annotation | @Controller @Service @Autowired @RequestMapping |
2,初始化階段
呼叫init方法 | 載入配置檔案 |
---|---|
IOC容器初始化 | MAP |
掃描相關的類 | scan-package="" |
建立例項化並儲存至容器 | 通過反射機制將類例項化放入IOC容器 |
進行DI操作 | 掃描IOC容器的例項,給沒有賦值的屬性自動填充 |
初始化HandlerMapping | 將一個URL和一個Method進行一對一的對映 |
3,執行階段
呼叫doGet/doPost | Web容器呼叫doGet、doPost,傳入req和resp物件 |
---|---|
匹配HandlerMapping | 從req物件獲取輸入的URL,找到其對應的method |
反射呼叫method.invoke() | 利用反射呼叫方法並返回結果 |
response.getWriter().write() | 將返回結果輸出到瀏覽器 |
二,自定義配置
1,配置 application.properties 檔案
為了解析方便,用 application.properties 來代替 application.xml 檔案,具體配置內容如下:
scanPackage=com.yhd.spring01
2,配置web.xml檔案
<?xml version="1.0" encoding="UTF-8"?>
<!-- Registers the hand-written dispatcher servlet as the front controller for every request path. -->
<web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xmlns="http://java.sun.com/xml/ns/j2ee" xmlns:javaee="http://java.sun.com/xml/ns/javaee"
         xmlns:web="http://java.sun.com/xml/ns/javaee/web-app_2_5.xsd"
         xsi:schemaLocation="http://java.sun.com/xml/ns/j2ee
         http://java.sun.com/xml/ns/j2ee/web-app_2_4.xsd"
         version="2.4">
    <display-name>YHD Web Application</display-name>
    <servlet>
        <servlet-name>yhdmvc</servlet-name>
        <servlet-class>com.yhd.spring01.servlet.HdDispatcherServlet</servlet-class>
        <init-param>
            <!-- Classpath resource the servlet's init() loads; it holds scanPackage=... -->
            <param-name>contextConfigLocation</param-name>
            <param-value>application.properties</param-value>
        </init-param>
        <!-- Instantiate the servlet (running init) at deployment instead of on the first request. -->
        <load-on-startup>1</load-on-startup>
    </servlet>
    <servlet-mapping>
        <servlet-name>yhdmvc</servlet-name>
        <!-- Route every request through the dispatcher. -->
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>
3,自定義註解
/**
 * Marks a field that the mini container should populate during dependency injection.
 * When {@link #value()} is left empty, the dispatcher servlets in this article look the
 * dependency up by the field's fully-qualified type name.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface HdAutowired {
    /** Explicit bean name to inject; empty means "resolve by field type". */
    String value() default "";
}
import java.lang.annotation.*;
/**
 * Marks a class as a web controller. The dispatcher servlets instantiate annotated
 * classes and scan their {@link HdRequestMapping}-annotated methods for URL routes.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface HdController {
    /** Optional bean name; unused by the dispatchers shown here. */
    String value() default "";
}
/**
 * Declares a URL path. On a controller class it supplies the path prefix; on a
 * controller method it supplies the remainder of the route handled by that method.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface HdRequestMapping {
    /** Path segment, e.g. {@code "/demo"} or {@code "/query"}. */
    String value() default "";
}
/**
 * Binds a handler-method parameter to the HTTP request parameter whose name is
 * given by {@link #value()}.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface HdRequestParam {
    /** Name of the request parameter to read. */
    String value() default "";
}
/**
 * Marks a class as a service bean. The dispatcher servlets instantiate it and also
 * register the instance under each interface it implements.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface HdService {
    /** Optional bean name; when empty the class name is used as the key. */
    String value() default "";
}
4,編寫模擬業務
/**
 * Sample business service picked up by the container through {@link HdService}.
 */
@HdService
public class DemoService implements IDemoService {

    /**
     * Builds a self-introduction sentence.
     *
     * @param name the name to embed; {@code null} is rendered as the text "null"
     * @return {@code "My name is "} followed by {@code name}
     */
    @Override
    public String get(String name) {
        final StringBuilder sentence = new StringBuilder("My name is ");
        sentence.append(name);
        return sentence.toString();
    }
}
/**
 * Sample controller mounted under {@code /demo}; each handler writes its result
 * directly to the servlet response.
 */
@HdController
@HdRequestMapping("/demo")
public class DemoController {

    @HdAutowired
    private IDemoService demoService;

    /** {@code /demo/query?name=...}: writes the service's greeting for {@code name}. */
    @HdRequestMapping("/query")
    public void query(HttpServletRequest req, HttpServletResponse resp,
                      @HdRequestParam("name") String name) {
        write(resp, demoService.get(name));
    }

    /**
     * {@code /demo/add?a=..&b=..}: writes {@code "a+b=sum"}.
     * A missing parameter may arrive as {@code null} (the v3 dispatcher leaves unmatched
     * argument slots null); unboxing it would throw, so a readable message is written instead.
     */
    @HdRequestMapping("/add")
    public void add(HttpServletRequest req, HttpServletResponse resp,
                    @HdRequestParam("a") Integer a, @HdRequestParam("b") Integer b) {
        if (a == null || b == null) {
            write(resp, "parameters a and b are required");
            return;
        }
        // Widen to long so the sum of two large ints cannot silently overflow.
        write(resp, a + "+" + b + "=" + ((long) a + b));
    }

    /** {@code /demo/remove?id=...}: placeholder handler; intentionally does nothing yet. */
    @HdRequestMapping("/remove")
    public void remove(HttpServletRequest req, HttpServletResponse resp,
                       @HdRequestParam("id") Integer id) {
        // TODO: not implemented in this demo.
    }

    /**
     * Writes text to the response. I/O failures are rethrown (instead of only being
     * printed) so the dispatcher can report the request as failed.
     */
    private static void write(HttpServletResponse resp, String text) {
        try {
            resp.getWriter().write(text);
        } catch (IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }
}
三,容器初始化
1.0版本
流程分析
1.首先在doGet方法裡面呼叫doDispatch方法,根據請求路徑判斷路徑是否存在:如果不存在就返回404;如果存在,就從容器中拿到路徑對應的方法,通過反射呼叫對應的方法。
2.在Servlet初始化(init方法)階段,用流來載入配置檔案,從配置檔案讀取配置的包掃描路徑,根據包掃描路徑進行遞迴遍歷;利用反射建立所有標有@HdController註解的類的例項並加入到容器,並下鑽到類中,將類中每一個方法的絕對訪問路徑和方法加入到容器;同樣建立所有標有@HdService註解的類的例項,如果該類實現了介面,將該介面的全限定型別名和類例項物件也放入容器,達到根據介面注入的效果。
3.屬性賦值:遍歷容器中所有例項,如果某個屬性標有@HdAutowired註解,就將容器中對應的例項設定到該屬性上。
重要方法
1.clazz.isAnnotationPresent(HdController.class)
判斷clazz上有沒有HdController註解
2.field.set(mappings.get(clazz.getName()), mappings.get(beanName));
屬性賦值:第一個引數是要設值的目標物件(即該屬性所屬的例項),第二個引數是要設定的值。
3.method.invoke(mappings.get(method.getDeclaringClass().getName()), new Object[]{req, resp, params.get("name")[0]});
通過反射呼叫方法:第一個引數是方法所屬類的例項物件,第二個引數是方法的引數陣列。
程式碼
/**
 * Version 1 of the mini MVC dispatcher. Configuration loading, instantiation,
 * URL mapping and field injection all happen inside {@link #init(ServletConfig)}.
 *
 * @author yhd
 * @createtime 2021/1/31 15:49
 * @description Simulates the creation of an IoC container.
 */
public class HdDispatcherServlet extends HttpServlet {

    /**
     * Single registry for everything in this version. It holds:
     * fully-qualified class name or bean name -> bean instance (a {@code null} placeholder right
     * after scanning); interface name -> service instance; request URL -> handler {@link Method}.
     */
    private final Map<String, Object> mappings = new HashMap<>();

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }

    @SneakyThrows
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doDispatch(req, resp);
    }

    /**
     * Looks up the handler mapped to the request path and invokes it reflectively.
     * Writes {@code "404 NotFound!"} when no handler is mapped.
     */
    private void doDispatch(HttpServletRequest req, HttpServletResponse resp)
            throws IOException, InvocationTargetException, IllegalAccessException {
        // Remove only the leading context path; String.replace would also strip later matches.
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        if (contextPath != null && !contextPath.isEmpty() && url.startsWith(contextPath)) {
            url = url.substring(contextPath.length());
        }
        url = url.replaceAll("/+", "/");

        Object handler = this.mappings.get(url);
        if (!(handler instanceof Method)) {
            resp.getWriter().write("404 NotFound!");
            return;
        }
        Method method = (Method) handler;

        // Bind every parameter by type or @HdRequestParam, instead of assuming one "name" argument.
        Class<?>[] types = method.getParameterTypes();
        java.lang.annotation.Annotation[][] annotations = method.getParameterAnnotations();
        Object[] args = new Object[types.length];
        for (int i = 0; i < types.length; i++) {
            if (types[i] == HttpServletRequest.class) {
                args[i] = req;
            } else if (types[i] == HttpServletResponse.class) {
                args[i] = resp;
            } else {
                for (java.lang.annotation.Annotation annotation : annotations[i]) {
                    if (!(annotation instanceof HdRequestParam)) {
                        continue;
                    }
                    String value = req.getParameter(((HdRequestParam) annotation).value());
                    // Missing parameters stay null; Integer parameters are parsed from the text.
                    if (value != null && (types[i] == Integer.class || types[i] == int.class)) {
                        args[i] = Integer.valueOf(value.trim());
                    } else {
                        args[i] = value;
                    }
                }
            }
        }
        method.invoke(mappings.get(method.getDeclaringClass().getName()), args);
    }

    /**
     * Initialises the container:
     * <ol>
     *   <li>load the properties file named by {@code contextConfigLocation};</li>
     *   <li>scan {@code scanPackage};</li>
     *   <li>instantiate {@code @HdController} and {@code @HdService} classes;</li>
     *   <li>map controller methods to URLs;</li>
     *   <li>inject {@code @HdAutowired} fields.</li>
     * </ol>
     *
     * @throws ServletException if any step fails (failures are no longer silently swallowed)
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        super.init(config);
        String location = config.getInitParameter("contextConfigLocation");
        // try-with-resources so the configuration stream is always closed.
        try (InputStream is = this.getClass().getClassLoader().getResourceAsStream(location)) {
            if (is == null) {
                throw new ServletException("Config file not found on classpath: " + location);
            }
            // Load the configuration file.
            Properties configContext = new Properties();
            configContext.load(is);
            // Register every scanned class name with a null placeholder.
            String scanPackage = configContext.getProperty("scanPackage");
            doScanner(scanPackage);
            // Iterate over a snapshot: this loop adds URL and interface keys to the HashMap,
            // which would make its keySet iterator throw ConcurrentModificationException.
            for (String className : new java.util.ArrayList<>(mappings.keySet())) {
                if (!className.contains(".")) {
                    continue;
                }
                Class<?> clazz = Class.forName(className);
                if (clazz.isAnnotationPresent(HdController.class)) {
                    mappings.put(className, clazz.getDeclaredConstructor().newInstance());
                    // Optional class-level path prefix.
                    String baseUrl = "";
                    if (clazz.isAnnotationPresent(HdRequestMapping.class)) {
                        baseUrl = clazz.getAnnotation(HdRequestMapping.class).value();
                    }
                    for (Method method : clazz.getMethods()) {
                        if (!method.isAnnotationPresent(HdRequestMapping.class)) {
                            continue;
                        }
                        HdRequestMapping requestMapping = method.getAnnotation(HdRequestMapping.class);
                        // Absolute URL of the handler, with repeated slashes collapsed.
                        String url = ("/" + baseUrl + "/" + requestMapping.value()).replaceAll("/+", "/");
                        mappings.put(url, method);
                        System.out.println("Mapped " + url + "," + method);
                    }
                } else if (clazz.isAnnotationPresent(HdService.class)) {
                    String beanName = clazz.getAnnotation(HdService.class).value();
                    if ("".equals(beanName)) {
                        beanName = clazz.getName();
                    }
                    Object instance = clazz.getDeclaredConstructor().newInstance();
                    mappings.put(beanName, instance);
                    // Also register under every implemented interface so interface-typed fields resolve.
                    for (Class<?> i : clazz.getInterfaces()) {
                        mappings.put(i.getName(), instance);
                    }
                }
            }
            // Field injection into every bean; Method values are handlers, not beans.
            for (Object bean : mappings.values()) {
                if (bean == null || bean instanceof Method) {
                    continue;
                }
                for (Field field : bean.getClass().getDeclaredFields()) {
                    if (!field.isAnnotationPresent(HdAutowired.class)) {
                        continue;
                    }
                    String beanName = field.getAnnotation(HdAutowired.class).value();
                    if ("".equals(beanName)) {
                        beanName = field.getType().getName();
                    }
                    field.setAccessible(true);
                    field.set(bean, mappings.get(beanName));
                }
            }
            System.out.print("Diy MVC Framework is init");
        } catch (ServletException e) {
            throw e;
        } catch (Exception e) {
            throw new ServletException("Diy MVC Framework init failed", e);
        }
    }

    /**
     * Recursively registers every {@code .class} file under {@code scanPackage} in
     * {@link #mappings} with a {@code null} placeholder value.
     *
     * @throws IllegalArgumentException if the package is not configured or not on the classpath
     * @throws java.net.URISyntaxException if the package URL cannot be converted to a file path
     */
    private void doScanner(String scanPackage) throws java.net.URISyntaxException {
        if (scanPackage == null || scanPackage.trim().isEmpty()) {
            throw new IllegalArgumentException("scanPackage is not configured");
        }
        // ClassLoader.getResource expects a name without a leading '/' (unlike Class.getResource).
        URL url = this.getClass().getClassLoader().getResource(scanPackage.replace('.', '/'));
        if (url == null) {
            throw new IllegalArgumentException("Package not found on classpath: " + scanPackage);
        }
        // toURI() decodes %20 / non-ASCII segments that url.getFile() would leave encoded.
        File[] files = new File(url.toURI()).listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + file.getName());
            } else if (file.getName().endsWith(".class")) {
                String clazzName = scanPackage + "." + file.getName().replace(".class", "");
                mappings.put(clazzName, null);
            }
        }
    }
}
2.0版本
分析
1.0版本的初始化邏輯幾乎全部寫在了init一個方法裡面,程式碼耦合度十分高,不符合開發規範。
思路
採用設計模式(工廠模式、單例模式、委派模式、策略模式),改造業務邏輯。
程式碼
/**
* @author yhd
* @createtime 2021/2/1 11:29
*/
public class HdDispatcherServlet2 extends HttpServlet {
private Map<String, Object> ioc = new ConcurrentHashMap<>();
private Map<String, Method> handlerMappings = new ConcurrentHashMap<>();
private List<String> classNames = new CopyOnWriteArrayList<>();
private Properties configContext = new Properties();
private static final String CONFIG_LOCATION = "contextConfigLocation";
@Override
public void init(ServletConfig config) throws ServletException {
//1.載入配置檔案
loadConfig(config.getInitParameter(CONFIG_LOCATION));
//2.掃描所有的元件
doScanPackages(configContext.getProperty("scanPackage"));
//3.將元件加入到容器
refersh();
//4.屬性設值
population();
//5.建立方法與路徑的對映
routingAndMapping();
}
/**
* 建立方法與路徑的對映
*/
private void routingAndMapping() {
classNames.forEach(className -> {
Object instance = ioc.get(className);
if (instance.getClass().isAnnotationPresent(HdController.class)) {
String baseUrl = "";
if (instance.getClass().isAnnotationPresent(HdRequestMapping.class)) {
baseUrl += instance.getClass().getAnnotation(HdRequestMapping.class).value().trim();
}
String finalBaseUrl = baseUrl;
Arrays.asList(instance.getClass().getDeclaredMethods()).forEach(method -> {
if (method.isAnnotationPresent(HdRequestMapping.class)) {
String methodUrl = finalBaseUrl;
methodUrl += method.getAnnotation(HdRequestMapping.class).value().trim();
handlerMappings.put(methodUrl, method);
}
});
}
});
}
/**
* 屬性設值
*/
private void population() {
Set<String> keySet = ioc.keySet();
keySet.forEach(key -> {
Field[] fields = ioc.get(key).getClass().getFields();
Arrays.asList(fields).forEach(field -> {
if (field.isAnnotationPresent(HdAutowired.class)) {
HdAutowired autowired = field.getAnnotation(HdAutowired.class);
String name = autowired.value().trim();
if ("".equals(autowired.value().trim())) {
name = field.getType().getName();
}
try {
field.setAccessible(true);
field.set(name, ioc.get(name));
} catch (IllegalAccessException e) {
}
}
});
});
}
/**
* 容器重新整理
* 元件加入到容器中
*/
@SneakyThrows
private void refersh() {
if (classNames == null || classNames.isEmpty()) {
throw new RuntimeException("元件掃描出現異常!");
}
for (String className : classNames) {
Class<?> clazz = Class.forName(className);
if (clazz.isAnnotationPresent(HdController.class)) {
//TODO 類名處理
ioc.put(clazz.getSimpleName(), clazz.newInstance());
} else if (clazz.isAnnotationPresent(HdService.class)) {
Object instance = clazz.newInstance();
ioc.put(clazz.getSimpleName(), instance);
Class<?>[] interfaces = clazz.getInterfaces();
for (Class<?> inter : interfaces) {
ioc.put(inter.getSimpleName(), clazz);
}
} else {
continue;
}
}
}
/**
* 元件掃描
*
* @param scanPackage
*/
private void doScanPackages(String scanPackage) {
URL url = getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
File files = new File(url.getFile());
for (File file : files.listFiles()) {
if (file.isDirectory()) {
doScanPackages(scanPackage + "." + file.getName());
} else {
if (!file.getName().endsWith(".class")) {
continue;
}
String className = scanPackage + "." + file.getName().replace(".class", "");
classNames.add(className);
}
}
}
/**
* 載入配置檔案
*
* @param initParameter
*/
@SneakyThrows
private void loadConfig(String initParameter) {
InputStream is = getClass().getClassLoader().getResourceAsStream(initParameter);
configContext.load(is);
}
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
doPost(req, resp);
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
try {
doDispatcher(req, resp);
} catch (Exception e) {
throw new RuntimeException(" 500 server error!");
}
}
@SneakyThrows
private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) {
String realPath = req.getRequestURI().replace(req.getContextPath(), "");
Map<String, String[]> parameterMap = req.getParameterMap();
if (!handlerMappings.containsKey(realPath)) {
throw new RuntimeException("404 Not Found!");
}
Method method = handlerMappings.get(realPath);
Class<?>[] parameterTypes = method.getParameterTypes();
Object[] paramValues = new Object[parameterTypes.length];
for (int i = 0; i < parameterTypes.length - 1; i++) {
Class param = parameterTypes[i];
if (param == HttpServletRequest.class) {
paramValues[i] = req;
}
if (param == HttpServletResponse.class) {
paramValues[i] = resp;
}
if (param == String.class) {
HdRequestParam requestParam = parameterTypes[i].getAnnotation(HdRequestParam.class);
String value = requestParam.value();
String[] realParam = parameterMap.get(value);
paramValues[i] = Arrays.toString(realParam)
.replaceAll("\\[|\\]", "")
.replaceAll("\\s", ",");
}
}
method.invoke(method.getDeclaringClass().getSimpleName(), paramValues);
}
private Object convertParamType() {
return null;
}
}
3.0版本
分析
HandlerMapping還不能像SpringMVC一樣支援正則,url引數還不支援強制型別轉換,反射呼叫之前還需要重新獲取bean的name。
改造 HandlerMapping,在真實的 Spring 原始碼中,HandlerMapping 其實是一個 List 而非 Map。List 中的元素是一個自定義的型別。
思路
使用內部類維護requestMapping和url之間的關係。
程式碼
/**
 * Version 3 of the mini MVC dispatcher. Routes are stored as a list of {@link Handler}
 * entries, each matched by a compiled URL pattern. Request values are converted to the
 * handler's parameter types.
 */
public class HdDispatcherServlet3 extends HttpServlet {
    /** Bean registry keyed by fully-qualified class name, custom service name, or interface name. */
    private final Map<String, Object> ioc = new ConcurrentHashMap<>();
    /** Fully-qualified names of every class found under {@code scanPackage}. */
    private final List<String> classNames = new CopyOnWriteArrayList<>();
    private final Properties configContext = new Properties();
    private static final String CONFIG_LOCATION = "contextConfigLocation";
    /** Ordered route table; the first handler whose pattern matches the whole path wins. */
    private final List<Handler> handlerMapping = new CopyOnWriteArrayList<>();

    /**
     * One route: a URL pattern, the controller method it dispatches to, and the argument
     * index for each bound request value. Static, so it holds no hidden reference to the servlet.
     */
    private static final class Handler {
        /** Controller instance the method is invoked on. */
        private final Object controller;
        /** Mapped handler method. */
        private final Method method;
        /** Route pattern; a request matches only if the entire path matches. */
        private final Pattern pattern;
        /** Request-parameter name (or servlet request/response class name) -> argument index. */
        private final Map<String, Integer> paramIndexMapping = new HashMap<>();

        Handler(Pattern pattern, Object controller, Method method) {
            this.controller = controller;
            this.method = method;
            this.pattern = pattern;
            putParamIndexMapping(method);
        }

        /** Records where each {@code @HdRequestParam} value and the servlet request/response go. */
        private void putParamIndexMapping(Method method) {
            // Parameters annotated with @HdRequestParam are bound by name.
            Annotation[][] pa = method.getParameterAnnotations();
            for (int i = 0; i < pa.length; i++) {
                for (Annotation a : pa[i]) {
                    if (a instanceof HdRequestParam) {
                        String paramName = ((HdRequestParam) a).value().trim();
                        if (!paramName.isEmpty()) {
                            paramIndexMapping.put(paramName, i);
                        }
                    }
                }
            }
            // Servlet request/response are bound by type, keyed by their class name.
            Class<?>[] parameterTypes = method.getParameterTypes();
            for (int i = 0; i < parameterTypes.length; i++) {
                Class<?> type = parameterTypes[i];
                if (type == HttpServletRequest.class || type == HttpServletResponse.class) {
                    paramIndexMapping.put(type.getName(), i);
                }
            }
        }
    }

    /**
     * Runs every initialisation step in order.
     *
     * @throws ServletException wrapping the first failure of any step
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        super.init(config);
        try {
            //1. Load the configuration file.
            loadConfig(config.getInitParameter(CONFIG_LOCATION));
            //2. Scan all components.
            doScanPackages(configContext.getProperty("scanPackage"));
            //3. Instantiate components into the container.
            refresh();
            //4. Inject @HdAutowired fields.
            population();
            //5. Map request paths to handler methods.
            routingAndMapping();
        } catch (Exception e) {
            throw new ServletException("Diy MVC Framework init failed", e);
        }
    }

    /**
     * Builds one {@link Handler} per {@code @HdRequestMapping} method of every controller.
     * The pattern is {@code "/" + classPrefix + "/" + methodPath} with slashes collapsed.
     */
    private void routingAndMapping() {
        for (Object instance : ioc.values()) {
            Class<?> clazz = instance.getClass();
            if (!clazz.isAnnotationPresent(HdController.class)) {
                continue;
            }
            String url = "";
            if (clazz.isAnnotationPresent(HdRequestMapping.class)) {
                url = clazz.getAnnotation(HdRequestMapping.class).value();
            }
            for (Method method : clazz.getMethods()) {
                HdRequestMapping requestMapping = method.getAnnotation(HdRequestMapping.class);
                if (requestMapping == null) {
                    continue;
                }
                // The explicit "/" keeps "demo" + "query" from fusing into "demoquery".
                String regex = ("/" + url + "/" + requestMapping.value()).replaceAll("/+", "/");
                handlerMapping.add(new Handler(Pattern.compile(regex), instance, method));
            }
        }
    }

    /**
     * Injects {@code @HdAutowired} fields (including private ones) of every bean. An empty
     * annotation value resolves by the field's fully-qualified type name.
     */
    private void population() throws IllegalAccessException {
        for (Object bean : ioc.values()) {
            for (Field field : bean.getClass().getDeclaredFields()) {
                HdAutowired autowired = field.getAnnotation(HdAutowired.class);
                if (autowired == null) {
                    continue;
                }
                String name = autowired.value().trim();
                if (name.isEmpty()) {
                    name = field.getType().getName();
                }
                field.setAccessible(true);
                // Set on the bean instance (the original passed the String key as the target).
                field.set(bean, ioc.get(name));
            }
        }
    }

    /**
     * Instantiates every scanned {@code @HdController} / {@code @HdService} class.
     * Services are also registered under each implemented interface name.
     */
    private void refresh() throws ReflectiveOperationException {
        if (classNames.isEmpty()) {
            throw new RuntimeException("元件掃描出現異常!");
        }
        for (String className : classNames) {
            Class<?> clazz = Class.forName(className);
            if (clazz.isAnnotationPresent(HdController.class)) {
                ioc.put(clazz.getName(), clazz.getDeclaredConstructor().newInstance());
            } else if (clazz.isAnnotationPresent(HdService.class)) {
                Object instance = clazz.getDeclaredConstructor().newInstance();
                String beanName = clazz.getAnnotation(HdService.class).value().trim();
                ioc.put(beanName.isEmpty() ? clazz.getName() : beanName, instance);
                // Register the instance itself, not the Class object, so interface injection works.
                for (Class<?> inter : clazz.getInterfaces()) {
                    ioc.put(inter.getName(), instance);
                }
            }
        }
    }

    /**
     * Recursively collects fully-qualified names of every {@code .class} file under {@code scanPackage}.
     *
     * @param scanPackage dotted package name from the configuration
     */
    private void doScanPackages(String scanPackage) throws java.net.URISyntaxException {
        if (scanPackage == null || scanPackage.trim().isEmpty()) {
            throw new IllegalArgumentException("scanPackage is not configured");
        }
        // ClassLoader.getResource expects a name without a leading '/'.
        URL url = getClass().getClassLoader().getResource(scanPackage.replace('.', '/'));
        if (url == null) {
            throw new IllegalArgumentException("Package not found on classpath: " + scanPackage);
        }
        // toURI() decodes %20 / non-ASCII segments that url.getFile() leaves encoded.
        File[] files = new File(url.toURI()).listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            if (file.isDirectory()) {
                doScanPackages(scanPackage + "." + file.getName());
            } else if (file.getName().endsWith(".class")) {
                classNames.add(scanPackage + "." + file.getName().replace(".class", ""));
            }
        }
    }

    /**
     * Loads the properties file from the classpath, always closing the stream.
     *
     * @param initParameter classpath resource name
     */
    private void loadConfig(String initParameter) throws IOException {
        try (InputStream is = getClass().getClassLoader().getResourceAsStream(initParameter)) {
            if (is == null) {
                throw new IOException("Config file not found on classpath: " + initParameter);
            }
            configContext.load(is);
        }
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    /**
     * Dispatches the request; failures are reported as a ServletException that keeps the
     * underlying cause (the handler's own exception when invocation failed).
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        try {
            doDispatcher(req, resp);
        } catch (java.lang.reflect.InvocationTargetException e) {
            throw new ServletException(" 500 server error!", e.getCause());
        } catch (Exception e) {
            throw new ServletException(" 500 server error!", e);
        }
    }

    /**
     * Binds request values to the matched handler's arguments, invokes it, and writes any
     * non-null return value. Unknown paths get HTTP 404 with body {@code "404 Not Found!"}.
     */
    private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) throws Exception {
        Handler handler = getHandler(req);
        if (handler == null) {
            resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
            resp.getWriter().write("404 Not Found!");
            return;
        }
        Class<?>[] parameterTypes = handler.method.getParameterTypes();
        Object[] paramValues = new Object[parameterTypes.length];
        for (Map.Entry<String, String[]> param : req.getParameterMap().entrySet()) {
            Integer index = handler.paramIndexMapping.get(param.getKey());
            if (index == null) {
                continue;
            }
            // Multi-valued parameters are joined with single commas.
            paramValues[index] = convert(parameterTypes[index], String.join(",", param.getValue()));
        }
        // Request/response are assigned last so a request parameter cannot overwrite them.
        Integer reqIndex = handler.paramIndexMapping.get(HttpServletRequest.class.getName());
        if (reqIndex != null) {
            paramValues[reqIndex] = req;
        }
        Integer respIndex = handler.paramIndexMapping.get(HttpServletResponse.class.getName());
        if (respIndex != null) {
            paramValues[respIndex] = resp;
        }
        Object returnValue = handler.method.invoke(handler.controller, paramValues);
        // void handlers yield null from invoke; anything else is written as text.
        if (returnValue != null) {
            resp.getWriter().write(returnValue.toString());
        }
    }

    /**
     * Converts a raw request value to the handler's parameter type. Supports boxed/primitive
     * int, long, double and boolean; any other type receives the raw string.
     */
    private Object convert(Class<?> parameterType, String value) {
        String text = value.trim();
        if (parameterType == Integer.class || parameterType == int.class) {
            return Integer.valueOf(text);
        }
        if (parameterType == Long.class || parameterType == long.class) {
            return Long.valueOf(text);
        }
        if (parameterType == Double.class || parameterType == double.class) {
            return Double.valueOf(text);
        }
        if (parameterType == Boolean.class || parameterType == boolean.class) {
            return Boolean.valueOf(text);
        }
        return value;
    }

    /** Returns the first handler whose pattern matches the whole request path, or {@code null}. */
    private Handler getHandler(HttpServletRequest req) {
        if (handlerMapping.isEmpty()) {
            return null;
        }
        // Remove only the leading context path; String.replace would also strip later matches.
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        if (contextPath != null && !contextPath.isEmpty() && url.startsWith(contextPath)) {
            url = url.substring(contextPath.length());
        }
        url = url.replaceAll("/+", "/");
        for (Handler handler : handlerMapping) {
            if (handler.pattern.matcher(url).matches()) {
                return handler;
            }
        }
        return null;
    }
}