1#include <stdint.h>
2#include <string.h>
3
/*
 * Search the NUL-terminated haystack `h` for the two-byte needle `n`.
 *
 * The two needle bytes are packed into a 16-bit value and compared against a
 * 16-bit window that slides over the haystack one byte at a time.
 * Reads h[0], h[1], n[0] and n[1] unconditionally; strstr() only calls this
 * once it has verified that those bytes are non-NUL.
 *
 * Returns a pointer to the first match, or 0 if the haystack ends first.
 */
static char* twobyte_strstr(const unsigned char* h, const unsigned char* n) {
    const uint16_t needle = n[0] << 8 | n[1];
    uint16_t window = h[0] << 8 | h[1];
    const unsigned char* last = h + 1; /* newest byte held in the window */

    while (*last != 0 && window != needle) {
        last++;
        /* Storing into uint16_t discards the oldest byte of the window. */
        window = window << 8 | *last;
    }

    if (*last == 0)
        return 0;
    return (char*)last - 1;
}
10
/*
 * Search the NUL-terminated haystack `h` for the three-byte needle `n`.
 *
 * The needle is packed into the upper three bytes of a 32-bit value (low byte
 * zero) and compared against a window of the same layout that slides over the
 * haystack one byte at a time.
 * Reads h[0..2] and n[0..2] unconditionally; strstr() only calls this once it
 * has verified that those bytes are non-NUL.
 *
 * Returns a pointer to the first match, or 0 if the haystack ends first.
 */
static char* threebyte_strstr(const unsigned char* h, const unsigned char* n) {
    /*
     * Convert to uint32_t before shifting: an unsigned char promotes to int,
     * and shifting a byte >= 0x80 left by 24 would overflow a 32-bit int,
     * which is undefined behaviour.
     */
    uint32_t nw = (uint32_t)n[0] << 24 | (uint32_t)n[1] << 16 | (uint32_t)n[2] << 8;
    uint32_t hw = (uint32_t)h[0] << 24 | (uint32_t)h[1] << 16 | (uint32_t)h[2] << 8;

    /* h tracks the newest byte in the window; the oldest byte is shifted out. */
    for (h += 2; *h && hw != nw; hw = (hw | *++h) << 8)
        ;
    return *h ? (char*)h - 2 : 0;
}
18
/*
 * Search the NUL-terminated haystack `h` for the four-byte needle `n`.
 *
 * The needle is packed big-endian into a 32-bit value and compared against a
 * 32-bit window that slides over the haystack one byte at a time.
 * Reads h[0..3] and n[0..3] unconditionally; strstr() only calls this once it
 * has verified that those bytes are non-NUL.
 *
 * Returns a pointer to the first match, or 0 if the haystack ends first.
 */
static char* fourbyte_strstr(const unsigned char* h, const unsigned char* n) {
    /*
     * Convert to uint32_t before shifting: an unsigned char promotes to int,
     * and shifting a byte >= 0x80 left by 24 would overflow a 32-bit int,
     * which is undefined behaviour.
     */
    uint32_t nw = (uint32_t)n[0] << 24 | (uint32_t)n[1] << 16 | (uint32_t)n[2] << 8 | n[3];
    uint32_t hw = (uint32_t)h[0] << 24 | (uint32_t)h[1] << 16 | (uint32_t)h[2] << 8 | h[3];

    /* h tracks the newest byte in the window; the oldest byte is shifted out. */
    for (h += 3; *h && hw != nw; hw = hw << 8 | *++h)
        ;
    return *h ? (char*)h - 3 : 0;
}
26
/*
 * Larger / smaller of two values. These are function-like macros, so each
 * argument may be evaluated twice: pass only side-effect-free expressions.
 */
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

/*
 * Combine, using operator `op` (e.g. `|=` to set, `&` to test), the word of
 * bit array `a` that holds bit number `b` with a mask containing only that bit.
 * The word width is taken from the element type of `a`.
 */
#define BITOP(a, b, op) \
    ((a)[(size_t)(b) / (8 * sizeof *(a))] op(size_t) 1 << ((size_t)(b) % (8 * sizeof *(a))))

/*
 * Search the NUL-terminated haystack `h` for the NUL-terminated needle `n`
 * using a two-way search: the needle is split at a maximal-suffix position
 * `ms`, the right part is compared left-to-right and then the left part
 * right-to-left, combined with a last-byte shift table.
 *
 * Returns a pointer to the first match, or 0 if there is none.
 */
static char* twoway_strstr(const unsigned char* h, const unsigned char* n) {
    const unsigned char* z;
    size_t l, ip, jp, k, p, ms, p0, mem, mem0;
    /* One bit per possible byte value: set if that byte occurs in the needle. */
    size_t byteset[32 / sizeof(size_t)] = {0};
    /* shift[c] = 1 + index of the last occurrence of c in the needle.
     * Only entries whose byteset bit is set are ever written or read. */
    size_t shift[256];

    /* Compute the needle length and fill the byte set and shift table.
     * Walking h in lockstep detects a haystack shorter than the needle. */
    for (l = 0; n[l] && h[l]; l++) {
        BITOP(byteset, n[l], |=);
        shift[n[l]] = l + 1;
    }
    if (n[l])
        return 0; /* hit the end of h */

    /* Compute maximal suffix (ip == (size_t)-1 stands for "before index 0") */
    ip = -1;
    jp = 0;
    k = p = 1;
    while (jp + k < l) {
        if (n[ip + k] == n[jp + k]) {
            if (k == p) {
                jp += p;
                k = 1;
            } else
                k++;
        } else if (n[ip + k] > n[jp + k]) {
            jp += k;
            k = 1;
            p = jp - ip;
        } else {
            ip = jp++;
            k = p = 1;
        }
    }
    ms = ip;
    p0 = p;

    /* And with the opposite comparison */
    ip = -1;
    jp = 0;
    k = p = 1;
    while (jp + k < l) {
        if (n[ip + k] == n[jp + k]) {
            if (k == p) {
                jp += p;
                k = 1;
            } else
                k++;
        } else if (n[ip + k] < n[jp + k]) {
            jp += k;
            k = 1;
            p = jp - ip;
        } else {
            ip = jp++;
            k = p = 1;
        }
    }
    /* Keep whichever split position is further right, with its period.
     * The +1 makes the (size_t)-1 sentinel compare as the smallest value. */
    if (ip + 1 > ms + 1)
        ms = ip;
    else
        p = p0;

    /* Periodic needle? */
    if (memcmp(n, n + p, ms + 1)) {
        /* Not periodic: no prefix can be remembered between shifts. */
        mem0 = 0;
        p = MAX(ms, l - ms - 1) + 1;
    } else
        mem0 = l - p;
    mem = 0;

    /* Initialize incremental end-of-haystack pointer */
    z = h;

    /* Search loop */
    for (;;) {
        /* Update incremental end-of-haystack pointer */
        if (z - h < l) {
            /* Fast upper-bound estimate for MAX(l,63) */
            size_t grow = l | 63;
            const unsigned char* z2 = memchr(z, 0, grow);
            if (z2) {
                z = z2;
                if (z - h < l)
                    return 0; /* fewer than l bytes remain: no match possible */
            } else
                z += grow;
        }

        /* Check last byte first; advance by shift on mismatch */
        if (BITOP(byteset, h[l - 1], &)) {
            k = l - shift[h[l - 1]];
            if (k) {
                /* The table shift k is safe on its own. While a prefix of
                 * `mem` bytes is remembered, advance by at least `mem`, but
                 * never shrink a larger shift. */
                if (k < mem)
                    k = mem;
                h += k;
                mem = 0;
                continue;
            }
        } else {
            /* The window's last byte does not occur in the needle at all. */
            h += l;
            mem = 0;
            continue;
        }

        /* Compare right half */
        for (k = MAX(ms + 1, mem); n[k] && n[k] == h[k]; k++)
            ;
        if (n[k]) {
            h += k - ms;
            mem = 0;
            continue;
        }
        /* Compare left half */
        for (k = ms + 1; k > mem && n[k - 1] == h[k - 1]; k--)
            ;
        if (k <= mem)
            return (char*)h;
        h += p;
        mem = mem0;
    }
}
154
/*
 * Locate the first occurrence of the string `n` (needle) inside the string
 * `h` (haystack). Returns a pointer to the start of the match, `h` itself for
 * an empty needle, or a null pointer when there is no match.
 *
 * Needles of one to four bytes are dispatched to dedicated helpers; longer
 * needles go to twoway_strstr(). Needle and haystack bytes are inspected in
 * alternation so that neither string is read past its terminator.
 */
char* strstr(const char* h, const char* n) {
    /* An empty needle matches at the very start of the haystack. */
    if (n[0] == '\0')
        return (char*)h;

    /* Skip ahead to the first byte that could begin a match. */
    const char* start = strchr(h, *n);
    if (start == NULL)
        return NULL;

    /* One-byte needle: strchr has already found it. */
    if (n[1] == '\0')
        return (char*)start;

    /* Two-byte needle; the haystack must have at least two bytes left. */
    if (start[1] == '\0')
        return NULL;
    if (n[2] == '\0')
        return twobyte_strstr((const unsigned char*)start, (const unsigned char*)n);

    /* Three-byte needle. */
    if (start[2] == '\0')
        return NULL;
    if (n[3] == '\0')
        return threebyte_strstr((const unsigned char*)start, (const unsigned char*)n);

    /* Four-byte needle. */
    if (start[3] == '\0')
        return NULL;
    if (n[4] == '\0')
        return fourbyte_strstr((const unsigned char*)start, (const unsigned char*)n);

    /* Five bytes or more: general search. */
    return twoway_strstr((const unsigned char*)start, (const unsigned char*)n);
}
179