package handshake import ( "bytes" "context" "crypto/ecdh" "crypto/ed25519" "crypto/sha256" "errors" "io" "log/slog" "git.numenor-labs.us/dsfx/internal/lib/assert" "git.numenor-labs.us/dsfx/internal/lib/buffer" "git.numenor-labs.us/dsfx/internal/lib/crypto/encryption" "git.numenor-labs.us/dsfx/internal/lib/crypto/identity" "git.numenor-labs.us/dsfx/internal/lib/crypto/keyexchange" "git.numenor-labs.us/dsfx/internal/lib/logging" ) const ( DHKeySize = 32 IdentityKeySize = 32 ) // Initiate initiates the handshake process between the given actor // and the remote actor. func Initiate( ctx context.Context, conn io.ReadWriteCloser, lPrivKey ed25519.PrivateKey, rPubKey ed25519.PublicKey, ) ([]byte, error) { logger := logging.FromContext(ctx).WithGroup("handshake") // ------------------------------------------------------------------------ // Step 1: Ephemeral Key Exchange To Server logger.DebugContext(ctx, "creating dh key") // Create a new ECDH private key for the actor. ourDHKey, err := keyexchange.GenerateDHKey() if err != nil { return nil, err } assert.Assert(ourDHKey != nil, "failed to generate dh key") logger.DebugContext(ctx, "exporting dh key") // Export the public key of the actor's ECDH private key. ourDHKeyRaw, err := keyexchange.ExportPublicKey(ourDHKey.PublicKey()) if err != nil { return nil, err } welcomeMessage, err := buffer.NewLenPrefixed(ourDHKeyRaw) if err != nil { return nil, err } // Write the actor's public key to the connection. logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw))) n, err := conn.Write(welcomeMessage) if err != nil { return nil, err } if n != len(welcomeMessage) { return nil, errors.New("failed to write dh key") } // ------------------------------------------------------------------------ // Step 2: Ephemeral Key Exchange From Server // Read the remote actor's public key from the connection. logger.DebugContext(ctx, "waiting for server's dh key") remoteDHKeyRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, err } // Import the remote actor's public key. logger.DebugContext(ctx, "importing server's dh key") remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyRaw) if err != nil { return nil, err } // ------------------------------------------------------------------------ // Step 3: Client Authentication // Export the public key of the actor's signing key. logger.DebugContext(ctx, "exporting public signing key") ourPublicKeyRaw, err := identity.ExportPublicKey(identity.ToPublicKey(lPrivKey)) if err != nil { return nil, err } logger.DebugContext(ctx, "exporting remote public signing key") remotePublicKeyRaw, err := identity.ExportPublicKey(rPubKey) if err != nil { return nil, err } // Construct the message that will be signed by the client. // This message is formatted as follows: // rlt + sha256(ae + be) // This binds both the client and server's long term public keys to the // ephemeral keys that were exchanged in the previous step. This creates a // verifiable link between both parties and the ephemeral keys used to // establish the shared secret. logger.DebugContext(ctx, "building authentication message") authMessage, err := buildMessage(ourDHKey.PublicKey(), remoteDHKey) if err != nil { return nil, err } // Sign the message with the actor's private key. logger.DebugContext(ctx, "signing authentication message") signature, err := identity.Sign(lPrivKey, authMessage) if err != nil { return nil, err } // Compute the shared secret between the actor and the remote actor. logger.DebugContext(ctx, "computing shared secret") derivedKey, err := keyexchange.ComputeDHSecret(ourDHKey, remoteDHKey) if err != nil { return nil, err } assert.Assert(len(derivedKey) == 32, "invalid shared secret size") plaintext := make([]byte, 0, len(ourPublicKeyRaw)+len(signature)) plaintext = append(plaintext, ourPublicKeyRaw...) plaintext = append(plaintext, signature...) // Encrypt the message with the derived key. logger.DebugContext(ctx, "encrypting authentication message") boxedMsg, err := encryption.Encrypt(derivedKey, plaintext) if err != nil { return nil, err } // Write the boxed message to the connection. logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg))) boxedMsgPrepared, err := buffer.NewLenPrefixed(boxedMsg) if err != nil { return nil, err } n, err = conn.Write(boxedMsgPrepared) if err != nil { return nil, err } if n != len(boxedMsgPrepared) { return nil, errors.New("failed to write authentication message") } // ------------------------------------------------------------------------ // Step 4: Server Authentication // Read the authentication message from the connection. logger.DebugContext(ctx, "waiting for server's authentication message") authMessageBoxed, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, err } logger.DebugContext(ctx, "received authentication message", slog.Int("message.size", int(n))) // Decrypt the authentication message with the derived key. logger.DebugContext(ctx, "decrypting authentication message") plaintext, err = encryption.Decrypt(derivedKey, authMessageBoxed) if err != nil { return nil, err } // The server authentication is just verifying the signature it created of // the client authentication message. logger.DebugContext(ctx, "importing server's public signing key") remotePublicKey, err := identity.ImportPublicKey(remotePublicKeyRaw) if err != nil { return nil, err } logger.DebugContext(ctx, "verifying server's signature") if !identity.Verify(remotePublicKey, authMessage, plaintext) { return nil, errors.New("failed to verify server's signature") } // Finally, we need to let the server know that the handshake is complete. logger.DebugContext(ctx, "sending handshake complete message") handshakeCompleteMsg, err := buffer.NewLenPrefixed([]byte{0x01}) if err != nil { return nil, err } n, err = conn.Write(handshakeCompleteMsg) if err != nil { return nil, err } if n != len(handshakeCompleteMsg) { return nil, errors.New("failed to write handshake complete message") } logger.DebugContext(ctx, "handshake complete") return derivedKey, nil } // Accept accepts a handshake from the given actor and connection. It // returns the shared secret between the actor and the remote actor. func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey ed25519.PrivateKey) (ed25519.PublicKey, []byte, error) { logger := logging.FromContext(ctx).WithGroup("handshake") // ------------------------------------------------------------------------ // Step 1: Ephemeral Key Exchange From Client // Read the remote actor's public key from the connection. logger.DebugContext(ctx, "waiting for client's dh key") remoteDHKeyRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, nil, err } // Import the remote actor's public key. logger.DebugContext(ctx, "importing client's dh key") remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyRaw) if err != nil { return nil, nil, err } // ------------------------------------------------------------------------ // Step 2: Ephemeral Key Exchange To Client // Create a new ECDH private key for the actor. logger.DebugContext(ctx, "creating dh key") ourDHKey, err := keyexchange.GenerateDHKey() if err != nil { return nil, nil, err } // Export the public key of the actor's ECDH private key. logger.DebugContext(ctx, "exporting dh key") ourDHKeyRaw, err := keyexchange.ExportPublicKey(ourDHKey.PublicKey()) if err != nil { return nil, nil, err } // Write the actor's public key to the connection. logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw))) ourPrefixedDHKey, err := buffer.NewLenPrefixed(ourDHKeyRaw) if err != nil { return nil, nil, err } n, err := conn.Write(ourPrefixedDHKey) if err != nil { return nil, nil, err } if n != len(ourPrefixedDHKey) { return nil, nil, errors.New("failed to write dh key") } // ------------------------------------------------------------------------ // Step 3: Server Authentication // Read the authentication message from the connection. logger.DebugContext(ctx, "waiting for client's authentication message") authMessageRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, nil, err } logger.DebugContext( ctx, "received authentication message", slog.Int("message.size", len(authMessageRaw)), ) // Decrypt the authentication message with the derived key. logger.DebugContext(ctx, "computing shared secret") derivedKey, err := keyexchange.ComputeDHSecret(ourDHKey, remoteDHKey) if err != nil { return nil, nil, err } logger.DebugContext(ctx, "decrypting authentication message") plaintext, err := encryption.Decrypt(derivedKey, authMessageRaw) if err != nil { return nil, nil, err } clientPublicKeyRaw := plaintext[:44] signature := plaintext[44:] // Verify the client's public key and signature. logger.DebugContext(ctx, "importing client's public signing key") clientPublicKey, err := identity.ImportPublicKey(clientPublicKeyRaw) if err != nil { return nil, nil, err } // Construct the message that was signed by the client. // This message is formatted as follows: // rlt + sha256(ae + be) // This binds both the client and server's long term public keys to the // ephemeral keys that were exchanged in the previous step. This creates a // verifiable link between both parties and the ephemeral keys used to // establish the shared secret. logger.DebugContext(ctx, "building authentication message") authMessage, err := buildMessage(remoteDHKey, ourDHKey.PublicKey()) if err != nil { return nil, nil, err } logger.DebugContext(ctx, "verifying client's signature") if !identity.Verify(clientPublicKey, authMessage, signature) { return nil, nil, errors.New("failed to verify client's signature") } // Now we need to sign the authentication message with the server's private // key. This will be sent back to the client in the next step to authenticate // the server to the client. logger.DebugContext(ctx, "signing authentication message") serverSignature, err := identity.Sign(lPrivKey, authMessage) if err != nil { return nil, nil, err } logger.DebugContext(ctx, "encrypting server's signature") boxedMsg, err := encryption.Encrypt(derivedKey, serverSignature) if err != nil { return nil, nil, err } // Send the server's signature back to the client. logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg))) prefixedAuthMessage, err := buffer.NewLenPrefixed(boxedMsg) if err != nil { return nil, nil, err } n, err = conn.Write(prefixedAuthMessage) if err != nil { return nil, nil, err } if n != len(prefixedAuthMessage) { return nil, nil, errors.New("failed to write authentication message") } logger.DebugContext(ctx, "waiting for handshake complete message") // Read the handshake complete message from the client. handshakeCompleteMsg, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, nil, err } if !bytes.Equal(handshakeCompleteMsg, []byte{0x01}) { return nil, nil, errors.New("invalid handshake complete message") } // ------------------------------------------------------------------------ // Step 4: Client Authentication logger.DebugContext(ctx, "handshake complete") return clientPublicKey, derivedKey, nil } func buildMessage(clientPubKey *ecdh.PublicKey, serverPubKey *ecdh.PublicKey) ([]byte, error) { clientPubKeyRaw, err := keyexchange.ExportPublicKey(clientPubKey) if err != nil { return nil, err } serverPubKeyRaw, err := keyexchange.ExportPublicKey(serverPubKey) if err != nil { return nil, err } // Construct the message that will be signed by the client. // This message is formatted as follows: // rlt + sha256(ae + be) // This binds both the client and server's long term public keys to the // ephemeral keys that were exchanged in the previous step. This creates a // verifiable link between both parties and the ephemeral keys used to // establish the shared secret. message := make([]byte, 0, len(clientPubKeyRaw)+len(serverPubKeyRaw)) message = append(message, clientPubKeyRaw...) message = append(message, serverPubKeyRaw...) messageChecksum := sha256.Sum256(message) authMessage := make([]byte, 0, len(serverPubKeyRaw)+sha256.Size) authMessage = append(authMessage, serverPubKeyRaw...) authMessage = append(authMessage, messageChecksum[:]...) return authMessage, nil }