001/*
002 * This file is part of the JDrupes non-blocking HTTP Codec
003 * Copyright (C) 2016, 2017  Michael N. Lipp
004 *
005 * This program is free software; you can redistribute it and/or modify it 
006 * under the terms of the GNU Lesser General Public License as published
007 * by the Free Software Foundation; either version 3 of the License, or 
008 * (at your option) any later version.
009 *
010 * This program is distributed in the hope that it will be useful, but 
011 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
012 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public 
013 * License for more details.
014 *
015 * You should have received a copy of the GNU Lesser General Public License along 
016 * with this program; if not, see <http://www.gnu.org/licenses/>.
017 */
018
019package org.jdrupes.httpcodec.protocols.websocket;
020
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.EmptyStackException;
import java.util.Optional;
import java.util.Stack;
031
032import org.jdrupes.httpcodec.Codec;
033import org.jdrupes.httpcodec.Decoder;
034import org.jdrupes.httpcodec.Encoder;
035import org.jdrupes.httpcodec.protocols.http.HttpEncoder;
036import org.jdrupes.httpcodec.util.ByteBufferOutputStream;
037import org.jdrupes.httpcodec.util.ByteBufferUtils;
038
039/**
040 * The Websocket encoder.
041 */
042public class WsEncoder extends WsCodec 
043        implements Encoder<WsFrameHeader, WsFrameHeader> {
044
045        private static enum State { STARTING_FRAME, WRITING_HEADER,  
046                WRITING_LENGTH, WRITING_MASK, WRITING_PAYLOAD }
047        
048        private static float bytesPerCharUtf8           
049                = Charset.forName("utf-8").newEncoder().averageBytesPerChar();
050        private static final Result.Factory resultFactory = new Result.Factory();
051        
052        private SecureRandom randoms = new SecureRandom();
053        private State state = State.STARTING_FRAME;
054        private boolean continuationFrame;
055        private Stack<WsFrameHeader> messageHeaders = new Stack<>();
056        private int headerHead;
057        private long bytesToSend;
058        private long payloadSize;
059        private int payloadBytes;
060        private boolean doMask = false;
061        private byte[] maskingKey = new byte[4];
062        private int maskIndex;
063        private ByteBufferOutputStream convData = new ByteBufferOutputStream();
064
065        /**
066         * Creates new encoder.
067         * 
068         * @param mask set if the data is to be masked (client)
069         */
070        public WsEncoder(boolean mask) {
071                super();
072                this.doMask = mask;
073        }
074
075        public Encoder<WsFrameHeader, WsFrameHeader> setPeerDecoder(
076                        Decoder<WsFrameHeader, WsFrameHeader> decoder) {
077                linkClosingState((WsCodec)decoder);
078                return this;
079        }
080        
081        /**
082         * Returns the result factory for this codec.
083         * 
084         * @return the factory
085         */
086        protected Result.Factory resultFactory() {
087                return resultFactory;
088        }
089        
090        /* (non-Javadoc)
091         * @see org.jdrupes.httpcodec.Encoder#encoding()
092         */
093        @Override
094        public Class<WsFrameHeader> encoding() {
095                return WsFrameHeader.class;
096        }
097
098        private Result frameFinished(boolean endOfInput) {
099                // If we have encoded a close, adapt
100                boolean close = false;
101                if (messageHeaders.peek() instanceof WsCloseFrame) {
102                        switch (closingState()) {
103                        case OPEN:
104                                setClosingState(ClosingState.CLOSE_SENT);
105                                break;
106                        case CLOSE_RECEIVED:
107                                setClosingState(ClosingState.CLOSED);
108                                // fall through
109                        case CLOSED:
110                                if (!doMask) {
111                                        // Server side encoder
112                                        close = true;
113                                }
114                                break;
115                        case CLOSE_SENT:
116                                // Shouldn't happen
117                                break;
118                        }
119                }
120                // Fix statck
121                if (!(messageHeaders.peek() instanceof WsMessageHeader) 
122                                || endOfInput) {
123                        messageHeaders.pop();
124                }
125                state = State.STARTING_FRAME;
126                bytesToSend = 2;
127                return resultFactory().newResult(false, 
128                                !endOfInput || !messageHeaders.isEmpty(), close);
129        }
130        
131        /* (non-Javadoc)
132         * @see org.jdrupes.httpcodec.ResponseEncoder#encode(org.jdrupes.httpcodec.MessageHeader)
133         */
134        @Override
135        public void encode(WsFrameHeader messageHeader) {
136                if (state != State.STARTING_FRAME) {
137                        throw new IllegalStateException(
138                                        "Trying to start new frame while previous "
139                                                + "has not completely been sent");
140                }
141                if (messageHeader instanceof WsMessageHeader) {
142                        messageHeaders.clear();
143                        messageHeaders.push(messageHeader);
144                        if (((WsMessageHeader) messageHeader).isTextMode()) {
145                                headerHead = (1 << 8);
146                        } else {
147                                headerHead = (2 << 8);
148                        }
149                        continuationFrame = false;
150                } else {
151                        messageHeaders.push(messageHeader);
152                        if (messageHeader instanceof WsCloseFrame) {
153                                headerHead = (8 << 8);
154                        } else if (messageHeader instanceof WsPingFrame) {
155                                headerHead = (9 << 8);
156                        } else if (messageHeader instanceof WsPongFrame) {
157                                headerHead = (10 << 8);
158                        } else {
159                                throw new IllegalArgumentException(
160                                        "Invalid hessage header type");
161                        }
162                }
163                state = State.STARTING_FRAME;
164                bytesToSend = 2;
165        }
166
167        @Override
168        public Result encode(Buffer in, ByteBuffer out, boolean endOfInput) {
169                if (closingState() == ClosingState.CLOSED) {
170                        // Must no longer send anything. 
171                        // If server (!doMask) close connection.
172                        return resultFactory().newResult(false, false, !doMask);
173                }
174                Result result = null;
175                while (out.remaining() > 0) {
176                        switch(state) {
177                        case STARTING_FRAME:
178                                prepareHeaderHead(in, endOfInput);
179                                // If called again without new message header...
180                                continuationFrame = true;
181                                state = State.WRITING_HEADER;
182                                // fall through
183                        case WRITING_HEADER:
184                                out.put((byte)(headerHead >> 8 * --bytesToSend));
185                                if (bytesToSend > 0) {
186                                        continue;
187                                }
188                                if (payloadBytes > 0) {
189                                        state = State.WRITING_LENGTH;
190                                        bytesToSend = payloadBytes;
191                                        continue;
192                                }
193                                // Length written
194                                result = nextAfterLength(endOfInput);
195                                break;
196                        case WRITING_LENGTH:
197                                out.put((byte)(payloadSize >> 8 * --bytesToSend));
198                                if (bytesToSend > 0) {
199                                        continue;
200                                }
201                                result = nextAfterLength(endOfInput);
202                                break;
203                        case WRITING_MASK:
204                                out.put(maskingKey[4 - (int)bytesToSend]);
205                                if (--bytesToSend > 0) {
206                                        continue;
207                                }
208                                result = nextAfterMask(endOfInput);
209                                break;
210                        case WRITING_PAYLOAD:
211                                int posBefore = out.position();
212                                outputPayload(in, out);
213                                bytesToSend -= (out.position() - posBefore);
214                                if (bytesToSend == 0) {
215                                        convData.clear();
216                                        return frameFinished(endOfInput);
217                                }
218                                return resultFactory().newResult(!out.hasRemaining(),
219                                                (messageHeaders.peek() instanceof WsMessageHeader) 
220                                                        && !in.hasRemaining(), false);
221                        }
222                        if (result != null) {
223                                return result;
224                        }
225                }
226                return resultFactory().newResult(true, false, false);
227        }
228
229        /**
230         * Prepares the start (head) of the header. As a side effect, if
231         * "in" holds textual data (or if the data is obtained from the
232         * to be encoded message header (close frame)) it is written into
233         * convData because this is the only way to "calculate" the payload 
234         * size. 
235         * 
236         * @param in input data
237         * @param endOfInput set if end of input
238         */
239        private void prepareHeaderHead(Buffer in, boolean endOfInput) {
240                WsFrameHeader hdr = messageHeaders.peek();
241                if (hdr instanceof WsMessageHeader) {
242                        if (continuationFrame) {
243                                headerHead = 0;
244                        }
245                        if (endOfInput) {
246                                headerHead |= 0x8000;
247                        }
248                        // Prepare payload
249                        if (in instanceof CharBuffer) {
250                                convData.clear();
251                                convTextData(in);
252                        } else {
253                                payloadSize = in.remaining();
254                        }
255                } else {
256                        // Control frame
257                        headerHead |= 0x8000;
258                        // Prepare payload
259                        if (hdr instanceof WsCloseFrame) {
260                                payloadSize = 0;
261                                ((WsCloseFrame)hdr).statusCode().ifPresent(code -> {
262                                        convData.clear();
263                                        try {
264                                                convData.write(code >> 8);
265                                                convData.write(code & 0xff);
266                                                payloadSize = 2;
267                                        } catch (IOException e) {
268                                                // Formally thrown, cannot happen
269                                        }
270                                });
271                                ((WsCloseFrame)hdr).reason().ifPresent(reason -> {
272                                                convTextData(CharBuffer.wrap(reason));
273                                });
274                        } else if (hdr instanceof WsDefaultControlFrame) {
275                                payloadSize = ((WsDefaultControlFrame)hdr)
276                                                .applicationData().map(ByteBuffer::remaining).orElse(0);
277                        }
278                }
279                
280                // Finally add mask bit
281                if (doMask) {
282                        headerHead |= 0x80;
283                        randoms.nextBytes(maskingKey);
284                }
285
286                // Code payload size
287                if (payloadSize <= 125) {
288                        headerHead |= payloadSize;
289                        payloadBytes = 0;
290                } else if (payloadSize < 0x10000) {
291                        headerHead |= 126;
292                        payloadBytes = 2;
293                } else {
294                        headerHead |= 127;
295                        payloadBytes = 8;
296                }
297        }
298
299        private void convTextData(Buffer in) {
300                convData.setOverflowBufferSize(
301                                (int) (in.remaining() * bytesPerCharUtf8));
302                try {
303                        OutputStreamWriter charWriter = new OutputStreamWriter(
304                                convData, "utf-8");
305                        if (in.hasArray()) {
306                                // more efficient than CharSequence
307                                charWriter.write(((CharBuffer) in).array(),
308                                        in.arrayOffset() + in.position(),
309                                        in.remaining());
310                        } else {
311                                charWriter.append((CharBuffer) in);
312                        }
313                        in.position(in.limit());
314                        charWriter.flush();
315                        payloadSize = convData.bytesWritten();
316                } catch (IOException e) {
317                        // Formally thrown, cannot happen
318                }
319        }
320        
321        private Result nextAfterLength(boolean endOfInput) {
322                if (doMask) {
323                        bytesToSend = 4;
324                        state = State.WRITING_MASK;
325                        return null;
326                }
327                return nextAfterMask(endOfInput);
328        }
329        
330        private Result nextAfterMask(boolean endOfInput) {
331                if (payloadSize == 0) {
332                        return frameFinished(endOfInput);
333                }
334                maskIndex = 0;
335                bytesToSend = payloadSize;
336                state = State.WRITING_PAYLOAD;
337                return null;
338        }
339        
340        /**
341         * Copy payload to "out". Note that if we have textual data
342         * or a close frame, data has already been written into
343         * convData (see {@link #prepareHeaderHead(Buffer, boolean)}.
344         *
345         * @param in the input data, unless already wriiten to convData 
346         * @param out the out
347         */
348        private void outputPayload(Buffer in, ByteBuffer out) {
349                WsFrameHeader hdr = messageHeaders.peek();
350                if ((hdr instanceof WsMessageHeader) 
351                                && ((WsMessageHeader)hdr).isTextMode()
352                                || (hdr instanceof WsCloseFrame)) {
353                        // Data has been put into convData
354                        if (!doMask) {
355                                // Moves data from temporary buffers to "out"
356                                convData.assignBuffer(out);
357                                return;
358                        }
359                        // Retrieving as much as fits in out buffer for masking
360                        in = ByteBuffer.allocate(out.remaining());
361                        convData.assignBuffer((ByteBuffer)in);
362                        in.flip();
363                } else {
364                        // Take data from in
365                        if (hdr instanceof WsDefaultControlFrame) {
366                                in = ((WsDefaultControlFrame)hdr)
367                                                .applicationData().orElse(Codec.EMPTY_IN);
368                        }
369                        if (!doMask) {
370                                ByteBufferUtils.putAsMuchAsPossible(out, (ByteBuffer) in);
371                                return;
372                        }
373                }
374                // Mask while writing
375                while (bytesToSend > 0
376                        && in.hasRemaining() && out.hasRemaining()) {
377                        out.put((byte) (((ByteBuffer) in)
378                                .get() ^ maskingKey[maskIndex]));
379                        maskIndex = (maskIndex + 1) % 4;
380                }
381        }
382
383        /* (non-Javadoc)
384         * @see org.jdrupes.httpcodec.Decoder#getHeader()
385         */
386        @Override
387        public Optional<WsFrameHeader> header() {
388                try {
389                        return Optional.of(messageHeaders.peek());
390                } catch (EmptyStackException e) {
391                        return Optional.empty();
392                }
393        }
394
395        /**
396         * Results from {@link HttpEncoder} provide no additional
397         * information compared to {@link org.jdrupes.httpcodec.Codec.Result}. This
398         * class only provides a factory for creating concrete results.
399         */
400        public static class Result extends Codec.Result {
401        
402                protected Result(boolean overflow, boolean underflow,
403                        boolean closeConnection) {
404                        super(overflow, underflow, closeConnection);
405                }
406
407                /**
408                 * A factory for creating new Results.
409                 */
410                protected static class Factory extends Codec.Result.Factory {
411
412                        /**
413                         * Create new result.
414                         * 
415                         * @param overflow
416                         *            {@code true} if the data didn't fit in the out buffer
417                         * @param underflow
418                         *            {@code true} if more data is expected
419                         * @param closeConnection
420                         *            {@code true} if the connection should be closed
421                         * @return the result
422                         */
423                        public Result newResult(boolean overflow, boolean underflow,
424                                boolean closeConnection) {
425                                return new Result(overflow, underflow, closeConnection) {
426                                };
427                        }
428                }
429        }
430}