"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.WsHandlerLookup = exports.TimeoutHandler = exports.ResetConnectionHandler = exports.CloseConnectionHandler = exports.RejectWebSocketHandler = exports.ListenWebSocketHandler = exports.EchoWebSocketHandler = exports.PassThroughWebSocketHandler = void 0;
const _ = require("lodash");
const url = require("url");
const tls = require("tls");
const fs = require("fs/promises");
const WebSocket = require("ws");
const serialization_1 = require("../../serialization/serialization");
const request_handlers_1 = require("../requests/request-handlers");
Object.defineProperty(exports, "CloseConnectionHandler", { enumerable: true, get: function () { return request_handlers_1.CloseConnectionHandler; } });
Object.defineProperty(exports, "ResetConnectionHandler", { enumerable: true, get: function () { return request_handlers_1.ResetConnectionHandler; } });
Object.defineProperty(exports, "TimeoutHandler", { enumerable: true, get: function () { return request_handlers_1.TimeoutHandler; } });
const request_utils_1 = require("../../util/request-utils");
const header_utils_1 = require("../../util/header-utils");
const buffer_utils_1 = require("../../util/buffer-utils");
const http_agents_1 = require("../http-agents");
const rule_parameters_1 = require("../rule-parameters");
const passthrough_handling_1 = require("../passthrough-handling");
const websocket_handler_definitions_1 = require("./websocket-handler-definitions");
/**
 * Reports whether a ws socket is currently in the OPEN state, i.e. safe to send on.
 *
 * @param socket - a ws WebSocket instance
 * @returns true only when socket.readyState equals WebSocket.OPEN
 */
function isOpen(socket) {
    const { readyState } = socket;
    return WebSocket.OPEN === readyState;
}
/**
 * Checks whether a close status code may legally be sent in a close frame.
 * Based on ws's validation.js: accepts the application-specific range 3000-4999,
 * plus the standard range 1000-1014 minus the reserved codes 1004, 1005 and 1006.
 *
 * @param code - the close status code to check
 * @returns true if the code can be passed on to another socket's close()
 */
function isValidStatusCode(code) {
    const isApplicationCode = code >= 3000 && code <= 4999;
    if (isApplicationCode)
        return true;
    const isStandardRange = code >= 1000 && code <= 1014;
    return isStandardRange && code !== 1004 && code !== 1005 && code !== 1006;
}
// Matches the error message reporting a close frame with an invalid status code, capturing the
// offending code so it can be replayed to the other side of the proxy.
const INVALID_STATUS_REGEX = /Invalid WebSocket frame: invalid status code (\d+)/;
/**
 * Forwards traffic in ONE direction, from inSocket to outSocket. Messages, pings, pongs and
 * closes received on inSocket are re-sent on outSocket. Errors on inSocket tear down
 * outSocket's raw connection. Call it twice (once per direction) to bridge two sockets.
 *
 * @param inSocket - the ws socket whose events are forwarded
 * @param outSocket - the ws socket they are forwarded to (may be the same socket, for echoing)
 */
function pipeWebSocket(inSocket, outSocket) {
    // Builds a send-completion callback: if forwarding an operation failed, close the source too.
    const onPipeFailed = (op) => (err) => {
        if (!err)
            return;
        inSocket.close();
        console.error(`Websocket ${op} failed`, err);
    };
    inSocket.on('message', (msg, isBinary) => {
        if (isOpen(outSocket)) {
            outSocket.send(msg, { binary: isBinary }, onPipeFailed('message'));
        }
    });
    inSocket.on('close', (num, reason) => {
        // Only forward codes that are valid to send (see isValidStatusCode); otherwise close plainly.
        if (isValidStatusCode(num)) {
            try {
                outSocket.close(num, reason);
            }
            catch (e) {
                // close() can still throw (e.g. for an over-long reason): fall back to a plain close.
                console.warn(e);
                outSocket.close();
            }
        }
        else {
            outSocket.close();
        }
    });
    inSocket.on('ping', (data) => {
        if (isOpen(outSocket))
            outSocket.ping(data, undefined, onPipeFailed('ping'));
    });
    inSocket.on('pong', (data) => {
        if (isOpen(outSocket))
            outSocket.pong(data, undefined, onPipeFailed('pong'));
    });
    // If either socket has an general error (connection failure, but also could be invalid WS
    // frames) then we kill the raw connection upstream to simulate a generic connection error:
    inSocket.on('error', (err) => {
        console.log(`Error in proxied WebSocket:`, err);
        const rawOutSocket = outSocket;
        // Single exec (the regex has no /g flag, so it carries no lastIndex state between calls):
        const invalidStatusMatch = INVALID_STATUS_REGEX.exec(err.message);
        if (invalidStatusMatch) {
            const status = parseInt(invalidStatusMatch[1], 10);
            // Simulate errors elsewhere by messing with ws internals. This may break things,
            // that's effectively on purpose: we're simulating the client going wrong:
            const buf = Buffer.allocUnsafe(2);
            buf.writeUInt16BE(status); // status comes from readUInt16BE, so always fits
            const sender = rawOutSocket._sender;
            // Hand-built close frame (opcode 0x08) carrying the same invalid status code,
            // after which the raw connection is dropped:
            sender.sendFrame(sender.constructor.frame(buf, {
                fin: true,
                rsv1: false,
                opcode: 0x08,
                mask: true,
                readOnly: false
            }), () => {
                rawOutSocket._socket.destroy();
            });
        }
        else {
            // Unknown error, just kill the connection with no explanation
            rawOutSocket._socket.destroy();
        }
    });
}
/**
 * Replays an upstream rejection response (a non-upgrade reply to the WebSocket handshake) onto
 * the raw downstream socket, then destroys that socket.
 *
 * This never rejects. Failures while reading the upstream body are logged, and the downstream
 * socket is destroyed in every case, so callers may fire-and-forget it.
 *
 * @param socket - the raw downstream socket the upgrade request arrived on
 * @param rejectionResponse - the upstream response: a readable stream exposing statusCode,
 *   statusMessage and flat rawHeaders
 */
async function mirrorRejection(socket, rejectionResponse) {
    try {
        if (!socket.writable)
            return;
        const { statusCode, statusMessage, rawHeaders } = rejectionResponse;
        // The body written below is the stream's content, which (assuming rejectionResponse is a
        // Node http.IncomingMessage) has already had any chunked framing removed. Replaying
        // Transfer-Encoding would therefore corrupt the response. Without it, the body is
        // delimited by the connection close in `finally`.
        const headers = (0, header_utils_1.pairFlatRawHeaders)(rawHeaders)
            .filter(([key]) => key.toLowerCase() !== 'transfer-encoding');
        socket.write(rawResponse(statusCode || 500, statusMessage || 'Unknown error', headers));
        const body = await (0, buffer_utils_1.streamToBuffer)(rejectionResponse);
        if (socket.writable)
            socket.write(body);
    }
    catch (err) {
        console.warn('Failed to mirror upstream WebSocket rejection', err);
    }
    finally {
        // Always release the downstream socket, even if reading the upstream body failed.
        socket.destroy();
    }
}
/**
 * Builds the head of a raw HTTP/1.1 response: the status line, one CRLF-terminated line per
 * header, and the blank line that ends the header section. Each header line carries its own
 * CRLF, so an empty header list no longer leaks a stray CRLF into the start of the body.
 *
 * @param statusCode - numeric HTTP status
 * @param statusMessage - reason phrase
 * @param headers - [name, value] pairs; null/undefined are treated as no headers
 * @returns the response head, ready to be written before any body
 */
const rawResponse = (statusCode, statusMessage, headers = []) => `HTTP/1.1 ${statusCode} ${statusMessage}\r\n` +
    (headers ?? []).map(([key, value]) => `${key}: ${value}\r\n`).join('') +
    '\r\n';
/**
 * Proxies WebSocket connections to an upstream server. The upstream connection is opened first.
 * Only once it is open is the downstream upgrade completed, after which frames are piped in
 * both directions. Upstream rejections are mirrored back to the downstream client.
 */
class PassThroughWebSocketHandler extends websocket_handler_definitions_1.PassThroughWebSocketHandlerDefinition {
    /**
     * Lazily creates the ws server (noServer mode) used to complete downstream upgrades, and
     * wires each accepted connection to its upstream partner.
     */
    initializeWsServer() {
        if (this.wsServer)
            return;
        this.wsServer = new WebSocket.Server({
            noServer: true,
            // Mirror subprotocols back to the client:
            handleProtocols(protocols, request) {
                return request.upstreamWebSocketProtocol
                    // If there's no upstream socket, default to mirroring the first protocol. This matches
                    // WS's default behaviour - we could be stricter, but it'd be a breaking change.
                    ?? protocols.values().next().value
                    ?? false; // If there were no protocols specified and this is called for some reason
            },
        });
        this.wsServer.on('connection', (ws) => {
            // Bridge both directions between the downstream socket and its upstream partner:
            pipeWebSocket(ws, ws.upstreamWebSocket);
            pipeWebSocket(ws.upstreamWebSocket, ws);
        });
    }
    /**
     * Resolves the CA certificates to trust for upstream TLS. These are Node's root
     * certificates plus each configured extra CA, given either inline (`cert`) or as a file
     * (`certPath`). The combined promise is built once and cached on the instance.
     *
     * @returns undefined when no extra CAs are configured (so no `ca` option is passed
     *   upstream), otherwise a promise of PEM strings
     */
    async trustedCACertificates() {
        if (!this.extraCACertificates.length)
            return undefined;
        if (!this._trustedCACertificates) {
            this._trustedCACertificates = Promise.all(tls.rootCertificates
                .concat(this.extraCACertificates.map(certObject => {
                if ('cert' in certObject) {
                    return certObject.cert.toString('utf8');
                }
                else {
                    return fs.readFile(certObject.certPath, 'utf8');
                }
            })));
        }
        return this._trustedCACertificates;
    }
    /**
     * Works out the upstream ws:// or wss:// URL for an incoming upgrade request and connects
     * to it. Three cases are handled: forwarding to a configured target, transparent proxying
     * (no hostname in the URL, so the Host/:authority header is used), and absolute-URL
     * proxy requests.
     *
     * @param req - the incoming upgrade request (its rawHeaders may be mutated when forwarding)
     * @param socket - the raw downstream socket
     * @param head - any bytes already read past the request head
     */
    async handle(req, socket, head) {
        this.initializeWsServer();
        let { protocol, hostname, port, path } = url.parse(req.url);
        const rawHeaders = req.rawHeaders;
        const reqMessage = req;
        const isH2Downstream = (0, request_utils_1.isHttp2)(req);
        const hostHeaderName = isH2Downstream ? ':authority' : 'host';
        hostname = await (0, passthrough_handling_1.getClientRelativeHostname)(hostname, req.remoteIpAddress, (0, passthrough_handling_1.getDnsLookupFunction)(this.lookupOptions));
        if (this.forwarding) {
            const { targetHost, updateHostHeader } = this.forwarding;
            if (!targetHost.includes('/')) {
                // We're forwarding to a bare hostname, just overwrite that bit:
                [hostname, port] = targetHost.split(':');
            }
            else {
                // Forwarding to a full URL; override the host & protocol, but never the path.
                ({ protocol, hostname, port } = url.parse(targetHost));
            }
            // Connect directly to the forwarding target URL
            const wsUrl = `${protocol}//${hostname}${port ? ':' + port : ''}${path}`;
            // Optionally update the host header too:
            let hostHeader = (0, header_utils_1.findRawHeader)(rawHeaders, hostHeaderName);
            if (!hostHeader) {
                // Should never happen really, but just in case:
                hostHeader = [hostHeaderName, hostname];
                rawHeaders.unshift(hostHeader);
            }
            if (updateHostHeader === undefined || updateHostHeader === true) {
                // If updateHostHeader is true, or just not specified, match the new target
                hostHeader[1] = hostname + (port ? `:${port}` : '');
            }
            else if (updateHostHeader) {
                // If it's an explicit custom value, use that directly.
                hostHeader[1] = updateHostHeader;
            } // Otherwise: falsey means don't touch it.
            await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head);
        }
        else if (!hostname) { // No hostname in URL means transparent proxy, so use Host header
            const hostHeader = req.headers[hostHeaderName];
            [hostname, port] = hostHeader.split(':');
            // __lastHopEncrypted is set in http-combo-server, for requests that have explicitly
            // CONNECTed upstream (which may then up/downgrade from the current encryption).
            if (socket.__lastHopEncrypted !== undefined) {
                protocol = socket.__lastHopEncrypted ? 'wss' : 'ws';
            }
            else {
                protocol = reqMessage.connection.encrypted ? 'wss' : 'ws';
            }
            const wsUrl = `${protocol}://${hostname}${port ? ':' + port : ''}${path}`;
            await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head);
        }
        else {
            // Connect directly according to the specified URL
            const wsUrl = `${protocol.replace('http', 'ws')}//${hostname}${port ? ':' + port : ''}${path}`;
            await this.connectUpstream(wsUrl, reqMessage, rawHeaders, socket, head);
        }
    }
    /**
     * Opens the upstream WebSocket. On success it completes the downstream upgrade and pipes
     * the two sockets together. On rejection it mirrors the upstream response downstream. On
     * any other error it ends the downstream socket.
     *
     * @param wsUrl - the upstream ws:// or wss:// URL
     * @param req - the incoming upgrade request
     * @param rawHeaders - the (possibly rewritten) raw request headers to send upstream
     * @param incomingSocket - the raw downstream socket
     * @param head - any bytes already read past the request head
     */
    async connectUpstream(wsUrl, req, rawHeaders, incomingSocket, head) {
        const parsedUrl = url.parse(wsUrl);
        const effectivePort = (0, request_utils_1.getEffectivePort)(parsedUrl);
        const strictHttpsChecks = (0, passthrough_handling_1.shouldUseStrictHttps)(parsedUrl.hostname, effectivePort, this.ignoreHostHttpsErrors);
        // Use a client cert if it's listed for the host+port or whole hostname
        const hostWithPort = `${parsedUrl.hostname}:${effectivePort}`;
        const clientCert = this.clientCertificateHostMap[hostWithPort] ||
            this.clientCertificateHostMap[parsedUrl.hostname] ||
            {};
        const trustedCerts = await this.trustedCACertificates();
        const caConfig = trustedCerts
            ? { ca: trustedCerts }
            : {};
        const proxySettingSource = (0, rule_parameters_1.assertParamDereferenced)(this.proxyConfig);
        const agent = await (0, http_agents_1.getAgent)({
            protocol: parsedUrl.protocol,
            hostname: parsedUrl.hostname,
            port: effectivePort,
            proxySettingSource,
            tryHttp2: false,
            keepAlive: false // Not a thing for websockets: they take over the whole connection
        });
        // We have to flatten the headers, as WS doesn't support raw headers - it builds its own
        // header object internally.
        const headers = (0, header_utils_1.rawHeadersToObjectPreservingCase)(rawHeaders);
        // Subprotocols have to be handled explicitly. WS takes control of the headers itself,
        // and checks the response, so we need to parse the client headers and use them manually:
        const originalSubprotocols = (0, header_utils_1.findRawHeaders)(rawHeaders, 'sec-websocket-protocol')
            .flatMap(([_k, value]) => value.split(',').map(p => p.trim()));
        // Drop empty subprotocols, to better handle mildly badly behaved clients
        const filteredSubprotocols = originalSubprotocols.filter(p => !!p);
        // If the subprotocols are invalid (there are some empty strings, or an entirely empty value) then
        // WS will reject the upgrade. With this, we reset the header to the 'equivalent' valid version, to
        // avoid unnecessarily rejecting clients who send mildly wrong headers (empty protocol values).
        if (originalSubprotocols.length !== filteredSubprotocols.length) {
            if (filteredSubprotocols.length) {
                // Note that req.headers is auto-lowercased by Node, so we can ignore case
                req.headers['sec-websocket-protocol'] = filteredSubprotocols.join(',');
            }
            else {
                delete req.headers['sec-websocket-protocol'];
            }
        }
        const upstreamWebSocket = new WebSocket(wsUrl, filteredSubprotocols, {
            maxPayload: 0,
            agent,
            lookup: (0, passthrough_handling_1.getDnsLookupFunction)(this.lookupOptions),
            // WS generates its own handshake headers, so drop the client's versions of them:
            headers: _.omitBy(headers, (_v, headerName) => headerName.toLowerCase().startsWith('sec-websocket') ||
                headerName.toLowerCase() === 'connection' ||
                headerName.toLowerCase() === 'upgrade'),
            // TLS options:
            ...(0, passthrough_handling_1.getUpstreamTlsOptions)(strictHttpsChecks),
            ...clientCert,
            ...caConfig
        });
        upstreamWebSocket.once('open', () => {
            // Used in the subprotocol selection handler during the upgrade:
            req.upstreamWebSocketProtocol = upstreamWebSocket.protocol || false;
            this.wsServer.handleUpgrade(req, incomingSocket, head, (ws) => {
                ws.upstreamWebSocket = upstreamWebSocket;
                incomingSocket.emit('ws-upgrade', ws);
                this.wsServer.emit('connection', ws); // This pipes the connections together
            });
        });
        // If the upstream says no, we say no too.
        let unexpectedResponse = false;
        // Parameters renamed so they don't shadow the downstream `req` above.
        upstreamWebSocket.on('unexpected-response', (upstreamReq, upstreamRes) => {
            // Flag first, so the error handler below ignores anything the cleanup triggers:
            unexpectedResponse = true;
            console.log(`Unexpected websocket response from ${wsUrl}: ${upstreamRes.statusCode}`);
            // Clean up the downstream connection. Don't leave a floating rejection; if mirroring
            // fails for any reason, at least make sure the downstream socket is released:
            mirrorRejection(incomingSocket, upstreamRes).catch((err) => {
                console.warn(err);
                incomingSocket.destroy();
            });
            // Clean up the upstream connection (WS would do this automatically, but doesn't if you listen to this event)
            // See https://github.com/websockets/ws/blob/45e17acea791d865df6b255a55182e9c42e5877a/lib/websocket.js#L1050
            // We don't match that perfectly, but this should be effectively equivalent:
            upstreamReq.destroy();
            if (upstreamReq.socket && !upstreamReq.socket.destroyed) {
                upstreamReq.socket.destroy();
            }
            upstreamWebSocket.terminate();
        });
        // If there's some other error, we just kill the socket:
        upstreamWebSocket.on('error', (e) => {
            if (unexpectedResponse)
                return; // Handled separately above
            console.warn(e);
            incomingSocket.end();
        });
        incomingSocket.on('error', () => upstreamWebSocket.close(1011)); // Internal error
    }
    /**
     * @internal
     */
    static deserialize(data, channel, ruleParams) {
        // By default, we assume we just need to assign the right prototype
        return _.create(this.prototype, {
            ...data,
            proxyConfig: (0, serialization_1.deserializeProxyConfig)(data.proxyConfig, channel, ruleParams),
            extraCACertificates: data.extraCACertificates || [],
            ignoreHostHttpsErrors: data.ignoreHostCertificateErrors,
            clientCertificateHostMap: _.mapValues(data.clientCertificateHostMap, ({ pfx, passphrase }) => ({ pfx: (0, serialization_1.deserializeBuffer)(pfx), passphrase })),
        });
    }
}
exports.PassThroughWebSocketHandler = PassThroughWebSocketHandler;
/**
 * Accepts WebSocket upgrades and sends every received frame straight back to its sender.
 */
class EchoWebSocketHandler extends websocket_handler_definitions_1.EchoWebSocketHandlerDefinition {
    /**
     * Creates the ws server (noServer mode) on first use. Each accepted socket is piped into
     * itself, which produces the echo.
     */
    initializeWsServer() {
        if (!this.wsServer) {
            this.wsServer = new WebSocket.Server({ noServer: true });
            this.wsServer.on('connection', (echoSocket) => {
                pipeWebSocket(echoSocket, echoSocket);
            });
        }
    }
    /**
     * Completes the upgrade, announces the new socket via 'ws-upgrade' on the raw socket, and
     * then hands it to the server's 'connection' listeners.
     */
    async handle(req, socket, head) {
        this.initializeWsServer();
        const onUpgraded = (upgradedSocket) => {
            socket.emit('ws-upgrade', upgradedSocket);
            this.wsServer.emit('connection', upgradedSocket);
        };
        this.wsServer.handleUpgrade(req, socket, head, onUpgraded);
    }
}
exports.EchoWebSocketHandler = EchoWebSocketHandler;
/**
 * Accepts WebSocket upgrades and silently consumes whatever the client sends.
 */
class ListenWebSocketHandler extends websocket_handler_definitions_1.ListenWebSocketHandlerDefinition {
    /**
     * Creates the ws server (noServer mode) on first use. Accepted sockets are resumed, and
     * their incoming data is otherwise ignored.
     */
    initializeWsServer() {
        if (!this.wsServer) {
            this.wsServer = new WebSocket.Server({ noServer: true });
            this.wsServer.on('connection', (listeningSocket) => {
                // Accept but ignore the incoming websocket data
                listeningSocket.resume();
            });
        }
    }
    /**
     * Completes the upgrade, announces the new socket via 'ws-upgrade' on the raw socket, and
     * then hands it to the server's 'connection' listeners.
     */
    async handle(req, socket, head) {
        this.initializeWsServer();
        const onUpgraded = (upgradedSocket) => {
            socket.emit('ws-upgrade', upgradedSocket);
            this.wsServer.emit('connection', upgradedSocket);
        };
        this.wsServer.handleUpgrade(req, socket, head, onUpgraded);
    }
}
exports.ListenWebSocketHandler = ListenWebSocketHandler;
/**
 * Refuses WebSocket upgrades with the configured status, headers and optional body, then
 * drops the connection.
 */
class RejectWebSocketHandler extends websocket_handler_definitions_1.RejectWebSocketHandlerDefinition {
    async handle(req, socket, head) {
        const rawHeaderPairs = (0, header_utils_1.objectHeadersToRaw)(this.headers);
        const responseHead = rawResponse(this.statusCode, this.statusMessage, rawHeaderPairs);
        socket.write(responseHead);
        if (this.body) {
            socket.write(this.body);
        }
        // A trailing CRLF is always written, with or without a body:
        socket.write('\r\n');
        socket.destroy();
    }
}
exports.RejectWebSocketHandler = RejectWebSocketHandler;
exports.WsHandlerLookup = {
    'ws-passthrough': PassThroughWebSocketHandler,
    'ws-echo': EchoWebSocketHandler,
    'ws-listen': ListenWebSocketHandler,
    'ws-reject': RejectWebSocketHandler,
    'close-connection': request_handlers_1.CloseConnectionHandler,
    'reset-connection': request_handlers_1.ResetConnectionHandler,
    'timeout': request_handlers_1.TimeoutHandler
};
//# sourceMappingURL=websocket-handlers.js.map