Search code examples
javaunit-testingcode-generation

Representing Java toString results as Java source code


I have a String value of a User list object generated at runtime using toString method:

[User[firstName=John, lastName=Smith]]

Is there a fast way to generate Java source code that would create such an object? Ideally it would generate the following source code:

Arrays.asList(new User("John", "Smith")); // might also use getters-setters

This will make writing of test code assertions faster for me.

My object is defined as follows:

public record User (String firstName, String lastName) {}

Solution

  • Out of curiosity I took a stab at this and it turns out to be reasonably simple, if you can live with some restrictions. Some of those are fundamental; others just require a bit more code to work around.

    • this handles only:
      • record classes
      • lists of those record classes
      • strings, primitives and their wrappers
    • it assumes the input is actually generated by the default toString and is not forgiving (extra/missing whitespace at the wrong place breaks it).
    • the top-level class(es) that can occur have to be handed to the parser
    • records are assumed to not override their toString
    • the canonical constructor of each record is used
    • some inputs are simply ambiguous; they will be resolved to something, but there's no guarantee which interpretation is chosen
    • there is limited error checking
    • there is no validation for the primitive/wrapper types: nonsensical inputs will be quietly translated. Your compiler will tell you, though.
    • I'm using the new switch features, but that's trivial to change, if needed
    • I haven't hand-written a parser in quite a while. Don't look at this for best practices.
    • there are definitely bugs

    Adding support for Map would be fairly straightforward. And due to the nature of these parsers, modifying it to actually build the objects instead of the object construction code would be reasonably simple as well. Both of these are left as an exercise for the reader.

    The full tests are below, but as an example, here's how to use it:

    import java.util.List;
    
    public class Foo {
      /**
       * Feeds a handful of sample {@code toString()} strings through the parser and prints the
       * generated Java source for each, one per line, in the order listed.
       */
      public static void main(String[] args) {
        List<Class<? extends Record>> knownRecords = List.of(User.class, UserContainer.class);
        RecordToStringParser converter = new RecordToStringParser(knownRecords);

        String[] samples = {
            "User[firstName=John, lastName=Smith]",
            "User[firstName=John, Doe, lastName=Smith]",
            "UserContainer[user=User[firstName=John, lastName=Smith], flags=13]",
            "UserContainer[user=null, flags=42]",
            "[User[firstName=John, lastName=Smith], UserContainer[user=null, flags=1]]",
            "UserContainer[user=null, flags=what are you talking about]",
        };
        for (String sample : samples) {
          String generated = converter.convert(sample);
          System.out.println(generated);
        }
      }
    }
    
    // Demo record with two String components: a first name and a last name.
    record User(String firstName, String lastName) {}
    // Demo record that nests a User component next to a primitive int component.
    record UserContainer(User user, int flags) {}
    

    This prints

    new User("John", "Smith")
    new User("John, Doe", "Smith")
    new UserContainer(new User("John", "Smith"), 13)
    new UserContainer(null, 42)
    List.of(new User("John", "Smith"), new UserContainer(null, 1))
    new UserContainer(null, what are you talking about)
    

    Note how the last one just blindly dumps the value without verifying the type? 🤷

    import java.lang.reflect.RecordComponent;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    import java.util.function.Function;
    import java.util.regex.Matcher;
    import java.util.regex.Pattern;
    import java.util.stream.Collectors;
    
    /**
     * Translates the default {@code toString()} text of records, and of lists of records, into Java
     * source code that constructs them again via each record's components in declaration order.
     *
     * <p>Supported component types are records, {@link List}, {@link String}, primitives and the
     * primitive wrapper classes. Simple values are located by searching for the text that is expected
     * to follow them, so values that themselves contain that text are ambiguous. Primitive and wrapper
     * values are copied into the output verbatim without validation.
     */
    public class RecordToStringParser {

      // An upper-case ASCII letter followed by any number of ASCII letters or digits. The quantifier is
      // '*' rather than '+' so that single-letter record names (e.g. "A") are accepted as well.
      private static final Pattern RECORD_NAME = Pattern.compile("^[A-Z][A-Za-z0-9]*");
      private static final Set<Class<?>> WRAPPER_CLASSES = Set.of(
          Boolean.class, Byte.class, Character.class, Short.class, Integer.class, Long.class, Float.class, Double.class);

      /** Record classes that may appear where the type is not known statically, keyed by simple name. */
      private final Map<String, Class<? extends Record>> classes;

      /**
       * Creates a parser that knows a single top-level record class.
       *
       * @param c1 the record class that may appear at the top level or inside lists
       * @throws IllegalArgumentException if the class name does not match the supported name pattern
       */
      public RecordToStringParser(Class<? extends Record> c1) {
        this(List.of(c1));
      }

      /**
       * Creates a parser that knows the given record classes by their simple names.
       *
       * @param inputClasses the record classes that may appear at the top level or inside lists
       * @throws IllegalArgumentException if the list is empty, if two classes share a simple name, or if
       *     a simple name does not match the supported name pattern
       */
      public RecordToStringParser(List<Class<? extends Record>> inputClasses) {
        if (inputClasses.isEmpty()) {
          throw new IllegalArgumentException("No record classes specified!");
        }
        try {
          classes = inputClasses.stream()
              .collect(Collectors.toMap(Class::getSimpleName, Function.identity()));
        } catch (IllegalStateException e) {
          // this happens when simple names are not unique!
          throw new IllegalArgumentException("Colliding class names detected!", e);
        }
        for (String name : classes.keySet()) {
          if (!RECORD_NAME.matcher(name).matches()) {
            throw new IllegalArgumentException("Nonconforming class name found: " + name);
          }
        }
      }

      /**
       * Converts one {@code toString()} text into Java source code.
       *
       * @param input the text to parse: a record, a list, or {@code null}
       * @return a Java expression such as {@code new User("John", "Smith")} or {@code List.of(...)}
       * @throws IllegalArgumentException if the input cannot be parsed completely
       */
      public String convert(String input) {
        return new Parse(input).parse();
      }

      /** State of a single conversion: the input, the read position and the output built so far. */
      private class Parse {
        private final String in;
        private int pos = 0;
        private final StringBuilder out = new StringBuilder();

        Parse(String input) {
          this.in = input;
        }

        /** Parses the whole input and fails if anything is left over afterwards. */
        String parse() {
          parseNext();
          if (pos != in.length()) {
            throw err("Unexpected content after successful parse");
          }
          return out.toString();
        }

        /** Parses a list if the next character is '[', otherwise a record (or null). */
        private void parseNext() {
          if (isNext('[')) {
            parseList();
          } else {
            parseRecord();
          }
        }

        /** Parses a record whose type is determined by the name found in the input. */
        private void parseRecord() {
          if (maybeConsume("null")) {
            out.append("null");
            return;
          }
          Matcher matcher = RECORD_NAME.matcher(in);
          // '^' in the pattern anchors at the region start, i.e. at the current position.
          matcher.region(pos, in.length());
          if (!matcher.find()) {
            throw err("Record expected.");
          }
          pos = matcher.end();
          consume('[');
          String className = matcher.group();
          Class<? extends Record> recordClass = classes.get(className);
          if (recordClass == null) {
            throw err("Unknown record class " + className);
          }
          emitRecord(recordClass);
        }

        /** Parses a record whose type is already known from the enclosing record's component. */
        private void parseRecord(Class<? extends Record> type) {
          if (maybeConsume("null")) {
            out.append("null");
            return;
          }
          consume(type.getSimpleName());
          consume('[');
          emitRecord(type);
        }

        /** Emits the constructor call for a record whose name and opening '[' were already consumed. */
        private void emitRecord(Class<? extends Record> type) {
          out.append("new ").append(type.getSimpleName()).append("(");
          parseMembers(type);
          out.append(")");
          consume(']');
        }

        /** Parses {@code name=value} pairs for every component of the record, in declaration order. */
        private void parseMembers(Class<? extends Record> recordClass) {
          RecordComponent[] components = recordClass.getRecordComponents();
          for (int i = 0; i < components.length; i++) {
            RecordComponent component = components[i];
            boolean isLastComponent = i == components.length - 1;
            consume(component.getName());
            consume('=');
            Class<?> type = component.getType();
            if (type.isRecord()) {
              parseRecord(type.asSubclass(Record.class));
            } else if (List.class == type) {
              parseList();
            } else if (type == String.class || type.isPrimitive() || WRAPPER_CLASSES.contains(type)) {
              parseSimpleValues(components, i, isLastComponent, type);
            } else {
              throw err("Unsupported component type " + type);
            }
            if (!isLastComponent) {
              consume(", ");
              out.append(", ");
            }
          }
        }

        /**
         * Parses a string, primitive or wrapper value. The value ends where the text expected after it
         * starts: {@code ]} for the last component, otherwise {@code ", nextName="}.
         */
        private void parseSimpleValues(RecordComponent[] components, int i, boolean isLastComponent, Class<?> type) {
          String expectedSuffix = isLastComponent ? "]" : ", " + components[i + 1].getName() + "=";
          if (pos >= in.length()) {
            // report running out of input as such instead of as a missing suffix
            throw err("Unexpected end of string.");
          }
          int valueEnd = in.indexOf(expectedSuffix, pos);
          if (valueEnd == -1) {
            throw err("Failed to find end of value!");
          }
          String value = in.substring(pos, valueEnd);
          if (type == String.class && !value.equals("null")) {
            // interpret string values null as the value null and not the string "null"
            out.append('"');
            for (char c : value.toCharArray()) {
              switch (c) {
                case '"', '\\' -> out.append('\\').append(c);
                case '\n' -> out.append("\\n");
                // a raw line terminator inside a string literal does not compile, so escape CR too
                case '\r' -> out.append("\\r");
                case '\t' -> out.append("\\t");
                case '\b' -> out.append("\\b");
                case '\f' -> out.append("\\f");
                default -> out.append(c);
              }
            }
            out.append('"');
          } else {
            // no verification: the text is copied as-is
            out.append(value);
          }
          pos = valueEnd;
        }

        /** Parses {@code [a, b, ...]} (or null) into {@code List.of(a, b, ...)}. */
        private void parseList() {
          if (maybeConsume("null")) {
            out.append("null");
            return;
          }
          consume('[');
          out.append("List.of(");
          if (!isNext(']')) {
            parseNext();
            while (maybeConsume(", ")) {
              out.append(", ");
              // every element, not just the first, may itself be a list
              parseNext();
            }
          }
          consume(']');
          out.append(")");
        }

        /** Tells whether the next character is {@code expected}; throws if the input is exhausted. */
        private boolean isNext(char expected) {
          if (pos >= in.length()) {
            throw err("Unexpected end of string.");
          }
          return in.charAt(pos) == expected;
        }

        /** Consumes {@code expected} if the input continues with it and reports whether it did. */
        private boolean maybeConsume(String expected) {
          boolean matches = in.regionMatches(pos, expected, 0, expected.length());
          if (matches) {
            pos += expected.length();
          }
          return matches;
        }

        /** Consumes the single character {@code expected} or throws. */
        private void consume(char expected) {
          if (pos >= in.length()) {
            throw err("Unexpected end of string, expected " + expected);
          }
          if (in.charAt(pos) != expected) {
            throw err("Expected " + expected);
          }
          pos++;
        }

        /** Consumes the text {@code expected} or throws. */
        private void consume(String expected) {
          if (!maybeConsume(expected)) {
            throw err("Expected " + expected);
          }
        }

        /** Builds (does not throw) an exception that includes the position and the parser state. */
        private IllegalArgumentException err(String msg) {
          return new IllegalArgumentException(String.format("Unexpected input at pos %d (%s): %s", pos, this, msg));
        }

        @Override
        public String toString() {
          // simple way to see parser state in the debugger ;-)
          return in.substring(0, pos) + "<|>" + in.substring(pos);
        }
      }
    }
    

    Test class:

    import static org.junit.jupiter.api.Assertions.*;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    import org.junit.jupiter.api.Test;
    import org.junit.jupiter.params.ParameterizedTest;
    import org.junit.jupiter.params.provider.MethodSource;
    
    /**
     * Tests for {@code RecordToStringParser}: constructor validation, conversion of real
     * {@code toString()} output, and inputs that must be rejected.
     */
    public class RecordToStringParserTest {
      // Shared, immutable fixture: every record class the parsing tests may encounter.
      private static final List<Class<? extends Record>> BASE_RECORD_CLASSES = List.of(
          Empty.class, Simple.class, WithNumber.class, NestingSimple.class, NestingListOfSimple.class, Recursive.class,
          TwoStrings.class, NestingRecord.class);

      @Test
      void mustProvideClasses() {
        final List<Class<? extends Record>> noClasses = List.of();
        assertThrows(IllegalArgumentException.class, () -> new RecordToStringParser(noClasses));
      }

      @Test
      void classNamesMustStartWithUpperCase() {
        assertThrows(IllegalArgumentException.class, () -> new RecordToStringParser(lowerRecordName.class));
      }

      @Test
      void classNamesMustBeAlphaNumeric() {
        assertThrows(IllegalArgumentException.class, () -> new RecordToStringParser(NonAlphaNum_.class));
      }

      @Test
      void simpleClassNamesMustBeUnique() {
        List<Class<? extends Record>> collidingClasses = List.of(NS1.Collision.class, NS2.Collision.class);
        assertThrows(IllegalArgumentException.class,
            () -> new RecordToStringParser(collidingClasses));
      }

      /**
       * Pairs of (object whose {@code toString()} is parsed, expected generated source). The entries
       * after the "WRONG" marker document known-ambiguous or unvalidated behavior.
       */
      @SuppressWarnings("unchecked")
      public static Object[][] realToStringSources() {
        return new Object[][] {
            {new Empty(), "new Empty()"},
            {new Simple("bar"), "new Simple(\"bar\")"},
            {new Simple(""), "new Simple(\"\")"},
            {new Simple(null), "new Simple(null)"},
            {new Simple("\n"), "new Simple(\"\\n\")"},
            {new Simple("\\"), "new Simple(\"\\\\\")"},
            {new Simple("\"\\\""), "new Simple(\"\\\"\\\\\\\"\")"}, // thanks, I hate it
            {new WithNumber(1), "new WithNumber(1)"},
            {new NestingSimple(new Simple("bar")), "new NestingSimple(new Simple(\"bar\"))"},
            {new NestingSimple(null), "new NestingSimple(null)"},
            {new Recursive(null), "new Recursive(null)"},
            {new Recursive(new Recursive(null)), "new Recursive(new Recursive(null))"},
            {List.of(), "List.of()"},
            {List.of(new Empty()), "List.of(new Empty())"},
            {List.of(new Empty(), new Empty()), "List.of(new Empty(), new Empty())"},
            {List.of(List.of(new Empty())), "List.of(List.of(new Empty()))"},
            {new NestingListOfSimple(List.of(new Simple("bar"))), "new NestingListOfSimple(List.of(new Simple(\"bar\")))"},
            {new TwoStrings("foo", "bar"), "new TwoStrings(\"foo\", \"bar\")"},
            {new TwoStrings("foo, bar", "bar"), "new TwoStrings(\"foo, bar\", \"bar\")"},
            {new TwoStrings("John", "Smith"), "new TwoStrings(\"John\", \"Smith\")"},
            // "WRONG" results start here
            // List.of() can't contain null values, but we don't check for that, so we tolerate it
            {new ArrayList<>(Arrays.asList(null, new Empty())), "List.of(null, new Empty())"},
            // Simple[foo=null] is ambiguous, we decided to interpret it as null and not the string value "null"
            {new Simple("null"), "new Simple(null)"},
            // We don't check for insanity/generics breakage
            {new NestingListOfSimple((List<Simple>) ((List<?>) List.of(new WithNumber(1)))),
                "new NestingListOfSimple(List.of(new WithNumber(1)))"},
            // Well, that's just ambiguous!
            {new TwoStrings("foo, bar=", "bar"), "new TwoStrings(\"foo\", \", bar=bar\")"},
            // We don't verify simple types, so it'll blindly output nonsense
            {"WithNumber[foo=bar]", "new WithNumber(bar)"},
        };
      }

      @ParameterizedTest
      @MethodSource("realToStringSources")
      void parseRealToString(Object record, String expectedOutput) {
        String input = record.toString();
        RecordToStringParser parser = new RecordToStringParser(BASE_RECORD_CLASSES);
        String output = parser.convert(input);
        assertEquals(expectedOutput, output);
      }

      /** Inputs for which {@code convert} is expected to throw {@link IllegalArgumentException}. */
      public static String[] unparseableStrings() {
        return new String[] {
            "",
            "]",
            "[",
            "Empty",
            "Empty[",
            "Simple[foo=",
            "NestingSimple[foo=10]",
            "UnknownClass[]",
            // "WithNumber[foo=bar]", // this isn't validated
            "[]]",
            "[bar]", // non-record values in lists are not supported
            "{}", // maps are not supported, but wouldn't be too hard
            "NestingRecord[r=Empty[]]", // this could easily be supported, but why?
        };
      }

      @ParameterizedTest
      @MethodSource("unparseableStrings")
      void failToParse(String input) {
        RecordToStringParser parser = new RecordToStringParser(BASE_RECORD_CLASSES);
        assertThrows(IllegalArgumentException.class, () -> parser.convert(input));
      }

      record lowerRecordName(String foo) {}

      record NonAlphaNum_(String foo) {}

      record Empty() {}

      record Simple(String foo) {}

      record NestingSimple(Simple foo) {}

      record NestingListOfSimple(List<Simple> foo) {}

      record Recursive(Recursive foo) {}

      record WithNumber(int foo) {}

      record TwoStrings(String foo, String bar) {}

      record NestingRecord(Record r) {}

      static class NS1 {
        record Collision(String foo) {}
      }

      static class NS2 {
        record Collision(String foo) {}
      }
    }