1package CyrusSasl;
2
3import java.io.*;
4
5public class SaslInputStream extends InputStream
6{
7    static final boolean DoEncrypt = true;
8    static final boolean DoDebug = false;
9    private static int BUFFERSIZE = 16384;
10
11    // if bufferend < bufferstart, we've wrapped around
12    private byte[] buffer=new byte[BUFFERSIZE];
13    private int bufferstart = 0;
14    private int bufferend = 0;
15    private int size = 0;
16
17    private GenericCommon conn;
18
19    public InputStream in;
20
21    public SaslInputStream(InputStream in, GenericCommon conn)
22    {
23	if (DoDebug) {
24	    System.err.println("DEBUG constructing SaslInputStream");
25	}
26	this.in = in;
27	this.conn = conn;
28    }
29
30    public synchronized int available() throws IOException
31    {
32	int ina = in.available();
33	if (ina > 1) ina = 1;
34
35	return size + ina;
36    }
37
38    private void buffer_add(byte[] str,int len) throws IOException
39    {
40	if (str == null) {
41	    // nothing to add
42	    return;
43	}
44
45	byte[] b = str;
46
47	/* xxx this can be optimized */
48	for (int lup=0;lup<len;lup++) {
49	    buffer[bufferend]=b[lup];
50	    bufferend = (bufferend + 1) % BUFFERSIZE;
51
52	    size++;
53	    if (size >= BUFFERSIZE) {
54		throw new IOException();
55	    }
56	}
57    }
58
59    private void buffer_add(byte[] str) throws IOException
60    {
61	buffer_add(str,str.length);
62    }
63
64    private void readsome() throws IOException
65    {
66	int len=in.available();
67
68	if (DoDebug) {
69	    System.err.println("DEBUG in readsome(), avail " + len);
70	}
71
72	if (len > BUFFERSIZE || len == 0)
73	    len = BUFFERSIZE;
74
75	byte[]tmp=new byte[len];
76	len = in.read(tmp);
77
78	if (len>0) {
79	    if (DoEncrypt) {
80		buffer_add( conn.decode(tmp,len) );
81	    } else {
82		buffer_add(tmp, len);
83	    }
84	}
85    }
86
87    public synchronized void close() throws IOException
88    {
89	super.close();
90    }
91
92    public synchronized void reset() throws IOException
93    {
94	throw new IOException();
95    }
96
97    public synchronized void mark(int readlimit)
98    {
99	return;
100    }
101
102    public boolean markSupported()
103    {
104	return false;
105    }
106
107    /* read a single byte */
108    public synchronized int read() throws IOException
109    {
110	int ret;
111
112	if (DoDebug) {
113	    System.err.println("DEBUG in read(), size " + size);
114	}
115	if (size == 0) {
116	    readsome();
117	}
118
119	if (size == 0) {
120	    if (DoDebug) {
121		System.err.println("DEBUG read() returning -1");
122	    }
123	    return -1;
124	}
125
126	ret = buffer[bufferstart];
127	bufferstart = (bufferstart + 1) % BUFFERSIZE;
128	size--;
129
130	if (DoDebug) {
131	    System.err.println("DEBUG read() returning " + ret);
132	}
133	return ret;
134    }
135
136    public synchronized int read(byte b[]) throws IOException
137    {
138	return read(b,0,b.length);
139    }
140
141    public synchronized int read(byte b[],
142				 int off,
143				 int len) throws IOException
144    {
145	if (DoDebug) {
146	    System.err.println("DEBUG in read(b, off, len), size " + size);
147	}
148	if (off < 0 || len < 0) {
149	    throw new IndexOutOfBoundsException();
150	}
151	if (len == 0) {
152	    return 0;
153	}
154
155	// block only if we need to
156	if (size == 0) {
157	    readsome();
158	    if (size == 0) {
159		if (DoDebug) {
160		    System.err.println("DEBUG read(b, off, len) returning -1");
161		}
162		return -1;
163	    }
164	}
165
166	int l;
167	for (l = off; l < len + off; l++) {
168	    if (bufferstart == bufferend) break;
169
170	    b[l] = buffer[bufferstart];
171	    bufferstart = (bufferstart + 1) % BUFFERSIZE;
172	    size--;
173	}
174
175	if (DoDebug) {
176	    System.err.println("DEBUG read() returning " + (l - off));
177	}
178	return l - off;
179    }
180
181    public synchronized long skip(long n) throws IOException
182    {
183	if (n<=0) return 0;
184
185	long toskip = n;
186	while (toskip > 0) {
187	    if (size == 0) {
188		readsome();
189		if (size == 0) {
190		    return n - toskip;
191		}
192	    }
193
194	    if (toskip > size) {
195		toskip -= size;
196		bufferstart = bufferend = size = 0;
197	    } else {
198		// we've got all the data we need to skip
199		size -= toskip;
200		bufferstart = (int) ((bufferstart + toskip) % BUFFERSIZE);
201	    }
202	}
203
204	// skipped the full amount
205	return n;
206    }
207}
208
209