1/* MMS extension for TCP NAT alteration.
2 * (C) 2002 by Filip Sneppe <filip.sneppe@cronos.be>
3 * based on ip_nat_ftp.c and ip_nat_irc.c
4 *
5 * ip_nat_mms.c v0.3 2002-09-22
6 *
7 *      This program is free software; you can redistribute it and/or
8 *      modify it under the terms of the GNU General Public License
9 *      as published by the Free Software Foundation; either version
10 *      2 of the License, or (at your option) any later version.
11 *
12 *      Module load syntax:
13 *      insmod ip_nat_mms.o ports=port1,port2,...port<MAX_PORTS>
14 *
 *      Please give the ports of all MMS servers you wish to connect to.
16 *      If you don't specify ports, the default will be TCP port 1755.
17 *
18 *      More info on MMS protocol, firewalls and NAT:
19 *      http://msdn.microsoft.com/library/default.asp?url=/library/en-us/dnwmt/html/MMSFirewall.asp
20 *      http://www.microsoft.com/windows/windowsmedia/serve/firewall.asp
21 *
22 *      The SDP project people are reverse-engineering MMS:
23 *      http://get.to/sdp
24 */
25
26
27#include <linux/module.h>
28#include <linux/netfilter_ipv4.h>
29#include <linux/ip.h>
30#include <linux/tcp.h>
31#include <net/tcp.h>
32#include <linux/netfilter_ipv4/ip_nat.h>
33#include <linux/netfilter_ipv4/ip_nat_helper.h>
34#include <linux/netfilter_ipv4/ip_nat_rule.h>
35#include <linux/netfilter_ipv4/ip_conntrack_mms.h>
36#include <linux/netfilter_ipv4/ip_conntrack_helper.h>
37
/* Debug hooks, compiled out: both macros expand to nothing, so their
   arguments are never evaluated and a DEBUGP() used as the only statement
   of an if-body leaves that body empty. */
#define DEBUGP(format, args...)
#define DUMP_BYTES(address, counter)

/* MMS control ports given at load time via "ports=".  The loops in this
   file stop at the first zero entry; init() substitutes MMS_PORT when
   ports[0] is zero. */
#define MAX_PORTS 8
static int ports[MAX_PORTS];
/* Incremented by init() once per successfully registered helper. */
static int ports_c = 0;

#ifdef MODULE_PARM
MODULE_PARM(ports, "1-" __MODULE_STRING(MAX_PORTS) "i");
#endif

MODULE_AUTHOR("Filip Sneppe <filip.sneppe@cronos.be>");
MODULE_DESCRIPTION("Microsoft Windows Media Services (MMS) NAT module");
MODULE_LICENSE("GPL");

/* Lock defined outside this file (presumably by the MMS conntrack helper —
   TODO confirm); mms_nat_help() holds it while rewriting a packet. */
DECLARE_LOCK_EXTERN(ip_mms_lock);
54
55
56static int mms_data_fixup(const struct ip_ct_mms_expect *ct_mms_info,
57                          struct ip_conntrack *ct,
58                          struct sk_buff **pskb,
59                          enum ip_conntrack_info ctinfo,
60                          struct ip_conntrack_expect *expect)
61{
62	u_int32_t newip;
63	struct ip_conntrack_tuple t;
64	struct iphdr *iph = (*pskb)->nh.iph;
65	struct tcphdr *tcph = (void *) iph + iph->ihl * 4;
66	char *data = (char *)tcph + tcph->doff * 4;
67	int i, j, k, port;
68	u_int16_t mms_proto;
69
70	u_int32_t *mms_chunkLenLV    = (u_int32_t *)(data + MMS_SRV_CHUNKLENLV_OFFSET);
71	u_int32_t *mms_chunkLenLM    = (u_int32_t *)(data + MMS_SRV_CHUNKLENLM_OFFSET);
72	u_int32_t *mms_messageLength = (u_int32_t *)(data + MMS_SRV_MESSAGELENGTH_OFFSET);
73
74	int zero_padding;
75
76	char buffer[28];         /* "\\255.255.255.255\UDP\65635" * 2 (for unicode) */
77	char unicode_buffer[75]; /* 27*2 (unicode) + 20 + 1 */
78	char proto_string[6];
79
80	MUST_BE_LOCKED(&ip_mms_lock);
81
82	/* what was the protocol again ? */
83	mms_proto = expect->tuple.dst.protonum;
84	sprintf(proto_string, "%u", mms_proto);
85
86	DEBUGP("ip_nat_mms: mms_data_fixup: info (seq %u + %u) in %u, proto %s\n",
87	       expect->seq, ct_mms_info->len, ntohl(tcph->seq),
88	       mms_proto == IPPROTO_UDP ? "UDP"
89	       : mms_proto == IPPROTO_TCP ? "TCP":proto_string);
90
91	newip = ct->tuplehash[IP_CT_DIR_REPLY].tuple.dst.ip;
92
93	/* Alter conntrack's expectations. */
94	t = expect->tuple;
95	t.dst.ip = newip;
96	for (port = ct_mms_info->port; port != 0; port++) {
97		t.dst.u.tcp.port = htons(port);
98		if (ip_conntrack_change_expect(expect, &t) == 0) {
99			DEBUGP("ip_nat_mms: mms_data_fixup: using port %d\n", port);
100			break;
101		}
102	}
103
104	if(port == 0)
105		return 0;
106
107	sprintf(buffer, "\\\\%u.%u.%u.%u\\%s\\%u",
108	        NIPQUAD(newip),
109		expect->tuple.dst.protonum == IPPROTO_UDP ? "UDP"
110		: expect->tuple.dst.protonum == IPPROTO_TCP ? "TCP":proto_string,
111		port);
112	DEBUGP("ip_nat_mms: new unicode string=%s\n", buffer);
113
114	memset(unicode_buffer, 0, sizeof(char)*75);
115
116	for (i=0; i<strlen(buffer); ++i)
117		*(unicode_buffer+i*2)=*(buffer+i);
118
119	DEBUGP("ip_nat_mms: mms_data_fixup: padding: %u len: %u\n", ct_mms_info->padding, ct_mms_info->len);
120	DEBUGP("ip_nat_mms: mms_data_fixup: offset: %u\n", MMS_SRV_UNICODE_STRING_OFFSET+ct_mms_info->len);
121	DUMP_BYTES(data+MMS_SRV_UNICODE_STRING_OFFSET, 60);
122
123	/* add end of packet to it */
124	for (j=0; j<ct_mms_info->padding; ++j) {
125		DEBUGP("ip_nat_mms: mms_data_fixup: i=%u j=%u byte=%u\n",
126		       i, j, (u8)*(data+MMS_SRV_UNICODE_STRING_OFFSET+ct_mms_info->len+j));
127		*(unicode_buffer+i*2+j) = *(data+MMS_SRV_UNICODE_STRING_OFFSET+ct_mms_info->len+j);
128	}
129
130	/* pad with zeroes at the end ? see explanation of weird math below */
131	zero_padding = (8-(strlen(buffer)*2 + ct_mms_info->padding + 4)%8)%8;
132	for (k=0; k<zero_padding; ++k)
133		*(unicode_buffer+i*2+j+k)= (char)0;
134
135	DEBUGP("ip_nat_mms: mms_data_fixup: zero_padding = %u\n", zero_padding);
136	DEBUGP("ip_nat_mms: original=> chunkLenLV=%u chunkLenLM=%u messageLength=%u\n",
137	       *mms_chunkLenLV, *mms_chunkLenLM, *mms_messageLength);
138
139	/* explanation, before I forget what I did:
140	   strlen(buffer)*2 + ct_mms_info->padding + 4 must be divisable by 8;
141	   divide by 8 and add 3 to compute the mms_chunkLenLM field,
142	   but note that things may have to be padded with zeroes to align by 8
143	   bytes, hence we add 7 and divide by 8 to get the correct length */
144	*mms_chunkLenLM    = (u_int32_t) (3+(strlen(buffer)*2+ct_mms_info->padding+11)/8);
145	*mms_chunkLenLV    = *mms_chunkLenLM+2;
146	*mms_messageLength = *mms_chunkLenLV*8;
147
148	DEBUGP("ip_nat_mms: modified=> chunkLenLV=%u chunkLenLM=%u messageLength=%u\n",
149	       *mms_chunkLenLV, *mms_chunkLenLM, *mms_messageLength);
150
151	ip_nat_mangle_tcp_packet(pskb, ct, ctinfo,
152	                         expect->seq - ntohl(tcph->seq),
153	                         ct_mms_info->len + ct_mms_info->padding, unicode_buffer,
154	                         strlen(buffer)*2 + ct_mms_info->padding + zero_padding);
155	DUMP_BYTES(unicode_buffer, 60);
156
157	return 1;
158}
159
160static unsigned int
161mms_nat_expected(struct sk_buff **pskb,
162                 unsigned int hooknum,
163                 struct ip_conntrack *ct,
164                 struct ip_nat_info *info)
165{
166	struct ip_nat_multi_range mr;
167	u_int32_t newdstip, newsrcip, newip;
168
169	struct ip_conntrack *master = master_ct(ct);
170
171	IP_NF_ASSERT(info);
172	IP_NF_ASSERT(master);
173
174	IP_NF_ASSERT(!(info->initialized & (1 << HOOK2MANIP(hooknum))));
175
176	DEBUGP("ip_nat_mms: mms_nat_expected: We have a connection!\n");
177
178	newdstip = master->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.ip;
179	newsrcip = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.ip;
180	DEBUGP("ip_nat_mms: mms_nat_expected: hook %s: newsrc->newdst %u.%u.%u.%u->%u.%u.%u.%u\n",
181	       hooknum == NF_IP_POST_ROUTING ? "POSTROUTING"
182	       : hooknum == NF_IP_PRE_ROUTING ? "PREROUTING"
183	       : hooknum == NF_IP_LOCAL_OUT ? "OUTPUT" : "???",
184	       NIPQUAD(newsrcip), NIPQUAD(newdstip));
185
186	if (HOOK2MANIP(hooknum) == IP_NAT_MANIP_SRC)
187		newip = newsrcip;
188	else
189		newip = newdstip;
190
191	DEBUGP("ip_nat_mms: mms_nat_expected: IP to %u.%u.%u.%u\n", NIPQUAD(newip));
192
193	mr.rangesize = 1;
194	/* We don't want to manip the per-protocol, just the IPs. */
195	mr.range[0].flags = IP_NAT_RANGE_MAP_IPS;
196	mr.range[0].min_ip = mr.range[0].max_ip = newip;
197
198	return ip_nat_setup_info(ct, &mr, hooknum);
199}
200
201
202static unsigned int mms_nat_help(struct ip_conntrack *ct,
203			 struct ip_conntrack_expect *exp,
204			 struct ip_nat_info *info,
205			 enum ip_conntrack_info ctinfo,
206			 unsigned int hooknum,
207			 struct sk_buff **pskb)
208{
209	struct iphdr *iph = (*pskb)->nh.iph;
210	struct tcphdr *tcph = (void *) iph + iph->ihl * 4;
211	unsigned int datalen;
212	int dir;
213	struct ip_ct_mms_expect *ct_mms_info;
214
215	if (!exp)
216		DEBUGP("ip_nat_mms: no exp!!");
217
218	ct_mms_info = &exp->help.exp_mms_info;
219
220	/* Only mangle things once: original direction in POST_ROUTING
221	   and reply direction on PRE_ROUTING. */
222	dir = CTINFO2DIR(ctinfo);
223	if (!((hooknum == NF_IP_POST_ROUTING && dir == IP_CT_DIR_ORIGINAL)
224	    ||(hooknum == NF_IP_PRE_ROUTING && dir == IP_CT_DIR_REPLY))) {
225		DEBUGP("ip_nat_mms: mms_nat_help: not touching dir %s at hook %s\n",
226		       dir == IP_CT_DIR_ORIGINAL ? "ORIG" : "REPLY",
227		       hooknum == NF_IP_POST_ROUTING ? "POSTROUTING"
228		       : hooknum == NF_IP_PRE_ROUTING ? "PREROUTING"
229		       : hooknum == NF_IP_LOCAL_OUT ? "OUTPUT" : "???");
230		return NF_ACCEPT;
231	}
232	DEBUGP("ip_nat_mms: mms_nat_help: beyond not touching (dir %s at hook %s)\n",
233	       dir == IP_CT_DIR_ORIGINAL ? "ORIG" : "REPLY",
234	       hooknum == NF_IP_POST_ROUTING ? "POSTROUTING"
235	       : hooknum == NF_IP_PRE_ROUTING ? "PREROUTING"
236	       : hooknum == NF_IP_LOCAL_OUT ? "OUTPUT" : "???");
237
238	datalen = (*pskb)->len - iph->ihl * 4 - tcph->doff * 4;
239
240	DEBUGP("ip_nat_mms: mms_nat_help: %u+%u=%u %u %u\n", exp->seq, ct_mms_info->len,
241	       exp->seq + ct_mms_info->len,
242	       ntohl(tcph->seq),
243	       ntohl(tcph->seq) + datalen);
244
245	LOCK_BH(&ip_mms_lock);
246	/* Check wether the whole IP/proto/port pattern is carried in the payload */
247	if (between(exp->seq + ct_mms_info->len,
248	    ntohl(tcph->seq),
249	    ntohl(tcph->seq) + datalen)) {
250		if (!mms_data_fixup(ct_mms_info, ct, pskb, ctinfo, exp)) {
251			UNLOCK_BH(&ip_mms_lock);
252			return NF_DROP;
253		}
254	} else {
255		/* Half a match?  This means a partial retransmisison.
256		   It's a cracker being funky. */
257		if (net_ratelimit()) {
258			printk("ip_nat_mms: partial packet %u/%u in %u/%u\n",
259			       exp->seq, ct_mms_info->len,
260			       ntohl(tcph->seq),
261			       ntohl(tcph->seq) + datalen);
262		}
263		UNLOCK_BH(&ip_mms_lock);
264		return NF_DROP;
265	}
266	UNLOCK_BH(&ip_mms_lock);
267
268	return NF_ACCEPT;
269}
270
/* One NAT helper per configured port; filled in and registered by init(). */
static struct ip_nat_helper mms[MAX_PORTS];
/* Backing storage for each helper's name: "mms" or "mms-<index>". */
static char mms_names[MAX_PORTS][10];
273
274/* Not __exit: called from init() */
275static void fini(void)
276{
277	int i;
278
279	for (i = 0; (i < MAX_PORTS) && ports[i]; i++) {
280		DEBUGP("ip_nat_mms: unregistering helper for port %d\n", ports[i]);
281		ip_nat_helper_unregister(&mms[i]);
282	}
283}
284
285static int __init init(void)
286{
287	int i, ret = 0;
288	char *tmpname;
289
290	if (ports[0] == 0)
291		ports[0] = MMS_PORT;
292
293	for (i = 0; (i < MAX_PORTS) && ports[i]; i++) {
294
295		memset(&mms[i], 0, sizeof(struct ip_nat_helper));
296
297		mms[i].tuple.dst.protonum = IPPROTO_TCP;
298		mms[i].tuple.src.u.tcp.port = htons(ports[i]);
299		mms[i].mask.dst.protonum = 0xFFFF;
300		mms[i].mask.src.u.tcp.port = 0xFFFF;
301		mms[i].help = mms_nat_help;
302		mms[i].me = THIS_MODULE;
303		mms[i].flags = 0;
304		mms[i].expect = mms_nat_expected;
305
306		tmpname = &mms_names[i][0];
307		if (ports[i] == MMS_PORT)
308			sprintf(tmpname, "mms");
309		else
310			sprintf(tmpname, "mms-%d", i);
311		mms[i].name = tmpname;
312
313		DEBUGP("ip_nat_mms: register helper for port %d\n",
314				ports[i]);
315		ret = ip_nat_helper_register(&mms[i]);
316
317		if (ret) {
318			printk("ip_nat_mms: error registering "
319			       "helper for port %d\n", ports[i]);
320			fini();
321			return ret;
322		}
323		ports_c++;
324	}
325
326	return ret;
327}
328
329module_init(init);
330module_exit(fini);
331