1/* Program for computing integer expressions using the GNU Multiple Precision
2   Arithmetic Library.
3
4Copyright 1997, 1999-2002, 2005, 2008, 2012, 2015 Free Software Foundation, Inc.
5
6This program is free software; you can redistribute it and/or modify it under
7the terms of the GNU General Public License as published by the Free Software
8Foundation; either version 3 of the License, or (at your option) any later
9version.
10
11This program is distributed in the hope that it will be useful, but WITHOUT ANY
12WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
13PARTICULAR PURPOSE.  See the GNU General Public License for more details.
14
15You should have received a copy of the GNU General Public License along with
16this program.  If not, see https://www.gnu.org/licenses/.  */
17
18
/* This expression evaluator works by building an expression tree (using a
20   recursive descent parser) which is then evaluated.  The expression tree is
21   useful since we want to optimize certain expressions (like a^b % c).
22
23   Usage: pexpr [options] expr ...
24   (Assuming you called the executable `pexpr' of course.)
25
26   Command line options:
27
28   -b        print output in binary
29   -o        print output in octal
30   -d        print output in decimal (the default)
31   -x        print output in hexadecimal
32   -b<NUM>   print output in base NUM
33   -t        print timing information
34   -html     output html
35   -wml      output wml
36   -split    split long lines each 80th digit
37*/
38
39/* Define LIMIT_RESOURCE_USAGE if you want to make sure the program doesn't
40   use up extensive resources (cpu, memory).  Useful for the GMP demo on the
41   GMP web site, since we cannot load the server too much.  */
42
43#include "pexpr-config.h"
44
#include <ctype.h>
#include <limits.h>
#include <setjmp.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
51
52#include <time.h>
53#include <sys/types.h>
54#include <sys/time.h>
55#if HAVE_SYS_RESOURCE_H
56#include <sys/resource.h>
57#endif
58
59#include "gmp.h"
60
61/* SunOS 4 and HPUX 9 don't define a canonical SIGSTKSZ, use a default. */
62#ifndef SIGSTKSZ
63#define SIGSTKSZ  4096
64#endif
65
66
/* TIME (t, func): execute the statement FUNC once and store in T the
   difference between the cputime () reading taken right after it and the
   one taken right before it.  */
#define TIME(t,func)							\
  do {									\
    int time_start_ = cputime ();					\
    { func; }								\
    (t) = cputime () - time_start_;					\
  } while (0)
74
/* GMP version 1.x compatibility.  Compiled only when __GNU_MP_VERSION is
   not at least 2.  It supplies the __mpz_struct/mpz_t/mpz_ptr type names on
   top of MP_INT, maps the division names used in this file onto mpz_div,
   mpz_mod and mpz_div_2exp, and defines mpz_sgn in terms of the `size'
   field (-1, 0 or 1).  */
#if ! (__GNU_MP_VERSION >= 2)
typedef MP_INT __mpz_struct;
typedef __mpz_struct mpz_t[1];
typedef __mpz_struct *mpz_ptr;
#define mpz_fdiv_q	mpz_div
#define mpz_fdiv_r	mpz_mod
#define mpz_tdiv_q_2exp	mpz_div_2exp
#define mpz_sgn(Z) ((Z)->size < 0 ? -1 : (Z)->size > 0)
#endif
85
/* GMP version 2.0 compatibility.  Compiled only when the GMP headers do not
   announce version 2.1 or later.  */
#if ! (__GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1)
/* Exchange the mpz objects A and B by swapping their header structs.  Both
   arguments are parenthesized so that pointer expressions (not only plain
   identifiers) expand correctly, and the temporary avoids the reserved
   double-underscore namespace.  Each argument is evaluated twice.  */
#define mpz_swap(a,b) \
  do { __mpz_struct swap_tmp_ = *(a); *(a) = *(b); *(b) = swap_tmp_; } while (0)
#endif
91
/* Jump buffer shared by the parser and the evaluator.  main () arms it with
   setjmp before parsing and again before evaluating; error paths set
   `error' and leave through longjmp (errjmpbuf, ...).  */
jmp_buf errjmpbuf;
93
/* Node kinds of the expression tree.  The enumerators keep their original
   order, so their numeric values are unchanged.  */
enum op_t
{
  NOP,				/* no operation; ends the function table */
  LIT,				/* literal number */
  NEG, NOT,			/* unary minus, bitwise complement */
  PLUS, MINUS, MULT,		/* basic arithmetic */
  DIV, MOD, REM, INVMOD,	/* division family */
  POW,				/* exponentiation */
  AND, IOR, XOR,		/* bitwise logic */
  SLL, SRA,			/* shifts */
  POPCNT, HAMDIST,		/* bit counting */
  GCD, LCM,
  SQRT, ROOT,
  FAC,				/* factorial */
  LOG, LOG2,
  FERMAT, MERSENNE, FIBONACCI,	/* number sequences */
  RANDOM, NEXTPRIME,
  BINOM,			/* binomial coefficient */
  TIMING			/* time(expr): evaluate and print CPU time */
};
98
/* Type for the expression tree.  A node is either a literal (op == LIT), in
   which case operands.val holds the number, or an operator node, in which
   case operands.ops holds the children.  Unary operators are built with
   rhs == NULL (see the makeexp calls that pass NULL).  */
struct expr
{
  enum op_t op;
  union
  {
    struct {struct expr *lhs, *rhs;} ops;
    mpz_t val;
  } operands;
};

/* Pointer to a tree node; this is the type the parser and evaluator pass
   around.  */
typedef struct expr *expr_t;
111
/* Signal handler: prints a diagnostic for the signal and exits.  */
void cleanup_and_exit (int);

/* Recursive descent parser.  expr, term, power and factor each take the
   current input position and a place to store the tree they build, and
   return the position just past what they consumed.  */
char *skipspace (char *);
void makeexp (expr_t *, enum op_t, expr_t, expr_t);
void free_expr (expr_t);
char *expr (char *, expr_t *);
char *term (char *, expr_t *);
char *power (char *, expr_t *);
char *factor (char *, expr_t *);
int match (char *, char *);
int matchp (char *, char *);
int cputime (void);

/* Tree evaluators: plain, and reduced modulo the last argument.  */
void mpz_eval_expr (mpz_ptr, expr_t);
void mpz_eval_mod_expr (mpz_ptr, expr_t, mpz_ptr);
127
/* Message describing the most recent error; set right before each longjmp
   to errjmpbuf and printed by main ().  */
char *error;
/* Nonzero: print the result itself; cleared by -noprint, in which case only
   its approximate size is printed.  */
int flag_print = 1;
/* Nonzero (-t): report evaluation and output conversion times.  */
int print_timing = 0;
/* Set by -html and -wml respectively.  */
int flag_html = 0;
int flag_wml = 0;
/* Nonzero (-split): break the printed result after every 80 characters.  */
int flag_splitup_output = 0;
/* Markup emitted before each '\n': "" by default, "<br>" for -html,
   "<br/>" for -wml.  */
char *newline = "";
/* Random state used by the RANDOM operator; initialized and seeded in
   main ().  */
gmp_randstate_t rstate;
136
137
138
/* cputime () returns CPU time in milliseconds.  The implementation is
   picked at configuration time: nothing is defined here when HAVE_CPUTIME
   is set; otherwise getrusage is preferred, then clock, and as a last
   resort a stub that always reports 0.  */
#if ! HAVE_CPUTIME
#if HAVE_GETRUSAGE
int
cputime (void)
{
  struct rusage usage;

  getrusage (0, &usage);
  /* User time only: whole seconds plus microseconds, both as ms.  */
  return usage.ru_utime.tv_sec * 1000 + usage.ru_utime.tv_usec / 1000;
}
#elif HAVE_CLOCK
int
cputime (void)
{
  /* Pick the order of operations according to the tick rate.  */
  if (CLOCKS_PER_SEC >= 100000)
    return clock () / (CLOCKS_PER_SEC / 1000);
  return clock () * 1000 / CLOCKS_PER_SEC;
}
#else
int
cputime (void)
{
  /* No timing facility available.  */
  return 0;
}
#endif
#endif
168
169
/* Return nonzero when a local of this (callee) frame sits at a lower
   address than CALLER_LOCAL, which points at a local of the caller.  */
int
stack_downwards_helper (char *xp)
{
  char callee_local;
  char *caller_local = xp;

  return caller_local > &callee_local;
}

/* Return nonzero if the stack appears to grow towards lower addresses,
   judged by comparing a local of this function with a local of a function
   it calls.  */
int
stack_downwards_p (void)
{
  char probe;
  int downwards = stack_downwards_helper (&probe);

  return downwards;
}
182
183
/* Install cleanup_and_exit as the handler for SIGILL, SIGSEGV, SIGBUS
   (where defined), SIGFPE and SIGABRT, running on an alternate signal stack
   when the platform offers sigaltstack or sigstack.  With
   LIMIT_RESOURCE_USAGE defined, also cap core size, CPU time, data size and
   stack size and catch SIGXCPU.  */
void
setup_error_handler (void)
{
  /* SIGNAL(sig) hides the difference between sigaction and signal.  In the
     signal () variant `act' is a dummy so that the sa_flags assignments
     below still compile.  */
#if HAVE_SIGACTION
  struct sigaction act;
  act.sa_handler = cleanup_and_exit;
  sigemptyset (&(act.sa_mask));
#define SIGNAL(sig)  sigaction (sig, &act, NULL)
#else
  struct { int sa_flags; } act;
#define SIGNAL(sig)  signal (sig, cleanup_and_exit)
#endif
  act.sa_flags = 0;

  /* Set up a stack for signal handling.  A typical cause of error is stack
     overflow, and in such situation a signal can not be delivered on the
     overflown stack.  */
#if HAVE_SIGALTSTACK
  {
    /* AIX uses stack_t, MacOS uses struct sigaltstack, various other
       systems have both. */
#if HAVE_STACK_T
    stack_t s;
#else
    struct sigaltstack s;
#endif
    /* NOTE(review): the malloc result is not checked before use.  */
    s.ss_sp = malloc (SIGSTKSZ);
    s.ss_size = SIGSTKSZ;
    s.ss_flags = 0;
    if (sigaltstack (&s, NULL) != 0)
      perror("sigaltstack");
    act.sa_flags = SA_ONSTACK;
  }
#else
#if HAVE_SIGSTACK
  {
    struct sigstack s;
    s.ss_sp = malloc (SIGSTKSZ);
    /* Point at the far end of the block when the stack grows downwards.  */
    if (stack_downwards_p ())
      s.ss_sp += SIGSTKSZ;
    s.ss_onstack = 0;
    if (sigstack (&s, NULL) != 0)
      perror("sigstack");
    act.sa_flags = SA_ONSTACK;
  }
#else
#endif
#endif

#ifdef LIMIT_RESOURCE_USAGE
  {
    struct rlimit limit;

    /* No core dumps.  */
    limit.rlim_cur = limit.rlim_max = 0;
    setrlimit (RLIMIT_CORE, &limit);

    /* CPU time: soft limit 3, hard limit 4.  */
    limit.rlim_cur = 3;
    limit.rlim_max = 4;
    setrlimit (RLIMIT_CPU, &limit);

    /* Data segment: 16 MiB.  */
    limit.rlim_cur = limit.rlim_max = 16 * 1024 * 1024;
    setrlimit (RLIMIT_DATA, &limit);

    /* Stack: lower only the soft limit to 4 MiB, keeping the hard limit
       that getrlimit reported.  */
    getrlimit (RLIMIT_STACK, &limit);
    limit.rlim_cur = 4 * 1024 * 1024;
    setrlimit (RLIMIT_STACK, &limit);

    SIGNAL (SIGXCPU);
  }
#endif /* LIMIT_RESOURCE_USAGE */

  SIGNAL (SIGILL);
  SIGNAL (SIGSEGV);
#ifdef SIGBUS /* not in mingw */
  SIGNAL (SIGBUS);
#endif
  SIGNAL (SIGFPE);
  SIGNAL (SIGABRT);
}
263
264int
265main (int argc, char **argv)
266{
267  struct expr *e;
268  int i;
269  mpz_t r;
270  int errcode = 0;
271  char *str;
272  int base = 10;
273
274  setup_error_handler ();
275
276  gmp_randinit (rstate, GMP_RAND_ALG_LC, 128);
277
278  {
279#if HAVE_GETTIMEOFDAY
280    struct timeval tv;
281    gettimeofday (&tv, NULL);
282    gmp_randseed_ui (rstate, tv.tv_sec + tv.tv_usec);
283#else
284    time_t t;
285    time (&t);
286    gmp_randseed_ui (rstate, t);
287#endif
288  }
289
290  mpz_init (r);
291
292  while (argc > 1 && argv[1][0] == '-')
293    {
294      char *arg = argv[1];
295
296      if (arg[1] >= '0' && arg[1] <= '9')
297	break;
298
299      if (arg[1] == 't')
300	print_timing = 1;
301      else if (arg[1] == 'b' && arg[2] >= '0' && arg[2] <= '9')
302	{
303	  base = atoi (arg + 2);
304	  if (base < 2 || base > 62)
305	    {
306	      fprintf (stderr, "error: invalid output base\n");
307	      exit (-1);
308	    }
309	}
310      else if (arg[1] == 'b' && arg[2] == 0)
311	base = 2;
312      else if (arg[1] == 'x' && arg[2] == 0)
313	base = 16;
314      else if (arg[1] == 'X' && arg[2] == 0)
315	base = -16;
316      else if (arg[1] == 'o' && arg[2] == 0)
317	base = 8;
318      else if (arg[1] == 'd' && arg[2] == 0)
319	base = 10;
320      else if (arg[1] == 'v' && arg[2] == 0)
321	{
322	  printf ("pexpr linked to gmp %s\n", __gmp_version);
323	}
324      else if (strcmp (arg, "-html") == 0)
325	{
326	  flag_html = 1;
327	  newline = "<br>";
328	}
329      else if (strcmp (arg, "-wml") == 0)
330	{
331	  flag_wml = 1;
332	  newline = "<br/>";
333	}
334      else if (strcmp (arg, "-split") == 0)
335	{
336	  flag_splitup_output = 1;
337	}
338      else if (strcmp (arg, "-noprint") == 0)
339	{
340	  flag_print = 0;
341	}
342      else
343	{
344	  fprintf (stderr, "error: unknown option `%s'\n", arg);
345	  exit (-1);
346	}
347      argv++;
348      argc--;
349    }
350
351  for (i = 1; i < argc; i++)
352    {
353      int s;
354      int jmpval;
355
356      /* Set up error handler for parsing expression.  */
357      jmpval = setjmp (errjmpbuf);
358      if (jmpval != 0)
359	{
360	  fprintf (stderr, "error: %s%s\n", error, newline);
361	  fprintf (stderr, "       %s%s\n", argv[i], newline);
362	  if (! flag_html)
363	    {
364	      /* ??? Dunno how to align expression position with arrow in
365		 HTML ??? */
366	      fprintf (stderr, "       ");
367	      for (s = jmpval - (long) argv[i]; --s >= 0; )
368		putc (' ', stderr);
369	      fprintf (stderr, "^\n");
370	    }
371
372	  errcode |= 1;
373	  continue;
374	}
375
376      str = expr (argv[i], &e);
377
378      if (str[0] != 0)
379	{
380	  fprintf (stderr,
381		   "error: garbage where end of expression expected%s\n",
382		   newline);
383	  fprintf (stderr, "       %s%s\n", argv[i], newline);
384	  if (! flag_html)
385	    {
386	      /* ??? Dunno how to align expression position with arrow in
387		 HTML ??? */
388	      fprintf (stderr, "        ");
389	      for (s = str - argv[i]; --s; )
390		putc (' ', stderr);
391	      fprintf (stderr, "^\n");
392	    }
393
394	  errcode |= 1;
395	  free_expr (e);
396	  continue;
397	}
398
399      /* Set up error handler for evaluating expression.  */
400      if (setjmp (errjmpbuf))
401	{
402	  fprintf (stderr, "error: %s%s\n", error, newline);
403	  fprintf (stderr, "       %s%s\n", argv[i], newline);
404	  if (! flag_html)
405	    {
406	      /* ??? Dunno how to align expression position with arrow in
407		 HTML ??? */
408	      fprintf (stderr, "       ");
409	      for (s = str - argv[i]; --s >= 0; )
410		putc (' ', stderr);
411	      fprintf (stderr, "^\n");
412	    }
413
414	  errcode |= 2;
415	  continue;
416	}
417
418      if (print_timing)
419	{
420	  int t;
421	  TIME (t, mpz_eval_expr (r, e));
422	  printf ("computation took %d ms%s\n", t, newline);
423	}
424      else
425	mpz_eval_expr (r, e);
426
427      if (flag_print)
428	{
429	  size_t out_len;
430	  char *tmp, *s;
431
432	  out_len = mpz_sizeinbase (r, base >= 0 ? base : -base) + 2;
433#ifdef LIMIT_RESOURCE_USAGE
434	  if (out_len > 100000)
435	    {
436	      printf ("result is about %ld digits, not printing it%s\n",
437		      (long) out_len - 3, newline);
438	      exit (-2);
439	    }
440#endif
441	  tmp = malloc (out_len);
442
443	  if (print_timing)
444	    {
445	      int t;
446	      printf ("output conversion ");
447	      TIME (t, mpz_get_str (tmp, base, r));
448	      printf ("took %d ms%s\n", t, newline);
449	    }
450	  else
451	    mpz_get_str (tmp, base, r);
452
453	  out_len = strlen (tmp);
454	  if (flag_splitup_output)
455	    {
456	      for (s = tmp; out_len > 80; s += 80)
457		{
458		  fwrite (s, 1, 80, stdout);
459		  printf ("%s\n", newline);
460		  out_len -= 80;
461		}
462
463	      fwrite (s, 1, out_len, stdout);
464	    }
465	  else
466	    {
467	      fwrite (tmp, 1, out_len, stdout);
468	    }
469
470	  free (tmp);
471	  printf ("%s\n", newline);
472	}
473      else
474	{
475	  printf ("result is approximately %ld digits%s\n",
476		  (long) mpz_sizeinbase (r, base >= 0 ? base : -base),
477		  newline);
478	}
479
480      free_expr (e);
481    }
482
483  mpz_clear (r);
484
485  exit (errcode);
486}
487
488char *
489expr (char *str, expr_t *e)
490{
491  expr_t e2;
492
493  str = skipspace (str);
494  if (str[0] == '+')
495    {
496      str = term (str + 1, e);
497    }
498  else if (str[0] == '-')
499    {
500      str = term (str + 1, e);
501      makeexp (e, NEG, *e, NULL);
502    }
503  else if (str[0] == '~')
504    {
505      str = term (str + 1, e);
506      makeexp (e, NOT, *e, NULL);
507    }
508  else
509    {
510      str = term (str, e);
511    }
512
513  for (;;)
514    {
515      str = skipspace (str);
516      switch (str[0])
517	{
518	case 'p':
519	  if (match ("plus", str))
520	    {
521	      str = term (str + 4, &e2);
522	      makeexp (e, PLUS, *e, e2);
523	    }
524	  else
525	    return str;
526	  break;
527	case 'm':
528	  if (match ("minus", str))
529	    {
530	      str = term (str + 5, &e2);
531	      makeexp (e, MINUS, *e, e2);
532	    }
533	  else
534	    return str;
535	  break;
536	case '+':
537	  str = term (str + 1, &e2);
538	  makeexp (e, PLUS, *e, e2);
539	  break;
540	case '-':
541	  str = term (str + 1, &e2);
542	  makeexp (e, MINUS, *e, e2);
543	  break;
544	default:
545	  return str;
546	}
547    }
548}
549
550char *
551term (char *str, expr_t *e)
552{
553  expr_t e2;
554
555  str = power (str, e);
556  for (;;)
557    {
558      str = skipspace (str);
559      switch (str[0])
560	{
561	case 'm':
562	  if (match ("mul", str))
563	    {
564	      str = power (str + 3, &e2);
565	      makeexp (e, MULT, *e, e2);
566	      break;
567	    }
568	  if (match ("mod", str))
569	    {
570	      str = power (str + 3, &e2);
571	      makeexp (e, MOD, *e, e2);
572	      break;
573	    }
574	  return str;
575	case 'd':
576	  if (match ("div", str))
577	    {
578	      str = power (str + 3, &e2);
579	      makeexp (e, DIV, *e, e2);
580	      break;
581	    }
582	  return str;
583	case 'r':
584	  if (match ("rem", str))
585	    {
586	      str = power (str + 3, &e2);
587	      makeexp (e, REM, *e, e2);
588	      break;
589	    }
590	  return str;
591	case 'i':
592	  if (match ("invmod", str))
593	    {
594	      str = power (str + 6, &e2);
595	      makeexp (e, REM, *e, e2);
596	      break;
597	    }
598	  return str;
599	case 't':
600	  if (match ("times", str))
601	    {
602	      str = power (str + 5, &e2);
603	      makeexp (e, MULT, *e, e2);
604	      break;
605	    }
606	  if (match ("thru", str))
607	    {
608	      str = power (str + 4, &e2);
609	      makeexp (e, DIV, *e, e2);
610	      break;
611	    }
612	  if (match ("through", str))
613	    {
614	      str = power (str + 7, &e2);
615	      makeexp (e, DIV, *e, e2);
616	      break;
617	    }
618	  return str;
619	case '*':
620	  str = power (str + 1, &e2);
621	  makeexp (e, MULT, *e, e2);
622	  break;
623	case '/':
624	  str = power (str + 1, &e2);
625	  makeexp (e, DIV, *e, e2);
626	  break;
627	case '%':
628	  str = power (str + 1, &e2);
629	  makeexp (e, MOD, *e, e2);
630	  break;
631	default:
632	  return str;
633	}
634    }
635}
636
637char *
638power (char *str, expr_t *e)
639{
640  expr_t e2;
641
642  str = factor (str, e);
643  while (str[0] == '!')
644    {
645      str++;
646      makeexp (e, FAC, *e, NULL);
647    }
648  str = skipspace (str);
649  if (str[0] == '^')
650    {
651      str = power (str + 1, &e2);
652      makeexp (e, POW, *e, e2);
653    }
654  return str;
655}
656
657int
658match (char *s, char *str)
659{
660  char *ostr = str;
661  int i;
662
663  for (i = 0; s[i] != 0; i++)
664    {
665      if (str[i] != s[i])
666	return 0;
667    }
668  str = skipspace (str + i);
669  return str - ostr;
670}
671
/* Like match, but additionally require an opening parenthesis after S and
   any spaces following it.  Returns the number of characters up to and
   including that `(', or 0 when STR does not have this form.  */
int
matchp (char *s, char *str)
{
  char *p;
  int k = 0;

  while (s[k] != 0 && str[k] == s[k])
    k++;
  if (s[k] != 0)
    return 0;

  p = skipspace (str + k);
  if (p[0] != '(')
    return 0;
  return p - str + 1;
}
688
/* One entry of the table of functions callable as NAME(...).  */
struct functions
{
  char *spelling;	/* name as written in the expression */
  enum op_t op;		/* node kind built for a call */
  int arity; /* 1 or 2 means real arity; 0 means arbitrary.  */
};
695
/* Table of functions recognized by factor ().  The entry with op == NOP
   ends the table.  Entries with arity 0 still need at least two operands in
   factor (); further operands are folded in from the left.
   NOTE(review): "root" is listed here when __GNU_MP_VERSION >= 2, but the
   ROOT case in mpz_eval_expr is compiled under a different version test;
   confirm that the two are meant to differ.  */
struct functions fns[] =
{
  {"sqrt", SQRT, 1},
#if __GNU_MP_VERSION >= 2
  {"root", ROOT, 2},
  {"popc", POPCNT, 1},
  {"hamdist", HAMDIST, 2},
#endif
  {"gcd", GCD, 0},
#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
  {"lcm", LCM, 0},
#endif
  {"and", AND, 0},
  {"ior", IOR, 0},
#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
  {"xor", XOR, 0},
#endif
  {"plus", PLUS, 0},
  {"pow", POW, 2},
  {"minus", MINUS, 2},
  {"mul", MULT, 0},
  {"div", DIV, 2},
  {"mod", MOD, 2},
  {"rem", REM, 2},
#if __GNU_MP_VERSION >= 2
  {"invmod", INVMOD, 2},
#endif
  {"log", LOG, 2},
  {"log2", LOG2, 1},
  {"F", FERMAT, 1},
  {"M", MERSENNE, 1},
  {"fib", FIBONACCI, 1},
  {"Fib", FIBONACCI, 1},
  {"random", RANDOM, 1},
  {"nextprime", NEXTPRIME, 1},
  {"binom", BINOM, 2},
  {"binomial", BINOM, 2},
  {"fac", FAC, 1},
  {"fact", FAC, 1},
  {"factorial", FAC, 1},
  {"time", TIMING, 1},
  {"", NOP, 0}
};
739
740char *
741factor (char *str, expr_t *e)
742{
743  expr_t e1, e2;
744
745  str = skipspace (str);
746
747  if (isalpha (str[0]))
748    {
749      int i;
750      int cnt;
751
752      for (i = 0; fns[i].op != NOP; i++)
753	{
754	  if (fns[i].arity == 1)
755	    {
756	      cnt = matchp (fns[i].spelling, str);
757	      if (cnt != 0)
758		{
759		  str = expr (str + cnt, &e1);
760		  str = skipspace (str);
761		  if (str[0] != ')')
762		    {
763		      error = "expected `)'";
764		      longjmp (errjmpbuf, (int) (long) str);
765		    }
766		  makeexp (e, fns[i].op, e1, NULL);
767		  return str + 1;
768		}
769	    }
770	}
771
772      for (i = 0; fns[i].op != NOP; i++)
773	{
774	  if (fns[i].arity != 1)
775	    {
776	      cnt = matchp (fns[i].spelling, str);
777	      if (cnt != 0)
778		{
779		  str = expr (str + cnt, &e1);
780		  str = skipspace (str);
781
782		  if (str[0] != ',')
783		    {
784		      error = "expected `,' and another operand";
785		      longjmp (errjmpbuf, (int) (long) str);
786		    }
787
788		  str = skipspace (str + 1);
789		  str = expr (str, &e2);
790		  str = skipspace (str);
791
792		  if (fns[i].arity == 0)
793		    {
794		      while (str[0] == ',')
795			{
796			  makeexp (&e1, fns[i].op, e1, e2);
797			  str = skipspace (str + 1);
798			  str = expr (str, &e2);
799			  str = skipspace (str);
800			}
801		    }
802
803		  if (str[0] != ')')
804		    {
805		      error = "expected `)'";
806		      longjmp (errjmpbuf, (int) (long) str);
807		    }
808
809		  makeexp (e, fns[i].op, e1, e2);
810		  return str + 1;
811		}
812	    }
813	}
814    }
815
816  if (str[0] == '(')
817    {
818      str = expr (str + 1, e);
819      str = skipspace (str);
820      if (str[0] != ')')
821	{
822	  error = "expected `)'";
823	  longjmp (errjmpbuf, (int) (long) str);
824	}
825      str++;
826    }
827  else if (str[0] >= '0' && str[0] <= '9')
828    {
829      expr_t res;
830      char *s, *sc;
831
832      res = malloc (sizeof (struct expr));
833      res -> op = LIT;
834      mpz_init (res->operands.val);
835
836      s = str;
837      while (isalnum (str[0]))
838	str++;
839      sc = malloc (str - s + 1);
840      memcpy (sc, s, str - s);
841      sc[str - s] = 0;
842
843      mpz_set_str (res->operands.val, sc, 0);
844      *e = res;
845      free (sc);
846    }
847  else
848    {
849      error = "operand expected";
850      longjmp (errjmpbuf, (int) (long) str);
851    }
852  return str;
853}
854
/* Return the address of the first character of STR that is not a space
   (' ').  Only the space character is skipped; tabs and newlines are
   not.  */
char *
skipspace (char *str)
{
  char *p;

  for (p = str; *p == ' '; p++)
    ;
  return p;
}
862
863/* Make a new expression with operation OP and right hand side
864   RHS and left hand side lhs.  Put the result in R.  */
865void
866makeexp (expr_t *r, enum op_t op, expr_t lhs, expr_t rhs)
867{
868  expr_t res;
869  res = malloc (sizeof (struct expr));
870  res -> op = op;
871  res -> operands.ops.lhs = lhs;
872  res -> operands.ops.rhs = rhs;
873  *r = res;
874  return;
875}
876
877/* Free the memory used by expression E.  */
878void
879free_expr (expr_t e)
880{
881  if (e->op != LIT)
882    {
883      free_expr (e->operands.ops.lhs);
884      if (e->operands.ops.rhs != NULL)
885	free_expr (e->operands.ops.rhs);
886    }
887  else
888    {
889      mpz_clear (e->operands.val);
890    }
891}
892
893/* Evaluate the expression E and put the result in R.  */
894void
895mpz_eval_expr (mpz_ptr r, expr_t e)
896{
897  mpz_t lhs, rhs;
898
899  switch (e->op)
900    {
901    case LIT:
902      mpz_set (r, e->operands.val);
903      return;
904    case PLUS:
905      mpz_init (lhs); mpz_init (rhs);
906      mpz_eval_expr (lhs, e->operands.ops.lhs);
907      mpz_eval_expr (rhs, e->operands.ops.rhs);
908      mpz_add (r, lhs, rhs);
909      mpz_clear (lhs); mpz_clear (rhs);
910      return;
911    case MINUS:
912      mpz_init (lhs); mpz_init (rhs);
913      mpz_eval_expr (lhs, e->operands.ops.lhs);
914      mpz_eval_expr (rhs, e->operands.ops.rhs);
915      mpz_sub (r, lhs, rhs);
916      mpz_clear (lhs); mpz_clear (rhs);
917      return;
918    case MULT:
919      mpz_init (lhs); mpz_init (rhs);
920      mpz_eval_expr (lhs, e->operands.ops.lhs);
921      mpz_eval_expr (rhs, e->operands.ops.rhs);
922      mpz_mul (r, lhs, rhs);
923      mpz_clear (lhs); mpz_clear (rhs);
924      return;
925    case DIV:
926      mpz_init (lhs); mpz_init (rhs);
927      mpz_eval_expr (lhs, e->operands.ops.lhs);
928      mpz_eval_expr (rhs, e->operands.ops.rhs);
929      mpz_fdiv_q (r, lhs, rhs);
930      mpz_clear (lhs); mpz_clear (rhs);
931      return;
932    case MOD:
933      mpz_init (rhs);
934      mpz_eval_expr (rhs, e->operands.ops.rhs);
935      mpz_abs (rhs, rhs);
936      mpz_eval_mod_expr (r, e->operands.ops.lhs, rhs);
937      mpz_clear (rhs);
938      return;
939    case REM:
940      /* Check if lhs operand is POW expression and optimize for that case.  */
941      if (e->operands.ops.lhs->op == POW)
942	{
943	  mpz_t powlhs, powrhs;
944	  mpz_init (powlhs);
945	  mpz_init (powrhs);
946	  mpz_init (rhs);
947	  mpz_eval_expr (powlhs, e->operands.ops.lhs->operands.ops.lhs);
948	  mpz_eval_expr (powrhs, e->operands.ops.lhs->operands.ops.rhs);
949	  mpz_eval_expr (rhs, e->operands.ops.rhs);
950	  mpz_powm (r, powlhs, powrhs, rhs);
951	  if (mpz_cmp_si (rhs, 0L) < 0)
952	    mpz_neg (r, r);
953	  mpz_clear (powlhs);
954	  mpz_clear (powrhs);
955	  mpz_clear (rhs);
956	  return;
957	}
958
959      mpz_init (lhs); mpz_init (rhs);
960      mpz_eval_expr (lhs, e->operands.ops.lhs);
961      mpz_eval_expr (rhs, e->operands.ops.rhs);
962      mpz_fdiv_r (r, lhs, rhs);
963      mpz_clear (lhs); mpz_clear (rhs);
964      return;
965#if __GNU_MP_VERSION >= 2
966    case INVMOD:
967      mpz_init (lhs); mpz_init (rhs);
968      mpz_eval_expr (lhs, e->operands.ops.lhs);
969      mpz_eval_expr (rhs, e->operands.ops.rhs);
970      mpz_invert (r, lhs, rhs);
971      mpz_clear (lhs); mpz_clear (rhs);
972      return;
973#endif
974    case POW:
975      mpz_init (lhs); mpz_init (rhs);
976      mpz_eval_expr (lhs, e->operands.ops.lhs);
977      if (mpz_cmpabs_ui (lhs, 1) <= 0)
978	{
979	  /* For 0^rhs and 1^rhs, we just need to verify that
980	     rhs is well-defined.  For (-1)^rhs we need to
981	     determine (rhs mod 2).  For simplicity, compute
982	     (rhs mod 2) for all three cases.  */
983	  expr_t two, et;
984	  two = malloc (sizeof (struct expr));
985	  two -> op = LIT;
986	  mpz_init_set_ui (two->operands.val, 2L);
987	  makeexp (&et, MOD, e->operands.ops.rhs, two);
988	  e->operands.ops.rhs = et;
989	}
990
991      mpz_eval_expr (rhs, e->operands.ops.rhs);
992      if (mpz_cmp_si (rhs, 0L) == 0)
993	/* x^0 is 1 */
994	mpz_set_ui (r, 1L);
995      else if (mpz_cmp_si (lhs, 0L) == 0)
996	/* 0^y (where y != 0) is 0 */
997	mpz_set_ui (r, 0L);
998      else if (mpz_cmp_ui (lhs, 1L) == 0)
999	/* 1^y is 1 */
1000	mpz_set_ui (r, 1L);
1001      else if (mpz_cmp_si (lhs, -1L) == 0)
1002	/* (-1)^y just depends on whether y is even or odd */
1003	mpz_set_si (r, (mpz_get_ui (rhs) & 1) ? -1L : 1L);
1004      else if (mpz_cmp_si (rhs, 0L) < 0)
1005	/* x^(-n) is 0 */
1006	mpz_set_ui (r, 0L);
1007      else
1008	{
1009	  unsigned long int cnt;
1010	  unsigned long int y;
1011	  /* error if exponent does not fit into an unsigned long int.  */
1012	  if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1013	    goto pow_err;
1014
1015	  y = mpz_get_ui (rhs);
1016	  /* x^y == (x/(2^c))^y * 2^(c*y) */
1017#if __GNU_MP_VERSION >= 2
1018	  cnt = mpz_scan1 (lhs, 0);
1019#else
1020	  cnt = 0;
1021#endif
1022	  if (cnt != 0)
1023	    {
1024	      if (y * cnt / cnt != y)
1025		goto pow_err;
1026	      mpz_tdiv_q_2exp (lhs, lhs, cnt);
1027	      mpz_pow_ui (r, lhs, y);
1028	      mpz_mul_2exp (r, r, y * cnt);
1029	    }
1030	  else
1031	    mpz_pow_ui (r, lhs, y);
1032	}
1033      mpz_clear (lhs); mpz_clear (rhs);
1034      return;
1035    pow_err:
1036      error = "result of `pow' operator too large";
1037      mpz_clear (lhs); mpz_clear (rhs);
1038      longjmp (errjmpbuf, 1);
1039    case GCD:
1040      mpz_init (lhs); mpz_init (rhs);
1041      mpz_eval_expr (lhs, e->operands.ops.lhs);
1042      mpz_eval_expr (rhs, e->operands.ops.rhs);
1043      mpz_gcd (r, lhs, rhs);
1044      mpz_clear (lhs); mpz_clear (rhs);
1045      return;
1046#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1047    case LCM:
1048      mpz_init (lhs); mpz_init (rhs);
1049      mpz_eval_expr (lhs, e->operands.ops.lhs);
1050      mpz_eval_expr (rhs, e->operands.ops.rhs);
1051      mpz_lcm (r, lhs, rhs);
1052      mpz_clear (lhs); mpz_clear (rhs);
1053      return;
1054#endif
1055    case AND:
1056      mpz_init (lhs); mpz_init (rhs);
1057      mpz_eval_expr (lhs, e->operands.ops.lhs);
1058      mpz_eval_expr (rhs, e->operands.ops.rhs);
1059      mpz_and (r, lhs, rhs);
1060      mpz_clear (lhs); mpz_clear (rhs);
1061      return;
1062    case IOR:
1063      mpz_init (lhs); mpz_init (rhs);
1064      mpz_eval_expr (lhs, e->operands.ops.lhs);
1065      mpz_eval_expr (rhs, e->operands.ops.rhs);
1066      mpz_ior (r, lhs, rhs);
1067      mpz_clear (lhs); mpz_clear (rhs);
1068      return;
1069#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1070    case XOR:
1071      mpz_init (lhs); mpz_init (rhs);
1072      mpz_eval_expr (lhs, e->operands.ops.lhs);
1073      mpz_eval_expr (rhs, e->operands.ops.rhs);
1074      mpz_xor (r, lhs, rhs);
1075      mpz_clear (lhs); mpz_clear (rhs);
1076      return;
1077#endif
1078    case NEG:
1079      mpz_eval_expr (r, e->operands.ops.lhs);
1080      mpz_neg (r, r);
1081      return;
1082    case NOT:
1083      mpz_eval_expr (r, e->operands.ops.lhs);
1084      mpz_com (r, r);
1085      return;
1086    case SQRT:
1087      mpz_init (lhs);
1088      mpz_eval_expr (lhs, e->operands.ops.lhs);
1089      if (mpz_sgn (lhs) < 0)
1090	{
1091	  error = "cannot take square root of negative numbers";
1092	  mpz_clear (lhs);
1093	  longjmp (errjmpbuf, 1);
1094	}
1095      mpz_sqrt (r, lhs);
1096      return;
1097#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1098    case ROOT:
1099      mpz_init (lhs); mpz_init (rhs);
1100      mpz_eval_expr (lhs, e->operands.ops.lhs);
1101      mpz_eval_expr (rhs, e->operands.ops.rhs);
1102      if (mpz_sgn (rhs) <= 0)
1103	{
1104	  error = "cannot take non-positive root orders";
1105	  mpz_clear (lhs); mpz_clear (rhs);
1106	  longjmp (errjmpbuf, 1);
1107	}
1108      if (mpz_sgn (lhs) < 0 && (mpz_get_ui (rhs) & 1) == 0)
1109	{
1110	  error = "cannot take even root orders of negative numbers";
1111	  mpz_clear (lhs); mpz_clear (rhs);
1112	  longjmp (errjmpbuf, 1);
1113	}
1114
1115      {
1116	unsigned long int nth = mpz_get_ui (rhs);
1117	if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1118	  {
1119	    /* If we are asked to take an awfully large root order, cheat and
1120	       ask for the largest order we can pass to mpz_root.  This saves
1121	       some error prone special cases.  */
1122	    nth = ~(unsigned long int) 0;
1123	  }
1124	mpz_root (r, lhs, nth);
1125      }
1126      mpz_clear (lhs); mpz_clear (rhs);
1127      return;
1128#endif
1129    case FAC:
1130      mpz_eval_expr (r, e->operands.ops.lhs);
1131      if (mpz_size (r) > 1)
1132	{
1133	  error = "result of `!' operator too large";
1134	  longjmp (errjmpbuf, 1);
1135	}
1136      mpz_fac_ui (r, mpz_get_ui (r));
1137      return;
1138#if __GNU_MP_VERSION >= 2
1139    case POPCNT:
1140      mpz_eval_expr (r, e->operands.ops.lhs);
1141      { long int cnt;
1142	cnt = mpz_popcount (r);
1143	mpz_set_si (r, cnt);
1144      }
1145      return;
1146    case HAMDIST:
1147      { long int cnt;
1148	mpz_init (lhs); mpz_init (rhs);
1149	mpz_eval_expr (lhs, e->operands.ops.lhs);
1150	mpz_eval_expr (rhs, e->operands.ops.rhs);
1151	cnt = mpz_hamdist (lhs, rhs);
1152	mpz_clear (lhs); mpz_clear (rhs);
1153	mpz_set_si (r, cnt);
1154      }
1155      return;
1156#endif
1157    case LOG2:
1158      mpz_eval_expr (r, e->operands.ops.lhs);
1159      { unsigned long int cnt;
1160	if (mpz_sgn (r) <= 0)
1161	  {
1162	    error = "logarithm of non-positive number";
1163	    longjmp (errjmpbuf, 1);
1164	  }
1165	cnt = mpz_sizeinbase (r, 2);
1166	mpz_set_ui (r, cnt - 1);
1167      }
1168      return;
1169    case LOG:
1170      { unsigned long int cnt;
1171	mpz_init (lhs); mpz_init (rhs);
1172	mpz_eval_expr (lhs, e->operands.ops.lhs);
1173	mpz_eval_expr (rhs, e->operands.ops.rhs);
1174	if (mpz_sgn (lhs) <= 0)
1175	  {
1176	    error = "logarithm of non-positive number";
1177	    mpz_clear (lhs); mpz_clear (rhs);
1178	    longjmp (errjmpbuf, 1);
1179	  }
1180	if (mpz_cmp_ui (rhs, 256) >= 0)
1181	  {
1182	    error = "logarithm base too large";
1183	    mpz_clear (lhs); mpz_clear (rhs);
1184	    longjmp (errjmpbuf, 1);
1185	  }
1186	cnt = mpz_sizeinbase (lhs, mpz_get_ui (rhs));
1187	mpz_set_ui (r, cnt - 1);
1188	mpz_clear (lhs); mpz_clear (rhs);
1189      }
1190      return;
1191    case FERMAT:
1192      {
1193	unsigned long int t;
1194	mpz_init (lhs);
1195	mpz_eval_expr (lhs, e->operands.ops.lhs);
1196	t = (unsigned long int) 1 << mpz_get_ui (lhs);
1197	if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0 || t == 0)
1198	  {
1199	    error = "too large Mersenne number index";
1200	    mpz_clear (lhs);
1201	    longjmp (errjmpbuf, 1);
1202	  }
1203	mpz_set_ui (r, 1);
1204	mpz_mul_2exp (r, r, t);
1205	mpz_add_ui (r, r, 1);
1206	mpz_clear (lhs);
1207      }
1208      return;
1209    case MERSENNE:
1210      mpz_init (lhs);
1211      mpz_eval_expr (lhs, e->operands.ops.lhs);
1212      if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0)
1213	{
1214	  error = "too large Mersenne number index";
1215	  mpz_clear (lhs);
1216	  longjmp (errjmpbuf, 1);
1217	}
1218      mpz_set_ui (r, 1);
1219      mpz_mul_2exp (r, r, mpz_get_ui (lhs));
1220      mpz_sub_ui (r, r, 1);
1221      mpz_clear (lhs);
1222      return;
1223    case FIBONACCI:
1224      { mpz_t t;
1225	unsigned long int n, i;
1226	mpz_init (lhs);
1227	mpz_eval_expr (lhs, e->operands.ops.lhs);
1228	if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1229	  {
1230	    error = "Fibonacci index out of range";
1231	    mpz_clear (lhs);
1232	    longjmp (errjmpbuf, 1);
1233	  }
1234	n = mpz_get_ui (lhs);
1235	mpz_clear (lhs);
1236
1237#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1238	mpz_fib_ui (r, n);
1239#else
1240	mpz_init_set_ui (t, 1);
1241	mpz_set_ui (r, 1);
1242
1243	if (n <= 2)
1244	  mpz_set_ui (r, 1);
1245	else
1246	  {
1247	    for (i = 3; i <= n; i++)
1248	      {
1249		mpz_add (t, t, r);
1250		mpz_swap (t, r);
1251	      }
1252	  }
1253	mpz_clear (t);
1254#endif
1255      }
1256      return;
1257    case RANDOM:
1258      {
1259	unsigned long int n;
1260	mpz_init (lhs);
1261	mpz_eval_expr (lhs, e->operands.ops.lhs);
1262	if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1263	  {
1264	    error = "random number size out of range";
1265	    mpz_clear (lhs);
1266	    longjmp (errjmpbuf, 1);
1267	  }
1268	n = mpz_get_ui (lhs);
1269	mpz_clear (lhs);
1270	mpz_urandomb (r, rstate, n);
1271      }
1272      return;
1273    case NEXTPRIME:
1274      {
1275	mpz_eval_expr (r, e->operands.ops.lhs);
1276	mpz_nextprime (r, r);
1277      }
1278      return;
1279    case BINOM:
1280      mpz_init (lhs); mpz_init (rhs);
1281      mpz_eval_expr (lhs, e->operands.ops.lhs);
1282      mpz_eval_expr (rhs, e->operands.ops.rhs);
1283      {
1284	unsigned long int k;
1285	if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1286	  {
1287	    error = "k too large in (n over k) expression";
1288	    mpz_clear (lhs); mpz_clear (rhs);
1289	    longjmp (errjmpbuf, 1);
1290	  }
1291	k = mpz_get_ui (rhs);
1292	mpz_bin_ui (r, lhs, k);
1293      }
1294      mpz_clear (lhs); mpz_clear (rhs);
1295      return;
1296    case TIMING:
1297      {
1298	int t0;
1299	t0 = cputime ();
1300	mpz_eval_expr (r, e->operands.ops.lhs);
1301	printf ("time: %d\n", cputime () - t0);
1302      }
1303      return;
1304    default:
1305      abort ();
1306    }
1307}
1308
1309/* Evaluate the expression E modulo MOD and put the result in R.  */
1310void
1311mpz_eval_mod_expr (mpz_ptr r, expr_t e, mpz_ptr mod)
1312{
1313  mpz_t lhs, rhs;
1314
1315  switch (e->op)
1316    {
1317      case POW:
1318	mpz_init (lhs); mpz_init (rhs);
1319	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1320	mpz_eval_expr (rhs, e->operands.ops.rhs);
1321	mpz_powm (r, lhs, rhs, mod);
1322	mpz_clear (lhs); mpz_clear (rhs);
1323	return;
1324      case PLUS:
1325	mpz_init (lhs); mpz_init (rhs);
1326	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1327	mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1328	mpz_add (r, lhs, rhs);
1329	if (mpz_cmp_si (r, 0L) < 0)
1330	  mpz_add (r, r, mod);
1331	else if (mpz_cmp (r, mod) >= 0)
1332	  mpz_sub (r, r, mod);
1333	mpz_clear (lhs); mpz_clear (rhs);
1334	return;
1335      case MINUS:
1336	mpz_init (lhs); mpz_init (rhs);
1337	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1338	mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1339	mpz_sub (r, lhs, rhs);
1340	if (mpz_cmp_si (r, 0L) < 0)
1341	  mpz_add (r, r, mod);
1342	else if (mpz_cmp (r, mod) >= 0)
1343	  mpz_sub (r, r, mod);
1344	mpz_clear (lhs); mpz_clear (rhs);
1345	return;
1346      case MULT:
1347	mpz_init (lhs); mpz_init (rhs);
1348	mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1349	mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1350	mpz_mul (r, lhs, rhs);
1351	mpz_mod (r, r, mod);
1352	mpz_clear (lhs); mpz_clear (rhs);
1353	return;
1354      default:
1355	mpz_init (lhs);
1356	mpz_eval_expr (lhs, e);
1357	mpz_mod (r, lhs, mod);
1358	mpz_clear (lhs);
1359	return;
1360    }
1361}
1362
1363void
1364cleanup_and_exit (int sig)
1365{
1366  switch (sig) {
1367#ifdef LIMIT_RESOURCE_USAGE
1368  case SIGXCPU:
1369    printf ("expression took too long to evaluate%s\n", newline);
1370    break;
1371#endif
1372  case SIGFPE:
1373    printf ("divide by zero%s\n", newline);
1374    break;
1375  default:
1376    printf ("expression required too much memory to evaluate%s\n", newline);
1377    break;
1378  }
1379  exit (-2);
1380}
1381