package handshake

import (
	"bytes"
	"context"
	"crypto/ecdh"
	"crypto/ed25519"
	"crypto/sha256"
	"errors"
	"io"
	"log/slog"

	"koti.casa/numenor-labs/dsfx/internal/lib/assert"
	"koti.casa/numenor-labs/dsfx/internal/lib/buffer"
	"koti.casa/numenor-labs/dsfx/internal/lib/crypto/encryption"
	"koti.casa/numenor-labs/dsfx/internal/lib/crypto/identity"
	"koti.casa/numenor-labs/dsfx/internal/lib/crypto/keyexchange"
	"koti.casa/numenor-labs/dsfx/internal/lib/logging"
)

// Key-size constants exported by the handshake package.
//
// NOTE(review): neither constant is referenced in the visible part of this
// file; the meanings below are inferred from the names and should be
// confirmed against their callers.
const (
	// DHKeySize is presumably the size, in bytes, of an ephemeral
	// Diffie-Hellman key — TODO confirm.
	DHKeySize       = 32
	// IdentityKeySize is presumably the size, in bytes, of a raw long-term
	// identity (signing) public key — TODO confirm.
	IdentityKeySize = 32
)

// Initiate runs the initiating (client) side of the handshake over conn.
//
// The exchange, as implemented below, is:
//
//  1. Send our freshly generated ephemeral DH public key, length-prefixed.
//  2. Read the remote side's length-prefixed ephemeral DH public key.
//  3. Send encrypt(derivedKey, ourSigningPublicKey || signature), where the
//     signature is made with lPrivKey over the message returned by
//     buildMessage for the two ephemeral keys.
//  4. Read the remote side's encrypted signature over that same message,
//     verify it against rPubKey, and send a one-byte (0x01) completion
//     message.
//
// lPrivKey is the local long-term signing key and rPubKey is the signing key
// the remote side is expected to prove possession of. On success the key
// produced by keyexchange.ComputeDHSecret is returned. On any failure the
// returned slice is nil and the error is non-nil.
//
// ctx is used only to obtain the logger and to annotate log records; this
// function does not itself observe cancellation, and it never closes conn.
func Initiate(
	ctx context.Context,
	conn io.ReadWriteCloser,
	lPrivKey ed25519.PrivateKey,
	rPubKey ed25519.PublicKey,
) ([]byte, error) {
	logger := logging.FromContext(ctx).WithGroup("handshake")

	// ------------------------------------------------------------------------
	// Step 1: Ephemeral Key Exchange To Server

	logger.DebugContext(ctx, "creating dh key")
	// Create a new ECDH private key for the actor.
	ourDHKey, err := keyexchange.GenerateDHKey()
	if err != nil {
		return nil, err
	}
	assert.Assert(ourDHKey != nil, "failed to generate dh key")

	logger.DebugContext(ctx, "exporting dh key")
	// Export the public key of the actor's ECDH private key.
	ourDHKeyRaw, err := keyexchange.ExportPublicKey(ourDHKey.PublicKey())
	if err != nil {
		return nil, err
	}

	welcomeMessage, err := buffer.NewLenPrefixed(ourDHKeyRaw)
	if err != nil {
		return nil, err
	}

	// Write the actor's public key to the connection.
	logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw)))
	n, err := conn.Write(welcomeMessage)
	if err != nil {
		return nil, err
	}
	// Treat a short write without an error as a failure.
	if n != len(welcomeMessage) {
		return nil, errors.New("failed to write dh key")
	}

	// ------------------------------------------------------------------------
	// Step 2: Ephemeral Key Exchange From Server

	// Read the remote actor's public key from the connection.
	logger.DebugContext(ctx, "waiting for server's dh key")
	remoteDHKeyRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
	if err != nil {
		return nil, err
	}

	// Import the remote actor's public key.
	logger.DebugContext(ctx, "importing server's dh key")
	remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyRaw)
	if err != nil {
		return nil, err
	}

	// ------------------------------------------------------------------------
	// Step 3: Client Authentication

	// Export the public key of the actor's signing key.
	logger.DebugContext(ctx, "exporting public signing key")
	ourPublicKeyRaw, err := identity.ExportPublicKey(identity.ToPublicKey(lPrivKey))
	if err != nil {
		return nil, err
	}

	logger.DebugContext(ctx, "exporting remote public signing key")
	remotePublicKeyRaw, err := identity.ExportPublicKey(rPubKey)
	if err != nil {
		return nil, err
	}

	// Construct the message that will be signed by the client. buildMessage
	// takes the client's ephemeral key first and the server's second, and
	// returns:
	//   serverEphemeralKey + sha256(clientEphemeralKey + serverEphemeralKey)
	// Both sides sign this same message, which ties each signature to the
	// ephemeral keys used to establish the shared secret.
	// NOTE(review): the message contains only ephemeral keys; neither
	// long-term signing key is part of the signed data.
	logger.DebugContext(ctx, "building authentication message")
	authMessage, err := buildMessage(ourDHKey.PublicKey(), remoteDHKey)
	if err != nil {
		return nil, err
	}

	// Sign the message with the actor's private key.
	logger.DebugContext(ctx, "signing authentication message")
	signature, err := identity.Sign(lPrivKey, authMessage)
	if err != nil {
		return nil, err
	}

	// Compute the shared secret between the actor and the remote actor.
	logger.DebugContext(ctx, "computing shared secret")
	derivedKey, err := keyexchange.ComputeDHSecret(ourDHKey, remoteDHKey)
	if err != nil {
		return nil, err
	}
	assert.Assert(len(derivedKey) == 32, "invalid shared secret size")

	// The authentication payload is our exported signing key followed
	// directly by the signature; the receiver splits it by fixed offset.
	plaintext := make([]byte, 0, len(ourPublicKeyRaw)+len(signature))
	plaintext = append(plaintext, ourPublicKeyRaw...)
	plaintext = append(plaintext, signature...)

	// Encrypt the message with the derived key.
	logger.DebugContext(ctx, "encrypting authentication message")
	boxedMsg, err := encryption.Encrypt(derivedKey, plaintext)
	if err != nil {
		return nil, err
	}

	// Write the boxed message to the connection.
	logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg)))
	boxedMsgPrepared, err := buffer.NewLenPrefixed(boxedMsg)
	if err != nil {
		return nil, err
	}
	n, err = conn.Write(boxedMsgPrepared)
	if err != nil {
		return nil, err
	}
	if n != len(boxedMsgPrepared) {
		return nil, errors.New("failed to write authentication message")
	}

	// ------------------------------------------------------------------------
	// Step 4: Server Authentication

	// Read the authentication message from the connection.
	logger.DebugContext(ctx, "waiting for server's authentication message")
	authMessageBoxed, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
	if err != nil {
		return nil, err
	}
	// Log the size of the message that was just read (not the byte count of
	// the previous write, which is what n still holds at this point).
	logger.DebugContext(ctx, "received authentication message", slog.Int("message.size", len(authMessageBoxed)))

	// Decrypt the authentication message with the derived key. The decrypted
	// payload is the server's signature and nothing else.
	logger.DebugContext(ctx, "decrypting authentication message")
	plaintext, err = encryption.Decrypt(derivedKey, authMessageBoxed)
	if err != nil {
		return nil, err
	}

	// The server authentication is just verifying the signature it created of
	// the client authentication message.
	logger.DebugContext(ctx, "importing server's public signing key")
	remotePublicKey, err := identity.ImportPublicKey(remotePublicKeyRaw)
	if err != nil {
		return nil, err
	}

	logger.DebugContext(ctx, "verifying server's signature")
	if !identity.Verify(remotePublicKey, authMessage, plaintext) {
		return nil, errors.New("failed to verify server's signature")
	}

	// Finally, we need to let the server know that the handshake is complete.
	logger.DebugContext(ctx, "sending handshake complete message")
	handshakeCompleteMsg, err := buffer.NewLenPrefixed([]byte{0x01})
	if err != nil {
		return nil, err
	}
	n, err = conn.Write(handshakeCompleteMsg)
	if err != nil {
		return nil, err
	}
	if n != len(handshakeCompleteMsg) {
		return nil, errors.New("failed to write handshake complete message")
	}

	logger.DebugContext(ctx, "handshake complete")
	return derivedKey, nil
}

// Accept runs the accepting (server) side of the handshake over conn.
//
// The exchange, as implemented below, is:
//
//  1. Read the client's length-prefixed ephemeral DH public key.
//  2. Send our freshly generated ephemeral DH public key, length-prefixed.
//  3. Read and decrypt the client's authentication payload
//     (clientSigningPublicKey || signature) and verify the signature over
//     the message returned by buildMessage for the two ephemeral keys.
//  4. Sign that same message with lPrivKey, send the encrypted signature,
//     and wait for the client's one-byte (0x01) completion message.
//
// On success it returns the signing public key the client presented and the
// key produced by keyexchange.ComputeDHSecret. On any failure both results
// are nil and the error is non-nil.
//
// Accept only proves that the client holds the private key matching the
// public key it sent; it does not compare that key against any allow-list.
// Callers that need to restrict who may connect must inspect the returned
// public key themselves.
//
// ctx is used only to obtain the logger and to annotate log records; this
// function does not itself observe cancellation, and it never closes conn.
func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey ed25519.PrivateKey) (ed25519.PublicKey, []byte, error) {
	// encodedPublicKeySize is the fixed offset at which the client's
	// authentication payload is split into signing key and signature.
	// NOTE(review): 44 matches the length of a DER/PKIX-encoded Ed25519
	// public key; presumably that is what identity.ExportPublicKey emits —
	// confirm.
	const encodedPublicKeySize = 44

	logger := logging.FromContext(ctx).WithGroup("handshake")

	// ------------------------------------------------------------------------
	// Step 1: Ephemeral Key Exchange From Client

	// Read the remote actor's public key from the connection.
	logger.DebugContext(ctx, "waiting for client's dh key")
	remoteDHKeyRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
	if err != nil {
		return nil, nil, err
	}

	// Import the remote actor's public key.
	logger.DebugContext(ctx, "importing client's dh key")
	remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyRaw)
	if err != nil {
		return nil, nil, err
	}

	// ------------------------------------------------------------------------
	// Step 2: Ephemeral Key Exchange To Client

	// Create a new ECDH private key for the actor.
	logger.DebugContext(ctx, "creating dh key")
	ourDHKey, err := keyexchange.GenerateDHKey()
	if err != nil {
		return nil, nil, err
	}

	// Export the public key of the actor's ECDH private key.
	logger.DebugContext(ctx, "exporting dh key")
	ourDHKeyRaw, err := keyexchange.ExportPublicKey(ourDHKey.PublicKey())
	if err != nil {
		return nil, nil, err
	}

	// Write the actor's public key to the connection.
	logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw)))
	ourPrefixedDHKey, err := buffer.NewLenPrefixed(ourDHKeyRaw)
	if err != nil {
		return nil, nil, err
	}
	n, err := conn.Write(ourPrefixedDHKey)
	if err != nil {
		return nil, nil, err
	}
	// Treat a short write without an error as a failure.
	if n != len(ourPrefixedDHKey) {
		return nil, nil, errors.New("failed to write dh key")
	}

	// ------------------------------------------------------------------------
	// Step 3: Client Authentication

	// Read the authentication message from the connection.
	logger.DebugContext(ctx, "waiting for client's authentication message")
	authMessageRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
	if err != nil {
		return nil, nil, err
	}
	logger.DebugContext(
		ctx,
		"received authentication message",
		slog.Int("message.size", len(authMessageRaw)),
	)

	// Compute the shared secret, which is the key the client used to encrypt
	// its authentication message.
	logger.DebugContext(ctx, "computing shared secret")
	derivedKey, err := keyexchange.ComputeDHSecret(ourDHKey, remoteDHKey)
	if err != nil {
		return nil, nil, err
	}

	logger.DebugContext(ctx, "decrypting authentication message")
	plaintext, err := encryption.Decrypt(derivedKey, authMessageRaw)
	if err != nil {
		return nil, nil, err
	}

	// The payload length is controlled by the remote peer. Reject anything
	// too short to contain the signing key instead of letting the slice
	// expressions below panic.
	if len(plaintext) < encodedPublicKeySize {
		return nil, nil, errors.New("invalid authentication message")
	}
	clientPublicKeyRaw := plaintext[:encodedPublicKeySize]
	signature := plaintext[encodedPublicKeySize:]

	// Verify the client's public key and signature.
	logger.DebugContext(ctx, "importing client's public signing key")
	clientPublicKey, err := identity.ImportPublicKey(clientPublicKeyRaw)
	if err != nil {
		return nil, nil, err
	}

	// Reconstruct the message that was signed by the client. buildMessage
	// takes the client's ephemeral key first and the server's second, and
	// returns:
	//   serverEphemeralKey + sha256(clientEphemeralKey + serverEphemeralKey)
	// Both sides sign this same message, which ties each signature to the
	// ephemeral keys used to establish the shared secret.
	// NOTE(review): the message contains only ephemeral keys; neither
	// long-term signing key is part of the signed data.
	logger.DebugContext(ctx, "building authentication message")
	authMessage, err := buildMessage(remoteDHKey, ourDHKey.PublicKey())
	if err != nil {
		return nil, nil, err
	}

	logger.DebugContext(ctx, "verifying client's signature")
	if !identity.Verify(clientPublicKey, authMessage, signature) {
		return nil, nil, errors.New("failed to verify client's signature")
	}

	// ------------------------------------------------------------------------
	// Step 4: Server Authentication

	// Now we need to sign the authentication message with the server's private
	// key. This will be sent back to the client in the next step to authenticate
	// the server to the client.
	logger.DebugContext(ctx, "signing authentication message")
	serverSignature, err := identity.Sign(lPrivKey, authMessage)
	if err != nil {
		return nil, nil, err
	}

	logger.DebugContext(ctx, "encrypting server's signature")
	boxedMsg, err := encryption.Encrypt(derivedKey, serverSignature)
	if err != nil {
		return nil, nil, err
	}

	// Send the server's signature back to the client.
	logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg)))
	prefixedAuthMessage, err := buffer.NewLenPrefixed(boxedMsg)
	if err != nil {
		return nil, nil, err
	}
	n, err = conn.Write(prefixedAuthMessage)
	if err != nil {
		return nil, nil, err
	}
	if n != len(prefixedAuthMessage) {
		return nil, nil, errors.New("failed to write authentication message")
	}

	logger.DebugContext(ctx, "waiting for handshake complete message")
	// Read the handshake complete message from the client. It must be exactly
	// the single byte 0x01.
	handshakeCompleteMsg, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
	if err != nil {
		return nil, nil, err
	}
	if !bytes.Equal(handshakeCompleteMsg, []byte{0x01}) {
		return nil, nil, errors.New("invalid handshake complete message")
	}

	logger.DebugContext(ctx, "handshake complete")
	return clientPublicKey, derivedKey, nil
}

// buildMessage produces the byte string that both handshake participants
// sign. Given the client's and the server's ephemeral ECDH public keys, the
// result is:
//
//	export(serverPubKey) + sha256(export(clientPubKey) + export(serverPubKey))
//
// where export is keyexchange.ExportPublicKey. Argument order matters: the
// client's key always comes first, regardless of which side is calling. An
// error from either export is returned unchanged.
func buildMessage(clientPubKey *ecdh.PublicKey, serverPubKey *ecdh.PublicKey) ([]byte, error) {
	clientRaw, err := keyexchange.ExportPublicKey(clientPubKey)
	if err != nil {
		return nil, err
	}

	serverRaw, err := keyexchange.ExportPublicKey(serverPubKey)
	if err != nil {
		return nil, err
	}

	// Hash the two exported keys in client-then-server order. Writes to a
	// hash.Hash never return an error, so the results are not checked.
	transcript := sha256.New()
	transcript.Write(clientRaw)
	transcript.Write(serverRaw)

	// Start the output with a private copy of the server key, with room
	// reserved for the digest, so that appending the digest can never touch
	// serverRaw's backing array.
	out := make([]byte, len(serverRaw), len(serverRaw)+sha256.Size)
	copy(out, serverRaw)

	return transcript.Sum(out), nil
}