1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.kiy0taka.dbunit;
17
import static org.kiy0taka.dbunit.DataSetBuilder.dataSet;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.net.URL;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.MissingResourceException;
import java.util.Properties;
import java.util.PropertyResourceBundle;
import java.util.ResourceBundle;
import java.util.Set;

import javax.sql.DataSource;

import org.apache.commons.dbcp.BasicDataSource;
import org.dbunit.Assertion;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.DatabaseConfig.ConfigProperty;
import org.dbunit.database.DatabaseDataSourceConnection;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.excel.XlsDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunit.dataset.xml.FlatXmlProducer;
import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.model.FrameworkField;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.Statement;
import org.xml.sax.InputSource;
59
60
61
62
63
64 public class DbUnitRunner extends BlockJUnit4ClassRunner {
65
66 private static final ResourceBundle BUNDLE;
67
68 static {
69 BUNDLE = PropertyResourceBundle.getBundle("dbunit-runner");
70 loadDriver(BUNDLE.getString("driver"));
71 }
72
73 protected static void loadDriver(String driverName) {
74 try {
75 Class.forName(driverName);
76 } catch (ClassNotFoundException e) {
77 throw new RuntimeException(e);
78 }
79 }
80
81 private enum DataSetType {
82 xml() {
83 public IDataSet createDataSet(URL url) throws DataSetException, IOException {
84 return new FlatXmlDataSet(new FlatXmlProducer(new InputSource(url.openStream())));
85 }
86 },
87 xls() {
88 public IDataSet createDataSet(URL url) throws DataSetException, IOException {
89 return new XlsDataSet(url.openStream());
90 }
91 };
92 public abstract IDataSet createDataSet(URL url) throws DataSetException, IOException;
93 }
94
95 protected DataSource dataSource;
96
97 protected Connection testConnection;
98
99 protected String jdbcUrl = BUNDLE.getString("url");
100
101 protected String username = BUNDLE.getString("username");
102
103 protected String password = BUNDLE.getString("password");
104
105 protected String schema = optionalValue(BUNDLE, "schema");
106
107 protected Properties configProperties = new Properties();
108
109
110
111
112
113
114 public DbUnitRunner(Class<?> testClass) throws InitializationError {
115 super(testClass);
116 for (ConfigProperty cp : DatabaseConfig.ALL_PROPERTIES) {
117 try {
118 configProperties.put(cp.getProperty(), BUNDLE.getString(cp.getProperty()));
119 } catch (MissingResourceException ignore) {
120
121 }
122 }
123 }
124
125 protected Statement methodBlock(final FrameworkMethod method) {
126 Statement stmt = super.methodBlock(method);
127 DbUnitTest ann = method.getAnnotation(DbUnitTest.class);
128 return ann == null ? stmt : new DbUnitStatement(ann, stmt);
129 }
130
131 protected List<FrameworkMethod> computeTestMethods() {
132 Set<FrameworkMethod> set = new HashSet<FrameworkMethod>(super.computeTestMethods());
133 set.addAll(getTestClass().getAnnotatedMethods(DbUnitTest.class));
134 return new ArrayList<FrameworkMethod>(set);
135 }
136
137 protected Object createTest() throws Exception {
138 Object result = super.createTest();
139 dataSource = createDataSource();
140 List<FrameworkField> connFields = getTestClass().getAnnotatedFields(TestConnection.class);
141 if (!connFields.isEmpty()) {
142 testConnection = dataSource.getConnection();
143 for (FrameworkField ff : connFields) {
144 final Field f = ff.getField();
145 AccessController.doPrivileged(new SetAccessibleAction(f));
146 f.set(result, testConnection);
147 }
148 }
149 List<FrameworkField> dsFields = getTestClass().getAnnotatedFields(TestDataSource.class);
150 if (!dsFields.isEmpty()) {
151 for (FrameworkField ff : dsFields) {
152 final Field f = ff.getField();
153 AccessController.doPrivileged(new SetAccessibleAction(f));
154 f.set(result, dataSource);
155 }
156 }
157 return result;
158 }
159
160 protected DataSource createDataSource() {
161 BasicDataSource result = new BasicDataSource();
162 result.setUsername(username);
163 result.setPassword(password);
164 result.setUrl(jdbcUrl);
165 return result;
166 }
167
168 protected static String optionalValue(ResourceBundle bundle, String key) {
169 try {
170 return bundle.getString(key);
171 } catch (MissingResourceException ignore) {
172 return null;
173 }
174 }
175
176 private static class SetAccessibleAction implements PrivilegedAction<Object> {
177
178 private Field field;
179
180 public SetAccessibleAction(Field field) {
181 this.field = field;
182 }
183
184 public Object run() {
185 field.setAccessible(true);
186 return null;
187 }
188 }
189
190 protected class DbUnitStatement extends Statement {
191 private DbUnitTest ann;
192 private Statement statement;
193
194 protected DbUnitStatement(DbUnitTest ann, Statement statement) {
195 this.ann = ann;
196 this.statement = statement;
197 }
198
199 public void evaluate() throws Throwable {
200 IDatabaseConnection conn = createDatabaseConnection();
201 try {
202 executeUpdate(conn, ann.sql());
203 IDataSet initData = dataSet(load(ann.init())).nullValue(ann.nullValue()).toDataSet();
204 ann.operation().toDatabaseOperation().execute(conn, initData);
205 statement.evaluate();
206 if (testConnection != null) {
207 testConnection.commit();
208 }
209 } catch (Throwable e) {
210 if (testConnection != null) {
211 testConnection.rollback();
212 }
213 throw e;
214 } finally {
215 if (testConnection != null) {
216 testConnection.close();
217 }
218 conn.close();
219 }
220 if (!ann.expected().isEmpty()) {
221 assertTables();
222 }
223 }
224
225 protected void assertTables() {
226 IDatabaseConnection conn = createDatabaseConnection();
227 try {
228 IDataSet expected = dataSet(load(ann.expected()))
229 .excludeColumns(ann.excludeColumns())
230 .nullValue(ann.nullValue())
231 .rtrim(ann.rtrim())
232 .toDataSet();
233 IDataSet actual = dataSet(conn.createDataSet(expected.getTableNames()))
234 .excludeColumns(ann.excludeColumns())
235 .rtrim(ann.rtrim())
236 .toDataSet();
237 Assertion.assertEquals(expected, actual);
238 } catch (SQLException e) {
239 throw new RuntimeException(e);
240 } catch (DatabaseUnitException e) {
241 throw new RuntimeException(e);
242 } finally {
243 try {
244 conn.close();
245 } catch (SQLException e) {
246 throw new RuntimeException(e);
247 }
248 }
249 }
250
251 protected IDataSet load(String path) {
252 URL url = getTestClass().getJavaClass().getResource(path);
253 if (url == null) {
254 throw new RuntimeException(new FileNotFoundException(path));
255 }
256 String suffix = path.substring(path.lastIndexOf('.') + 1).toLowerCase(Locale.getDefault());
257 try {
258 return DataSetType.valueOf(suffix).createDataSet(url);
259 } catch (Exception e) {
260 throw new RuntimeException(e);
261 }
262 }
263
264 protected IDatabaseConnection createDatabaseConnection() {
265 try {
266 DatabaseDataSourceConnection result = new DatabaseDataSourceConnection(dataSource, schema);
267 DatabaseConfig config = result.getConfig();
268 config.setPropertiesByString(configProperties);
269 return result;
270 } catch (SQLException e) {
271 throw new RuntimeException(e);
272 } catch (DatabaseUnitException e) {
273 throw new RuntimeException(e);
274 }
275 }
276
277 protected void executeUpdate(IDatabaseConnection conn, String... sql) throws SQLException {
278 for (String s : sql) {
279 if (s.isEmpty()) {
280 continue;
281 }
282 PreparedStatement stmt = null;
283 try {
284 stmt = conn.getConnection().prepareStatement(s);
285 stmt.executeUpdate();
286 } finally {
287 if (stmt != null) {
288 stmt.close();
289 }
290 }
291 }
292 }
293 }
294 }