dsfx/pkg/handshake/handshake.go
2025-03-09 17:09:39 -04:00

362 lines
12 KiB
Go

package handshake
import (
"bytes"
"context"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/sha256"
"errors"
"io"
"log/slog"
"koti.casa/numenor-labs/dsfx/pkg/assert"
"koti.casa/numenor-labs/dsfx/pkg/crypto/encryption"
"koti.casa/numenor-labs/dsfx/pkg/crypto/identity"
"koti.casa/numenor-labs/dsfx/pkg/crypto/keyexchange"
"koti.casa/numenor-labs/dsfx/pkg/frame"
"koti.casa/numenor-labs/dsfx/pkg/logging"
)
// Byte lengths of the values that travel over the wire during a handshake.
const (
	// BoxedClientAuthMessageSize is the length, in bytes, of the encrypted
	// (boxed) authentication message the client sends.
	BoxedClientAuthMessageSize = 353
	// BoxedServerAuthMessageSize is the length, in bytes, of the encrypted
	// (boxed) authentication message the server sends.
	BoxedServerAuthMessageSize = 130

	// ECDHPublicKeySize is the length, in bytes, of an exported ECDH public
	// key. Initiate rejects a peer key frame whose length differs from this.
	ECDHPublicKeySize = 97
	// ECDSAPublicKeySize is the length, in bytes, of an ECDSA public key.
	ECDSAPublicKeySize = 222
)
// Initiate runs the initiating (client) side of the handshake over conn.
//
// lPrivKey is the local long-term signing key and rPubKey is the long-term
// public signing key the remote side is expected to hold. The exchange is:
//
//  1. send a freshly generated ephemeral DH public key,
//  2. receive the server's ephemeral DH public key,
//  3. send our long-term public key plus a signature over the authentication
//     message, encrypted with the DH-derived key,
//  4. receive the server's encrypted signature over the same authentication
//     message, verify it against rPubKey, and send a one-byte completion
//     marker.
//
// On success the DH-derived shared secret is returned. Any failure aborts the
// handshake and returns the underlying error with a nil secret.
func Initiate(
	ctx context.Context,
	conn io.ReadWriteCloser,
	lPrivKey *ecdsa.PrivateKey,
	rPubKey *ecdsa.PublicKey,
) ([]byte, error) {
	lg := logging.FromContext(ctx).WithGroup("handshake")

	// ------------------------------------------------------------------------
	// Step 1: Ephemeral Key Exchange To Server

	// A fresh DH key pair is generated for every handshake.
	lg.DebugContext(ctx, "creating dh key")
	localDH, err := keyexchange.GenerateDHKey()
	if err != nil {
		return nil, err
	}
	assert.Assert(localDH != nil, "failed to generate dh key")

	// Serialize our ephemeral public key for the wire.
	lg.DebugContext(ctx, "exporting dh key")
	localDHRaw, err := keyexchange.ExportPublicKey(localDH.PublicKey())
	if err != nil {
		return nil, err
	}
	assert.Assert(len(localDHRaw) == ECDHPublicKeySize, "invalid dh key size")

	lg.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(localDHRaw)))
	if _, err := frame.New(localDHRaw).WriteTo(conn); err != nil {
		return nil, err
	}

	// ------------------------------------------------------------------------
	// Step 2: Ephemeral Key Exchange From Server

	lg.DebugContext(ctx, "waiting for server's dh key")
	serverDHFrame := frame.New(nil)
	if _, err := serverDHFrame.ReadFrom(conn); err != nil {
		return nil, err
	}
	// Reject anything that is not exactly one exported DH public key.
	if len(serverDHFrame.Contents()) != ECDHPublicKeySize {
		return nil, errors.New("invalid dh key size")
	}

	lg.DebugContext(ctx, "importing server's dh key")
	serverDH, err := keyexchange.ImportPublicKey(serverDHFrame.Contents())
	if err != nil {
		return nil, err
	}

	// ------------------------------------------------------------------------
	// Step 3: Client Authentication

	// Serialized forms of both long-term signing keys: ours is sent to the
	// server, the remote one is re-imported below to verify the reply.
	lg.DebugContext(ctx, "exporting public signing key")
	localSigningRaw, err := identity.ExportPublicKey(&lPrivKey.PublicKey)
	if err != nil {
		return nil, err
	}
	lg.DebugContext(ctx, "exporting remote public signing key")
	serverSigningRaw, err := identity.ExportPublicKey(rPubKey)
	if err != nil {
		return nil, err
	}

	// The message both sides sign is derived from the two ephemeral keys
	// (client key first, server key second), which ties each signature to
	// this particular key exchange.
	lg.DebugContext(ctx, "building authentication message")
	toSign, err := buildMessage(localDH.PublicKey(), serverDH)
	if err != nil {
		return nil, err
	}

	lg.DebugContext(ctx, "signing authentication message")
	localSig, err := identity.Sign(lPrivKey, toSign)
	if err != nil {
		return nil, err
	}

	// Derive the symmetric key shared with the server.
	lg.DebugContext(ctx, "computing shared secret")
	secret, err := keyexchange.ComputeDHSecret(localDH, serverDH)
	if err != nil {
		return nil, err
	}
	assert.Assert(len(secret) == 32, "invalid shared secret size")

	// Payload layout: our exported signing key immediately followed by our
	// signature.
	payload := append(
		append(make([]byte, 0, len(localSigningRaw)+len(localSig)), localSigningRaw...),
		localSig...,
	)

	lg.DebugContext(ctx, "encrypting authentication message")
	sealed, err := encryption.Encrypt(secret, payload)
	if err != nil {
		return nil, err
	}

	lg.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(sealed)))
	if _, err := frame.New(sealed).WriteTo(conn); err != nil {
		return nil, err
	}

	// ------------------------------------------------------------------------
	// Step 4: Server Authentication

	lg.DebugContext(ctx, "waiting for server's authentication message")
	serverAuthFrame := frame.New(nil)
	received, err := serverAuthFrame.ReadFrom(conn)
	if err != nil {
		return nil, err
	}
	lg.DebugContext(ctx, "received authentication message", slog.Int("message.size", int(received)))

	lg.DebugContext(ctx, "decrypting authentication message")
	serverSig, err := encryption.Decrypt(secret, serverAuthFrame.Contents())
	if err != nil {
		return nil, err
	}

	// The server proves its identity by signing the very same authentication
	// message; the decrypted payload is that signature and nothing else.
	lg.DebugContext(ctx, "importing server's public signing key")
	serverSigningKey, err := identity.ImportPublicKey(serverSigningRaw)
	if err != nil {
		return nil, err
	}

	lg.DebugContext(ctx, "verifying server's signature")
	if ok := identity.Verify(serverSigningKey, toSign, serverSig); !ok {
		return nil, errors.New("failed to verify server's signature")
	}

	// Tell the server we accepted its signature so it can finish its side.
	lg.DebugContext(ctx, "sending handshake complete message")
	if _, err := frame.New([]byte{0x01}).WriteTo(conn); err != nil {
		return nil, err
	}

	lg.DebugContext(ctx, "handshake complete")
	return secret, nil
}
// Accept runs the accepting (server) side of the handshake over conn, using
// lPrivKey as the local long-term signing key.
//
// The exchange mirrors Initiate:
//
//  1. receive the client's ephemeral DH public key,
//  2. send a freshly generated ephemeral DH public key,
//  3. receive and decrypt the client's authentication message (its exported
//     long-term public key followed by a signature) and verify the signature,
//  4. sign the same authentication message, send it back encrypted, and wait
//     for the client's one-byte completion marker.
//
// On success it returns the client's long-term public signing key, as claimed
// in its authentication message and proven by its signature, together with
// the DH-derived shared secret. Accept does not decide whether that client
// identity is acceptable; that is left to the caller. On any failure all
// return values other than the error are nil.
func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.PrivateKey) (*ecdsa.PublicKey, []byte, error) {
	logger := logging.FromContext(ctx).WithGroup("handshake")

	// ------------------------------------------------------------------------
	// Step 1: Ephemeral Key Exchange From Client

	// Read the remote actor's public key from the connection.
	logger.DebugContext(ctx, "waiting for client's dh key")
	remoteDHKeyFrame := frame.New(nil)
	_, err := remoteDHKeyFrame.ReadFrom(conn)
	if err != nil {
		return nil, nil, err
	}

	// Import the remote actor's public key.
	logger.DebugContext(ctx, "importing client's dh key")
	remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyFrame.Contents())
	if err != nil {
		return nil, nil, err
	}

	// ------------------------------------------------------------------------
	// Step 2: Ephemeral Key Exchange To Client

	// Create a new ECDH private key for the actor.
	logger.DebugContext(ctx, "creating dh key")
	ourDHKey, err := keyexchange.GenerateDHKey()
	if err != nil {
		return nil, nil, err
	}

	// Export the public key of the actor's ECDH private key.
	logger.DebugContext(ctx, "exporting dh key")
	ourDHKeyRaw, err := keyexchange.ExportPublicKey(ourDHKey.PublicKey())
	if err != nil {
		return nil, nil, err
	}

	// Write the actor's public key to the connection.
	logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw)))
	_, err = frame.New(ourDHKeyRaw).WriteTo(conn)
	if err != nil {
		return nil, nil, err
	}

	// ------------------------------------------------------------------------
	// Step 3: Client Authentication

	// Read the authentication message from the connection.
	logger.DebugContext(ctx, "waiting for client's authentication message")
	authMessageFrame := frame.New(nil)
	n, err := authMessageFrame.ReadFrom(conn)
	if err != nil {
		return nil, nil, err
	}
	logger.DebugContext(ctx, "received authentication message", slog.Int("message.size", int(n)))

	// Compute the shared secret between the actor and the remote actor.
	logger.DebugContext(ctx, "computing shared secret")
	derivedKey, err := keyexchange.ComputeDHSecret(ourDHKey, remoteDHKey)
	if err != nil {
		return nil, nil, err
	}

	// Decrypt the authentication message with the derived key.
	logger.DebugContext(ctx, "decrypting authentication message")
	plaintext, err := encryption.Decrypt(derivedKey, authMessageFrame.Contents())
	if err != nil {
		return nil, nil, err
	}

	// The plaintext is the client's exported public key followed by its
	// signature. Its length is chosen by the peer, and at this point the peer
	// is still unauthenticated (anyone can complete the DH exchange), so it
	// must be validated before slicing: a short message would otherwise panic
	// with a slice-bounds error. A message with no bytes after the key carries
	// an empty signature, which can never verify, so it is rejected here too.
	if len(plaintext) <= identity.ExportedPublicKeySize {
		return nil, nil, errors.New("invalid authentication message size")
	}
	clientPublicKeyRaw := plaintext[:identity.ExportedPublicKeySize]
	signature := plaintext[identity.ExportedPublicKeySize:]

	// Verify the client's public key and signature.
	logger.DebugContext(ctx, "importing client's public signing key")
	clientPublicKey, err := identity.ImportPublicKey(clientPublicKeyRaw)
	if err != nil {
		return nil, nil, err
	}

	// Reconstruct the message that was signed by the client. It is derived
	// from the two ephemeral keys exchanged above; the argument order (client
	// key first, server key second) must match the order used by the client.
	logger.DebugContext(ctx, "building authentication message")
	authMessage, err := buildMessage(remoteDHKey, ourDHKey.PublicKey())
	if err != nil {
		return nil, nil, err
	}

	logger.DebugContext(ctx, "verifying client's signature")
	if !identity.Verify(clientPublicKey, authMessage, signature) {
		return nil, nil, errors.New("failed to verify client's signature")
	}

	// ------------------------------------------------------------------------
	// Step 4: Server Authentication

	// Now we need to sign the authentication message with the server's private
	// key. This will be sent back to the client to authenticate the server to
	// the client.
	logger.DebugContext(ctx, "signing authentication message")
	serverSignature, err := identity.Sign(lPrivKey, authMessage)
	if err != nil {
		return nil, nil, err
	}

	logger.DebugContext(ctx, "encrypting server's signature")
	boxedMsg, err := encryption.Encrypt(derivedKey, serverSignature)
	if err != nil {
		return nil, nil, err
	}

	// Send the server's signature back to the client.
	logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg)))
	_, err = frame.New(boxedMsg).WriteTo(conn)
	if err != nil {
		return nil, nil, err
	}

	// Read the handshake complete message from the client. The client only
	// sends it after it has verified the server's signature.
	logger.DebugContext(ctx, "waiting for handshake complete message")
	handshakeCompleteFrame := frame.New(nil)
	_, err = handshakeCompleteFrame.ReadFrom(conn)
	if err != nil {
		return nil, nil, err
	}
	if !bytes.Equal(handshakeCompleteFrame.Contents(), []byte{0x01}) {
		return nil, nil, errors.New("invalid handshake complete message")
	}

	logger.DebugContext(ctx, "handshake complete")
	return clientPublicKey, derivedKey, nil
}
// buildMessage returns the byte string that both handshake parties sign.
//
// With C and S denoting the exported forms of clientPubKey and serverPubKey,
// the result is
//
//	S || sha256(C || S)
//
// Both arguments are the ephemeral ECDH keys exchanged during the handshake,
// so a signature over this message is tied to one specific key exchange. The
// argument order matters: swapping the keys yields a different message.
//
// NOTE(review): earlier comments described the layout as "rlt + sha256(ae +
// be)", i.e. prefixed with a long-term key. The code prefixes the server's
// ephemeral key instead; confirm which of the two is intended.
func buildMessage(clientPubKey *ecdh.PublicKey, serverPubKey *ecdh.PublicKey) ([]byte, error) {
	clientRaw, err := keyexchange.ExportPublicKey(clientPubKey)
	if err != nil {
		return nil, err
	}
	serverRaw, err := keyexchange.ExportPublicKey(serverPubKey)
	if err != nil {
		return nil, err
	}

	// Hash the client key followed by the server key. Feeding the two parts
	// to the hash one after the other is equivalent to hashing their
	// concatenation.
	digest := sha256.New()
	digest.Write(clientRaw)
	digest.Write(serverRaw)

	// Start the output with the server key, then let Sum append the digest
	// into the pre-sized buffer.
	out := make([]byte, 0, len(serverRaw)+sha256.Size)
	out = append(out, serverRaw...)
	return digest.Sum(out), nil
}