/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.milo.opcua.stack.core.channel;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.util.ReferenceCountUtil;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Signature;
import java.security.SignatureException;
import java.util.List;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.eclipse.milo.opcua.stack.core.UaException;
import org.eclipse.milo.opcua.stack.core.channel.ChannelParameters;
import org.eclipse.milo.opcua.stack.core.channel.ChannelSecurity;
import org.eclipse.milo.opcua.stack.core.channel.EncodingLimits;
import org.eclipse.milo.opcua.stack.core.channel.MessageAbortException;
import org.eclipse.milo.opcua.stack.core.channel.MessageDecodeException;
import org.eclipse.milo.opcua.stack.core.channel.SecureChannel;
import org.eclipse.milo.opcua.stack.core.channel.headers.AsymmetricSecurityHeader;
import org.eclipse.milo.opcua.stack.core.channel.headers.SequenceHeader;
import org.eclipse.milo.opcua.stack.core.channel.headers.SymmetricSecurityHeader;
import org.eclipse.milo.opcua.stack.core.channel.messages.ErrorMessage;
import org.eclipse.milo.opcua.stack.core.security.SecurityAlgorithm;
import org.eclipse.milo.opcua.stack.core.util.BufferUtil;
import org.eclipse.milo.opcua.stack.core.util.SignatureUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reassembles the chunks of one OPC UA secure-conversation message into a single buffer.
 *
 * <p>For every chunk the decoder skips the fixed message header, reads the asymmetric or
 * symmetric security header, decrypts and verifies the chunk when the channel requires it,
 * validates the sequence number and (when encrypted) the padding, and appends the chunk body to
 * a {@link CompositeByteBuf}.
 *
 * <p>When decoding fails, the chunk buffers and the partially assembled composite are released
 * by this class before the exception propagates.
 */
public final class ChunkDecoder {

    /** Number of bytes preceding the security header in every chunk. */
    private static final int MESSAGE_HEADER_SIZE = 12;

    /** Index of the byte identifying the chunk type; {@code 'A'} marks an abort chunk. */
    private static final int CHUNK_TYPE_INDEX = 3;

    // Raw OPC UA StatusCode values reported when decoding fails.
    private static final long BAD_INTERNAL_ERROR = 0x80020000L;
    private static final long BAD_CERTIFICATE_INVALID = 0x80120000L;
    private static final long BAD_SECURITY_CHECKS_FAILED = 0x80130000L;
    private static final long BAD_APPLICATION_SIGNATURE_INVALID = 0x80580000L;
    private static final long BAD_TCP_MESSAGE_TOO_LARGE = 0x80800000L;
    private static final long BAD_SECURE_CHANNEL_TOKEN_UNKNOWN = 0x80870000L;

    private final AsymmetricDecoder asymmetricDecoder = new AsymmetricDecoder();
    private final SymmetricDecoder symmetricDecoder = new SymmetricDecoder();

    /** Sequence number of the last accepted chunk, or -1 before any chunk has been decoded. */
    private volatile long lastSequenceNumber = -1L;

    private final ChannelParameters parameters;
    private final EncodingLimits encodingLimits;

    /**
     * @param parameters     channel parameters; the local max message size is enforced on
     *                       the reassembled message.
     * @param encodingLimits limits passed to the asymmetric security header decoder.
     */
    public ChunkDecoder(ChannelParameters parameters, EncodingLimits encodingLimits) {
        this.parameters = parameters;
        this.encodingLimits = encodingLimits;
    }

    /**
     * Decodes chunks secured with the asymmetric security header.
     *
     * @param channel      the secure channel the chunks were received on.
     * @param chunkBuffers the chunks of one message, in order; released on failure.
     * @return the reassembled message body and its request id.
     * @throws MessageAbortException  if an abort chunk ({@code 'A'}) was received.
     * @throws MessageDecodeException if any chunk fails to decode or validate.
     */
    public DecodedMessage decodeAsymmetric(SecureChannel channel, List<ByteBuf> chunkBuffers)
        throws MessageAbortException, MessageDecodeException {

        return decode(asymmetricDecoder, channel, chunkBuffers);
    }

    /**
     * Decodes chunks secured with the symmetric security header.
     *
     * <p>All token ids are checked against the channel's current and previous tokens before any
     * chunk is decoded.
     *
     * @param channel      the secure channel the chunks were received on.
     * @param chunkBuffers the chunks of one message, in order; released on failure.
     * @return the reassembled message body and its request id.
     * @throws MessageAbortException  if an abort chunk ({@code 'A'}) was received.
     * @throws MessageDecodeException if any chunk fails to decode or validate.
     */
    public DecodedMessage decodeSymmetric(SecureChannel channel, List<ByteBuf> chunkBuffers)
        throws MessageAbortException, MessageDecodeException {

        try {
            validateSymmetricSecurityHeaders(channel, chunkBuffers);
        } catch (UaException e) {
            releaseAll(chunkBuffers);
            throw new MessageDecodeException(e);
        } catch (RuntimeException e) {
            // e.g. a chunk too short to contain a token id: still release what we own.
            releaseAll(chunkBuffers);
            throw e;
        }

        return decode(symmetricDecoder, channel, chunkBuffers);
    }

    /**
     * Runs {@code decoder} over the chunks and releases every buffer if decoding fails.
     */
    private static DecodedMessage decode(
        AbstractDecoder decoder,
        SecureChannel channel,
        List<ByteBuf> chunkBuffers
    ) throws MessageAbortException, MessageDecodeException {

        CompositeByteBuf composite = BufferUtil.compositeBuffer();

        try {
            return decoder.decode(channel, composite, chunkBuffers);
        } catch (MessageAbortException e) {
            releaseAll(composite, chunkBuffers);
            throw e;
        } catch (UaException e) {
            releaseAll(composite, chunkBuffers);
            throw new MessageDecodeException(e);
        } catch (RuntimeException e) {
            // Malformed chunks (e.g. truncated headers) surface as unchecked Netty/JDK
            // exceptions; release the buffers so they are not leaked, then propagate unchanged.
            releaseAll(composite, chunkBuffers);
            throw e;
        }
    }

    /** Releases every chunk buffer, tolerating buffers that were already released. */
    private static void releaseAll(List<ByteBuf> chunkBuffers) {
        chunkBuffers.forEach(ReferenceCountUtil::safeRelease);
    }

    /** Releases the composite and then every chunk buffer, tolerating prior releases. */
    private static void releaseAll(CompositeByteBuf composite, List<ByteBuf> chunkBuffers) {
        ReferenceCountUtil.safeRelease(composite);
        releaseAll(chunkBuffers);
    }

    /**
     * Checks that every chunk carries a token id matching the channel's current or previous
     * token.
     *
     * <p>When the channel has no {@link ChannelSecurity} yet, only token id 0 is accepted. This
     * is the same rule {@code SymmetricDecoder.readSecurityHeader} applies.
     *
     * @throws UaException with Bad_SecureChannelTokenUnknown on the first unknown token id.
     */
    private static void validateSymmetricSecurityHeaders(
        SecureChannel secureChannel,
        List<ByteBuf> chunkBuffers
    ) throws UaException {

        ChannelSecurity channelSecurity = secureChannel.getChannelSecurity();

        long currentTokenId;
        long previousTokenId;
        if (channelSecurity == null) {
            currentTokenId = 0L;
            previousTokenId = -1L;
        } else {
            currentTokenId = channelSecurity.getCurrentToken().getTokenId().longValue();
            previousTokenId = channelSecurity.getPreviousToken()
                .map(t -> t.getTokenId().longValue())
                .orElse(-1L);
        }

        for (ByteBuf chunkBuffer : chunkBuffers) {
            // The token id is the first field after the message header.
            long tokenId = chunkBuffer.getUnsignedIntLE(MESSAGE_HEADER_SIZE);

            if (tokenId != currentTokenId && tokenId != previousTokenId) {
                String message = String.format(
                    "received unknown secure channel token: tokenId=%s currentTokenId=%s previousTokenId=%s",
                    tokenId, currentTokenId, previousTokenId
                );
                throw new UaException(BAD_SECURE_CHANNEL_TOKEN_UNKNOWN, message);
            }
        }
    }

    /**
     * Sequence-number check: each chunk must carry {@code last + 1}, except near the top of
     * the unsigned 32-bit range where the number may roll over to a value below 1024.
     */
    static class LegacySequenceNumberValidator {

        /** 0xFFFFFFFF, the largest UInt32 value. */
        private static final long MAX_SEQUENCE_NUMBER = 0xFFFFFFFFL;

        /** 0xFFFFFBFF; from here on a rollover to a small sequence number is tolerated. */
        private static final long ROLLOVER_THRESHOLD = MAX_SEQUENCE_NUMBER - 1024L;

        /** A rolled-over sequence number must be smaller than this. */
        private static final long ROLLOVER_LIMIT = 1024L;

        private LegacySequenceNumberValidator() {
        }

        /**
         * @param lastSequenceNumber the previously accepted sequence number, or -1 if none.
         * @param sequenceNumber     the sequence number of the chunk being decoded.
         * @return {@code true} if {@code sequenceNumber} is an acceptable successor.
         */
        static boolean validateSequenceNumber(long lastSequenceNumber, long sequenceNumber) {
            if (lastSequenceNumber == -1L) {
                // Nothing received yet: any starting value is accepted.
                return true;
            }

            if (lastSequenceNumber >= ROLLOVER_THRESHOLD && lastSequenceNumber < MAX_SEQUENCE_NUMBER) {
                return sequenceNumber < ROLLOVER_LIMIT || sequenceNumber == lastSequenceNumber + 1L;
            }

            if (lastSequenceNumber == MAX_SEQUENCE_NUMBER) {
                // No successor fits in 32 bits; the sender must have rolled over.
                return sequenceNumber >= 0L && sequenceNumber < ROLLOVER_LIMIT;
            }

            return sequenceNumber == lastSequenceNumber + 1L;
        }
    }

    /** Decoder for chunks protected with the channel's symmetric (per-token) keys. */
    private final class SymmetricDecoder extends AbstractDecoder {

        /** Keys selected by the token id of the chunk currently being decoded. */
        private volatile ChannelSecurity.SecurityKeys securityKeys;

        /** Cached decryption cipher, initialized for the token id held in {@link #cipherId}. */
        private volatile Cipher cipher;

        /** Token id the cached cipher was initialized for, or -1 if none. */
        private volatile long cipherId;

        private SymmetricDecoder() {
            this.cipher = null;
            this.cipherId = -1L;
        }

        /**
         * Reads the token id and selects the matching (current or previous) keys.
         *
         * @throws UaException with Bad_SecureChannelTokenUnknown if the token id is unknown.
         */
        @Override
        public void readSecurityHeader(SecureChannel channel, ByteBuf chunkBuffer) throws UaException {
            long receivedTokenId = SymmetricSecurityHeader.decode(chunkBuffer).getTokenId();

            ChannelSecurity channelSecurity = channel.getChannelSecurity();

            if (channelSecurity == null) {
                if (receivedTokenId != 0L) {
                    throw new UaException(
                        BAD_SECURE_CHANNEL_TOKEN_UNKNOWN,
                        "unknown secure channel token: " + receivedTokenId
                    );
                }
                return;
            }

            long currentTokenId = channelSecurity.getCurrentToken().getTokenId().longValue();

            if (receivedTokenId == currentTokenId) {
                securityKeys = channelSecurity.getCurrentKeys();
            } else {
                long previousTokenId = channelSecurity.getPreviousToken()
                    .map(t -> t.getTokenId().longValue())
                    .orElse(-1L);

                logger.debug("Attempting to use SecurityKeys from previousTokenId={}", previousTokenId);

                if (receivedTokenId != previousTokenId) {
                    logger.warn(
                        "receivedTokenId={} did not match previousTokenId={}",
                        receivedTokenId, previousTokenId
                    );
                    throw new UaException(
                        BAD_SECURE_CHANNEL_TOKEN_UNKNOWN,
                        "unknown secure channel token: " + receivedTokenId
                    );
                }

                // The previous keys are needed for HMAC verification too, not only for
                // decryption; otherwise Sign-only chunks would be checked with the wrong key.
                boolean keysNeeded =
                    channel.isSymmetricSigningEnabled() || channel.isSymmetricEncryptionEnabled();

                if (keysNeeded && channelSecurity.getPreviousKeys().isPresent()) {
                    securityKeys = channelSecurity.getPreviousKeys().get();
                }
            }

            // Re-initialize the cached cipher only when the token (and thus the keys) changed.
            if (cipherId != receivedTokenId && channel.isSymmetricEncryptionEnabled()) {
                cipher = initCipher(channel);
                cipherId = receivedTokenId;
            }
        }

        @Override
        public Cipher getCipher(SecureChannel channel) {
            assert cipher != null;
            return cipher;
        }

        @Override
        public int getCipherTextBlockSize(SecureChannel channel) {
            return channel.getSymmetricBlockSize();
        }

        @Override
        public int getSignatureSize(SecureChannel channel) {
            return channel.getSymmetricSignatureSize();
        }

        /**
         * Verifies the HMAC stored in the last {@code signatureSize} bytes of the chunk.
         *
         * <p>The HMAC is computed over everything before it. The comparison uses
         * {@link MessageDigest#isEqual}.
         *
         * @throws UaException with Bad_SecurityChecksFailed if the signature does not match.
         */
        @Override
        public void verifyChunk(SecureChannel channel, ByteBuf chunkBuffer) throws UaException {
            SecurityAlgorithm algorithm = channel.getSecurityPolicy().getSymmetricSignatureAlgorithm();
            byte[] signatureKey = channel.getDecryptionKeys(securityKeys).getSignatureKey();
            int signatureSize = channel.getSymmetricSignatureSize();
            int signatureStart = chunkBuffer.writerIndex() - signatureSize;

            ByteBuffer signedData = chunkBuffer.nioBuffer(0, chunkBuffer.writerIndex());
            // Casts to Buffer keep the bytecode compatible with Java 8 (covariant overrides in 9+).
            ((Buffer) signedData).position(0);
            ((Buffer) signedData).limit(signatureStart);

            byte[] computedSignature = SignatureUtil.hmac(algorithm, signatureKey, signedData);

            byte[] receivedSignature = new byte[signatureSize];
            chunkBuffer.getBytes(signatureStart, receivedSignature);

            if (!MessageDigest.isEqual(computedSignature, receivedSignature)) {
                throw new UaException(BAD_SECURITY_CHECKS_FAILED, "could not verify signature");
            }
        }

        @Override
        protected boolean isAsymmetric() {
            return false;
        }

        @Override
        public boolean isEncryptionEnabled(SecureChannel channel) {
            return channel.isSymmetricEncryptionEnabled();
        }

        @Override
        public boolean isSigningEnabled(SecureChannel channel) {
            return channel.isSymmetricSigningEnabled();
        }

        /**
         * Builds a decryption cipher from the currently selected keys and initialization vector.
         *
         * @throws UaException with Bad_InternalError if the cipher cannot be created.
         */
        private Cipher initCipher(SecureChannel channel) throws UaException {
            try {
                String transformation =
                    channel.getSecurityPolicy().getSymmetricEncryptionAlgorithm().getTransformation();
                ChannelSecurity.SecretKeys decryptionKeys = channel.getDecryptionKeys(securityKeys);

                SecretKeySpec keySpec = new SecretKeySpec(decryptionKeys.getEncryptionKey(), "AES");
                IvParameterSpec ivSpec = new IvParameterSpec(decryptionKeys.getInitializationVector());

                Cipher newCipher = Cipher.getInstance(transformation);
                newCipher.init(Cipher.DECRYPT_MODE, keySpec, ivSpec);
                return newCipher;
            } catch (GeneralSecurityException e) {
                throw new UaException(BAD_INTERNAL_ERROR, e);
            }
        }
    }

    /** Decoder for chunks protected with the local private key and the remote certificate. */
    private final class AsymmetricDecoder extends AbstractDecoder {

        private AsymmetricDecoder() {
        }

        @Override
        public void readSecurityHeader(SecureChannel channel, ByteBuf chunkBuffer) {
            // Decoded only to advance past the header; its contents are not used here.
            AsymmetricSecurityHeader.decode(chunkBuffer, encodingLimits);
        }

        /**
         * Creates a cipher that decrypts with the channel's local private key.
         *
         * @throws UaException with Bad_InternalError if the cipher cannot be created.
         */
        @Override
        public Cipher getCipher(SecureChannel channel) throws UaException {
            try {
                String transformation =
                    channel.getSecurityPolicy().getAsymmetricEncryptionAlgorithm().getTransformation();
                Cipher cipher = Cipher.getInstance(transformation);
                cipher.init(Cipher.DECRYPT_MODE, channel.getKeyPair().getPrivate());
                return cipher;
            } catch (GeneralSecurityException e) {
                throw new UaException(BAD_INTERNAL_ERROR, e);
            }
        }

        @Override
        public int getCipherTextBlockSize(SecureChannel channel) {
            return channel.getLocalAsymmetricCipherTextBlockSize();
        }

        @Override
        public int getSignatureSize(SecureChannel channel) {
            return channel.getRemoteAsymmetricSignatureSize();
        }

        /**
         * Verifies the trailing signature with the public key of the remote certificate.
         *
         * @throws UaException with Bad_SecurityChecksFailed if the signature does not verify,
         *                     Bad_InternalError for an unknown algorithm,
         *                     Bad_ApplicationSignatureInvalid if signature processing fails,
         *                     or Bad_CertificateInvalid if the public key is unusable.
         */
        @Override
        public void verifyChunk(SecureChannel channel, ByteBuf chunkBuffer) throws UaException {
            String transformation =
                channel.getSecurityPolicy().getAsymmetricSignatureAlgorithm().getTransformation();
            int signatureSize = channel.getRemoteAsymmetricSignatureSize();
            int signatureStart = chunkBuffer.writerIndex() - signatureSize;

            ByteBuffer signedData = chunkBuffer.nioBuffer(0, chunkBuffer.writerIndex());
            ((Buffer) signedData).position(0);
            ((Buffer) signedData).limit(signatureStart);

            try {
                Signature signature = Signature.getInstance(transformation);
                signature.initVerify(channel.getRemoteCertificate().getPublicKey());
                signature.update(signedData);

                byte[] signatureBytes = new byte[signatureSize];
                chunkBuffer.getBytes(signatureStart, signatureBytes);

                if (!signature.verify(signatureBytes)) {
                    throw new UaException(BAD_SECURITY_CHECKS_FAILED, "could not verify signature");
                }
            } catch (NoSuchAlgorithmException e) {
                throw new UaException(BAD_INTERNAL_ERROR, e);
            } catch (SignatureException e) {
                throw new UaException(BAD_APPLICATION_SIGNATURE_INVALID, e);
            } catch (InvalidKeyException e) {
                throw new UaException(BAD_CERTIFICATE_INVALID, e);
            }
        }

        @Override
        protected boolean isAsymmetric() {
            return true;
        }

        @Override
        public boolean isEncryptionEnabled(SecureChannel channel) {
            return channel.isAsymmetricEncryptionEnabled();
        }

        @Override
        public boolean isSigningEnabled(SecureChannel channel) {
            // NOTE(review): reuses the encryption flag; presumably asymmetric signing and
            // encryption are always enabled together for this channel type — confirm.
            return channel.isAsymmetricEncryptionEnabled();
        }
    }

    /** Shared chunk-processing pipeline for the symmetric and asymmetric decoders. */
    private abstract class AbstractDecoder {

        protected final Logger logger = LoggerFactory.getLogger(this.getClass());

        private AbstractDecoder() {
        }

        /**
         * Decodes every chunk and appends its body to {@code composite}.
         *
         * @return the composite body and the request id of the last chunk.
         * @throws MessageAbortException if an abort chunk is encountered.
         * @throws UaException           on any security, sequence, padding or size violation.
         */
        DecodedMessage decode(
            SecureChannel channel,
            CompositeByteBuf composite,
            List<ByteBuf> chunkBuffers
        ) throws MessageAbortException, UaException {

            int signatureSize = getSignatureSize(channel);
            int cipherTextBlockSize = getCipherTextBlockSize(channel);
            boolean encrypted = isEncryptionEnabled(channel);
            boolean signed = isSigningEnabled(channel);

            // Blocks larger than 256 bytes encode the padding length in two bytes instead of one.
            int paddingOverhead = encrypted ? (cipherTextBlockSize > 256 ? 2 : 1) : 0;

            long requestId = -1L;

            for (ByteBuf chunkBuffer : chunkBuffers) {
                char chunkType = (char) chunkBuffer.getByte(CHUNK_TYPE_INDEX);

                chunkBuffer.skipBytes(MESSAGE_HEADER_SIZE);
                readSecurityHeader(channel, chunkBuffer);

                if (encrypted) {
                    decryptChunk(channel, chunkBuffer);
                }

                // Start of the (possibly encrypted) sequence header.
                int encryptedStart = chunkBuffer.readerIndex();

                // Signature and padding offsets are computed relative to the start of the chunk.
                chunkBuffer.readerIndex(0);

                if (signed) {
                    verifyChunk(channel, chunkBuffer);
                }

                int paddingSize = encrypted
                    ? getPaddingSize(cipherTextBlockSize, signatureSize, chunkBuffer)
                    : 0;
                int bodyEnd = chunkBuffer.readableBytes() - signatureSize - paddingOverhead - paddingSize;

                chunkBuffer.readerIndex(encryptedStart);

                SequenceHeader sequenceHeader = SequenceHeader.decode(chunkBuffer);
                long sequenceNumber = sequenceHeader.getSequenceNumber();
                requestId = sequenceHeader.getRequestId();

                long previousSequenceNumber = ChunkDecoder.this.lastSequenceNumber;
                if (!LegacySequenceNumberValidator.validateSequenceNumber(previousSequenceNumber, sequenceNumber)) {
                    throw new UaException(
                        BAD_SECURITY_CHECKS_FAILED,
                        String.format(
                            "bad sequence number: %s, lastSequenceNumber=%s",
                            sequenceNumber, previousSequenceNumber
                        )
                    );
                }
                ChunkDecoder.this.lastSequenceNumber = sequenceNumber;

                ByteBuf bodyBuffer = chunkBuffer.readSlice(bodyEnd - chunkBuffer.readerIndex());

                if (encrypted) {
                    verifyPadding(chunkBuffer, paddingSize, signatureSize, paddingOverhead);
                }

                if (chunkType == 'A') {
                    ErrorMessage errorMessage = ErrorMessage.decode(bodyBuffer);
                    throw new MessageAbortException(errorMessage.getReason(), requestId, errorMessage.getError());
                }

                composite.addComponent(bodyBuffer);
                composite.writerIndex(composite.writerIndex() + bodyBuffer.readableBytes());
            }

            long maxMessageSize = ChunkDecoder.this.parameters.getLocalMaxMessageSize();
            if (maxMessageSize > 0 && composite.readableBytes() > maxMessageSize) {
                String errorMessage = String.format(
                    "message size exceeds configured limit: %s > %s",
                    composite.readableBytes(), maxMessageSize
                );
                throw new UaException(BAD_TCP_MESSAGE_TOO_LARGE, errorMessage);
            }

            return new DecodedMessage(composite, requestId);
        }

        /**
         * Checks the padding that follows the body slice.
         *
         * <p>{@code chunkBuffer}'s reader index must sit at the first padding byte. Every
         * padding byte, and the first padding-length byte after them, must equal the low byte
         * of {@code paddingSize}.
         */
        private void verifyPadding(
            ByteBuf chunkBuffer,
            int paddingSize,
            int signatureSize,
            int paddingOverhead
        ) throws UaException {

            int expectedPaddingSize = chunkBuffer.readableBytes() - signatureSize - paddingOverhead;
            if (paddingSize != expectedPaddingSize) {
                throw new UaException(BAD_SECURITY_CHECKS_FAILED, "bad padding size");
            }

            byte expectedPaddingByte = (byte) (paddingSize & 0xFF);
            int start = chunkBuffer.readerIndex();
            int end = start + paddingSize + 1;

            for (int i = start; i < end; i++) {
                if (chunkBuffer.getByte(i) != expectedPaddingByte) {
                    throw new UaException(BAD_SECURITY_CHECKS_FAILED, "bad padding sequence");
                }
            }
        }

        /**
         * Decrypts the readable bytes of {@code chunkBuffer} in place.
         *
         * <p>On return the buffer's readable region holds the plaintext (its writer index is
         * moved back accordingly).
         *
         * @throws UaException with Bad_SecurityChecksFailed if the ciphertext is not block
         *                     aligned or decryption fails.
         */
        private void decryptChunk(SecureChannel channel, ByteBuf chunkBuffer) throws UaException {
            int cipherTextBlockSize = getCipherTextBlockSize(channel);
            int cipherTextSize = chunkBuffer.readableBytes();

            // Previously only an assert (disabled in production), which let the asymmetric path
            // silently drop trailing bytes.
            if (cipherTextBlockSize <= 0 || cipherTextSize % cipherTextBlockSize != 0) {
                throw new UaException(
                    BAD_SECURITY_CHECKS_FAILED,
                    String.format(
                        "cipher text size %s is not a multiple of block size %s",
                        cipherTextSize, cipherTextBlockSize
                    )
                );
            }

            int blockCount = cipherTextSize / cipherTextBlockSize;

            ByteBuf plainTextBuffer = BufferUtil.pooledBuffer(cipherTextSize);

            try {
                ByteBuffer plainTextNioBuffer = plainTextBuffer.writerIndex(cipherTextSize).nioBuffer();
                ByteBuffer chunkNioBuffer = chunkBuffer.nioBuffer();

                Cipher cipher = getCipher(channel);

                if (isAsymmetric()) {
                    // The asymmetric cipher is applied to one block at a time.
                    for (int block = 0; block < blockCount; block++) {
                        ((Buffer) chunkNioBuffer).limit(chunkNioBuffer.position() + cipherTextBlockSize);
                        cipher.doFinal(chunkNioBuffer, plainTextNioBuffer);
                    }
                } else {
                    cipher.doFinal(chunkNioBuffer, plainTextNioBuffer);
                }

                ((Buffer) plainTextNioBuffer).flip();

                chunkBuffer.writerIndex(chunkBuffer.readerIndex());
                chunkBuffer.writeBytes(plainTextNioBuffer);
            } catch (GeneralSecurityException e) {
                throw new UaException(BAD_SECURITY_CHECKS_FAILED, e);
            } finally {
                plainTextBuffer.release();
            }
        }

        /**
         * Reads the padding length stored just before the signature.
         *
         * <p>For blocks larger than 256 bytes the length is read as a two-byte little-endian
         * value. {@code buffer}'s reader index is expected to be 0.
         */
        private int getPaddingSize(int cipherTextBlockSize, int signatureSize, ByteBuf buffer) {
            int lastPaddingByteOffset = buffer.readableBytes() - signatureSize - 1;

            return cipherTextBlockSize <= 256
                ? buffer.getUnsignedByte(lastPaddingByteOffset)
                : buffer.getUnsignedShortLE(lastPaddingByteOffset - 1);
        }

        /** Reads (and for symmetric chunks, validates) the security header, advancing the buffer. */
        protected abstract void readSecurityHeader(SecureChannel channel, ByteBuf chunkBuffer) throws UaException;

        /** Returns a cipher ready to decrypt this decoder's chunks. */
        protected abstract Cipher getCipher(SecureChannel channel) throws UaException;

        protected abstract int getCipherTextBlockSize(SecureChannel channel);

        protected abstract int getSignatureSize(SecureChannel channel);

        /** Verifies the trailing signature of a chunk whose reader index is 0. */
        protected abstract void verifyChunk(SecureChannel channel, ByteBuf chunkBuffer) throws UaException;

        protected abstract boolean isAsymmetric();

        protected abstract boolean isEncryptionEnabled(SecureChannel channel);

        protected abstract boolean isSigningEnabled(SecureChannel channel);
    }

    /** The reassembled message body together with the request id from its sequence header. */
    public static class DecodedMessage {

        private final ByteBuf message;
        private final long requestId;

        private DecodedMessage(ByteBuf message, long requestId) {
            this.message = message;
            this.requestId = requestId;
        }

        /** @return the concatenated chunk bodies. */
        public ByteBuf getMessage() {
            return message;
        }

        /** @return the request id read from the last chunk's sequence header. */
        public long getRequestId() {
            return requestId;
        }
    }
}

