001/*
002 * This file is part of the JDrupes non-blocking HTTP Codec
003 * Copyright (C) 2016, 2018  Michael N. Lipp
004 *
005 * This program is free software; you can redistribute it and/or modify it 
006 * under the terms of the GNU Lesser General Public License as published
007 * by the Free Software Foundation; either version 3 of the License, or 
008 * (at your option) any later version.
009 *
010 * This program is distributed in the hope that it will be useful, but 
011 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
012 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public 
013 * License for more details.
014 *
015 * You should have received a copy of the GNU Lesser General Public License along 
016 * with this program; if not, see <http://www.gnu.org/licenses/>.
017 */
018
019package org.jdrupes.httpcodec.protocols.websocket;
020
021import java.nio.Buffer;
022import java.nio.ByteBuffer;
023import java.nio.CharBuffer;
024import java.nio.charset.Charset;
025import java.nio.charset.CoderResult;
026import java.util.Optional;
027
028import org.jdrupes.httpcodec.Decoder;
029import org.jdrupes.httpcodec.Encoder;
030import org.jdrupes.httpcodec.ProtocolException;
031import org.jdrupes.httpcodec.util.ByteBufferUtils;
032import org.jdrupes.httpcodec.util.OptimizedCharsetDecoder;
033
034/**
035 * The Websocket decoder.
036 */
037public class WsDecoder  extends WsCodec
038        implements Decoder<WsFrameHeader, WsFrameHeader> {
039
040        private static enum State { READING_HEADER, READING_LENGTH,
041                READING_MASK, READING_PAYLOAD, READING_PING_DATA,
042                READING_PONG_DATA, READING_CLOSE_DATA }
043        
044        private static enum Opcode { CONT_FRAME, TEXT_FRAME, BIN_FRAME,
045                CON_CLOSE, PING, PONG;
046
047                public static Opcode fromInt(int value) {
048                        switch (value) {
049                        case 0: return Opcode.CONT_FRAME;
050                        case 1: return Opcode.TEXT_FRAME;
051                        case 2: return Opcode.BIN_FRAME;
052                        case 8: return Opcode.CON_CLOSE;
053                        case 9: return Opcode.PING;
054                        case 10: return Opcode.PONG;
055                        }
056                        throw new IllegalArgumentException();
057                }
058        }
059        
060        private static Result.Factory resultFactory = new Result.Factory();
061        
062        private State state = State.READING_HEADER;
063        private long bytesExpected = 2;
064        private boolean dataMessageFinished = true;
065        private int curHeaderHead = 0;
066        private byte[] maskingKey = new byte[4];
067        private int maskIndex;
068        private long payloadLength = 0;
069        private OptimizedCharsetDecoder charDecoder = null;
070        private boolean receivedDataIsMasked;
071        private WsFrameHeader receivedHeader = null;
072        private WsFrameHeader reportedHeader = null;
073        private ByteBuffer controlData = null;
074        private CharBuffer controlChars = null;
075        
076        public Decoder<WsFrameHeader, WsFrameHeader> setPeerEncoder(
077                        Encoder<WsFrameHeader, WsFrameHeader> encoder) {
078                linkClosingState((WsCodec)encoder);
079                return this;
080        }
081        
082        /**
083         * Returns the result factory for this codec.
084         * 
085         * @return the factory
086         */
087        protected Result.Factory resultFactory() {
088                return resultFactory;
089        }
090        
091        /* (non-Javadoc)
092         * @see org.jdrupes.httpcodec.Decoder#decoding()
093         */
094        @Override
095        public Class<WsFrameHeader> decoding() {
096                return WsFrameHeader.class;
097        }
098
099        private void expectNextFrame() {
100                state = State.READING_HEADER;
101                bytesExpected = 2;
102                curHeaderHead = 0;
103                payloadLength = 0;
104                if (dataMessageFinished && charDecoder != null) {
105                        charDecoder.reset();
106                }               
107        }
108        
109        /* (non-Javadoc)
110         * @see org.jdrupes.httpcodec.Decoder#getHeader()
111         */
112        @Override
113        public Optional<WsFrameHeader> header() {
114                return Optional.ofNullable(receivedHeader);
115        }
116
117        private Result createResult(boolean overflow, boolean underflow, 
118                                boolean closeConnection, WsFrameHeader response, 
119                                boolean responseOnly) {
120                if (receivedHeader != null && receivedHeader != reportedHeader) {
121                        reportedHeader = receivedHeader;
122                        return resultFactory().newResult(overflow, underflow, 
123                                        closeConnection, true, response, responseOnly);
124                }
125                return resultFactory().newResult(overflow, underflow, 
126                                closeConnection, false, response, responseOnly);
127        }
128
129        private Result createResult(boolean overflow, boolean underflow) {
130                return createResult(overflow, underflow, false, null, false);
131        }
132
133        
134        /* (non-Javadoc)
135         * @see RequestDecoder#decode(java.nio.ByteBuffer, java.nio.Buffer, boolean)
136         */
137        @Override
138        public Decoder.Result<WsFrameHeader> decode(ByteBuffer in, Buffer out, 
139                        boolean endOfInput) throws ProtocolException {
140                Decoder.Result<WsFrameHeader> result = null;
141                while (in.hasRemaining()) {
142                        switch (state) {
143                        case READING_HEADER:
144                                curHeaderHead = (curHeaderHead << 8) | (in.get() & 0xFF);
145                                if (--bytesExpected == 0) {
146                                        // "Header head" is complete, retrieve some information
147                                        receivedDataIsMasked = (curHeaderHead & 0x80) != 0;
148                                        payloadLength = curHeaderHead & 0x7f;
149                                        if (payloadLength == 126) {
150                                                payloadLength = 0;
151                                                bytesExpected = 2;
152                                                state = State.READING_LENGTH;
153                                                continue; // shortcut, no need to check result
154                                        }
155                                        if (payloadLength == 127) {
156                                                payloadLength = 0;
157                                                bytesExpected = 8;
158                                                state = State.READING_LENGTH;
159                                                continue; // shortcut, no need to check result
160                                        }
161                                        if (receivedDataIsMasked) {
162                                                bytesExpected = 4;
163                                                state = State.READING_MASK;
164                                                continue; // shortcut, no need to check result
165                                        }
166                                        result = headerComplete();
167                                        break;
168                                }
169                                break;
170                                
171                        case READING_LENGTH:
172                                payloadLength = (payloadLength << 8) | (in.get() & 0xff);
173                                if (--bytesExpected > 0) {
174                                        continue; // shortcut, no need to check result
175                                }
176                                if (receivedDataIsMasked) {
177                                        bytesExpected = 4;
178                                        state = State.READING_MASK;
179                                        continue; // shortcut, no need to check result
180                                }
181                                result = headerComplete();
182                                break;
183                                
184                        case READING_MASK:
185                                maskingKey[4 - (int)bytesExpected] = in.get();
186                                if (--bytesExpected > 0) {
187                                        continue; // shortcut, no need to check result
188                                }
189                                maskIndex = 0;
190                                result = headerComplete();
191                                break;
192                                
193                        case READING_PAYLOAD:
194                                if (out == null) {
195                                        return createResult(true, false);
196                                }
197                                int initiallyAvailable = in.remaining();
198                                CoderResult decRes = copyData(out, in,
199                                        bytesExpected > Integer.MAX_VALUE
200                                        ? Integer.MAX_VALUE : (int) bytesExpected, 
201                                    endOfInput);
202                                bytesExpected -= (initiallyAvailable - in.remaining());
203                                if (bytesExpected == 0) {
204                                        expectNextFrame();
205                                        if (dataMessageFinished) {
206                                                result = createResult(false, false);
207                                        }
208                                        break;
209                                }
210                                return createResult(
211                                        (in.hasRemaining() && !out.hasRemaining())
212                                                || (decRes != null && decRes.isOverflow()),
213                                        !in.hasRemaining()
214                                                || (decRes != null && decRes.isUnderflow()));
215
216                        case READING_PING_DATA:
217                        case READING_PONG_DATA:
218                                initiallyAvailable = in.remaining();
219                                copyData(controlData, in, (int) bytesExpected, endOfInput);
220                                bytesExpected -= (initiallyAvailable - in.remaining());
221                                if (bytesExpected == 0) {
222                                        controlData.flip();
223                                        if (state == State.READING_PING_DATA) {
224                                                receivedHeader = new WsPingFrame(controlData);
225                                                result = createResult(false, !dataMessageFinished, 
226                                                        false, new WsPongFrame(controlData.duplicate()), true);
227                                                expectNextFrame();
228                                        } else {
229                                                receivedHeader = new WsPongFrame(controlData);
230                                                result = createResult(false, !dataMessageFinished);
231                                                expectNextFrame();
232                                        }
233                                        controlData = null;
234                                        return result;
235                                }
236                                return createResult(false, true);
237                                
238                        case READING_CLOSE_DATA:
239                                if (controlData.position() < 2) {
240                                        if (receivedDataIsMasked) {
241                                                controlData.put(
242                                                                (byte)(in.get() ^ maskingKey[maskIndex]));
243                                                maskIndex = (maskIndex + 1) % 4;
244                                        } else {
245                                                controlData.put(in.get());
246                                        }
247                                        bytesExpected -= 1;
248                                        if (bytesExpected == 0) {
249                                                // Close frame with status code only
250                                                expectNextFrame();
251                                                return createCloseResult();
252                                        }
253                                        continue;
254                                }
255                                if (charDecoder == null) {
256                                        charDecoder = new OptimizedCharsetDecoder(
257                                                Charset.forName("UTF-8").newDecoder());
258                                }
259                                initiallyAvailable = in.remaining();
260                                copyData(controlChars, in, (int) bytesExpected, endOfInput);
261                                bytesExpected -= (initiallyAvailable - in.remaining());
262                                if (bytesExpected == 0) {
263                                        expectNextFrame();
264                                        return createCloseResult();
265                                }
266                                return createResult(false, true);
267                        }
268                        if (result != null) {
269                                return result;
270                        }
271                }
272                return createResult(false, bytesExpected > 0);
273        }
274
275        private Decoder.Result<WsFrameHeader> headerComplete() {
276                receivedHeader = null;
277                reportedHeader = null;
278                boolean finalFrame = isFinalFrame();
279                if ((curHeaderHead >> 8 & 0x8) == 0) {
280                        // Not a control frame, update from FIN bit
281                        dataMessageFinished = finalFrame;
282                }
283                bytesExpected = payloadLength;
284                Opcode opcode = Opcode.fromInt(curHeaderHead >> 8 & 0xf);
285                switch (opcode) {
286                case CONT_FRAME:
287                        if (bytesExpected == 0) {
288                                // kind of ridiculous
289                                expectNextFrame();
290                                return createResult(false, !finalFrame);
291                        }
292                        state = State.READING_PAYLOAD;
293                        return null;
294                case TEXT_FRAME:
295                        if (charDecoder == null) {
296                                charDecoder = new OptimizedCharsetDecoder(
297                                        Charset.forName("UTF-8").newDecoder());
298                        }
299                        break;
300                case PING:
301                        if (bytesExpected == 0) {
302                                expectNextFrame();
303                                receivedHeader = new WsPingFrame(null);
304                                return createResult(false, !dataMessageFinished, 
305                                                false, new WsPongFrame(null), true);
306                        }
307                        controlData = ByteBuffer.allocate((int)bytesExpected);
308                        state = State.READING_PING_DATA;
309                        return null;
310                case PONG:
311                        if (bytesExpected == 0) {
312                                expectNextFrame();
313                                receivedHeader = new WsPongFrame(null);
314                                return createResult(false, !dataMessageFinished);
315                        }
316                        controlData = ByteBuffer.allocate((int)bytesExpected);
317                        state = State.READING_PONG_DATA;
318                        return null;
319                case CON_CLOSE:
320                        if (bytesExpected == 0) {
321                                expectNextFrame();
322                                return createCloseResult();
323                        }
324                        controlData = ByteBuffer.allocate(2);
325                        // upper limit (reached if each byte becomes a char)
326                        controlChars = CharBuffer.allocate((int)bytesExpected);
327                        state = State.READING_CLOSE_DATA;
328                        return null;
329                default:
330                        break;
331                }
332                receivedHeader = new WsMessageHeader(opcode == Opcode.TEXT_FRAME,
333                                bytesExpected > 0);
334                if (bytesExpected == 0) {
335                        expectNextFrame();
336                        return createResult(false, false);
337                }
338                state = State.READING_PAYLOAD;
339                return null;
340        }
341        
342        private Decoder.Result<WsFrameHeader> createCloseResult() {
343                Integer status = null;
344                if (controlData != null) {
345                        controlData.flip();
346                        status = 0;
347                        while (controlData.hasRemaining()) {
348                                status = (status << 8) | (controlData.get() & 0xff);
349                        }
350                        controlData = null;
351                        controlChars.flip();
352                        receivedHeader = new WsCloseFrame(status, controlChars);
353                        controlChars = null;
354                } else {
355                        receivedHeader = new WsCloseFrame(status, null);
356                }
357                
358                // Handle close status
359                WsCloseResponse ctrlResponse = null;
360                switch (closingState()) {
361                case OPEN:
362                        setClosingState(ClosingState.CLOSE_RECEIVED);
363                        // fall through
364                case CLOSE_RECEIVED:
365                        // Actually *sending* (i.e. encoding) the response will
366                        // advance the state furher.
367                        ctrlResponse = new WsCloseResponse(status);
368                        break;
369                case CLOSE_SENT:
370                        // Was sent, is now received, cycle completed.
371                        setClosingState(ClosingState.CLOSED);
372                        break;
373                case CLOSED:
374                        break;
375                }
376                // If received data is masked, we're on the server side.
377                return createResult(false, false, 
378                                receivedDataIsMasked && closingState() == ClosingState.CLOSED, 
379                                ctrlResponse, false);
380        }
381
382        private boolean isFinalFrame() {
383                return (curHeaderHead & 0x8000) != 0;
384        }
385        
386        private CoderResult copyData(
387                        Buffer out, ByteBuffer in, int limit, boolean endOfInput) {
388                if (out instanceof ByteBuffer) {
389                        if (!receivedDataIsMasked) {
390                                ByteBufferUtils.putAsMuchAsPossible((ByteBuffer) out, in, limit);
391                                return null;
392                        }
393                        while (limit > 0 && in.hasRemaining() && out.hasRemaining()) {
394                                ((ByteBuffer) out).put(
395                                                (byte)(in.get() ^ maskingKey[maskIndex]));
396                                maskIndex = (maskIndex + 1) % 4;
397                                limit -= 1;
398                        }
399                        return null;
400                } 
401                if (out instanceof CharBuffer) {
402                        if (receivedDataIsMasked) {
403                                ByteBuffer unmasked = ByteBuffer.allocate(1);
404                                CoderResult res = null;
405                                while (limit > 0 && in.hasRemaining() && out.hasRemaining()) {
406                                        unmasked.put((byte)(in.get() ^ maskingKey[maskIndex]));
407                                        maskIndex = (maskIndex + 1) % 4;
408                                        limit -= 1;
409                                        unmasked.flip();
410                                        res = charDecoder.decode(unmasked, (CharBuffer)out, 
411                                                        !in.hasRemaining() && endOfInput);
412                                        unmasked.clear();
413                                }
414                                return res;
415                        }
416                        int oldLimit = in.limit();
417                        try {
418                                if (in.remaining() > limit) {
419                                        in.limit(in.position() + limit);
420                                }
421                                return charDecoder.decode(in, (CharBuffer)out, endOfInput);
422                        } finally {
423                                in.limit(oldLimit);
424                        }
425                } else {
426                        throw new IllegalArgumentException(
427                                "Only Byte- or CharBuffer are allowed.");
428                }
429        }
430
431        /**
432         * Results from {@link WsDecoder} add no additional
433         * information to {@link org.jdrupes.httpcodec.Decoder.Result}. This
434         * class just provides a factory for creating concrete results.
435         * 
436         * The class is declared abstract to promote the usage of the factory
437         * method.
438         */
439        public abstract static class Result
440                extends Decoder.Result<WsFrameHeader> {
441
442                protected Result(boolean overflow, boolean underflow,
443                        boolean closeConnection, boolean headerCompleted,
444                        WsFrameHeader response, boolean responseOnly) {
445                        super(overflow, underflow, closeConnection, headerCompleted, response,
446                                responseOnly);
447                }
448
449                protected static class Factory 
450                        extends Decoder.Result.Factory<WsFrameHeader> {
451                        
452                        /**
453                         * Create a new result.
454                         * 
455                         * @param overflow
456                         *            {@code true} if the data didn't fit in the out buffer
457                         * @param underflow
458                         *            {@code true} if more data is expected
459                         * @param closeConnection
460                         *            {@code true} if the connection should be closed
461                         * @param headerCompleted
462                         *            {@code true} if the header has completely been decoded
463                         * @param response
464                         *            a response to send due to an error
465                         * @param responseOnly
466                         *            if the result includes a response this flag indicates
467                         *            that no further processing besides sending the
468                         *            response is required
469                         * @return the result
470                         */
471                        public Result newResult(boolean overflow, boolean underflow, 
472                                        boolean closeConnection, boolean headerCompleted, 
473                                        WsFrameHeader response, boolean responseOnly) {
474                                return new Result(overflow, underflow, closeConnection,
475                                                headerCompleted, response, responseOnly) {
476                                };
477                        }
478                }
479        }
480}