1/*
2 * Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.  Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26package com.oracle.security.ucrypto;
27
28import java.util.Set;
29import java.util.Arrays;
30import java.util.concurrent.ConcurrentSkipListSet;
31import java.lang.ref.*;
32import java.math.BigInteger;
33import java.nio.ByteBuffer;
34
35import java.security.SignatureSpi;
36import java.security.NoSuchAlgorithmException;
37import java.security.InvalidParameterException;
38import java.security.InvalidKeyException;
39import java.security.SignatureException;
40import java.security.Key;
41import java.security.PrivateKey;
42import java.security.PublicKey;
43
44import java.security.*;
45import java.security.interfaces.*;
46import java.security.spec.*;
47
48import sun.nio.ch.DirectBuffer;
49import java.nio.ByteBuffer;
50
51/**
52 * Signature implementation class. This class currently supports the
53 * following algorithms:
54 *
55 * . RSA:
56 *   . MD5withRSA
57 *   . SHA1withRSA
58 *   . SHA256withRSA
59 *   . SHA384withRSA
60 *   . SHA512withRSA
61 *
62 * @since 9
63 */
64class NativeRSASignature extends SignatureSpi {
65
66    private static final int PKCS1PADDING_LEN = 11;
67
68    // fields set in constructor
69    private final UcryptoMech mech;
70    private final int encodedLen;
71
72    // field for ensuring native memory is freed
73    private SignatureContextRef pCtxt = null;
74
75    //
76    // fields (re)set in every init()
77    //
78    private boolean initialized = false;
79    private boolean sign = true;
80    private int sigLength;
81    private NativeKey key;
82    private NativeRSAKeyFactory keyFactory; // may need a more generic type later
83
84    // public implementation classes
85    public static final class MD5 extends NativeRSASignature {
86        public MD5() throws NoSuchAlgorithmException {
87            super(UcryptoMech.CRYPTO_MD5_RSA_PKCS, 34);
88        }
89    }
90
91    public static final class SHA1 extends NativeRSASignature {
92        public SHA1() throws NoSuchAlgorithmException {
93            super(UcryptoMech.CRYPTO_SHA1_RSA_PKCS, 35);
94        }
95    }
96
97    public static final class SHA256 extends NativeRSASignature {
98        public SHA256() throws NoSuchAlgorithmException {
99            super(UcryptoMech.CRYPTO_SHA256_RSA_PKCS, 51);
100        }
101    }
102
103    public static final class SHA384 extends NativeRSASignature {
104        public SHA384() throws NoSuchAlgorithmException {
105            super(UcryptoMech.CRYPTO_SHA384_RSA_PKCS, 67);
106        }
107    }
108
109    public static final class SHA512 extends NativeRSASignature {
110        public SHA512() throws NoSuchAlgorithmException {
111            super(UcryptoMech.CRYPTO_SHA512_RSA_PKCS, 83);
112        }
113    }
114
115    // internal class for native resource cleanup
116    private static class SignatureContextRef extends PhantomReference<NativeRSASignature>
117        implements Comparable<SignatureContextRef> {
118
119        private static ReferenceQueue<NativeRSASignature> refQueue =
120            new ReferenceQueue<NativeRSASignature>();
121
122        // Needed to keep these references from being GC'ed until when their
123        // referents are GC'ed so we can do post-mortem processing
124        private static Set<SignatureContextRef> refList =
125            new ConcurrentSkipListSet<SignatureContextRef>();
126        //           Collections.synchronizedSortedSet(new TreeSet<SignatureContextRef>());
127
128        private final long id;
129        private final boolean sign;
130
131        private static void drainRefQueueBounded() {
132            while (true) {
133                SignatureContextRef next = (SignatureContextRef) refQueue.poll();
134                if (next == null) break;
135                next.dispose(true);
136            }
137        }
138
139        SignatureContextRef(NativeRSASignature ns, long id, boolean sign) {
140            super(ns, refQueue);
141            this.id = id;
142            this.sign = sign;
143            refList.add(this);
144            UcryptoProvider.debug("Resource: track Signature Ctxt " + this.id);
145            drainRefQueueBounded();
146        }
147
148        public int compareTo(SignatureContextRef other) {
149            if (this.id == other.id) {
150                return 0;
151            } else {
152                return (this.id < other.id) ? -1 : 1;
153            }
154        }
155
156        void dispose(boolean doCancel) {
157            refList.remove(this);
158            try {
159                if (doCancel) {
160                    UcryptoProvider.debug("Resource: free Signature Ctxt " + this.id);
161                    NativeRSASignature.nativeFinal(id, sign, null, 0, 0);
162                } else {
163                    UcryptoProvider.debug("Resource: stop tracking Signature Ctxt " + this.id);
164                }
165            } finally {
166                this.clear();
167            }
168        }
169    }
170
171    NativeRSASignature(UcryptoMech mech, int encodedLen)
172        throws NoSuchAlgorithmException {
173        this.mech = mech;
174        this.encodedLen = encodedLen;
175        this.keyFactory = new NativeRSAKeyFactory();
176    }
177
178    // deprecated but abstract
179    @SuppressWarnings("deprecation")
180    protected Object engineGetParameter(String param) throws InvalidParameterException {
181        throw new UnsupportedOperationException("getParameter() not supported");
182    }
183
184    @Override
185    protected synchronized void engineInitSign(PrivateKey privateKey)
186            throws InvalidKeyException {
187        if (privateKey == null) {
188            throw new InvalidKeyException("Key must not be null");
189        }
190        NativeKey newKey = key;
191        int newSigLength = sigLength;
192        // Need to check RSA key length whenever a new private key is set
193        if (privateKey != key) {
194            if (!(privateKey instanceof RSAPrivateKey)) {
195                throw new InvalidKeyException("RSAPrivateKey required. " +
196                    "Received: " + privateKey.getClass().getName());
197            }
198            RSAPrivateKey rsaPrivKey = (RSAPrivateKey) privateKey;
199            BigInteger mod = rsaPrivKey.getModulus();
200            newSigLength = checkRSAKeyLength(mod);
201            BigInteger pe = rsaPrivKey.getPrivateExponent();
202            try {
203                if (rsaPrivKey instanceof RSAPrivateCrtKey) {
204                    RSAPrivateCrtKey rsaPrivCrtKey = (RSAPrivateCrtKey) rsaPrivKey;
205                    newKey = (NativeKey) keyFactory.engineGeneratePrivate
206                        (new RSAPrivateCrtKeySpec(mod,
207                                                  rsaPrivCrtKey.getPublicExponent(),
208                                                  pe,
209                                                  rsaPrivCrtKey.getPrimeP(),
210                                                  rsaPrivCrtKey.getPrimeQ(),
211                                                  rsaPrivCrtKey.getPrimeExponentP(),
212                                                  rsaPrivCrtKey.getPrimeExponentQ(),
213                                                  rsaPrivCrtKey.getCrtCoefficient()));
214                } else {
215                    newKey = (NativeKey) keyFactory.engineGeneratePrivate
216                           (new RSAPrivateKeySpec(mod, pe));
217                }
218            } catch (InvalidKeySpecException ikse) {
219                throw new InvalidKeyException(ikse);
220            }
221        }
222        init(true, newKey, newSigLength);
223    }
224
225
226    @Override
227    protected synchronized void engineInitVerify(PublicKey publicKey)
228            throws InvalidKeyException {
229        if (publicKey == null) {
230            throw new InvalidKeyException("Key must not be null");
231        }
232        NativeKey newKey = key;
233        int newSigLength = sigLength;
234        // Need to check RSA key length whenever a new public key is set
235        if (publicKey != key) {
236            if (publicKey instanceof RSAPublicKey) {
237                BigInteger mod = ((RSAPublicKey) publicKey).getModulus();
238                newSigLength = checkRSAKeyLength(mod);
239                try {
240                    newKey = (NativeKey) keyFactory.engineGeneratePublic
241                        (new RSAPublicKeySpec(mod, ((RSAPublicKey) publicKey).getPublicExponent()));
242                } catch (InvalidKeySpecException ikse) {
243                    throw new InvalidKeyException(ikse);
244                }
245            } else {
246                throw new InvalidKeyException("RSAPublicKey required. " +
247                    "Received: " + publicKey.getClass().getName());
248            }
249        }
250        init(false, newKey, newSigLength);
251    }
252
253    // deprecated but abstract
254    @SuppressWarnings("deprecation")
255    protected void engineSetParameter(String param, Object value) throws InvalidParameterException {
256        throw new UnsupportedOperationException("setParameter() not supported");
257    }
258
259    @Override
260    protected synchronized byte[] engineSign() throws SignatureException {
261        try {
262            byte[] sig = new byte[sigLength];
263            int rv = doFinal(sig, 0, sigLength);
264            if (rv < 0) {
265                throw new SignatureException(new UcryptoException(-rv));
266            }
267            return sig;
268        } finally {
269            // doFinal should already be called, no need to cancel
270            reset(false);
271        }
272    }
273
274    @Override
275    protected synchronized int engineSign(byte[] outbuf, int offset, int len)
276        throws SignatureException {
277        boolean doCancel = true;
278        try {
279            if (outbuf == null || (offset < 0) || (outbuf.length < (offset + sigLength))
280                || (len < sigLength)) {
281                throw new SignatureException("Invalid output buffer. offset: " +
282                    offset + ". len: " + len + ". sigLength: " + sigLength);
283            }
284            int rv = doFinal(outbuf, offset, sigLength);
285            doCancel = false;
286            if (rv < 0) {
287                throw new SignatureException(new UcryptoException(-rv));
288            }
289            return sigLength;
290        } finally {
291            reset(doCancel);
292        }
293    }
294
295    @Override
296    protected synchronized void engineUpdate(byte b) throws SignatureException {
297        byte[] in = { b };
298        int rv = update(in, 0, 1);
299        if (rv < 0) {
300            throw new SignatureException(new UcryptoException(-rv));
301        }
302    }
303
304    @Override
305    protected synchronized void engineUpdate(byte[] in, int inOfs, int inLen)
306            throws SignatureException {
307        if (in == null || inOfs < 0 || inLen == 0) return;
308
309        int rv = update(in, inOfs, inLen);
310        if (rv < 0) {
311            throw new SignatureException(new UcryptoException(-rv));
312        }
313    }
314
315    @Override
316    protected synchronized void engineUpdate(ByteBuffer in) {
317        if (in == null || in.remaining() == 0) return;
318
319        if (in instanceof DirectBuffer == false) {
320            // cannot do better than default impl
321            super.engineUpdate(in);
322            return;
323        }
324        long inAddr = ((DirectBuffer)in).address();
325        int inOfs = in.position();
326        int inLen = in.remaining();
327
328        int rv = update((inAddr + inOfs), inLen);
329        if (rv < 0) {
330            throw new UcryptoException(-rv);
331        }
332        in.position(inOfs + inLen);
333    }
334
335    @Override
336    protected synchronized boolean engineVerify(byte[] sigBytes) throws SignatureException {
337        return engineVerify(sigBytes, 0, sigBytes.length);
338    }
339
340    @Override
341    protected synchronized boolean engineVerify(byte[] sigBytes, int sigOfs, int sigLen)
342        throws SignatureException {
343        boolean doCancel = true;
344        try {
345            if (sigBytes == null || (sigOfs < 0) || (sigBytes.length < (sigOfs + this.sigLength))
346                || (sigLen != this.sigLength)) {
347                throw new SignatureException("Invalid signature length: got " +
348                    sigLen + " but was expecting " + this.sigLength);
349            }
350
351            int rv = doFinal(sigBytes, sigOfs, sigLen);
352            doCancel = false;
353            if (rv == 0) {
354                return true;
355            } else {
356                UcryptoProvider.debug("Signature: " + mech + " verification error " +
357                             new UcryptoException(-rv).getMessage());
358                return false;
359            }
360        } finally {
361            reset(doCancel);
362        }
363    }
364
365    void reset(boolean doCancel) {
366        initialized = false;
367        if (pCtxt != null) {
368            pCtxt.dispose(doCancel);
369            pCtxt = null;
370        }
371    }
372
373    /**
374     * calls ucrypto_sign_init(...) or ucrypto_verify_init(...)
375     * @return pointer to the context
376     */
377    private native static long nativeInit(int mech, boolean sign,
378                                          long keyValue, int keyLength);
379
380    /**
381     * calls ucrypto_sign_update(...) or ucrypto_verify_update(...)
382     * @return an error status code (0 means SUCCESS)
383     */
384    private native static int nativeUpdate(long pContext, boolean sign,
385                                           byte[] in, int inOfs, int inLen);
386    /**
387     * calls ucrypto_sign_update(...) or ucrypto_verify_update(...)
388     * @return an error status code (0 means SUCCESS)
389     */
390    private native static int nativeUpdate(long pContext, boolean sign,
391                                           long pIn, int inLen);
392
393    /**
394     * calls ucrypto_sign_final(...) or ucrypto_verify_final(...)
395     * @return the length of signature bytes or verification status.
396     * If negative, it indicates an error status code
397     */
398    private native static int nativeFinal(long pContext, boolean sign,
399                                          byte[] sig, int sigOfs, int sigLen);
400
401    // actual init() implementation - caller should clone key if needed
402    private void init(boolean sign, NativeKey key, int sigLength) {
403        reset(true);
404        this.sign = sign;
405        this.sigLength = sigLength;
406        this.key = key;
407        long pCtxtVal = nativeInit(mech.value(), sign, key.value(),
408                                   key.length());
409        initialized = (pCtxtVal != 0L);
410        if (initialized) {
411            pCtxt = new SignatureContextRef(this, pCtxtVal, sign);
412        } else {
413            throw new UcryptoException("Cannot initialize Signature");
414        }
415    }
416
417    private void ensureInitialized() {
418        if (!initialized) {
419            init(sign, key, sigLength);
420            if (!initialized) {
421                throw new UcryptoException("Cannot initialize Signature");
422            }
423        }
424    }
425
426    // returns 0 (success) or negative (ucrypto error occurred)
427    private int update(byte[] in, int inOfs, int inLen) {
428        if (inOfs < 0 || inOfs + inLen > in.length) {
429            throw new ArrayIndexOutOfBoundsException("inOfs :" + inOfs +
430                ". inLen: " + inLen + ". in.length: " + in.length);
431        }
432        ensureInitialized();
433        int k = nativeUpdate(pCtxt.id, sign, in, inOfs, inLen);
434        if (k < 0) {
435            reset(false);
436        }
437        return k;
438    }
439
440    // returns 0 (success) or negative (ucrypto error occurred)
441    private int update(long pIn, int inLen) {
442        ensureInitialized();
443        int k = nativeUpdate(pCtxt.id, sign, pIn, inLen);
444        if (k < 0) {
445            reset(false);
446        }
447        return k;
448    }
449
450    // returns 0 (success) or negative (ucrypto error occurred)
451    private int doFinal(byte[] sigBytes, int sigOfs, int sigLen) {
452        ensureInitialized();
453        int k = nativeFinal(pCtxt.id, sign, sigBytes, sigOfs, sigLen);
454        return k;
455    }
456
457    // check and return RSA key size in number of bytes
458    private int checkRSAKeyLength(BigInteger mod) throws InvalidKeyException {
459        int keySize = (mod.bitLength() + 7) >> 3;
460        int maxDataSize = keySize - PKCS1PADDING_LEN;
461        if (maxDataSize < encodedLen) {
462            throw new InvalidKeyException
463                ("Key is too short for this signature algorithm. maxDataSize: " +
464                    maxDataSize + ". encodedLen: " + encodedLen);
465        }
466        return keySize;
467    }
468}
469