1130812Smarcel#define _GNU_SOURCE
2130812Smarcel#include <stdint.h>
3130812Smarcel#include <string.h>
4130812Smarcel
/*
 * Find a 2-byte needle n in the k-byte haystack h (caller guarantees k >= 2).
 * The needle is packed big-endian into a 16-bit word. A rolling 16-bit
 * window holding the two most recent haystack bytes is compared against it.
 * Returns a pointer to the first match, or a null pointer if there is none.
 */
static char* twobyte_memmem(const unsigned char* h, size_t k, const unsigned char* n) {
    const uint16_t want = (uint16_t)(n[0] << 8 | n[1]);
    uint16_t window = (uint16_t)(h[0] << 8 | h[1]);
    const unsigned char* next = h + 2; /* first byte not yet in the window */
    size_t left = k - 2;               /* bytes still to be shifted in */

    for (;;) {
        if (window == want)
            return (char*)next - 2;
        if (left-- == 0)
            return 0;
        window = (uint16_t)(window << 8 | *next++);
    }
}
12130812Smarcel
/*
 * Find a 3-byte needle n in the k-byte haystack h (caller guarantees k >= 3).
 * The needle is packed into the top 24 bits of a 32-bit word, with the low
 * 8 bits zero. The haystack window is kept in the same layout: each new byte
 * is OR'd into the low 8 bits and then shifted up, so the oldest byte falls
 * off the top.
 * Returns a pointer to the first match, or a null pointer if there is none.
 */
static char* threebyte_memmem(const unsigned char* h, size_t k, const unsigned char* n) {
    /* Cast to uint32_t BEFORE shifting. An unsigned char promotes to signed
     * int, and a byte >= 0x80 shifted left by 24 overflows int, which is
     * undefined behavior. */
    uint32_t nw = (uint32_t)n[0] << 24 | (uint32_t)n[1] << 16 | (uint32_t)n[2] << 8;
    uint32_t hw = (uint32_t)h[0] << 24 | (uint32_t)h[1] << 16 | (uint32_t)h[2] << 8;
    for (h += 3, k -= 3; k; k--, hw = (hw | *h++) << 8)
        if (hw == nw)
            return (char*)h - 3;
    /* The final window has not been checked inside the loop yet. */
    return hw == nw ? (char*)h - 3 : 0;
}
21130812Smarcel
/*
 * Find a 4-byte needle n in the k-byte haystack h (caller guarantees k >= 4).
 * The needle is packed big-endian into a 32-bit word. A rolling 32-bit
 * window holding the four most recent haystack bytes is compared against it.
 * Returns a pointer to the first match, or a null pointer if there is none.
 */
static char* fourbyte_memmem(const unsigned char* h, size_t k, const unsigned char* n) {
    /* Cast to uint32_t BEFORE shifting. An unsigned char promotes to signed
     * int, and a byte >= 0x80 shifted left by 24 overflows int, which is
     * undefined behavior. */
    uint32_t nw = (uint32_t)n[0] << 24 | (uint32_t)n[1] << 16 | (uint32_t)n[2] << 8 | n[3];
    uint32_t hw = (uint32_t)h[0] << 24 | (uint32_t)h[1] << 16 | (uint32_t)h[2] << 8 | h[3];
    for (h += 4, k -= 4; k; k--, hw = hw << 8 | *h++)
        if (hw == nw)
            return (char*)h - 4;
    /* The final window has not been checked inside the loop yet. */
    return hw == nw ? (char*)h - 4 : 0;
}
30130812Smarcel
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

/* Apply operator `op` to bit number b of the bit array a (an array of
 * unsigned words): BITOP(a, b, |=) sets the bit, BITOP(a, b, &) tests it. */
#define BITOP(a, b, op) \
    ((a)[(size_t)(b) / (8 * sizeof *(a))] op (size_t)1 << ((size_t)(b) % (8 * sizeof *(a))))

/*
 * Two-Way (Crochemore-Perrin) search for the l-byte needle n in the haystack
 * range [h, z). A last-byte shift table is added on top of the basic
 * algorithm.
 *
 * Returns a pointer to the first occurrence of n within [h, z), or a null
 * pointer if there is none. Expects l >= 1 and h <= z. memmem() only calls
 * this for needles longer than 4 bytes.
 */
static char* twoway_memmem(const unsigned char* h, const unsigned char* z, const unsigned char* n,
                           size_t l) {
    size_t i, ip, jp, k, p, ms, p0, mem, mem0;
    /* 256-bit set: bit c is set iff byte value c occurs in the needle.
     * Uses {0} because an empty initializer {} is not valid ISO C before C23. */
    size_t byteset[32 / sizeof(size_t)] = {0};
    /* shift[c] = 1 + index of the last occurrence of byte c in the needle.
     * Entries for bytes absent from the needle are never initialized. They
     * are never read either, because every lookup is guarded by a byteset
     * test. */
    size_t shift[256];

    /* Fill the byte set and the shift table */
    for (i = 0; i < l; i++) {
        BITOP(byteset, n[i], |=);
        shift[n[i]] = i + 1;
    }

    /* Compute maximal suffix. ip starts at (size_t)-1 and relies on unsigned
     * wraparound, so that n[ip + k] addresses n[k - 1]. */
    ip = -1;
    jp = 0;
    k = p = 1;
    while (jp + k < l) {
        if (n[ip + k] == n[jp + k]) {
            if (k == p) {
                jp += p;
                k = 1;
            } else
                k++;
        } else if (n[ip + k] > n[jp + k]) {
            jp += k;
            k = 1;
            p = jp - ip;
        } else {
            ip = jp++;
            k = p = 1;
        }
    }
    ms = ip;
    p0 = p;

    /* And with the opposite comparison */
    ip = -1;
    jp = 0;
    k = p = 1;
    while (jp + k < l) {
        if (n[ip + k] == n[jp + k]) {
            if (k == p) {
                jp += p;
                k = 1;
            } else
                k++;
        } else if (n[ip + k] < n[jp + k]) {
            jp += k;
            k = 1;
            p = jp - ip;
        } else {
            ip = jp++;
            k = p = 1;
        }
    }
    /* Keep the longer of the two maximal suffixes (the comparison adds 1 so
     * that ip == (size_t)-1 orders as the smallest value). The needle is then
     * split into n[0..ms] (left half) and n[ms+1..l) (right half). p is the
     * period found for the chosen suffix. */
    if (ip + 1 > ms + 1)
        ms = ip;
    else
        p = p0;

    /* Periodic needle? The left half repeating at distance p means p is the
     * period of the whole needle. In that case, after shifting by p, the
     * first mem0 = l - p bytes of the new window are already known to match.
     * Otherwise no memory is kept, and a full-match failure shifts by the
     * larger half plus one. */
    if (memcmp(n, n + p, ms + 1)) {
        mem0 = 0;
        p = MAX(ms, l - ms - 1) + 1;
    } else
        mem0 = l - p;
    mem = 0;

    /* Search loop */
    for (;;) {
        /* If remainder of haystack is shorter than needle, done. z - h is
         * never negative here, so the cast is exact and avoids a
         * signed/unsigned comparison. */
        if ((size_t)(z - h) < l)
            return 0;

        /* Check last byte first; advance by shift on mismatch */
        if (BITOP(byteset, h[l - 1], &)) {
            /* Align the last occurrence of this byte in the needle with it */
            k = l - shift[h[l - 1]];
            if (k) {
                /* Periodic needle with a remembered prefix: a shift shorter
                 * than the period is widened to l - p. */
                if (mem0 && mem && k < p)
                    k = l - p;
                h += k;
                mem = 0;
                continue;
            }
        } else {
            /* The byte does not occur in the needle: skip past it entirely */
            h += l;
            mem = 0;
            continue;
        }

        /* Compare right half (skipping bytes already known to match) */
        for (k = MAX(ms + 1, mem); k < l && n[k] == h[k]; k++)
            ;
        if (k < l) {
            h += k - ms;
            mem = 0;
            continue;
        }
        /* Compare left half, right to left, down to the remembered prefix */
        for (k = ms + 1; k > mem && n[k - 1] == h[k - 1]; k--)
            ;
        if (k <= mem)
            return (char*)h;
        h += p;
        mem = mem0;
    }
}
143130812Smarcel
/*
 * Find the first occurrence of the l-byte needle n0 in the k-byte haystack h0.
 *
 * Returns a pointer into h0 at the start of the match. An empty needle
 * matches at h0 itself. Returns a null pointer when there is no occurrence.
 * The search first jumps to the first occurrence of the needle's first byte.
 * It then uses a word-window scan for needles of 2-4 bytes and Two-Way
 * search for longer ones.
 */
void* memmem(const void* h0, size_t k, const void* n0, size_t l) {
    const unsigned char* const base = h0;
    const unsigned char* const needle = n0;
    const unsigned char* start;
    size_t rest;

    /* The empty needle matches at the very beginning. */
    if (l == 0)
        return (void*)base;
    /* A needle longer than the haystack can never match. */
    if (l > k)
        return 0;

    /* Skip ahead to the first occurrence of the needle's first byte. */
    start = memchr(base, needle[0], k);
    if (start == 0)
        return 0;
    if (l == 1)
        return (void*)start;

    /* Re-check the length against what remains after the skip. */
    rest = k - (size_t)(start - base);
    if (rest < l)
        return 0;

    switch (l) {
    case 2:
        return twobyte_memmem(start, rest, needle);
    case 3:
        return threebyte_memmem(start, rest, needle);
    case 4:
        return fourbyte_memmem(start, rest, needle);
    default:
        return twoway_memmem(start, start + rest, needle, l);
    }
}
171130812Smarcel