Search code examples
Tags: java, unit-testing, reflection, junit, intellij-idea

How can I run all JUnit unit tests except those ending in "IntegrationTest" in my IntelliJ IDEA project using the integrated test runner?


I want to run all JUnit unit tests in my IntelliJ IDEA project (excluding JUnit integration tests) using JUnit's static suite() method. Why the static suite() method? Because it lets me use IntelliJ IDEA's JUnit test runner to run all unit tests in my application, while easily excluding integration tests by naming convention. The code so far looks like this:

package com.acme;

import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Aggregates the project's unit tests into a single JUnit 3 suite exposed through the static
 * {@link #suite()} method, so that IntelliJ IDEA's JUnit runner can run them all at once.
 *
 * <p>Integration tests (classes whose names end in "IntegrationTest") are meant to be left out.
 * The class list in {@link #getUnitTestClasses()} is currently maintained by hand.
 */
public class AllUnitTests extends TestCase {

    /**
     * Builds the combined suite.
     *
     * @return a suite containing every class returned by {@link #getUnitTestClasses()}
     */
    public static Test suite() {
        return createTestSuite(getUnitTestClasses());
    }

    /**
     * Returns the unit-test classes to run. Do not add integration tests here.
     *
     * <p>Typed as {@code Class<? extends TestCase>} so that adding a non-TestCase class is a
     * compile error instead of an unchecked cast later on.
     */
    private static List<Class<? extends TestCase>> getUnitTestClasses() {
        List<Class<? extends TestCase>> classes = new ArrayList<Class<? extends TestCase>>();
        classes.add(CalculatorTest.class);
        return classes;
    }

    /**
     * Creates a suite named "All Unit Tests" and adds each class to it via
     * {@link TestSuite#addTestSuite(Class)}.
     *
     * @param allClasses the test classes to include, in order
     * @return the populated suite
     */
    private static TestSuite createTestSuite(List<Class<? extends TestCase>> allClasses) {
        TestSuite suite = new TestSuite("All Unit Tests");
        for (Class<? extends TestCase> testClass : allClasses) {
            suite.addTestSuite(testClass);
        }
        return suite;
    }

}

The method getUnitTestClasses() should be rewritten to add all project classes that extend TestCase, except those whose names end in "IntegrationTest".

I know this is easy to do in Maven, for example, but I need to do it in IntelliJ IDEA so I can use the integrated test runner - I like the green bar :)


Solution

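  • I've written some code that does most of the work. It works only if your compiled classes are files on the local disk rather than entries in a JAR. All you need is one known class in the package; for this purpose you could create a Locator.java class whose only job is to let the code find the package directory. Note that only that one package directory is scanned (sub-packages are not visited), so you still need to filter the result for TestCase subclasses and drop names ending in "IntegrationTest".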

    public class ClassEnumerator {
        public static void main(String[] args) throws ClassNotFoundException {
            List<Class<?>> list = listClassesInSamePackage(Locator.class, true);
    
            System.out.println(list);
        }
    
        private static List<Class<?>> listClassesInSamePackage(Class<?> locator, boolean includeLocator) 
                                                                          throws ClassNotFoundException {
    
            File packageFile = getPackageFile(locator);
    
            String ignore = includeLocator ? null : locator.getSimpleName() + ".class";
    
            return toClassList(locator.getPackage().getName(), listClassNames(packageFile, ignore));
        }
    
        private static File getPackageFile(Class<?> locator) {
            URL url = locator.getClassLoader().getResource(locator.getName().replace(".", "/") + ".class");
            if (url == null) {
                throw new RuntimeException("Cannot locate " + Locator.class.getName());
            }
    
            try {
            return new File(url.toURI()).getParentFile();
            }
            catch (URISyntaxException e) {
                throw new RuntimeException(e);
            }
        }
    
        private static String[] listClassNames(File packageFile, final String ignore) {
            return packageFile.list(new FilenameFilter(){
                @Override
                public boolean accept(File dir, String name) {
                    if (name.equals(ignore)) {
                        return false;
                    }
                    return name.endsWith(".class");
                }
            });
        }
    
        private static List<Class<?>> toClassList(String packageName, String[] classNames)
                                                                 throws ClassNotFoundException {
    
            List<Class<?>> result = new ArrayList<Class<?>>(classNames.length);
            for (String className : classNames) {
                // Strip the .class
                String simpleName = className.substring(0, className.length() - 6);
    
                result.add(Class.forName(packageName + "." + simpleName));
            }
            return result;
        }
    }