1/*
2 * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24/*
25 * @test
26 * @bug 8006259
27 * @summary Test several modes of operation using vectors from SP 800-38A
28 * @modules java.xml.bind
29 * @run main CheckExampleVectors
30 */
31
32import java.io.*;
33import java.security.*;
34import java.util.*;
35import java.util.function.*;
36import javax.xml.bind.DatatypeConverter;
37import javax.crypto.*;
38import javax.crypto.spec.*;
39
public class CheckExampleVectors {

    /** Modes of operation that may appear in the vector file. */
    private enum Mode {
        ECB,
        CBC,
        CFB1,
        CFB8,
        CFB128,
        OFB,
        CTR
    }

    /** Whether a vector's blocks are to be encrypted or decrypted. */
    private enum Operation {
        Encrypt,
        Decrypt
    }

    /**
     * One input / expected-output pair of a vector, parsed from a line of the
     * form {@code <base64 input>,<base64 output>}.
     */
    private static class Block {
        private final byte[] input;
        private final byte[] output;

        public Block(String settings) {
            String[] settingsParts = settings.split(",");
            input = stringToBytes(settingsParts[0]);
            output = stringToBytes(settingsParts[1]);
        }
        public byte[] getInput() {
            return input;
        }
        public byte[] getOutput() {
            return output;
        }
    }

    /**
     * A single test vector: a header line of the form
     * {@code <Mode>,<Operation>,<base64 key>[,<base64 iv>]} plus the blocks
     * that follow it in the vector file.
     */
    private static class TestVector {
        private final Mode mode;
        private final Operation operation;
        private final byte[] key;
        private final byte[] iv;
        private final List<Block> blocks = new ArrayList<>();

        public TestVector(String settings) {
            String[] settingsParts = settings.split(",");
            mode = Mode.valueOf(settingsParts[0]);
            operation = Operation.valueOf(settingsParts[1]);
            key = stringToBytes(settingsParts[2]);
            // A missing IV field (split() also drops a trailing empty one)
            // leaves iv null, so the cipher is initialized without parameters
            // (e.g. for ECB).
            iv = (settingsParts.length > 3) ? stringToBytes(settingsParts[3]) : null;
        }

        public Mode getMode() {
            return mode;
        }
        public Operation getOperation() {
            return operation;
        }
        public byte[] getKey() {
            return key;
        }
        /** Returns the IV, or {@code null} when the vector has none. */
        public byte[] getIv() {
            return iv;
        }
        public void addBlock(Block b) {
            blocks.add(b);
        }
        public Iterable<Block> getBlocks() {
            return blocks;
        }
    }

    private static final String VECTOR_FILE_NAME = "NIST_800_38A_vectors.txt";
    /** Modes that at least one installed provider must pass; see checkSupportedModes. */
    private static final Mode[] REQUIRED_MODES = {Mode.ECB, Mode.CBC, Mode.CTR};
    /** Modes for which some provider matched every block of some vector. */
    private static final Set<Mode> supportedModes = EnumSet.noneOf(Mode.class);

    public static void main(String[] args) throws Exception {
        checkAllProviders();
        checkSupportedModes();
    }

    /**
     * Decodes a base64 field from the vector file.
     *
     * Uses java.util.Base64 rather than javax.xml.bind.DatatypeConverter,
     * whose module is deprecated for removal (and removed in JDK 11). The
     * MIME decoder is used because, like DatatypeConverter, it skips
     * characters outside the base64 alphabet instead of rejecting them.
     *
     * @param v the base64 text
     * @return the decoded bytes, or {@code null} for an empty field
     */
    private static byte[] stringToBytes(String v) {
        if (v.isEmpty()) {
            return null;
        }
        return Base64.getMimeDecoder().decode(v);
    }

    /** Returns the mode component of a JCA transformation string. */
    private static String toModeString(Mode mode) {
        return mode.toString();
    }

    /** Maps an Operation to the corresponding Cipher opmode constant. */
    private static int toCipherOperation(Operation op) {
        switch (op) {
            case Encrypt:
                return Cipher.ENCRYPT_MODE;
            case Decrypt:
                return Cipher.DECRYPT_MODE;
        }

        throw new RuntimeException("Unknown operation: " + op);
    }

    private static void log(String str) {
        System.out.println(str);
    }

    /**
     * Runs one vector through "AES/<mode>/NoPadding" from the named provider.
     *
     * If the provider does not offer the transformation, a message is logged
     * and the vector is skipped. Otherwise every block must match, after which
     * the vector's mode is recorded in supportedModes.
     *
     * @throws RuntimeException if a block's output differs from the expected
     *         output, or if the cipher cannot be set up for another reason
     */
    private static void checkVector(String providerName, TestVector test) {

        String modeString = toModeString(test.getMode());
        String cipherString = "AES" + "/" + modeString + "/" + "NoPadding";
        log("checking: " + cipherString + " on " + providerName);
        try {
            Cipher cipher = Cipher.getInstance(cipherString, providerName);
            SecretKeySpec key = new SecretKeySpec(test.getKey(), "AES");
            int opmode = toCipherOperation(test.getOperation());
            if (test.getIv() != null) {
                cipher.init(opmode, key, new IvParameterSpec(test.getIv()));
            } else {
                cipher.init(opmode, key);
            }
            int blockIndex = 0;
            for (Block curBlock : test.getBlocks()) {
                // update() may return null when it produces no output. That
                // never equals a non-null expected block, so it is reported
                // as a mismatch.
                byte[] blockOutput = cipher.update(curBlock.getInput());
                byte[] expectedBlockOutput = curBlock.getOutput();
                if (!Arrays.equals(blockOutput, expectedBlockOutput)) {
                    throw new RuntimeException("Blocks do not match at index "
                        + blockIndex);
                }
                blockIndex++;
            }
            log("success");
            supportedModes.add(test.getMode());
        } catch (NoSuchAlgorithmException ex) {
            log("algorithm not supported");
        } catch (NoSuchProviderException | NoSuchPaddingException
            | InvalidKeyException | InvalidAlgorithmParameterException ex) {
            throw new RuntimeException(ex);
        }
    }

    /** Returns true for a non-null line starting with "//". */
    private static boolean isComment(String line) {
        return (line != null) && line.startsWith("//");
    }

    /**
     * Reads the next vector. The format is a header line, a line holding the
     * block count, then that many block lines. Comment lines before the
     * header are skipped.
     *
     * @return the vector, or {@code null} at end of file or at an empty line
     *         (nothing after an empty line is read)
     * @throws IOException on a read error, or if the file ends before the
     *         block count or all announced blocks have been read
     */
    private static TestVector readVector(BufferedReader in) throws IOException {
        String line;
        while (isComment(line = in.readLine())) {
            // skip comment lines
        }
        if (line == null || line.isEmpty()) {
            return null;
        }

        TestVector newVector = new TestVector(line);
        String numBlocksStr = in.readLine();
        if (numBlocksStr == null) {
            throw new IOException(
                "Unexpected end of file: missing block count after: " + line);
        }
        int numBlocks = Integer.parseInt(numBlocksStr.trim());
        for (int i = 0; i < numBlocks; i++) {
            String blockLine = in.readLine();
            if (blockLine == null) {
                throw new IOException("Unexpected end of file: expected "
                    + numBlocks + " blocks but found " + i + " after: " + line);
            }
            newVector.addBlock(new Block(blockLine));
        }

        return newVector;
    }

    /**
     * Loads every vector from VECTOR_FILE_NAME, found in the directory named
     * by the "test.src" system property (default "."). Each vector is then
     * checked against every installed provider.
     */
    private static void checkAllProviders() throws IOException {
        File dataFile = new File(System.getProperty("test.src", "."),
                                 VECTOR_FILE_NAME);
        List<TestVector> allTests = new ArrayList<>();
        // try-with-resources so the reader is closed even if parsing fails.
        try (BufferedReader in = new BufferedReader(new FileReader(dataFile))) {
            TestVector newTest;
            while ((newTest = readVector(in)) != null) {
                allTests.add(newTest);
            }
        }

        for (Provider provider : Security.getProviders()) {
            checkProvider(provider.getName(), allTests);
        }
    }

    /** Checks every vector against a single provider. */
    private static void checkProvider(String providerName,
                                      List<TestVector> allVectors) {

        for (TestVector curVector : allVectors) {
            checkVector(providerName, curVector);
        }
    }

    /*
     *  This method helps ensure that the test is working properly by
     *  verifying that the test was able to check the test vectors for
     *  some of the modes of operation.
     */
    private static void checkSupportedModes() {
        for (Mode curMode : REQUIRED_MODES) {
            if (!supportedModes.contains(curMode)) {
                throw new RuntimeException(
                    "Mode not supported by any provider: " + curMode);
            }
        }

    }

}
248
249