1. 传统的JDBC模式

在没有使用ORM框架时,我们基本都是通过JDBC进行数据库的操作,一般的逻辑代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
public static void main(String[] args) {
Connection connection = null;
PreparedStatement preparedStatement = null;
ResultSet resultSet = null;
try {
// 加载数据库驱动
Class.forName("com.mysql.jdbc.Driver");
// 通过驱动管理类获取数据库链接
connection = DriverManager.getConnection("jdbc:mysql://localhost:3306/mybatis?
characterEncoding=utf-8", "root", "root");
// 定义sql语句?表示占位符
String sql = "select * from user where username = ?";
// 获取预处理statement
preparedStatement = connection.prepareStatement(sql);
// 设置参数,第一个参数为sql语句中参数的序号(从1开始),第二个参数为设置的参数值
preparedStatement.setString(1, "tom");
// 向数据库发出sql执行查询,查询出结果集
resultSet = preparedStatement.executeQuery();
// 遍历查询结果集
while (resultSet.next()) {
int id = resultSet.getInt("id");
String username = resultSet.getString("username");
// 封装User
user.setId(id);
user.setUsername(username);
}
System.out.println(user);
} catch (Exception e) {
e.printStackTrace();
} finally {
// 释放资源
if (resultSet != null) {
try {
resultSet.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
if (preparedStatement != null) {
try {
preparedStatement.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
if (connection != null) {
try {
connection.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
}
}

通过对JDBC的代码进行分析,可以发现存在以下问题:

  1. 数据库连接创建、释放频繁造成系统资源浪费,从而影响系统性能;
  2. Sql语句在代码中硬编码,造成代码不易维护,实际应用中sql变化的可能较大,sql变动需要改变java代码;
  3. 使用preparedStatement向占位符传参存在硬编码,因为sql语句的where条件不一定,可能多也可能少,修改sql还要修改代码,系统不易维护;
  4. 对结果集解析存在硬编码(查询列名),sql变化导致解析代码变化,系统不易维护,如果能将数据库记录封装成pojo对象解析比较方便;

由于硬编码可以通过配置文件解决,针对上述的问题,所以我们可以想到如下的解决思路:

  1. 使用数据库连接池初始化连接资源;
  2. 将sql语句抽取到xml配置文件中;
  3. 使用反射、内省等底层技术,自动将实体与表进行属性与字段的自动映射;

根据上述的思路进行自定义的Mybatis框架编写。

2. 客户端开发

首先客户端要提供数据库的连接信息以及SQL的信息。
根据设计思路,这些信息通过配置文件来解决,所以在resources目录下创建sqlMapConfig.xml文件,内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
<configuration>
<!-- 数据库配置信息-->
<dataSource>
<property name="driver" value="com.mysql.jdbc.Driver"></property>
<property name="jdbcUrl" value="jdbc:mysql:///test"></property>
<property name="username" value="root"></property>
<property name="password" value="123456"></property>
</dataSource>

<mappers>
<mapper resource="UserMapper.xml"></mapper>
</mappers>
</configuration>

在这个文件中配置了要加载那些mapper.xml,这样只需要加载一次就完成了所有的加载。
在同个目录下创建映射配置文件UserMapper.xml,并定义了简单的增删改查的SQL。内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
<mapper namespace="com.ormtest.mapper.UserMapper">

<select id="selectAll" resultType="com.ormtest.pojo.User">
select * from user
</select>

<select id="selectOne" resultType="com.ormtest.pojo.User" parameterType="com.ormtest.pojo.User">
select * from user where id = #{id} and username = #{username}
</select>

<!--添加用户-->
<insert id="insertUser" parameterType="com.ormtest.pojo.User" >
insert into user values(#{id},#{username},#{password})
</insert>

<!--修改-->
<update id="updateUser" parameterType="com.ormtest.pojo.User">
update user set username = #{username} where id = #{id}
</update>

<!--删除-->
<delete id="deleteUser" parameterType="java.lang.Integer">
delete from user where id = #{id}
</delete>

</mapper>

同时也要生成对应的POJO对象,并定义Dao层接口,代码简单,此处就不粘贴了。
到此为止,客户的代码就编写完成了,等完成框架的编写之后,就可以进行测试了。

3. 自定义框架开发

开发的时候,我们根据逻辑流程一步步进行。

首先我们需要一个类来接受和加载核心配置文件:

1
2
3
4
5
6
7
8
public class Resources {
// 将xml配置文件加载成为字节流
public static InputStream getResourceAsSteam(String path) {
InputStream resourceAsStream = Resources.class.getClassLoader().getResourceAsStream(path);
return resourceAsStream;
}

}

核心配置文件加载完成之后,就需要对字节流进行解析,所以声明一个解析xml的类,最终这个类要返回一个SqlSessionFactory的对象:

1
2
3
4
5
6
7
8
9
10
11
12
public class SqlSessionFactoryBuilder {
// 使用构造器模式,将复杂对象进行逐步构建
public SqlSessionFactory build(InputStream in) throws Exception {
// 使用dom4j读取字节流(也就是核心配置文件)的内容
// 封装成一个configuration对象
XMLConfigBuilder xmlConfigBuilder = new XMLConfigBuilder();
Configuration configuration = xmlConfigBuilder.parseConfig(in);
// 根据配置信息创建sqlSession的工厂
SqlSessionFactory sqlSessionFactory = new DefaultSqlSessionFactory(configuration);
return sqlSessionFactory;
}
}

在这个解析的过程中,声明了一些解析类,比如XMLConfigBuilder进行核心配置文件的解析,里面还嵌套调用了mapper.xml的解析类,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
public class XMLConfigBuilder {

private Configuration configuration;

private static final String DRIVER = "driver";
private static final String JDBCURL = "jdbcUrl";
private static final String USERNAME = "username";
private static final String PASSWORD = "password";

public XMLConfigBuilder() {
this.configuration = new Configuration();
}

/**
* 通过dom4j进行字节流解析
* @param inputStream
* @return
*/
public Configuration parseConfig(InputStream inputStream) throws DocumentException, PropertyVetoException {
// 借助dom4j,进行解析,得到整个的文档对象
Document document = new SAXReader().read(inputStream);
// 得到根对象,即configuration标签
Element rootElement = document.getRootElement();
// 获取property标签,进行数据库连接信息的解析加载
List<Element> elementList = rootElement.selectNodes("//property");
// 借助properties对象,进行属性保存
Properties properties = new Properties();
for (Element element : elementList) {
String name = element.attributeValue("name");
String value = element.attributeValue("value");
properties.setProperty(name, value);
}

// 进行数据库连接信息封装,使用C3P0连接池
ComboPooledDataSource comboPooledDataSource = new ComboPooledDataSource();
comboPooledDataSource.setDriverClass(properties.getProperty(DRIVER));
comboPooledDataSource.setJdbcUrl(properties.getProperty(JDBCURL));
comboPooledDataSource.setUser(properties.getProperty(USERNAME));
comboPooledDataSource.setPassword(properties.getProperty(PASSWORD));

// 将数据源进行保存
configuration.setDataSource(comboPooledDataSource);

// 进行mapper.xml的解析工作
// 首先得到需要加载的xml
List<Element> mappers = rootElement.selectNodes("//mapper");

for (Element element : mappers) {
// 得到需要加载的mapper文件路径
String mapperPath = element.attributeValue("resource");
// 对mapper文件进行解析,得到mapperstatement
InputStream resourceAsSteam = Resources.getResourceAsSteam(mapperPath);
XMLMapperBuilder xmlMapperBuilder = new XMLMapperBuilder(configuration);
xmlMapperBuilder.parseMapper(resourceAsSteam);
}

// 返回整体配置对象
return this.configuration;
}
}

通过解析mapper.xml,生成对应的mappedStatement对象,存入Configuration对象中,因为全局只有一个Configuration对象,所以每次都把这个对象进行传递。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
public class XMLMapperBuilder {

private Configuration configuration;

public XMLMapperBuilder(Configuration configuration) {
this.configuration = configuration;
}

/**
* 使用dom4j解析mapper.xml
* @param inputStream
*/
public void parseMapper(InputStream inputStream) throws DocumentException {
Document document = new SAXReader().read(inputStream);
Element rootElement = document.getRootElement();
// 得到当前mapper的namespace
String namespace = rootElement.attributeValue("namespace");
//递归遍历当前节点所有的子节点
List<Element> elementList = rootElement.elements();
for (Element element : elementList) {
String id = element.attributeValue("id");
String resultType = element.attributeValue("resultType");
String parameterType = element.attributeValue("parameterType");
String sqlText = element.getTextTrim();
MappedStatement mappedStatement = new MappedStatement();
mappedStatement.setId(id);
mappedStatement.setResultType(resultType);
mappedStatement.setParameterType(parameterType);
mappedStatement.setSql(sqlText);
// 增加sql类型字段
mappedStatement.setSqlCommandType(element.getName());
// mapper的namespace和SQL语句的id,组成唯一id
String key = namespace+"."+id;
configuration.getMappedStatementMap().put(key,mappedStatement);
}
}
}

到此,所有的配置文件都加载完成,最终根据生成的Configuration对象,返回对应的SqlSessionFactory对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
public class DefaultSqlSessionFactory implements SqlSessionFactory{
// 通过构造方法注入,保证从上到下只有一个configuration对象
private Configuration configuration;

public DefaultSqlSessionFactory(Configuration configuration) {
this.configuration = configuration;
}

@Override
public SqlSession openSession() {
return new DefaultSqlSession(configuration);
}
}

当客户端调用openSession方法时,就会返回一个SqlSession对象,该对象里面封装了增删改查的方法,供客户端调用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
public class DefaultSqlSession implements SqlSession{

private Configuration configuration;

public DefaultSqlSession(Configuration configuration) {
this.configuration = configuration;
}

@Override
public Object doQuery(String statementId, boolean resultTypeFlag, Object... params) throws Exception {
// 根据statementId得到要执行的SQL对象
MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId);
// 判断当前的SQL类型是什么
switch (mappedStatement.getSqlCommandType()) {
case "select": {
if (resultTypeFlag) {
return selectList(statementId, params);
} else {
return selectOne(statementId, params);
}
}
case "insert":{
return insert(statementId, params);
}
case "update":{
return update(statementId, params);
}
case "delete":{
return delete(statementId, params);
}
}
return null;
}

@Override
public <E> List<E> selectList(String statementId, Object... params) throws Exception {
// 根据statementId得到要执行的SQL对象
MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId);
Executor executor = new SimpleExecutor();
List<Object> result = executor.query(configuration, mappedStatement, params);
return (List<E>) result;
}

@Override
public <T> T selectOne(String statementId, Object... params) throws Exception {
List<Object> objects = this.selectList(statementId, params);
if (1 == objects.size()) {
return (T) objects.get(0);
} else {
throw new RuntimeException("查询结果为空或查询结果过多");
}
}

@Override
public int insert(String statementId, Object... params) throws Exception {
MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId);
Executor executor = new SimpleExecutor();
executor.updateDatabase(configuration, mappedStatement, params);
return 1;
}

@Override
public int update(String statementId, Object... params) throws Exception {
MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId);
Executor executor = new SimpleExecutor();
executor.updateDatabase(configuration, mappedStatement, params);
return 1;
}

@Override
public int delete(String statementId, Object... params) throws Exception {
MappedStatement mappedStatement = configuration.getMappedStatementMap().get(statementId);
Executor executor = new SimpleExecutor();
executor.updateDatabase(configuration, mappedStatement, params);
return 1;
}

@Override
public <T> T getMapper(Class<?> mapperClass) {
// 根据JDK动态代理生成代理对象,对方法进行加工
Object proxyInstance = Proxy.newProxyInstance(DefaultSqlSession.class.getClassLoader(), new Class[]{mapperClass}, new InvocationHandler() {
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
// 这里无法获取mapper.xml的信息,所以为了方便识别,需要将namespace和SQL语句的id与接口的全限定名和方法保持一致
// 获取方法名
String methodName = method.getName();
// 获取所属的接口class名称
String className = method.getDeclaringClass().getName();
// 得到唯一的statementId
String statementId = className + "." + methodName;
// 根据方法的返回结果类型进行判断
Type genericReturnType = method.getGenericReturnType();
if(genericReturnType instanceof ParameterizedType){
return doQuery(statementId, true, args);
}
return doQuery(statementId, false, args);
}
});

return (T) proxyInstance;
}
}

最终的增删改查操作实际上还是通过JDBC来实现,所以声明一个Executor类来专门执行具体操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
public class SimpleExecutor implements Executor{

@Override
public <E> List<E> query(Configuration configuration, MappedStatement mappedStatement, Object... params) throws Exception {
PreparedStatement preparedStatement = getPreparedStatement(configuration, mappedStatement, params);
// 6 执行SQL,得到结果集
ResultSet resultSet = preparedStatement.executeQuery();
// 7 对结果集进行转换
List<Object> resultList = new ArrayList<>();
// 7.1 得到返回结果的类型
String resultType = mappedStatement.getResultType();
// 7.2 转换为类
Class<?> resultTypeClass = getClassType(resultType);
// 7.3 遍历结果集,逐个进行转换
while (resultSet.next()) {
// 声明返回类
Object o = resultTypeClass.newInstance();
// 获取元数据
ResultSetMetaData metaData = resultSet.getMetaData();
// 此处是从1开始
for (int i = 1; i <= metaData.getColumnCount(); i++) {
// 字段名
String columnName = metaData.getColumnName(i);
// 对应字段的值
Object value = resultSet.getObject(columnName);
//使用内省,根据数据库表字段和实体属性的对应关系,完成封装
PropertyDescriptor propertyDescriptor = new PropertyDescriptor(columnName, resultTypeClass);
// 获取写方法,进行值写入
Method writeMethod = propertyDescriptor.getWriteMethod();
writeMethod.invoke(o,value);
}
// 将本次转换的结果加入返回结果中
resultList.add(o);
}

return (List<E>) resultList;
}

@Override
public boolean updateDatabase(Configuration configuration, MappedStatement mappedStatement, Object... params) throws ClassNotFoundException, SQLException, IllegalAccessException, NoSuchFieldException {
PreparedStatement preparedStatement = getPreparedStatement(configuration, mappedStatement, params[0]);
boolean execute = preparedStatement.execute();
return execute;
}

private PreparedStatement getPreparedStatement(Configuration configuration, MappedStatement mappedStatement, Object... params) throws SQLException, ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
// 执行jdbc过程
// 1 注册驱动获取连接,直接获取C3P0连接池里的连接
Connection connection = configuration.getDataSource().getConnection();
// 2 获取要执行的SQL
String sql = mappedStatement.getSql();
// 3 对SQL中存在的参数进行提取和转换
BoundSql boundSql = getBoundSql(sql);
// 4 获取预处理对象
PreparedStatement preparedStatement = connection.prepareStatement(boundSql.getSqlText());
// 5 设置参数
// 5.1 获取参数类型
String parameterType = mappedStatement.getParameterType();
// 5.2 根据类型获取类
Class<?> parameterTypeClass = getClassType(parameterType);
// 5.3 得到解析的参数列表
List<ParameterMapping> parameterMappingList = boundSql.getParameterMappingList();
for (int i = 0; i < parameterMappingList.size(); i++) {
ParameterMapping parameterMapping = parameterMappingList.get(i);
// 得到参数值,即#{id}中的id
String content = parameterMapping.getContent();
Field declaredField = null;
// 判断是否是基本数据类型或者其包装类
if (isCommonDataType(parameterTypeClass) || isWrapClass(parameterTypeClass)) {
declaredField = parameterTypeClass.getDeclaredField("value");
} else {
declaredField = parameterTypeClass.getDeclaredField(content);
}
// 设置权限暴力访问,防止属性私有不让访问
declaredField.setAccessible(true);
// 得到对应的值
Object o = declaredField.get(params[0]);
// 将参数拼接到SQL上
preparedStatement.setObject(i + 1, o);
}
return preparedStatement;
}

/**
* 判断当前类型是否是基本数据类型
* @param clazz
* @return
*/
private Boolean isCommonDataType(Class clazz){
return clazz.isPrimitive();
}

private boolean isWrapClass(Class clazz){
try {
return ((Class) clazz.getField("TYPE").get(null)).isPrimitive();
} catch (Exception e) {
return false;
}
}

/**
* 根据类型获取对应的class
* @param type
* @return
*/
private Class<?> getClassType(String type) throws ClassNotFoundException {
if(type != null){
Class<?> clazz = Class.forName(type);
return clazz;
}
return null;
}

/**
* 对mapper中的原SQL进行解析和替换
* 由于jdbc只认识?占位符,所以要把#{id}进行替换
* 同时要得到其中的id,用于定位获取参数
* @param sql
* @return
*/
private BoundSql getBoundSql(String sql) {
// 标记处理类
ParameterMappingTokenHandler parameterMappingTokenHandler = new ParameterMappingTokenHandler();
// 第一个参数是开始标记、第二个是结束标记、第三个是使用那个处理类
GenericTokenParser genericTokenParser = new GenericTokenParser("#{", "}", parameterMappingTokenHandler);
// 得到处理后的SQL(参数已经变为?)
String formatSql = genericTokenParser.parse(sql);
// 处理过程中,处理类已经将参数中的值进行了存储,直接获取即可
List<ParameterMapping> parameterMappings = parameterMappingTokenHandler.getParameterMappings();
// 通过构造方法进行赋值
BoundSql boundSql = new BoundSql(formatSql, parameterMappings);
return boundSql;
}
}

还有一些基础的POJO这里就不再粘贴了,到此,框架的代码就基本开发完毕,下面进行测试。

4. 测试

在客户端里声明测试类,进行如下的代码编写:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
public class MyTest {
private SqlSession sqlSession;

// 在测试方法执行前执行
@Before
public void prepare() throws Exception {
// 将配置文件进行加载,得到字节流
InputStream resourceAsSteam = Resources.getResourceAsSteam("sqlMapConfig.xml");
// 此次加载会将xml文件里面的内容加载到框架内
// 通过字节流,得到sqlSession工厂
SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(resourceAsSteam);
// 通过工厂的openSession方法,生成一个session
sqlSession = sqlSessionFactory.openSession();
}

// 不使用代理的模式
@Test
public void test() throws Exception {
List<User> userList = sqlSession.selectList("com.ormtest.mapper.UserMapper.selectAll");
for (User user : userList) {
System.out.println(user);
}
}

// 使用代理模式
@Test
public void test2() throws Exception {
UserMapper userMapper = sqlSession.getMapper(UserMapper.class);
List<User> userList = userMapper.selectAll();
for (User user : userList) {
System.out.println(user);
}
}
}

最终测试通过,查询结果正确。

至此,自定义Mybatis的简单实现就完成了~~