/************************************************************************* * * * CESeCore: CE Security Core * * * * This software is free software; you can redistribute it and/or * * modify it under the terms of the GNU Lesser General Public * * License as published by the Free Software Foundation; either * * version 2.1 of the License, or any later version. * * * * See terms of license at gnu.org. * * * *************************************************************************/ package org.cesecore.util; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; import java.beans.XMLDecoder; import java.beans.XMLEncoder; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.EOFException; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.Date; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.Random; import java.util.Set; import java.util.TreeMap; import org.apache.log4j.Logger; import org.cesecore.certificates.certificateprofile.CertificatePolicy; import org.junit.Test; /** * * @version $Id: SecureXMLDecoderTest.java 26353 2017-08-17 14:57:40Z mikekushner $ */ public class SecureXMLDecoderTest { private static final Logger log = Logger.getLogger(SecureXMLDecoderTest.class); @Test public void testElementaryTypes() throws IOException { log.trace(">testElementaryTypes"); decodeCompare(""); decodeCompare("\n"); decodeCompare("\n\n"); decodeCompare("\nA\n"); decodeCompare("\n&\n"); decodeCompare("\n-12345\n"); decodeCompare("\n-12.34\n"); decodeCompare("\nhello\n"); decodeCompare("\ntrue\n"); decodeCompare("\n\n\n"); decodeCompare("\n\n" + "\n123\n\n" + "\n456\n\n" + "\n"); log.trace("testBasicCollections"); // Empty list decodeCompare("\n\n\n"); // List of two integers decodeCompare("\n\n" + "\n123\n\n" + "\n456\n\n" + "\n"); // Map from int to string decodeCompare("\n\n" + "\n-1\nA\n\n" + "\n10\nB\n\n" + "\n\n" + "\n-1\n\n" + "\n3\n\n" + "\n"); log.trace("testMultipleObjects"); decodeCompare("\n-12345\nABC\n"); log.trace(">testMultipleObjects"); } private static enum MockEnum { FOO; } /** * Tests encoding and decoding an enum */ @Test public void testEnum() throws IOException { final Map root = new LinkedHashMap<>(); root.put("testEnum", MockEnum.FOO); // Encode final ByteArrayOutputStream baos = new ByteArrayOutputStream(); final XMLEncoder encoder = new XMLEncoder(baos); encoder.writeObject(root); encoder.close(); // Try to decode it and compare decodeCompare(baos.toByteArray()); } /** * Encodes a complex value with the standard XMLEncoder and tries to decode it again. * @throws UnsupportedEncodingException */ @Test public void testComplexEncodeDecode() throws IOException { log.trace(">testComplexEncodeDecode"); final Map root = new LinkedHashMap<>(); root.put("testfloat", 12.3); root.put("testnull", null); root.put("testutf8string", "Test ÅÄÖ \u4f60\u597d"); root.put("teststrangechars", "\0001\0002fooString"); root.put("testchar1", '<'); root.put("testchar2", '\\'); root.put("testchar3", 'å'); root.put("testbool", false); root.put("testemptyset", Collections.EMPTY_SET); root.put("testemptylist", Collections.emptyList()); root.put("testClass", SecureXMLDecoder.class); final List unmodifiable = new ArrayList<>(); unmodifiable.add('A'); unmodifiable.add('B'); root.put("testemptylist", Collections.unmodifiableList(unmodifiable)); root.put("testbyte", Byte.valueOf((byte)123)); root.put("testshort", Short.valueOf((short)12345)); final Set set = new HashSet<>(); set.add("Test"); set.add(12345); set.add(new ArrayList()); root.put("testhashset", set); root.put("testbytearray", new byte[] { -128, 0, 123, 45, 67, 89, 127 }); root.put("teststringarray", new String[] { "Hello", "World" }); final Map map = new LinkedHashMap<>(); map.put(123, "ABC"); root.put("testmaparray", new Map[] { map }); root.put("testbooleanarray", new boolean[] { true, false, true }); root.put("testnestedarray", new byte[][] { new byte[] { 1, 2 }, new byte[] { 3, 4 } }); final Map treeMap = new TreeMap<>(new Comparator() { @Override public int compare(String o1, String o2) { return o1.hashCode() - o2.hashCode(); } }); root.put("testdate", new Date(1457949109000L)); root.put("testproperties0", new Properties()); final Properties props1 = new Properties(); props1.put("test.something", "value"); root.put("testproperties1", props1); final Properties propsDefaults = new Properties(); propsDefaults.put("test.something1", "default1"); propsDefaults.put("test.something2", "default2"); final Properties props2 = new Properties(propsDefaults); props2.put("test.something1", "override"); root.put("testproperties2", props2); treeMap.put("aaa", 1); treeMap.put("bbb", 2); treeMap.put("ccc", 3); treeMap.put("ddd", 4); treeMap.put("eee", 5); root.put("testtreemap", treeMap); final List list = new ArrayList<>(2); final Map nested = new HashMap<>(); nested.put("testlong", Long.valueOf(Long.MAX_VALUE)); list.add(nested); root.put("testlist", list); final Map propmap = new HashMap<>(); root.put("b64getmap", new Base64GetHashMap(propmap)); root.put("b64putmap", new Base64PutHashMap(propmap)); root.put("certpolicy", new CertificatePolicy("1.2.3.4", "Finders keepers!", "http://example.com/policy")); // Base64PutHashMap // Encode final ByteArrayOutputStream baos = new ByteArrayOutputStream(); final XMLEncoder encoder = new XMLEncoder(baos); encoder.writeObject(root); encoder.close(); // Try to decode it and compare decodeCompare(baos.toByteArray()); log.trace("testNotAllowedType"); // Encode final ByteArrayOutputStream baos = new ByteArrayOutputStream(); final XMLEncoder encoder = new XMLEncoder(baos); encoder.writeObject(new Random()); // java.util.Random is serializable, but isn't whitelisted encoder.close(); decodeBad(baos.toByteArray()); log.trace(" root = new LinkedHashMap<>(); root.put("testClass", Object.class); // Encode final ByteArrayOutputStream baos = new ByteArrayOutputStream(); final XMLEncoder encoder = new XMLEncoder(baos); encoder.writeObject(root); encoder.close(); // Try to decode it and compare try { decodeCompare(baos.toByteArray()); fail("Test should have failed when decoding an unauthorized class."); } catch (IOException e) { } } @Test public void testNotAllowedMethod() { log.trace(">testNotAllowedMethod"); final String xml = "\n\n" + "\n123\n\n" + "\n0\n\n" + "\n"; decodeBad(xml.getBytes()); log.trace("decodeBad(" + new String(xml) + ")"); } try (XMLDecoder xmldec = new XMLDecoder(new ByteArrayInputStream(xml))) { xmldec.readObject(); // Should succeed (XMLDecoder is insecure) } try (SecureXMLDecoder securedec = new SecureXMLDecoder(new ByteArrayInputStream(xml))) { securedec.readObject(); fail("Should not accept arbitrary classes/methods"); } catch (IOException e) { // NOPMD: Expected } log.trace("decodeCompare(" + new String(xml) + ")"); } final List expectedObjs = new ArrayList<>(); final List actualObjs = new ArrayList<>(); try (XMLDecoder xmldec = new XMLDecoder(new ByteArrayInputStream(xml))) { int i = 1; while (true) { log.debug("Reading object " + (i++) + " from the standard JDK XMLDecoder"); expectedObjs.add(xmldec.readObject()); } } catch (ArrayIndexOutOfBoundsException e) { // NOPMD: Expected, happens when we reach the end } try (SecureXMLDecoder securedec = new SecureXMLDecoder(new ByteArrayInputStream(xml))) { int i = 1; while (true) { log.debug("Reading object " + (i++) + " from SecureXMLDecoder"); actualObjs.add(securedec.readObject()); } } catch (EOFException e) { // NOPMD: Expected, happens when we reach the end } // Compare the results assertEquals("Number of objects decoded from XML differ.", expectedObjs.size(), actualObjs.size()); final int count = expectedObjs.size(); log.debug("Comparing " + count + " objects"); for (int i = 0; i < count; i++) { final Object expected = expectedObjs.get(i); final Object actual = actualObjs.get(i); compareObjects(expected, actual); } log.trace("compareObjects(" + expected + ", " + actual + ")"); } if (expected == null) { assertNull("Deserialized value should have been null", actual); return; } assertNotNull("Deserialized value should NOT be null", actual); assertEquals("Class of deserialized value differs.", expected.getClass(), actual.getClass()); if (expected instanceof List) { final List expectedList = (List)expected; final List actualList = (List)actual; assertEquals("Number of elements in lists differ.", expectedList.size(), actualList.size()); final Iterator expectedIter = expectedList.iterator(); final Iterator actualIter = actualList.iterator(); while (expectedIter.hasNext()) { final Object expectedElem = expectedIter.next(); final Object actualElem = actualIter.next(); compareObjects(expectedElem, actualElem); } } else if (expected instanceof LinkedHashMap || expected instanceof TreeMap) { final Map expectedMap = (Map)expected; final Map actualMap = (Map)actual; assertEquals("Number of elements in maps differ.", expectedMap.size(), actualMap.size()); // For LinkedHashMaps we expect the entries to come in the same order final Iterator expectedIter = expectedMap.entrySet().iterator(); final Iterator actualIter = expectedMap.entrySet().iterator(); while (expectedIter.hasNext()) { final Map.Entry expectedEntry = (Map.Entry)expectedIter.next(); final Map.Entry actualEntry = (Map.Entry)actualIter.next(); compareObjects(expectedEntry.getKey(), actualEntry.getKey()); compareObjects(expectedEntry.getValue(), actualEntry.getValue()); } } else if (expected instanceof Map) { final Map expectedMap = (Map)expected; final Map actualMap = (Map)actual; assertEquals("Number of elements in maps differ.", expectedMap.size(), actualMap.size()); for (Object key : expectedMap.keySet()) { final Object expectedValue = expectedMap.get(key); final Object actualValue = actualMap.get(key); compareObjects(expectedValue, actualValue); } } else if (expected.getClass().isArray()) { // Note: The array could be of a primitive type so we can't cast it to Object[] final int expectedLength = Array.getLength(expected); assertEquals("Number of array elements differ.", expectedLength, Array.getLength(actual)); for (int i = 0; i < expectedLength; i++) { final Object expectedElem = Array.get(expected, i); final Object actualElem = Array.get(actual, i); compareObjects(expectedElem, actualElem); } } else { assertEquals("Deserialized values differ.", expected, actual); } if (log.isTraceEnabled()) { log.trace("