View Javadoc

1   /**
2    * Copyright (C) 2009 kiy0taka.org
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *         http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  package org.kiy0taka.dbunit;
17  
18  import static org.kiy0taka.dbunit.DataSetBuilder.dataSet;
19  
20  import java.io.FileNotFoundException;
21  import java.io.IOException;
22  import java.lang.reflect.Field;
23  import java.net.URL;
24  import java.security.AccessController;
25  import java.security.PrivilegedAction;
26  import java.sql.Connection;
27  import java.sql.PreparedStatement;
28  import java.sql.SQLException;
29  import java.util.ArrayList;
30  import java.util.HashSet;
31  import java.util.List;
32  import java.util.Locale;
33  import java.util.MissingResourceException;
34  import java.util.Properties;
35  import java.util.PropertyResourceBundle;
36  import java.util.ResourceBundle;
37  import java.util.Set;
38  
39  import javax.sql.DataSource;
40  
41  import org.apache.commons.dbcp.BasicDataSource;
42  import org.dbunit.Assertion;
43  import org.dbunit.DatabaseUnitException;
44  import org.dbunit.database.DatabaseConfig;
45  import org.dbunit.database.DatabaseConfig.ConfigProperty;
46  import org.dbunit.database.DatabaseDataSourceConnection;
47  import org.dbunit.database.IDatabaseConnection;
48  import org.dbunit.dataset.DataSetException;
49  import org.dbunit.dataset.IDataSet;
50  import org.dbunit.dataset.excel.XlsDataSet;
51  import org.dbunit.dataset.xml.FlatXmlDataSet;
52  import org.dbunit.dataset.xml.FlatXmlProducer;
53  import org.junit.runners.BlockJUnit4ClassRunner;
54  import org.junit.runners.model.FrameworkField;
55  import org.junit.runners.model.FrameworkMethod;
56  import org.junit.runners.model.InitializationError;
57  import org.junit.runners.model.Statement;
58  import org.xml.sax.InputSource;
59  
60  /**
61   * JUnit Runner implementation for DbUnit.
62   * @author kiy0taka
63   */
64  public class DbUnitRunner extends BlockJUnit4ClassRunner {
65  
66      private static final ResourceBundle BUNDLE;
67  
68      static {
69          BUNDLE = PropertyResourceBundle.getBundle("dbunit-runner");
70          loadDriver(BUNDLE.getString("driver"));
71      }
72  
73      protected static void loadDriver(String driverName) {
74          try {
75              Class.forName(driverName);
76          } catch (ClassNotFoundException e) {
77              throw new RuntimeException(e);
78          }
79      }
80  
81      private enum DataSetType {
82          xml() {
83              public IDataSet createDataSet(URL url) throws DataSetException, IOException {
84                  return new FlatXmlDataSet(new FlatXmlProducer(new InputSource(url.openStream())));
85              }
86          },
87          xls() {
88              public IDataSet createDataSet(URL url) throws DataSetException, IOException {
89                  return new XlsDataSet(url.openStream());
90              }
91          };
92          public abstract IDataSet createDataSet(URL url) throws DataSetException, IOException;
93      }
94  
95      protected DataSource dataSource;
96  
97      protected Connection testConnection;
98  
99      protected String jdbcUrl = BUNDLE.getString("url");
100 
101     protected String username = BUNDLE.getString("username");
102 
103     protected String password = BUNDLE.getString("password");
104 
105     protected String schema = optionalValue(BUNDLE, "schema");
106 
107     protected Properties configProperties = new Properties();
108 
109     /**
110      * Constract Runner for DbUnit.
111      * @param testClass Test Class
112      * @throws InitializationError Initialization error
113      */
114     public DbUnitRunner(Class<?> testClass) throws InitializationError {
115         super(testClass);
116         for (ConfigProperty cp : DatabaseConfig.ALL_PROPERTIES) {
117             try {
118                 configProperties.put(cp.getProperty(), BUNDLE.getString(cp.getProperty()));
119             } catch (MissingResourceException ignore) {
120                 // NOP
121             }
122         }
123     }
124 
125     protected Statement methodBlock(final FrameworkMethod method) {
126         Statement stmt = super.methodBlock(method);
127         DbUnitTest ann = method.getAnnotation(DbUnitTest.class);
128         return ann == null ? stmt : new DbUnitStatement(ann, stmt);
129     }
130 
131     protected List<FrameworkMethod> computeTestMethods() {
132         Set<FrameworkMethod> set = new HashSet<FrameworkMethod>(super.computeTestMethods());
133         set.addAll(getTestClass().getAnnotatedMethods(DbUnitTest.class));
134         return new ArrayList<FrameworkMethod>(set);
135     }
136 
137     protected Object createTest() throws Exception {
138         Object result = super.createTest();
139         dataSource = createDataSource();
140         List<FrameworkField> connFields = getTestClass().getAnnotatedFields(TestConnection.class);
141         if (!connFields.isEmpty()) {
142             testConnection = dataSource.getConnection();
143             for (FrameworkField ff : connFields) {
144                 final Field f = ff.getField();
145                 AccessController.doPrivileged(new SetAccessibleAction(f));
146                 f.set(result, testConnection);
147             }
148         }
149         List<FrameworkField> dsFields = getTestClass().getAnnotatedFields(TestDataSource.class);
150         if (!dsFields.isEmpty()) {
151             for (FrameworkField ff : dsFields) {
152                 final Field f = ff.getField();
153                 AccessController.doPrivileged(new SetAccessibleAction(f));
154                 f.set(result, dataSource);
155             }
156         }
157         return result;
158     }
159 
160     protected DataSource createDataSource() {
161         BasicDataSource result = new BasicDataSource();
162         result.setUsername(username);
163         result.setPassword(password);
164         result.setUrl(jdbcUrl);
165         return result;
166     }
167 
168     protected static String optionalValue(ResourceBundle bundle, String key) {
169         try {
170             return bundle.getString(key);
171         } catch (MissingResourceException ignore) {
172             return null;
173         }
174     }
175 
176     private static class SetAccessibleAction implements PrivilegedAction<Object> {
177 
178         private Field field;
179 
180         public SetAccessibleAction(Field field) {
181             this.field = field;
182         }
183 
184         public Object run() {
185             field.setAccessible(true);
186             return null;
187         }
188     }
189 
190     protected class DbUnitStatement extends Statement {
191         private DbUnitTest ann;
192         private Statement statement;
193 
194         protected DbUnitStatement(DbUnitTest ann, Statement statement) {
195             this.ann = ann;
196             this.statement = statement;
197         }
198 
199         public void evaluate() throws Throwable {
200             IDatabaseConnection conn = createDatabaseConnection();
201             try {
202                 executeUpdate(conn, ann.sql());
203                 IDataSet initData = dataSet(load(ann.init())).nullValue(ann.nullValue()).toDataSet();
204                 ann.operation().toDatabaseOperation().execute(conn, initData);
205                 statement.evaluate();
206                 if (testConnection != null) {
207                     testConnection.commit();
208                 }
209             } catch (Throwable e) {
210                 if (testConnection != null) {
211                     testConnection.rollback();
212                 }
213                 throw e;
214             } finally {
215                 if (testConnection != null) {
216                     testConnection.close();
217                 }
218                 conn.close();
219             }
220             if (!ann.expected().isEmpty()) {
221                 assertTables();
222             }
223         }
224 
225         protected void assertTables() {
226             IDatabaseConnection conn = createDatabaseConnection();
227             try {
228                 IDataSet expected = dataSet(load(ann.expected()))
229                     .excludeColumns(ann.excludeColumns())
230                     .nullValue(ann.nullValue())
231                     .rtrim(ann.rtrim())
232                     .toDataSet();
233                 IDataSet actual = dataSet(conn.createDataSet(expected.getTableNames()))
234                     .excludeColumns(ann.excludeColumns())
235                     .rtrim(ann.rtrim())
236                     .toDataSet();
237                 Assertion.assertEquals(expected, actual);
238             } catch (SQLException e) {
239                 throw new RuntimeException(e);
240             } catch (DatabaseUnitException e) {
241                 throw new RuntimeException(e);
242             } finally {
243                 try {
244                     conn.close();
245                 } catch (SQLException e) {
246                     throw new RuntimeException(e);
247                 }
248             }
249         }
250 
251         protected IDataSet load(String path) {
252             URL url = getTestClass().getJavaClass().getResource(path);
253             if (url == null) {
254                 throw new RuntimeException(new FileNotFoundException(path));
255             }
256             String suffix = path.substring(path.lastIndexOf('.') + 1).toLowerCase(Locale.getDefault());
257             try {
258                 return DataSetType.valueOf(suffix).createDataSet(url);
259             } catch (Exception e) {
260                 throw new RuntimeException(e);
261             }
262         }
263 
264         protected IDatabaseConnection createDatabaseConnection() {
265             try {
266                 DatabaseDataSourceConnection result = new DatabaseDataSourceConnection(dataSource, schema);
267                 DatabaseConfig config = result.getConfig();
268                 config.setPropertiesByString(configProperties);
269                 return result;
270             } catch (SQLException e) {
271                 throw new RuntimeException(e);
272             } catch (DatabaseUnitException e) {
273                 throw new RuntimeException(e);
274             }
275         }
276 
277         protected void executeUpdate(IDatabaseConnection conn, String... sql) throws SQLException {
278             for (String s : sql) {
279                 if (s.isEmpty()) {
280                     continue;
281                 }
282                 PreparedStatement stmt = null;
283                 try {
284                     stmt = conn.getConnection().prepareStatement(s);
285                     stmt.executeUpdate();
286                 } finally {
287                     if (stmt != null) {
288                         stmt.close();
289                     }
290                 }
291             }
292         }
293     }
294 }