001/*
002 * This file is part of the JDrupes non-blocking HTTP Codec
003 * Copyright (C) 2016, 2017  Michael N. Lipp
004 *
005 * This program is free software; you can redistribute it and/or modify it 
006 * under the terms of the GNU Lesser General Public License as published
007 * by the Free Software Foundation; either version 3 of the License, or 
008 * (at your option) any later version.
009 *
010 * This program is distributed in the hope that it will be useful, but 
011 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
012 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public 
013 * License for more details.
014 *
015 * You should have received a copy of the GNU Lesser General Public License along 
016 * with this program; if not, see <http://www.gnu.org/licenses/>.
017 */
018
019package org.jdrupes.httpcodec.protocols.websocket;
020
021import java.io.IOException;
022import java.io.OutputStreamWriter;
023import java.nio.Buffer;
024import java.nio.ByteBuffer;
025import java.nio.CharBuffer;
026import java.nio.charset.Charset;
027import java.security.SecureRandom;
028import java.util.EmptyStackException;
029import java.util.Optional;
030import java.util.Stack;
031
032import org.jdrupes.httpcodec.Codec;
033import org.jdrupes.httpcodec.Decoder;
034import org.jdrupes.httpcodec.Encoder;
035import org.jdrupes.httpcodec.protocols.http.HttpEncoder;
036import org.jdrupes.httpcodec.util.ByteBufferOutputStream;
037import org.jdrupes.httpcodec.util.ByteBufferUtils;
038
039/**
040 * The Websocket encoder.
041 */
042public class WsEncoder extends WsCodec 
043        implements Encoder<WsFrameHeader, WsFrameHeader> {
044
045        private static enum State { STARTING_FRAME, WRITING_HEADER,  
046                WRITING_LENGTH, WRITING_MASK, WRITING_PAYLOAD }
047        
048        private static float bytesPerCharUtf8           
049                = Charset.forName("utf-8").newEncoder().averageBytesPerChar();
050        private static final Result.Factory resultFactory = new Result.Factory();
051        
052        private SecureRandom randoms = new SecureRandom();
053        private State state = State.STARTING_FRAME;
054        private boolean continuationFrame;
055        private Stack<WsFrameHeader> messageHeaders = new Stack<>();
056        private int headerHead;
057        private long bytesToSend;
058        private long payloadSize;
059        private int payloadBytes;
060        private boolean doMask = false;
061        private byte[] maskingKey = new byte[4];
062        private int maskIndex;
063        private ByteBufferOutputStream convData = new ByteBufferOutputStream();
064
065        /**
066         * Creates new encoder.
067         * 
068         * @param mask set if the data is to be masked (client)
069         */
070        public WsEncoder(boolean mask) {
071                super();
072                this.doMask = mask;
073        }
074
075        public Encoder<WsFrameHeader, WsFrameHeader> setPeerDecoder(
076                        Decoder<WsFrameHeader, WsFrameHeader> decoder) {
077                linkClosingState((WsCodec)decoder);
078                return this;
079        }
080        
081        /**
082         * Returns the result factory for this codec.
083         * 
084         * @return the factory
085         */
086        protected Result.Factory resultFactory() {
087                return resultFactory;
088        }
089        
090        /* (non-Javadoc)
091         * @see org.jdrupes.httpcodec.Encoder#encoding()
092         */
093        @Override
094        public Class<WsFrameHeader> encoding() {
095                return WsFrameHeader.class;
096        }
097
098        private Result frameFinished(boolean endOfInput) {
099                // If we have encoded a close, adapt
100                boolean close = false;
101                if (messageHeaders.peek() instanceof WsCloseFrame) {
102                        switch (closingState()) {
103                        case OPEN:
104                                setClosingState(ClosingState.CLOSE_SENT);
105                                break;
106                        case CLOSE_RECEIVED:
107                                setClosingState(ClosingState.CLOSED);
108                                // fall through
109                        case CLOSED:
110                                if (!doMask) {
111                                        // Server side encoder
112                                        close = true;
113                                }
114                                break;
115                        case CLOSE_SENT:
116                                // Shouldn't happen
117                                break;
118                        }
119                }
120                // Fix statck
121                if (!(messageHeaders.peek() instanceof WsMessageHeader) 
122                                || endOfInput) {
123                        messageHeaders.pop();
124                }
125                state = State.STARTING_FRAME;
126                bytesToSend = 2;
127                return resultFactory().newResult(false, 
128                                !endOfInput || !messageHeaders.isEmpty(), close);
129        }
130        
131        /* (non-Javadoc)
132         * @see org.jdrupes.httpcodec.ResponseEncoder#encode(org.jdrupes.httpcodec.MessageHeader)
133         */
134        @Override
135        public void encode(WsFrameHeader messageHeader) {
136                if (state != State.STARTING_FRAME) {
137                        throw new IllegalStateException(
138                                        "Trying to start new frame while previous "
139                                                + "has not completely been sent");
140                }
141                if (messageHeader instanceof WsMessageHeader) {
142                        messageHeaders.clear();
143                        messageHeaders.push(messageHeader);
144                        if (((WsMessageHeader) messageHeader).isTextMode()) {
145                                headerHead = (1 << 8);
146                        } else {
147                                headerHead = (2 << 8);
148                        }
149                        continuationFrame = false;
150                } else {
151                        messageHeaders.push(messageHeader);
152                        if (messageHeader instanceof WsCloseFrame) {
153                                headerHead = (8 << 8);
154                        } else if (messageHeader instanceof WsPingFrame) {
155                                headerHead = (9 << 8);
156                        } else if (messageHeader instanceof WsPongFrame) {
157                                headerHead = (10 << 8);
158                        } else {
159                                throw new IllegalArgumentException(
160                                        "Invalid hessage header type");
161                        }
162                }
163                state = State.STARTING_FRAME;
164                bytesToSend = 2;
165        }
166
167        @Override
168        public Result encode(Buffer in, ByteBuffer out, boolean endOfInput) {
169                if (closingState() == ClosingState.CLOSED) {
170                        // Must no longer send anything. 
171                        // If server (!doMask) close connection.
172                        return resultFactory().newResult(false, false, !doMask);
173                }
174                Result result = null;
175                while (out.remaining() > 0) {
176                        switch(state) {
177                        case STARTING_FRAME:
178                                prepareHeaderHead(in, endOfInput);
179                                // If called again without new message header...
180                                continuationFrame = true;
181                                state = State.WRITING_HEADER;
182                                // fall through
183                        case WRITING_HEADER:
184                                out.put((byte)(headerHead >> 8 * --bytesToSend));
185                                if (bytesToSend > 0) {
186                                        continue;
187                                }
188                                if (payloadBytes > 0) {
189                                        state = State.WRITING_LENGTH;
190                                        bytesToSend = payloadBytes;
191                                        continue;
192                                }
193                                // Length written
194                                result = nextAfterLength(endOfInput);
195                                break;
196                        case WRITING_LENGTH:
197                                out.put((byte)(payloadSize >> 8 * --bytesToSend));
198                                if (bytesToSend > 0) {
199                                        continue;
200                                }
201                                result = nextAfterLength(endOfInput);
202                                break;
203                        case WRITING_MASK:
204                                out.put(maskingKey[4 - (int)bytesToSend]);
205                                if (--bytesToSend > 0) {
206                                        continue;
207                                }
208                                result = nextAfterMask(endOfInput);
209                                break;
210                        case WRITING_PAYLOAD:
211                                int posBefore = out.position();
212                                outputPayload(in, out);
213                                bytesToSend -= (out.position() - posBefore);
214                                if (bytesToSend == 0) {
215                                        convData.clear();
216                                        return frameFinished(endOfInput);
217                                }
218                                return resultFactory().newResult(!out.hasRemaining(),
219                                                (messageHeaders.peek() instanceof WsMessageHeader) 
220                                                        && !in.hasRemaining(), false);
221                        }
222                        if (result != null) {
223                                return result;
224                        }
225                }
226                return resultFactory().newResult(true, false, false);
227        }
228
229        /**
230         * Prepares the start (head) of the header. As a side effect, if
231         * "in" holds textual data (or if the data is obtained from the
232         * to be encoded message header (close frame)) it is written into
233         * convData because this is the only way to "calculate" the payload 
234         * size. 
235         * 
236         * @param in input data
237         * @param endOfInput set if end of input
238         */
239        private void prepareHeaderHead(Buffer in, boolean endOfInput) {
240                WsFrameHeader hdr = messageHeaders.peek();
241                if (hdr instanceof WsMessageHeader) {
242                        if (continuationFrame) {
243                                headerHead = 0;
244                        }
245                        if (endOfInput) {
246                                headerHead |= 0x8000;
247                        }
248                        // Prepare payload
249                        if (in instanceof CharBuffer) {
250                                convData.clear();
251                                payloadSize = convTextData(in);
252                        } else {
253                                payloadSize = in.remaining();
254                        }
255                } else {
256                        // Control frame
257                        headerHead |= 0x8000;
258                        // Prepare payload
259                        if (hdr instanceof WsCloseFrame) {
260                                payloadSize = 0;
261                                ((WsCloseFrame)hdr).statusCode().ifPresent(code -> {
262                                        convData.clear();
263                                        try {
264                                                convData.write(code >> 8);
265                                                convData.write(code & 0xff);
266                                                payloadSize = 2;
267                                        } catch (IOException e) {
268                                                // Formally thrown, cannot happen
269                                        }
270                                });
271                                ((WsCloseFrame)hdr).reason().ifPresent(reason -> {
272                                        payloadSize = convTextData(CharBuffer.wrap(reason));
273                                });
274                        } else if (hdr instanceof WsDefaultControlFrame) {
275                                payloadSize = ((WsDefaultControlFrame)hdr)
276                                                .applicationData().map(ByteBuffer::remaining).orElse(0);
277                        }
278                }
279                
280                // Finally add mask bit
281                if (doMask) {
282                        headerHead |= 0x80;
283                        randoms.nextBytes(maskingKey);
284                }
285
286                // Code payload size
287                if (payloadSize <= 125) {
288                        headerHead |= payloadSize;
289                        payloadBytes = 0;
290                } else if (payloadSize < 0x10000) {
291                        headerHead |= 126;
292                        payloadBytes = 2;
293                } else {
294                        headerHead |= 127;
295                        payloadBytes = 8;
296                }
297        }
298
299        private long convTextData(Buffer in) {
300                convData.setOverflowBufferSize(
301                                (int) (in.remaining() * bytesPerCharUtf8));
302                try {
303                        OutputStreamWriter charWriter = new OutputStreamWriter(
304                                convData, "utf-8");
305                        if (in.hasArray()) {
306                                // more efficient than CharSequence
307                                charWriter.write(((CharBuffer) in).array(),
308                                        in.arrayOffset() + in.position(),
309                                        in.remaining());
310                        } else {
311                                charWriter.append((CharBuffer) in);
312                        }
313                        // "in" is consumed, but don't move the position
314                        // until all data has been processed (from convData).
315                        charWriter.flush();
316                        return convData.bytesWritten();
317                } catch (IOException e) {
318                        // Formally thrown, cannot happen
319                        return 0;
320                }
321        }
322        
323        private Result nextAfterLength(boolean endOfInput) {
324                if (doMask) {
325                        bytesToSend = 4;
326                        state = State.WRITING_MASK;
327                        return null;
328                }
329                return nextAfterMask(endOfInput);
330        }
331        
332        private Result nextAfterMask(boolean endOfInput) {
333                if (payloadSize == 0) {
334                        return frameFinished(endOfInput);
335                }
336                maskIndex = 0;
337                bytesToSend = payloadSize;
338                state = State.WRITING_PAYLOAD;
339                return null;
340        }
341        
342        /**
343         * Copy payload to "out". Note that if we have textual data
344         * or a close frame, data has already been written into
345         * convData (see {@link #prepareHeaderHead(Buffer, boolean)}.
346         *
347         * @param in the input data, unless already wriiten to convData 
348         * @param out the out
349         */
350        private void outputPayload(Buffer in, ByteBuffer out) {
351                // Default is to use data directly from in buffer.
352                Buffer src = in;
353                WsFrameHeader hdr = messageHeaders.peek();
354                boolean textPayload = (hdr instanceof WsMessageHeader) 
355                                && ((WsMessageHeader)hdr).isTextMode();
356                if (textPayload || (hdr instanceof WsCloseFrame)) {
357                        // Data has been put into convData
358                        if (!doMask) {
359                                // Moves data from temporary buffers to "out"
360                                convData.assignBuffer(out);
361                        } else {
362                                // Retrieve into src as much as fits in 
363                                // out buffer for masking.
364                                src = ByteBuffer.allocate(out.remaining());
365                                convData.assignBuffer((ByteBuffer)src);
366                                src.flip();
367                        }
368                        if (convData.remaining() >= 0 && textPayload) {
369                                // Make full consumption visible "outside",
370                                // see convTextData.
371                                in.position(in.limit());
372                        }
373                        if (!doMask) {
374                                return;
375                        }
376                } else {
377                        if (hdr instanceof WsDefaultControlFrame) {
378                                // Data is taken from control frame.
379                                src = ((WsDefaultControlFrame)hdr)
380                                                .applicationData().orElse(Codec.EMPTY_IN);
381                        }
382                        if (!doMask) {
383                                ByteBufferUtils.putAsMuchAsPossible(out, (ByteBuffer) src);
384                                return;
385                        }
386                }
387                // Mask while writing
388                while (bytesToSend > 0
389                        && src.hasRemaining() && out.hasRemaining()) {
390                        out.put((byte) (((ByteBuffer) src)
391                                .get() ^ maskingKey[maskIndex]));
392                        maskIndex = (maskIndex + 1) % 4;
393                }
394        }
395
396        /* (non-Javadoc)
397         * @see org.jdrupes.httpcodec.Decoder#getHeader()
398         */
399        @Override
400        public Optional<WsFrameHeader> header() {
401                try {
402                        return Optional.of(messageHeaders.peek());
403                } catch (EmptyStackException e) {
404                        return Optional.empty();
405                }
406        }
407
408        /**
409         * Results from {@link HttpEncoder} provide no additional
410         * information compared to {@link org.jdrupes.httpcodec.Codec.Result}. This
411         * class only provides a factory for creating concrete results.
412         */
413        public static class Result extends Codec.Result {
414        
415                protected Result(boolean overflow, boolean underflow,
416                        boolean closeConnection) {
417                        super(overflow, underflow, closeConnection);
418                }
419
420                /**
421                 * A factory for creating new Results.
422                 */
423                protected static class Factory extends Codec.Result.Factory {
424
425                        /**
426                         * Create new result.
427                         * 
428                         * @param overflow
429                         *            {@code true} if the data didn't fit in the out buffer
430                         * @param underflow
431                         *            {@code true} if more data is expected
432                         * @param closeConnection
433                         *            {@code true} if the connection should be closed
434                         * @return the result
435                         */
436                        public Result newResult(boolean overflow, boolean underflow,
437                                boolean closeConnection) {
438                                return new Result(overflow, underflow, closeConnection) {
439                                };
440                        }
441                }
442        }
443}