1/*
2 * Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.  Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26package com.oracle.security.ucrypto;
27
28import java.nio.ByteBuffer;
29import java.util.Set;
30import java.util.Arrays;
31import java.util.Locale;
32import java.util.concurrent.ConcurrentSkipListSet;
33import java.lang.ref.*;
34
35import java.security.AlgorithmParameters;
36import java.security.GeneralSecurityException;
37import java.security.InvalidAlgorithmParameterException;
38import java.security.InvalidKeyException;
39import java.security.Key;
40import java.security.NoSuchAlgorithmException;
41import java.security.SecureRandom;
42
43
44import java.security.spec.AlgorithmParameterSpec;
45import java.security.spec.InvalidParameterSpecException;
46
47import javax.crypto.BadPaddingException;
48import javax.crypto.Cipher;
49import javax.crypto.CipherSpi;
50import javax.crypto.IllegalBlockSizeException;
51import javax.crypto.NoSuchPaddingException;
52import javax.crypto.ShortBufferException;
53
54import javax.crypto.spec.IvParameterSpec;
55
56/**
57 * Wrapper class which uses NativeCipher class and Java impls of padding scheme.
58 * This class currently supports
59 * - AES/ECB/PKCS5PADDING
60 * - AES/CBC/PKCS5PADDING
61 * - AES/CFB128/PKCS5PADDING
62 *
63 * @since 9
64 */
65public class NativeCipherWithJavaPadding extends CipherSpi {
66
67    private static interface Padding {
68        // ENC: generate and return the necessary padding bytes
69        int getPadLen(int dataLen);
70
71        // ENC: generate and return the necessary padding bytes
72        byte[] getPaddingBytes(int dataLen);
73
74        // DEC: process the decrypted data and buffer up the potential padding
75        // bytes
76        byte[] bufferBytes(byte[] intermediateData);
77
78        // DEC: return the length of internally buffered pad bytes
79        int getBufferedLength();
80
81        // DEC: unpad and place the output in 'out', starting from outOfs
82        // and return the number of bytes unpadded into 'out'.
83        int unpad(byte[] paddedData, byte[] out, int outOfs)
84                throws BadPaddingException, IllegalBlockSizeException,
85                ShortBufferException;
86
87        // DEC: Clears the padding object to the initial state
88        void clear();
89    }
90
91    private static class PKCS5Padding implements Padding {
92        private final int blockSize;
93        // buffer for storing the potential padding bytes
94        private ByteBuffer trailingBytes = null;
95
96        PKCS5Padding(int blockSize)
97            throws NoSuchPaddingException {
98            if (blockSize == 0) {
99                throw new NoSuchPaddingException
100                        ("PKCS#5 padding not supported with stream ciphers");
101            }
102            this.blockSize = blockSize;
103        }
104
105        public int getPadLen(int dataLen) {
106            return (blockSize - (dataLen & (blockSize - 1)));
107        }
108
109        public byte[] getPaddingBytes(int dataLen) {
110            byte padValue = (byte) getPadLen(dataLen);
111            byte[] paddingBytes = new byte[padValue];
112            Arrays.fill(paddingBytes, padValue);
113            return paddingBytes;
114        }
115
116        public byte[] bufferBytes(byte[] dataFromUpdate) {
117            if (dataFromUpdate == null || dataFromUpdate.length == 0) {
118                return null;
119            }
120            byte[] result = null;
121            if (trailingBytes == null) {
122                trailingBytes = ByteBuffer.wrap(new byte[blockSize]);
123            }
124            int tbSize = trailingBytes.position();
125            if (dataFromUpdate.length > trailingBytes.remaining()) {
126                int totalLen = dataFromUpdate.length + tbSize;
127                int newTBSize = totalLen % blockSize;
128                if (newTBSize == 0) {
129                    newTBSize = blockSize;
130                }
131                if (tbSize == 0) {
132                    result = Arrays.copyOf(dataFromUpdate, totalLen - newTBSize);
133                } else {
134                    // combine 'trailingBytes' and 'dataFromUpdate'
135                    result = Arrays.copyOf(trailingBytes.array(),
136                                           totalLen - newTBSize);
137                    if (result.length != tbSize) {
138                        System.arraycopy(dataFromUpdate, 0, result, tbSize,
139                                         result.length - tbSize);
140                    }
141                }
142                // update 'trailingBytes' w/ remaining bytes in 'dataFromUpdate'
143                trailingBytes.clear();
144                trailingBytes.put(dataFromUpdate,
145                                  dataFromUpdate.length - newTBSize, newTBSize);
146            } else {
147                trailingBytes.put(dataFromUpdate);
148            }
149            return result;
150        }
151
152        public int getBufferedLength() {
153            if (trailingBytes != null) {
154                return trailingBytes.position();
155            }
156            return 0;
157        }
158
159        public int unpad(byte[] lastData, byte[] out, int outOfs)
160                throws BadPaddingException, IllegalBlockSizeException,
161                ShortBufferException {
162            int tbSize = (trailingBytes == null? 0:trailingBytes.position());
163            int dataLen = tbSize + lastData.length;
164
165            // Special handling to match SunJCE provider behavior
166            if (dataLen <= 0) {
167                return 0;
168            } else if (dataLen % blockSize != 0) {
169                UcryptoProvider.debug("PKCS5Padding: unpad, buffered " + tbSize +
170                                 " bytes, last block " + lastData.length + " bytes");
171
172                throw new IllegalBlockSizeException
173                    ("Input length must be multiples of " + blockSize);
174            }
175
176            // check padding bytes
177            if (lastData.length == 0) {
178                if (tbSize != 0) {
179                    // work on 'trailingBytes' directly
180                    lastData = Arrays.copyOf(trailingBytes.array(), tbSize);
181                    trailingBytes.clear();
182                    tbSize = 0;
183                } else {
184                    throw new BadPaddingException("No pad bytes found!");
185                }
186            }
187            byte padValue = lastData[lastData.length - 1];
188            if (padValue < 1 || padValue > blockSize) {
189                UcryptoProvider.debug("PKCS5Padding: unpad, lastData: " + Arrays.toString(lastData));
190                UcryptoProvider.debug("PKCS5Padding: unpad, padValue=" + padValue);
191                throw new BadPaddingException("Invalid pad value: " + padValue);
192            }
193
194            // sanity check padding bytes
195            int padStartIndex = lastData.length - padValue;
196            for (int i = padStartIndex; i < lastData.length; i++) {
197                if (lastData[i] != padValue) {
198                    UcryptoProvider.debug("PKCS5Padding: unpad, lastData: " + Arrays.toString(lastData));
199                    UcryptoProvider.debug("PKCS5Padding: unpad, padValue=" + padValue);
200                    throw new BadPaddingException("Invalid padding bytes!");
201                }
202            }
203
204            int actualOutLen = dataLen - padValue;
205            // check output buffer capacity
206            if (out.length - outOfs < actualOutLen) {
207                throw new ShortBufferException("Output buffer too small, need " + actualOutLen +
208                    ", got " + (out.length - outOfs));
209            }
210            try {
211                if (tbSize != 0) {
212                    trailingBytes.rewind();
213                    if (tbSize < actualOutLen) {
214                        trailingBytes.get(out, outOfs, tbSize);
215                        outOfs += tbSize;
216                    } else {
217                        // copy from trailingBytes and we are done
218                        trailingBytes.get(out, outOfs, actualOutLen);
219                        return actualOutLen;
220                    }
221                }
222                if (lastData.length > padValue) {
223                    System.arraycopy(lastData, 0, out, outOfs,
224                                     lastData.length - padValue);
225                }
226                return actualOutLen;
227            } finally {
228                clear();
229            }
230        }
231
232        public void clear() {
233            if (trailingBytes != null) trailingBytes.clear();
234        }
235    }
236
237    public static final class AesEcbPKCS5 extends NativeCipherWithJavaPadding {
238        public AesEcbPKCS5() throws NoSuchAlgorithmException, NoSuchPaddingException {
239            super(new NativeCipher.AesEcbNoPadding(), "PKCS5Padding");
240        }
241    }
242
243    public static final class AesCbcPKCS5 extends NativeCipherWithJavaPadding {
244        public AesCbcPKCS5() throws NoSuchAlgorithmException, NoSuchPaddingException {
245            super(new NativeCipher.AesCbcNoPadding(), "PKCS5Padding");
246        }
247    }
248
249    public static final class AesCfb128PKCS5 extends NativeCipherWithJavaPadding {
250        public AesCfb128PKCS5() throws NoSuchAlgorithmException, NoSuchPaddingException {
251            super(new NativeCipher.AesCfb128NoPadding(), "PKCS5Padding");
252        }
253    }
254
255    // fields (re)set in every init()
256    private final NativeCipher nc;
257    private final Padding padding;
258    private final int blockSize;
259    private int lastBlockLen = 0;
260
261    // Only ECB, CBC, CTR, and CFB128 modes w/ NOPADDING for now
262    NativeCipherWithJavaPadding(NativeCipher nc, String paddingScheme)
263        throws NoSuchAlgorithmException, NoSuchPaddingException {
264        this.nc = nc;
265        this.blockSize = nc.engineGetBlockSize();
266        if (paddingScheme.toUpperCase(Locale.ROOT).equals("PKCS5PADDING")) {
267            padding = new PKCS5Padding(blockSize);
268        } else {
269            throw new NoSuchAlgorithmException("Unsupported padding scheme: " + paddingScheme);
270        }
271    }
272
273    void reset() {
274        padding.clear();
275        lastBlockLen = 0;
276    }
277
278    @Override
279    protected synchronized void engineSetMode(String mode) throws NoSuchAlgorithmException {
280        nc.engineSetMode(mode);
281    }
282
283    // see JCE spec
284    @Override
285    protected void engineSetPadding(String padding)
286            throws NoSuchPaddingException {
287        // Disallow change of padding for now since currently it's explicitly
288        // defined in transformation strings
289        throw new NoSuchPaddingException("Unsupported padding " + padding);
290    }
291
292    // see JCE spec
293    @Override
294    protected int engineGetBlockSize() {
295        return blockSize;
296    }
297
298    // see JCE spec
299    @Override
300    protected synchronized int engineGetOutputSize(int inputLen) {
301        int result = nc.engineGetOutputSize(inputLen);
302        if (nc.encrypt) {
303            result += padding.getPadLen(result);
304        } else {
305            result += padding.getBufferedLength();
306        }
307        return result;
308    }
309
310    // see JCE spec
311    @Override
312    protected synchronized byte[] engineGetIV() {
313        return nc.engineGetIV();
314    }
315
316    // see JCE spec
317    @Override
318    protected synchronized AlgorithmParameters engineGetParameters() {
319        return nc.engineGetParameters();
320    }
321
322    @Override
323    protected int engineGetKeySize(Key key) throws InvalidKeyException {
324        return nc.engineGetKeySize(key);
325    }
326
327    // see JCE spec
328    @Override
329    protected synchronized void engineInit(int opmode, Key key, SecureRandom random)
330            throws InvalidKeyException {
331        reset();
332        nc.engineInit(opmode, key, random);
333    }
334
335    // see JCE spec
336    @Override
337    protected synchronized void engineInit(int opmode, Key key,
338            AlgorithmParameterSpec params, SecureRandom random)
339            throws InvalidKeyException, InvalidAlgorithmParameterException {
340        reset();
341        nc.engineInit(opmode, key, params, random);
342    }
343
344    // see JCE spec
345    @Override
346    protected synchronized void engineInit(int opmode, Key key, AlgorithmParameters params,
347            SecureRandom random)
348            throws InvalidKeyException, InvalidAlgorithmParameterException {
349        reset();
350        nc.engineInit(opmode, key, params, random);
351    }
352
353    // see JCE spec
354    @Override
355    protected synchronized byte[] engineUpdate(byte[] in, int inOfs, int inLen) {
356        if (nc.encrypt) {
357            lastBlockLen += inLen;
358            lastBlockLen &= (blockSize - 1);
359            return nc.engineUpdate(in, inOfs, inLen);
360        } else {
361            return padding.bufferBytes(nc.engineUpdate(in, inOfs, inLen));
362        }
363    }
364
365    // see JCE spec
366    @Override
367    protected synchronized int engineUpdate(byte[] in, int inOfs, int inLen, byte[] out,
368            int outOfs) throws ShortBufferException {
369        if (nc.encrypt) {
370            lastBlockLen += inLen;
371            lastBlockLen &= (blockSize - 1);
372            return nc.engineUpdate(in, inOfs, inLen, out, outOfs);
373        } else {
374            byte[] result = padding.bufferBytes(nc.engineUpdate(in, inOfs, inLen));
375            if (result != null) {
376                System.arraycopy(result, 0, out, outOfs, result.length);
377                return result.length;
378            } else return 0;
379        }
380    }
381
382    // see JCE spec
383    @Override
384    protected synchronized byte[] engineDoFinal(byte[] in, int inOfs, int inLen)
385            throws IllegalBlockSizeException, BadPaddingException {
386        int estimatedOutLen = engineGetOutputSize(inLen);
387        byte[] out = new byte[estimatedOutLen];
388        try {
389            int actualOut = this.engineDoFinal(in, inOfs, inLen, out, 0);
390            // truncate off extra bytes
391            if (actualOut != out.length) {
392                out = Arrays.copyOf(out, actualOut);
393            }
394        } catch (ShortBufferException sbe) {
395            throw new UcryptoException("Internal Error", sbe);
396        } finally {
397            reset();
398        }
399        return out;
400    }
401
402    // see JCE spec
403    @Override
404    protected synchronized int engineDoFinal(byte[] in, int inOfs, int inLen, byte[] out,
405                                             int outOfs)
406        throws ShortBufferException, IllegalBlockSizeException,
407               BadPaddingException {
408        int estimatedOutLen = engineGetOutputSize(inLen);
409        if (out.length - outOfs < estimatedOutLen) {
410            throw new ShortBufferException("Actual: " + (out.length - outOfs) +
411                ". Estimated Out Length: " + estimatedOutLen);
412        }
413        try {
414            if (nc.encrypt) {
415                int k = nc.engineUpdate(in, inOfs, inLen, out, outOfs);
416                lastBlockLen += inLen;
417                lastBlockLen &= (blockSize - 1);
418                byte[] padBytes = padding.getPaddingBytes(lastBlockLen);
419                k += nc.engineDoFinal(padBytes, 0, padBytes.length, out, (outOfs + k));
420                return k;
421            } else {
422                byte[] tempOut = nc.engineDoFinal(in, inOfs, inLen);
423                int len = padding.unpad(tempOut, out, outOfs);
424                return len;
425            }
426        } finally {
427            reset();
428        }
429    }
430
431    // see JCE spec
432    @Override
433    protected synchronized byte[] engineWrap(Key key) throws IllegalBlockSizeException,
434                                                InvalidKeyException {
435        byte[] result = null;
436        try {
437            byte[] encodedKey = key.getEncoded();
438            if ((encodedKey == null) || (encodedKey.length == 0)) {
439                throw new InvalidKeyException("Cannot get an encoding of " +
440                                              "the key to be wrapped");
441            }
442            result = engineDoFinal(encodedKey, 0, encodedKey.length);
443        } catch (BadPaddingException e) {
444            // Should never happen for key wrapping
445            throw new UcryptoException("Internal Error", e);
446        }
447        return result;
448    }
449
450    // see JCE spec
451    @Override
452    protected synchronized Key engineUnwrap(byte[] wrappedKey, String wrappedKeyAlgorithm,
453                               int wrappedKeyType)
454        throws InvalidKeyException, NoSuchAlgorithmException {
455
456        byte[] encodedKey;
457        try {
458            encodedKey = engineDoFinal(wrappedKey, 0,
459                                       wrappedKey.length);
460        } catch (Exception e) {
461            throw (InvalidKeyException)
462                (new InvalidKeyException()).initCause(e);
463        }
464
465        return NativeCipher.constructKey(wrappedKeyType, encodedKey,
466                                         wrappedKeyAlgorithm);
467    }
468}
469