1/*	$NetBSD: sntrup761.c,v 1.3 2023/07/26 17:58:15 christos Exp $	*/
2/*  $OpenBSD: sntrup761.c,v 1.6 2023/01/11 02:13:52 djm Exp $ */
3
4/*
5 * Public Domain, Authors:
6 * - Daniel J. Bernstein
7 * - Chitchanok Chuengsatiansup
8 * - Tanja Lange
9 * - Christine van Vredendaal
10 */
11#include "includes.h"
12__RCSID("$NetBSD: sntrup761.c,v 1.3 2023/07/26 17:58:15 christos Exp $");
13
14#include <string.h>
15#include "crypto_api.h"
16
17#define int8 crypto_int8
18#define uint8 crypto_uint8
19#define int16 crypto_int16
20#define uint16 crypto_uint16
21#define int32 crypto_int32
22#define uint32 crypto_uint32
23#define int64 crypto_int64
24#define uint64 crypto_uint64
25
26/* from supercop-20201130/crypto_sort/int32/portable4/int32_minmax.inc */
/*
 * int32_MINMAX(a,b): branch-free compare-exchange of two int32 lvalues.
 * Afterwards a holds the smaller and b the larger of the two values.
 *
 * Both arguments are expanded several times, so they must be lvalues
 * without side effects.  The temporaries carry a "minmax_" prefix so that
 * callers may pass variables with any ordinary name (including "c" or
 * "ab") without being captured by the macro's own locals.
 *
 * Relies on >> of a negative int64_t being an arithmetic shift.
 */
#define int32_MINMAX(a,b) \
do { \
  int64_t minmax_ab = (int64_t)(b) ^ (int64_t)(a); \
  int64_t minmax_c = (int64_t)(b) - (int64_t)(a); \
  minmax_c ^= minmax_ab & (minmax_c ^ (b)); \
  minmax_c >>= 31; \
  minmax_c &= minmax_ab; \
  (a) ^= minmax_c; \
  (b) ^= minmax_c; \
} while(0)
37
38/* from supercop-20201130/crypto_sort/int32/portable4/sort.c */
39
40
/*
 * Sort the n int32 values stored at array, in place.
 *
 * Every operation on the data goes through int32_MINMAX, which leaves the
 * smaller value in its first argument.  All loop bounds and indices below
 * are computed from n alone, never from the array contents.
 *
 * array: pointer to n int32 elements (read and written).
 * n:     element count; n < 2 returns immediately.
 */
static void crypto_sort_int32(void *array,long long n)
{
  long long top,p,q,r,i,j;
  int32 *x = array;

  if (n < 2) return;
  /* top: 1 doubled until top >= n - top */
  top = 1;
  while (top < n - top) top += top;

  for (p = top;p >= 1;p >>= 1) {
    /* compare-exchange x[j] with x[j+p] across whole blocks of 2p entries */
    i = 0;
    while (i + 2 * p <= n) {
      for (j = i;j < i + p;++j)
        int32_MINMAX(x[j],x[j+p]);
      i += 2 * p;
    }
    /* trailing partial block: only pairs whose upper index is below n */
    for (j = i;j < n - p;++j)
      int32_MINMAX(x[j],x[j+p]);

    /* second phase: for each q > p, run x[j+p] against x[j+r], r = q,q/2,...,2p */
    i = 0;
    j = 0;
    for (q = top;q > p;q >>= 1) {
      /* finish the block that the previous q left partly processed */
      if (j != i) for (;;) {
        if (j == n - q) goto done;
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j + r]);
        x[j + p] = a;
        ++j;
        if (j == i + p) {
          i += 2 * p;
          break;
        }
      }
      /* whole blocks for which every index j+q stays inside the array */
      while (i + p <= n - q) {
        for (j = i;j < i + p;++j) {
          int32 a = x[j + p];
          for (r = q;r > p;r >>= 1)
            int32_MINMAX(a,x[j+r]);
          x[j + p] = a;
        }
        i += 2 * p;
      }
      /* now i + p > n - q */
      j = i;
      while (j < n - q) {
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j+r]);
        x[j + p] = a;
        ++j;
      }

      done: ;
    }
  }
}
98
99/* from supercop-20201130/crypto_sort/uint32/useint32/sort.c */
100
101/* can save time by vectorizing xor loops */
102/* can save time by integrating xor loops with int32_sort */
103
104static void crypto_sort_uint32(void *array,long long n)
105{
106  crypto_uint32 *x = array;
107  long long j;
108  for (j = 0;j < n;++j) x[j] ^= 0x80000000;
109  crypto_sort_int32(array,n);
110  for (j = 0;j < n;++j) x[j] ^= 0x80000000;
111}
112
113/* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.c */
114
115/*
116CPU division instruction typically takes time depending on x.
117This software is designed to take time independent of x.
118Time still varies depending on m; user must ensure that m is constant.
119Time also varies on CPUs where multiplication is variable-time.
120There could be more CPU issues.
121There could also be compiler issues.
122*/
123
/*
 * Divide x by m, storing the quotient in *q and the remainder in *r.
 *
 * The only division performed is 2^31 / m; x itself is handled with two
 * multiply-and-shift estimates followed by one masked correction, so no
 * branch or division depends on x.
 *
 * q: receives the quotient.
 * r: receives the remainder (the closing comment states x < m).
 * x: dividend.
 * m: divisor; the caller guarantees 0 < m < 16384.
 */
static void uint32_divmod_uint14(uint32 *q,uint16 *r,uint32 x,uint16 m)
{
  uint32 v = 0x80000000;
  uint32 qpart;
  uint32 mask;

  /* v = floor(2^31 / m): depends on m only */
  v /= m;

  /* caller guarantees m > 0 */
  /* caller guarantees m < 16384 */
  /* vm <= 2^31 <= vm+m-1 */
  /* xvm <= 2^31 x <= xvm+x(m-1) */

  *q = 0;

  /* first quotient estimate */
  qpart = (x*(uint64)v)>>31;
  /* 2^31 qpart <= xv <= 2^31 qpart + 2^31-1 */
  /* 2^31 qpart m <= xvm <= 2^31 qpart m + (2^31-1)m */
  /* 2^31 qpart m <= 2^31 x <= 2^31 qpart m + (2^31-1)m + x(m-1) */
  /* 0 <= 2^31 newx <= (2^31-1)m + x(m-1) */
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= (1-1/2^31)(2^14-1) + (2^32-1)((2^14-1)-1)/2^31 */

  x -= qpart*m; *q += qpart;
  /* x <= 49146 */

  /* second estimate on the reduced x */
  qpart = (x*(uint64)v)>>31;
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= m + 49146(2^14-1)/2^31 */
  /* 0 <= newx <= m + 0.4 */
  /* 0 <= newx <= m */

  x -= qpart*m; *q += qpart;
  /* x <= m */

  /* subtract m once more, then undo it with a mask if x wrapped below zero */
  x -= m; *q += 1;
  mask = -(x>>31);
  x += mask&(uint32)m; *q += mask;
  /* x < m */

  *r = x;
}
166
167
168static uint16 uint32_mod_uint14(uint32 x,uint16 m)
169{
170  uint32 q;
171  uint16 r;
172  uint32_divmod_uint14(&q,&r,x,m);
173  return r;
174}
175
176/* from supercop-20201130/crypto_kem/sntrup761/ref/int32.c */
177
/*
 * Signed counterpart of uint32_divmod_uint14.
 *
 * x is biased by 2^31 so the unsigned routine can be used; the quotient
 * and remainder of 2^31 itself are then subtracted, and a remainder that
 * went negative is repaired by adding m back (and adjusting the quotient)
 * under a mask rather than a branch.
 *
 * q: receives the quotient.
 * r: receives the remainder after the masked correction.
 * x: signed dividend.
 * m: divisor, passed straight to uint32_divmod_uint14 (0 < m < 16384).
 */
static void int32_divmod_uint14(int32 *q,uint16 *r,int32 x,uint16 m)
{
  uint32 uq,uq2;
  uint16 ur,ur2;
  uint32 mask;

  uint32_divmod_uint14(&uq,&ur,0x80000000+(uint32)x,m);
  uint32_divmod_uint14(&uq2,&ur2,0x80000000,m);
  ur -= ur2; uq -= uq2;
  /* bit 15 of ur is set when the 16-bit subtraction wrapped */
  mask = -(uint32)(ur>>15);
  ur += mask&m; uq += mask;
  *r = ur; *q = uq;
}
191
192
193static uint16 int32_mod_uint14(int32 x,uint16 m)
194{
195  int32 q;
196  uint16 r;
197  int32_divmod_uint14(&q,&r,x,m);
198  return r;
199}
200
201/* from supercop-20201130/crypto_kem/sntrup761/ref/paramsmenu.h */
/* pick one of these three: */
/* (selects which block of the parameter table further down is used) */
#define SIZE761
#undef SIZE653
#undef SIZE857

/* pick one of these two: */
/* (this file selects SNTRUP; code guarded by #ifdef LPR is compiled out) */
#define SNTRUP /* Streamlined NTRU Prime */
#undef LPR /* NTRU LPRime */
210
211/* from supercop-20201130/crypto_kem/sntrup761/ref/params.h */
#ifndef params_H
#define params_H

/* menu of parameter choices: */


/* what the menu means: */

/*
 * Each SIZE* block defines p, q, w and the encoded sizes as bare macros.
 * NOTE: p, q and w are single-letter macro names, so from here on no
 * identifier in this file may be called p, q or w.
 * Without LPR the block also defines Rq_bytes; with LPR it defines
 * tau0..tau3 instead.
 */
#if defined(SIZE761)
#define p 761
#define q 4591
#define Rounded_bytes 1007
#ifndef LPR
#define Rq_bytes 1158
#define w 286
#else
#define w 250
#define tau0 2156
#define tau1 114
#define tau2 2007
#define tau3 287
#endif

#elif defined(SIZE653)
#define p 653
#define q 4621
#define Rounded_bytes 865
#ifndef LPR
#define Rq_bytes 994
#define w 288
#else
#define w 252
#define tau0 2175
#define tau1 113
#define tau2 2031
#define tau3 290
#endif

#elif defined(SIZE857)
#define p 857
#define q 5167
#define Rounded_bytes 1152
#ifndef LPR
#define Rq_bytes 1322
#define w 322
#else
#define w 281
#define tau0 2433
#define tau1 101
#define tau2 2265
#define tau3 324
#endif

#else
#error "no parameter set defined"
#endif

/* I: number of entries in the LPR-only Inputs array */
#ifdef LPR
#define I 256
#endif

#endif
274
275/* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.h */
276#ifndef Decode_H
277#define Decode_H
278
279
280/* Decode(R,s,M,len) */
281/* assumes 0 < M[i] < 16384 */
282/* produces 0 <= R[i] < M[i] */
283
284#endif
285
286/* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.c */
287
/*
 * Decode(out,S,M,len): read bytes from S and produce len values in out.
 * Per the contract stated for Decode above: assumes 0 < M[i] < 16384 and
 * produces 0 <= out[i] < M[i].
 *
 * Adjacent moduli are merged pairwise.  For each pair, zero, one or two
 * bytes are taken from S as the low part (chosen from the size of
 * M[i]*M[i+1]), the merged list of about half the length is decoded
 * recursively, and each pair is then split again with a division by M[i].
 *
 * out: receives len decoded values.
 * S:   encoded input; the number of bytes read is determined by M and len.
 * M:   the len moduli.
 * len: number of values; len < 1 does nothing.
 */
static void Decode(uint16 *out,const unsigned char *S,const uint16 *M,long long len)
{
  if (len == 1) {
    /* base case: zero, one or two bytes depending on the modulus */
    if (M[0] == 1)
      *out = 0;
    else if (M[0] <= 256)
      *out = uint32_mod_uint14(S[0],M[0]);
    else
      *out = uint32_mod_uint14(S[0]+(((uint16)S[1])<<8),M[0]);
  }
  if (len > 1) {
    uint16 R2[(len+1)/2];   /* values decoded for the merged list */
    uint16 M2[(len+1)/2];   /* moduli of the merged list */
    uint16 bottomr[len/2];  /* low bytes read directly from S, per pair */
    uint32 bottomt[len/2];  /* weight of R2[] relative to bottomr[]: 1, 256 or 65536 */
    long long i;
    for (i = 0;i < len-1;i += 2) {
      uint32 m = M[i]*(uint32) M[i+1];
      if (m > 256*16383) {
        /* two bytes come straight from S */
        bottomt[i/2] = 256*256;
        bottomr[i/2] = S[0]+256*S[1];
        S += 2;
        M2[i/2] = (((m+255)>>8)+255)>>8;
      } else if (m >= 16384) {
        /* one byte comes straight from S */
        bottomt[i/2] = 256;
        bottomr[i/2] = S[0];
        S += 1;
        M2[i/2] = (m+255)>>8;
      } else {
        /* product already below 16384: nothing read here */
        bottomt[i/2] = 1;
        bottomr[i/2] = 0;
        M2[i/2] = m;
      }
    }
    /* odd len: the last modulus is carried over unchanged */
    if (i < len)
      M2[i/2] = M[i];
    Decode(R2,S,M2,(len+1)/2);
    for (i = 0;i < len-1;i += 2) {
      uint32 r = bottomr[i/2];
      uint32 r1;
      uint16 r0;
      r += bottomt[i/2]*R2[i/2];
      /* split the merged value: r0 = r mod M[i], r1 = r / M[i] */
      uint32_divmod_uint14(&r1,&r0,r,M[i]);
      r1 = uint32_mod_uint14(r1,M[i+1]); /* only needed for invalid inputs */
      *out++ = r0;
      *out++ = r1;
    }
    if (i < len)
      *out++ = R2[i/2];
  }
}
339
340/* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.h */
341#ifndef Encode_H
342#define Encode_H
343
344
345/* Encode(s,R,M,len) */
346/* assumes 0 <= R[i] < M[i] < 16384 */
347
348#endif
349
350/* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.c */
351
352/* 0 <= R[i] < M[i] < 16384 */
353static void Encode(unsigned char *out,const uint16 *R,const uint16 *M,long long len)
354{
355  if (len == 1) {
356    uint16 r = R[0];
357    uint16 m = M[0];
358    while (m > 1) {
359      *out++ = r;
360      r >>= 8;
361      m = (m+255)>>8;
362    }
363  }
364  if (len > 1) {
365    uint16 R2[(len+1)/2];
366    uint16 M2[(len+1)/2];
367    long long i;
368    for (i = 0;i < len-1;i += 2) {
369      uint32 m0 = M[i];
370      uint32 r = R[i]+R[i+1]*m0;
371      uint32 m = M[i+1]*m0;
372      while (m >= 16384) {
373        *out++ = r;
374        r >>= 8;
375        m = (m+255)>>8;
376      }
377      R2[i/2] = r;
378      M2[i/2] = m;
379    }
380    if (i < len) {
381      R2[i/2] = R[i];
382      M2[i/2] = M[i];
383    }
384    Encode(out,R2,M2,(len+1)/2);
385  }
386}
387
388/* from supercop-20201130/crypto_kem/sntrup761/ref/kem.c */
389
390#ifdef LPR
391#endif
392
393
394/* ----- masks */
395
396#ifndef LPR
397
398/* return -1 if x!=0; else return 0 */
399static int int16_nonzero_mask(int16 x)
400{
401  uint16 u = x; /* 0, else 1...65535 */
402  uint32 v = u; /* 0, else 1...65535 */
403  v = -v; /* 0, else 2^32-65535...2^32-1 */
404  v >>= 31; /* 0, else 1 */
405  return -v; /* 0, else -1 */
406}
407
408#endif
409
410/* return -1 if x<0; otherwise return 0 */
411static int int16_negative_mask(int16 x)
412{
413  uint16 u = x;
414  u >>= 15;
415  return -(int) u;
416  /* alternative with gcc -fwrapv: */
417  /* x>>15 compiles to CPU's arithmetic right shift */
418}
419
420/* ----- arithmetic mod 3 */
421
422typedef int8 small;
423
424/* F3 is always represented as -1,0,1 */
425/* so ZZ_fromF3 is a no-op */
426
427/* x must not be close to top int16 */
428static small F3_freeze(int16 x)
429{
430  return int32_mod_uint14(x+1,3)-1;
431}
432
433/* ----- arithmetic mod q */
434
435#define q12 ((q-1)/2)
436typedef int16 Fq;
437/* always represented as -q12...q12 */
438/* so ZZ_fromFq is a no-op */
439
440/* x must not be close to top int32 */
441static Fq Fq_freeze(int32 x)
442{
443  return int32_mod_uint14(x+q12,q)-q12;
444}
445
446#ifndef LPR
447
448static Fq Fq_recip(Fq a1)
449{
450  int i = 1;
451  Fq ai = a1;
452
453  while (i < q-2) {
454    ai = Fq_freeze(a1*(int32)ai);
455    i += 1;
456  }
457  return ai;
458}
459
460#endif
461
462/* ----- Top and Right */
463
464#ifdef LPR
465#define tau 16
466
467static int8 Top(Fq C)
468{
469  return (tau1*(int32)(C+tau0)+16384)>>15;
470}
471
472static Fq Right(int8 T)
473{
474  return Fq_freeze(tau3*(int32)T-tau2);
475}
476#endif
477
478/* ----- small polynomials */
479
480#ifndef LPR
481
482/* 0 if Weightw_is(r), else -1 */
483static int Weightw_mask(small *r)
484{
485  int weight = 0;
486  int i;
487
488  for (i = 0;i < p;++i) weight += r[i]&1;
489  return int16_nonzero_mask(weight-w);
490}
491
492/* R3_fromR(R_fromRq(r)) */
493static void R3_fromRq(small *out,const Fq *r)
494{
495  int i;
496  for (i = 0;i < p;++i) out[i] = F3_freeze(r[i]);
497}
498
499/* h = f*g in the ring R3 */
500static void R3_mult(small *h,const small *f,const small *g)
501{
502  small fg[p+p-1];
503  small result;
504  int i,j;
505
506  for (i = 0;i < p;++i) {
507    result = 0;
508    for (j = 0;j <= i;++j) result = F3_freeze(result+f[j]*g[i-j]);
509    fg[i] = result;
510  }
511  for (i = p;i < p+p-1;++i) {
512    result = 0;
513    for (j = i-p+1;j < p;++j) result = F3_freeze(result+f[j]*g[i-j]);
514    fg[i] = result;
515  }
516
517  for (i = p+p-2;i >= p;--i) {
518    fg[i-p] = F3_freeze(fg[i-p]+fg[i]);
519    fg[i-p+1] = F3_freeze(fg[i-p+1]+fg[i]);
520  }
521
522  for (i = 0;i < p;++i) h[i] = fg[i];
523}
524
/*
 * Compute into out the reciprocal of in in R3.
 * returns 0 if recip succeeded; else -1
 *
 * The loop always runs 2*p-1 times and every conditional update is done
 * with the masks from int16_negative_mask / int16_nonzero_mask, so the
 * sequence of operations does not depend on the coefficients.
 *
 * out: receives p coefficients.
 * in:  p coefficients of the element to invert.
 */
static int R3_recip(small *out,const small *in)
{
  small f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int sign,swap,t;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = 1;
  /* f starts with 1 at index 0 and -1 at indices p-1 and p, zero elsewhere */
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  /* g starts as the input with its coefficient order reversed */
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* shift v up by one position */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    sign = -g[0]*f[0];
    /* swap is -1 when delta > 0 and g[0] != 0, else 0 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    /* negate delta when swapping, then increment */
    delta ^= swap&(delta^-delta);
    delta += 1;

    /* masked exchange of (f,g) and (v,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* sign is chosen so that the new g[0] reduces to zero */
    for (i = 0;i < p+1;++i) g[i] = F3_freeze(g[i]+sign*f[i]);
    for (i = 0;i < p+1;++i) r[i] = F3_freeze(r[i]+sign*v[i]);

    /* shift g down by one position */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  /* undo the reversal and scale by f[0] */
  sign = f[0];
  for (i = 0;i < p;++i) out[i] = sign*v[p-1-i];

  /* 0 exactly when delta ended at 0 */
  return int16_nonzero_mask(delta);
}
568
569#endif
570
571/* ----- polynomials mod q */
572
573/* h = f*g in the ring Rq */
574static void Rq_mult_small(Fq *h,const Fq *f,const small *g)
575{
576  Fq fg[p+p-1];
577  Fq result;
578  int i,j;
579
580  for (i = 0;i < p;++i) {
581    result = 0;
582    for (j = 0;j <= i;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
583    fg[i] = result;
584  }
585  for (i = p;i < p+p-1;++i) {
586    result = 0;
587    for (j = i-p+1;j < p;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
588    fg[i] = result;
589  }
590
591  for (i = p+p-2;i >= p;--i) {
592    fg[i-p] = Fq_freeze(fg[i-p]+fg[i]);
593    fg[i-p+1] = Fq_freeze(fg[i-p+1]+fg[i]);
594  }
595
596  for (i = 0;i < p;++i) h[i] = fg[i];
597}
598
599#ifndef LPR
600
601/* h = 3f in Rq */
602static void Rq_mult3(Fq *h,const Fq *f)
603{
604  int i;
605
606  for (i = 0;i < p;++i) h[i] = Fq_freeze(3*f[i]);
607}
608
/* out = 1/(3*in) in Rq */
/* returns 0 if recip succeeded; else -1 */
/*
 * Same loop structure as R3_recip, with coefficients in Fq: 2*p-1
 * iterations, every conditional update expressed with masks.  r starts at
 * Fq_recip(3) rather than 1, which is where the factor 1/3 comes from.
 *
 * out: receives p coefficients.
 * in:  p small coefficients of the element to invert.
 */
static int Rq_recip3(Fq *out,const small *in)
{
  Fq f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int swap,t;
  int32 f0,g0;
  Fq scale;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = Fq_recip(3);
  /* f starts with 1 at index 0 and -1 at indices p-1 and p, zero elsewhere */
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  /* g starts as the input with its coefficient order reversed */
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* shift v up by one position */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* swap is -1 when delta > 0 and g[0] != 0, else 0 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    /* negate delta when swapping, then increment */
    delta ^= swap&(delta^-delta);
    delta += 1;

    /* masked exchange of (f,g) and (v,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* cross-multiply so that the new g[0] is f0*g0-g0*f0 = 0 */
    f0 = f[0];
    g0 = g[0];
    for (i = 0;i < p+1;++i) g[i] = Fq_freeze(f0*g[i]-g0*f[i]);
    for (i = 0;i < p+1;++i) r[i] = Fq_freeze(f0*r[i]-g0*v[i]);

    /* shift g down by one position */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  /* undo the reversal and scale by Fq_recip(f[0]) */
  scale = Fq_recip(f[0]);
  for (i = 0;i < p;++i) out[i] = Fq_freeze(scale*(int32)v[p-1-i]);

  /* 0 exactly when delta ended at 0 */
  return int16_nonzero_mask(delta);
}
656
657#endif
658
659/* ----- rounded polynomials mod q */
660
661static void Round(Fq *out,const Fq *a)
662{
663  int i;
664  for (i = 0;i < p;++i) out[i] = a[i]-F3_freeze(a[i]);
665}
666
667/* ----- sorting to generate short polynomial */
668
669static void Short_fromlist(small *out,const uint32 *in)
670{
671  uint32 L[p];
672  int i;
673
674  for (i = 0;i < w;++i) L[i] = in[i]&(uint32)-2;
675  for (i = w;i < p;++i) L[i] = (in[i]&(uint32)-3)|1;
676  crypto_sort_uint32(L,p);
677  for (i = 0;i < p;++i) out[i] = (L[i]&3)-1;
678}
679
680/* ----- underlying hash function */
681
#define Hash_bytes 32

/*
 * out = first Hash_bytes bytes of SHA-512(b || in), where b is a single
 * prefix byte.  E.g., b = 0 means out = Hash0(in).
 *
 * out:   receives Hash_bytes bytes.
 * b:     prefix value, stored as one byte.
 * in:    message of inlen bytes.
 * inlen: message length in bytes.
 */
static void Hash_prefix(unsigned char *out,int b,const unsigned char *in,int inlen)
{
  unsigned char msg[inlen+1];
  unsigned char digest[64];
  unsigned char *dst = msg;
  int idx;

  *dst++ = b;
  for (idx = 0; idx < inlen; ++idx)
    *dst++ = in[idx];

  crypto_hash_sha512(digest,msg,inlen+1);

  /* keep only the first half of the 64-byte digest */
  for (idx = 0; idx < Hash_bytes; ++idx)
    out[idx] = digest[idx];
}
696
697/* ----- higher-level randomness */
698
699static uint32 urandom32(void)
700{
701  unsigned char c[4];
702  uint32 out[4];
703
704  randombytes(c,4);
705  out[0] = (uint32)c[0];
706  out[1] = ((uint32)c[1])<<8;
707  out[2] = ((uint32)c[2])<<16;
708  out[3] = ((uint32)c[3])<<24;
709  return out[0]+out[1]+out[2]+out[3];
710}
711
712static void Short_random(small *out)
713{
714  uint32 L[p];
715  int i;
716
717  for (i = 0;i < p;++i) L[i] = urandom32();
718  Short_fromlist(out,L);
719}
720
721#ifndef LPR
722
723static void Small_random(small *out)
724{
725  int i;
726
727  for (i = 0;i < p;++i) out[i] = (((urandom32()&0x3fffffff)*3)>>30)-1;
728}
729
730#endif
731
732/* ----- Streamlined NTRU Prime Core */
733
734#ifndef LPR
735
736/* h,(f,ginv) = KeyGen() */
737static void KeyGen(Fq *h,small *f,small *ginv)
738{
739  small g[p];
740  Fq finv[p];
741
742  for (;;) {
743    Small_random(g);
744    if (R3_recip(ginv,g) == 0) break;
745  }
746  Short_random(f);
747  Rq_recip3(finv,f); /* always works */
748  Rq_mult_small(h,finv,g);
749}
750
751/* c = Encrypt(r,h) */
752static void Encrypt(Fq *c,const small *r,const Fq *h)
753{
754  Fq hr[p];
755
756  Rq_mult_small(hr,h,r);
757  Round(c,hr);
758}
759
760/* r = Decrypt(c,(f,ginv)) */
761static void Decrypt(small *r,const Fq *c,const small *f,const small *ginv)
762{
763  Fq cf[p];
764  Fq cf3[p];
765  small e[p];
766  small ev[p];
767  int mask;
768  int i;
769
770  Rq_mult_small(cf,c,f);
771  Rq_mult3(cf3,cf);
772  R3_fromRq(e,cf3);
773  R3_mult(ev,e,ginv);
774
775  mask = Weightw_mask(ev); /* 0 if weight w, else -1 */
776  for (i = 0;i < w;++i) r[i] = ((ev[i]^1)&~mask)^1;
777  for (i = w;i < p;++i) r[i] = ev[i]&~mask;
778}
779
780#endif
781
782/* ----- NTRU LPRime Core */
783
784#ifdef LPR
785
786/* (G,A),a = KeyGen(G); leaves G unchanged */
787static void KeyGen(Fq *A,small *a,const Fq *G)
788{
789  Fq aG[p];
790
791  Short_random(a);
792  Rq_mult_small(aG,G,a);
793  Round(A,aG);
794}
795
796/* B,T = Encrypt(r,(G,A),b) */
797static void Encrypt(Fq *B,int8 *T,const int8 *r,const Fq *G,const Fq *A,const small *b)
798{
799  Fq bG[p];
800  Fq bA[p];
801  int i;
802
803  Rq_mult_small(bG,G,b);
804  Round(B,bG);
805  Rq_mult_small(bA,A,b);
806  for (i = 0;i < I;++i) T[i] = Top(Fq_freeze(bA[i]+r[i]*q12));
807}
808
809/* r = Decrypt((B,T),a) */
810static void Decrypt(int8 *r,const Fq *B,const int8 *T,const small *a)
811{
812  Fq aB[p];
813  int i;
814
815  Rq_mult_small(aB,B,a);
816  for (i = 0;i < I;++i)
817    r[i] = -int16_negative_mask(Fq_freeze(Right(T[i])-aB[i]+4*w+1));
818}
819
820#endif
821
822/* ----- encoding I-bit inputs */
823
824#ifdef LPR
825
826#define Inputs_bytes (I/8)
827typedef int8 Inputs[I]; /* passed by reference */
828
829static void Inputs_encode(unsigned char *s,const Inputs r)
830{
831  int i;
832  for (i = 0;i < Inputs_bytes;++i) s[i] = 0;
833  for (i = 0;i < I;++i) s[i>>3] |= r[i]<<(i&7);
834}
835
836#endif
837
838/* ----- Expand */
839
840#ifdef LPR
841
842static const unsigned char aes_nonce[16] = {0};
843
844static void Expand(uint32 *L,const unsigned char *k)
845{
846  int i;
847  crypto_stream_aes256ctr((unsigned char *) L,4*p,aes_nonce,k);
848  for (i = 0;i < p;++i) {
849    uint32 L0 = ((unsigned char *) L)[4*i];
850    uint32 L1 = ((unsigned char *) L)[4*i+1];
851    uint32 L2 = ((unsigned char *) L)[4*i+2];
852    uint32 L3 = ((unsigned char *) L)[4*i+3];
853    L[i] = L0+(L1<<8)+(L2<<16)+(L3<<24);
854  }
855}
856
857#endif
858
859/* ----- Seeds */
860
861#ifdef LPR
862
#define Seeds_bytes 32

/* Fill s with Seeds_bytes bytes obtained from randombytes(). */
static void Seeds_random(unsigned char *s)
{
  randombytes(s,Seeds_bytes);
}
869
870#endif
871
872/* ----- Generator, HashShort */
873
874#ifdef LPR
875
876/* G = Generator(k) */
877static void Generator(Fq *G,const unsigned char *k)
878{
879  uint32 L[p];
880  int i;
881
882  Expand(L,k);
883  for (i = 0;i < p;++i) G[i] = uint32_mod_uint14(L[i],q)-q12;
884}
885
886/* out = HashShort(r) */
887static void HashShort(small *out,const Inputs r)
888{
889  unsigned char s[Inputs_bytes];
890  unsigned char h[Hash_bytes];
891  uint32 L[p];
892
893  Inputs_encode(s,r);
894  Hash_prefix(h,5,s,sizeof s);
895  Expand(L,h);
896  Short_fromlist(out,L);
897}
898
899#endif
900
901/* ----- NTRU LPRime Expand */
902
903#ifdef LPR
904
905/* (S,A),a = XKeyGen() */
906static void XKeyGen(unsigned char *S,Fq *A,small *a)
907{
908  Fq G[p];
909
910  Seeds_random(S);
911  Generator(G,S);
912  KeyGen(A,a,G);
913}
914
915/* B,T = XEncrypt(r,(S,A)) */
916static void XEncrypt(Fq *B,int8 *T,const int8 *r,const unsigned char *S,const Fq *A)
917{
918  Fq G[p];
919  small b[p];
920
921  Generator(G,S);
922  HashShort(b,r);
923  Encrypt(B,T,r,G,A,b);
924}
925
926#define XDecrypt Decrypt
927
928#endif
929
930/* ----- encoding small polynomials (including short polynomials) */
931
932#define Small_bytes ((p+3)/4)
933
934/* these are the only functions that rely on p mod 4 = 1 */
935
936static void Small_encode(unsigned char *s,const small *f)
937{
938  small x;
939  int i;
940
941  for (i = 0;i < p/4;++i) {
942    x = *f++ + 1;
943    x += (*f++ + 1)<<2;
944    x += (*f++ + 1)<<4;
945    x += (*f++ + 1)<<6;
946    *s++ = x;
947  }
948  x = *f++ + 1;
949  *s++ = x;
950}
951
952static void Small_decode(small *f,const unsigned char *s)
953{
954  unsigned char x;
955  int i;
956
957  for (i = 0;i < p/4;++i) {
958    x = *s++;
959    *f++ = ((small)(x&3))-1; x >>= 2;
960    *f++ = ((small)(x&3))-1; x >>= 2;
961    *f++ = ((small)(x&3))-1; x >>= 2;
962    *f++ = ((small)(x&3))-1;
963  }
964  x = *s++;
965  *f++ = ((small)(x&3))-1;
966}
967
968/* ----- encoding general polynomials */
969
970#ifndef LPR
971
972static void Rq_encode(unsigned char *s,const Fq *r)
973{
974  uint16 R[p],M[p];
975  int i;
976
977  for (i = 0;i < p;++i) R[i] = r[i]+q12;
978  for (i = 0;i < p;++i) M[i] = q;
979  Encode(s,R,M,p);
980}
981
982static void Rq_decode(Fq *r,const unsigned char *s)
983{
984  uint16 R[p],M[p];
985  int i;
986
987  for (i = 0;i < p;++i) M[i] = q;
988  Decode(R,s,M,p);
989  for (i = 0;i < p;++i) r[i] = ((Fq)R[i])-q12;
990}
991
992#endif
993
994/* ----- encoding rounded polynomials */
995
996static void Rounded_encode(unsigned char *s,const Fq *r)
997{
998  uint16 R[p],M[p];
999  int i;
1000
1001  for (i = 0;i < p;++i) R[i] = ((r[i]+q12)*10923)>>15;
1002  for (i = 0;i < p;++i) M[i] = (q+2)/3;
1003  Encode(s,R,M,p);
1004}
1005
1006static void Rounded_decode(Fq *r,const unsigned char *s)
1007{
1008  uint16 R[p],M[p];
1009  int i;
1010
1011  for (i = 0;i < p;++i) M[i] = (q+2)/3;
1012  Decode(R,s,M,p);
1013  for (i = 0;i < p;++i) r[i] = R[i]*3-q12;
1014}
1015
1016/* ----- encoding top polynomials */
1017
1018#ifdef LPR
1019
1020#define Top_bytes (I/2)
1021
1022static void Top_encode(unsigned char *s,const int8 *T)
1023{
1024  int i;
1025  for (i = 0;i < Top_bytes;++i)
1026    s[i] = T[2*i]+(T[2*i+1]<<4);
1027}
1028
1029static void Top_decode(int8 *T,const unsigned char *s)
1030{
1031  int i;
1032  for (i = 0;i < Top_bytes;++i) {
1033    T[2*i] = s[i]&15;
1034    T[2*i+1] = s[i]>>4;
1035  }
1036}
1037
1038#endif
1039
1040/* ----- Streamlined NTRU Prime Core plus encoding */
1041
1042#ifndef LPR
1043
/* In this configuration an input is a polynomial of p small coefficients. */
typedef small Inputs[p]; /* passed by reference */
/* Input handling is aliased to the small-polynomial routines. */
#define Inputs_random Short_random
#define Inputs_encode Small_encode
#define Inputs_bytes Small_bytes

/* Encoded sizes: the secret key holds two encoded small polynomials. */
#define Ciphertexts_bytes Rounded_bytes
#define SecretKeys_bytes (2*Small_bytes)
#define PublicKeys_bytes Rq_bytes
1052
1053/* pk,sk = ZKeyGen() */
1054static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1055{
1056  Fq h[p];
1057  small f[p],v[p];
1058
1059  KeyGen(h,f,v);
1060  Rq_encode(pk,h);
1061  Small_encode(sk,f); sk += Small_bytes;
1062  Small_encode(sk,v);
1063}
1064
1065/* C = ZEncrypt(r,pk) */
1066static void ZEncrypt(unsigned char *C,const Inputs r,const unsigned char *pk)
1067{
1068  Fq h[p];
1069  Fq c[p];
1070  Rq_decode(h,pk);
1071  Encrypt(c,r,h);
1072  Rounded_encode(C,c);
1073}
1074
1075/* r = ZDecrypt(C,sk) */
1076static void ZDecrypt(Inputs r,const unsigned char *C,const unsigned char *sk)
1077{
1078  small f[p],v[p];
1079  Fq c[p];
1080
1081  Small_decode(f,sk); sk += Small_bytes;
1082  Small_decode(v,sk);
1083  Rounded_decode(c,C);
1084  Decrypt(r,c,f,v);
1085}
1086
1087#endif
1088
1089/* ----- NTRU LPRime Expand plus encoding */
1090
1091#ifdef LPR
1092
1093#define Ciphertexts_bytes (Rounded_bytes+Top_bytes)
1094#define SecretKeys_bytes Small_bytes
1095#define PublicKeys_bytes (Seeds_bytes+Rounded_bytes)
1096
1097static void Inputs_random(Inputs r)
1098{
1099  unsigned char s[Inputs_bytes];
1100  int i;
1101
1102  randombytes(s,sizeof s);
1103  for (i = 0;i < I;++i) r[i] = 1&(s[i>>3]>>(i&7));
1104}
1105
1106/* pk,sk = ZKeyGen() */
1107static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1108{
1109  Fq A[p];
1110  small a[p];
1111
1112  XKeyGen(pk,A,a); pk += Seeds_bytes;
1113  Rounded_encode(pk,A);
1114  Small_encode(sk,a);
1115}
1116
1117/* c = ZEncrypt(r,pk) */
1118static void ZEncrypt(unsigned char *c,const Inputs r,const unsigned char *pk)
1119{
1120  Fq A[p];
1121  Fq B[p];
1122  int8 T[I];
1123
1124  Rounded_decode(A,pk+Seeds_bytes);
1125  XEncrypt(B,T,r,pk,A);
1126  Rounded_encode(c,B); c += Rounded_bytes;
1127  Top_encode(c,T);
1128}
1129
1130/* r = ZDecrypt(C,sk) */
1131static void ZDecrypt(Inputs r,const unsigned char *c,const unsigned char *sk)
1132{
1133  small a[p];
1134  Fq B[p];
1135  int8 T[I];
1136
1137  Small_decode(a,sk);
1138  Rounded_decode(B,c);
1139  Top_decode(T,c+Rounded_bytes);
1140  XDecrypt(r,B,T,a);
1141}
1142
1143#endif
1144
1145/* ----- confirmation hash */
1146
1147#define Confirm_bytes 32
1148
1149/* h = HashConfirm(r,pk,cache); cache is Hash4(pk) */
1150static void HashConfirm(unsigned char *h,const unsigned char *r,const unsigned char *pk,const unsigned char *cache)
1151{
1152#ifndef LPR
1153  unsigned char x[Hash_bytes*2];
1154  int i;
1155
1156  Hash_prefix(x,3,r,Inputs_bytes);
1157  for (i = 0;i < Hash_bytes;++i) x[Hash_bytes+i] = cache[i];
1158#else
1159  unsigned char x[Inputs_bytes+Hash_bytes];
1160  int i;
1161
1162  for (i = 0;i < Inputs_bytes;++i) x[i] = r[i];
1163  for (i = 0;i < Hash_bytes;++i) x[Inputs_bytes+i] = cache[i];
1164#endif
1165  Hash_prefix(h,2,x,sizeof x);
1166}
1167
1168/* ----- session-key hash */
1169
1170/* k = HashSession(b,y,z) */
1171static void HashSession(unsigned char *k,int b,const unsigned char *y,const unsigned char *z)
1172{
1173#ifndef LPR
1174  unsigned char x[Hash_bytes+Ciphertexts_bytes+Confirm_bytes];
1175  int i;
1176
1177  Hash_prefix(x,3,y,Inputs_bytes);
1178  for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Hash_bytes+i] = z[i];
1179#else
1180  unsigned char x[Inputs_bytes+Ciphertexts_bytes+Confirm_bytes];
1181  int i;
1182
1183  for (i = 0;i < Inputs_bytes;++i) x[i] = y[i];
1184  for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Inputs_bytes+i] = z[i];
1185#endif
1186  Hash_prefix(k,b,x,sizeof x);
1187}
1188
1189/* ----- Streamlined NTRU Prime and NTRU LPRime */
1190
1191/* pk,sk = KEM_KeyGen() */
1192static void KEM_KeyGen(unsigned char *pk,unsigned char *sk)
1193{
1194  int i;
1195
1196  ZKeyGen(pk,sk); sk += SecretKeys_bytes;
1197  for (i = 0;i < PublicKeys_bytes;++i) *sk++ = pk[i];
1198  randombytes(sk,Inputs_bytes); sk += Inputs_bytes;
1199  Hash_prefix(sk,4,pk,PublicKeys_bytes);
1200}
1201
1202/* c,r_enc = Hide(r,pk,cache); cache is Hash4(pk) */
1203static void Hide(unsigned char *c,unsigned char *r_enc,const Inputs r,const unsigned char *pk,const unsigned char *cache)
1204{
1205  Inputs_encode(r_enc,r);
1206  ZEncrypt(c,r,pk); c += Ciphertexts_bytes;
1207  HashConfirm(c,r_enc,pk,cache);
1208}
1209
1210/* c,k = Encap(pk) */
1211static void Encap(unsigned char *c,unsigned char *k,const unsigned char *pk)
1212{
1213  Inputs r;
1214  unsigned char r_enc[Inputs_bytes];
1215  unsigned char cache[Hash_bytes];
1216
1217  Hash_prefix(cache,4,pk,PublicKeys_bytes);
1218  Inputs_random(r);
1219  Hide(c,r_enc,r,pk,cache);
1220  HashSession(k,1,r_enc,c);
1221}
1222
1223/* 0 if matching ciphertext+confirm, else -1 */
1224static int Ciphertexts_diff_mask(const unsigned char *c,const unsigned char *c2)
1225{
1226  uint16 differentbits = 0;
1227  int len = Ciphertexts_bytes+Confirm_bytes;
1228
1229  while (len-- > 0) differentbits |= (*c++)^(*c2++);
1230  return (1&((differentbits-1)>>8))-1;
1231}
1232
1233/* k = Decap(c,sk) */
1234static void Decap(unsigned char *k,const unsigned char *c,const unsigned char *sk)
1235{
1236  const unsigned char *pk = sk + SecretKeys_bytes;
1237  const unsigned char *rho = pk + PublicKeys_bytes;
1238  const unsigned char *cache = rho + Inputs_bytes;
1239  Inputs r;
1240  unsigned char r_enc[Inputs_bytes];
1241  unsigned char cnew[Ciphertexts_bytes+Confirm_bytes];
1242  int mask;
1243  int i;
1244
1245  ZDecrypt(r,c,sk);
1246  Hide(cnew,r_enc,r,pk,cache);
1247  mask = Ciphertexts_diff_mask(c,cnew);
1248  for (i = 0;i < Inputs_bytes;++i) r_enc[i] ^= mask&(r_enc[i]^rho[i]);
1249  HashSession(k,1+mask,r_enc,c);
1250}
1251
1252/* ----- crypto_kem API */
1253
1254
/*
 * Generate a key pair by calling KEM_KeyGen.
 *
 * pk: receives PublicKeys_bytes bytes.
 * sk: receives SecretKeys_bytes + PublicKeys_bytes + Inputs_bytes +
 *     Hash_bytes bytes (the layout written by KEM_KeyGen).
 * Always returns 0.
 */
int crypto_kem_sntrup761_keypair(unsigned char *pk,unsigned char *sk)
{
  KEM_KeyGen(pk,sk);
  return 0;
}
1260
/*
 * Encapsulate to the public key pk by calling Encap.
 *
 * c:  receives Ciphertexts_bytes + Confirm_bytes bytes.
 * k:  receives the Hash_bytes-byte session key.
 * pk: public key of PublicKeys_bytes bytes.
 * Always returns 0.
 */
int crypto_kem_sntrup761_enc(unsigned char *c,unsigned char *k,const unsigned char *pk)
{
  Encap(c,k,pk);
  return 0;
}
1266
/*
 * Decapsulate the ciphertext c with the secret key sk by calling Decap.
 *
 * k:  receives the Hash_bytes-byte session key.
 * c:  ciphertext of Ciphertexts_bytes + Confirm_bytes bytes.
 * sk: secret key in the layout written by KEM_KeyGen.
 * Always returns 0, including when the ciphertext does not verify; Decap
 * then derives the key from the stored random bytes instead.
 */
int crypto_kem_sntrup761_dec(unsigned char *k,const unsigned char *c,const unsigned char *sk)
{
  Decap(k,c,sk);
  return 0;
}
1272
1273