From 9eaf8380693a7aa6f7305e8aa5b2c8ccd1514613 Mon Sep 17 00:00:00 2001 From: Dustin Stiles Date: Mon, 10 Mar 2025 09:21:34 -0400 Subject: [PATCH] use buffer pkg over frame --- pkg/buffer/lenprefixed.go | 57 +++++++++++++++++++++++ pkg/handshake/handshake.go | 94 ++++++++++++++++++++++++-------------- 2 files changed, 116 insertions(+), 35 deletions(-) create mode 100644 pkg/buffer/lenprefixed.go diff --git a/pkg/buffer/lenprefixed.go b/pkg/buffer/lenprefixed.go new file mode 100644 index 0000000..c4463c7 --- /dev/null +++ b/pkg/buffer/lenprefixed.go @@ -0,0 +1,57 @@ +package buffer + +import ( + "encoding/binary" + "errors" + "io" +) + +// MaxUint16 is the maximum value of a uint16. It is used to check if the +// length of the data is too large to be encoded in a uint16. +const MaxUint16 = 0xFFFF + +// ErrInvalidLength is the error message returned when the length of the data +// is too large to be encoded in a uint16. +var ErrInvalidLength = errors.New("data length is too large to be encoded in a uint16") + +// NewLenPrefixed returns a new buffer with a length prefix. The length prefix is +// 2 bytes long and is encoded in big-endian order. The length prefix is +// followed by the data. +func NewLenPrefixed(data []byte) ([]byte, error) { + length := len(data) + // Overflow Guard: If the length of the data is greater than the maximum + // value of a uint16, return an error. + if length > MaxUint16 { + return nil, ErrInvalidLength + } + buf := make([]byte, 2+len(data)) + binary.BigEndian.PutUint16(buf, uint16(len(data))) + copy(buf[2:], data) + return buf, nil +} + +func ReadLenPrefixed(maxSize uint16, r io.Reader) ([]byte, error) { + lenBuf := make([]byte, 2) + if _, err := io.ReadFull(r, lenBuf); err != nil { + return nil, err + } + if len(lenBuf) < 2 { + return nil, errors.New("buffer is too small to contain length prefix") + } + length := binary.BigEndian.Uint16(lenBuf) + + if length > maxSize { + return nil, errors.New("data length is too large to be encoded in a uint16") + } + + buf := make([]byte, length) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + + if len(buf) < int(length) { + return nil, errors.New("buffer is too small to contain data") + } + + return buf, nil +} diff --git a/pkg/handshake/handshake.go b/pkg/handshake/handshake.go index 9080bd5..74c86d2 100644 --- a/pkg/handshake/handshake.go +++ b/pkg/handshake/handshake.go @@ -11,22 +11,16 @@ import ( "log/slog" "koti.casa/numenor-labs/dsfx/pkg/assert" + "koti.casa/numenor-labs/dsfx/pkg/buffer" "koti.casa/numenor-labs/dsfx/pkg/crypto/encryption" "koti.casa/numenor-labs/dsfx/pkg/crypto/identity" "koti.casa/numenor-labs/dsfx/pkg/crypto/keyexchange" - "koti.casa/numenor-labs/dsfx/pkg/frame" "koti.casa/numenor-labs/dsfx/pkg/logging" ) const ( - // ECDHPublicKeySize is the size of an ECDH public key in bytes. - ECDHPublicKeySize = 97 - // ECDSAPublicKeySize is the size of an ECDSA public key in bytes. - ECDSAPublicKeySize = 222 - // BoxedClientAuthMessageSize is the size of a boxed client authentication message in bytes. - BoxedClientAuthMessageSize = 353 - // BoxedServerAuthMessageSize is the size of a boxed server authentication message in bytes. - BoxedServerAuthMessageSize = 130 + DHKeySize = 32 + IdentityKeySize = 32 ) // Initiate initiates the handshake process between the given actor @@ -56,32 +50,35 @@ func Initiate( if err != nil { return nil, err } - assert.Assert(len(ourDHKeyRaw) == ECDHPublicKeySize, "invalid dh key size") + + welcomeMessage, err := buffer.NewLenPrefixed(ourDHKeyRaw) + if err != nil { + return nil, err + } // Write the actor's public key to the connection. logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw))) - _, err = frame.New(ourDHKeyRaw).WriteTo(conn) + n, err := conn.Write(welcomeMessage) if err != nil { return nil, err } + if n != len(welcomeMessage) { + return nil, errors.New("failed to write dh key") + } // ------------------------------------------------------------------------ // Step 2: Ephemeral Key Exchange From Server // Read the remote actor's public key from the connection. logger.DebugContext(ctx, "waiting for server's dh key") - remoteDHKeyFrame := frame.New(nil) - _, err = remoteDHKeyFrame.ReadFrom(conn) + remoteDHKeyRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, err } - if len(remoteDHKeyFrame.Contents()) != ECDHPublicKeySize { - return nil, errors.New("invalid dh key size") - } // Import the remote actor's public key. logger.DebugContext(ctx, "importing server's dh key") - remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyFrame.Contents()) + remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyRaw) if err != nil { return nil, err } @@ -143,18 +140,24 @@ func Initiate( // Write the boxed message to the connection. logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg))) - _, err = frame.New(boxedMsg).WriteTo(conn) + boxedMsgPrepared, err := buffer.NewLenPrefixed(boxedMsg) if err != nil { return nil, err } + n, err = conn.Write(boxedMsgPrepared) + if err != nil { + return nil, err + } + if n != len(boxedMsgPrepared) { + return nil, errors.New("failed to write authentication message") + } // ------------------------------------------------------------------------ // Step 4: Server Authentication // Read the authentication message from the connection. logger.DebugContext(ctx, "waiting for server's authentication message") - authMessageFrame := frame.New(nil) - n, err := authMessageFrame.ReadFrom(conn) + authMessageBoxed, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, err } @@ -162,7 +165,7 @@ func Initiate( // Decrypt the authentication message with the derived key. logger.DebugContext(ctx, "decrypting authentication message") - plaintext, err = encryption.Decrypt(derivedKey, authMessageFrame.Contents()) + plaintext, err = encryption.Decrypt(derivedKey, authMessageBoxed) if err != nil { return nil, err } @@ -182,11 +185,17 @@ func Initiate( // Finally, we need to let the server know that the handshake is complete. logger.DebugContext(ctx, "sending handshake complete message") - handshakeCompleteMsg := []byte{0x01} - _, err = frame.New(handshakeCompleteMsg).WriteTo(conn) + handshakeCompleteMsg, err := buffer.NewLenPrefixed([]byte{0x01}) if err != nil { return nil, err } + n, err = conn.Write(handshakeCompleteMsg) + if err != nil { + return nil, err + } + if n != len(handshakeCompleteMsg) { + return nil, errors.New("failed to write handshake complete message") + } logger.DebugContext(ctx, "handshake complete") return derivedKey, nil @@ -202,15 +211,14 @@ func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.Privat // Read the remote actor's public key from the connection. logger.DebugContext(ctx, "waiting for client's dh key") - remoteDHKeyFrame := frame.New(nil) - _, err := remoteDHKeyFrame.ReadFrom(conn) + remoteDHKeyRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, nil, err } // Import the remote actor's public key. logger.DebugContext(ctx, "importing client's dh key") - remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyFrame.Contents()) + remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyRaw) if err != nil { return nil, nil, err } @@ -234,22 +242,32 @@ func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.Privat // Write the actor's public key to the connection. logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw))) - _, err = frame.New(ourDHKeyRaw).WriteTo(conn) + ourPrefixedDHKey, err := buffer.NewLenPrefixed(ourDHKeyRaw) if err != nil { return nil, nil, err } + n, err := conn.Write(ourPrefixedDHKey) + if err != nil { + return nil, nil, err + } + if n != len(ourPrefixedDHKey) { + return nil, nil, errors.New("failed to write dh key") + } // ------------------------------------------------------------------------ // Step 3: Server Authentication // Read the authentication message from the connection. logger.DebugContext(ctx, "waiting for client's authentication message") - authMessageFrame := frame.New(nil) - n, err := authMessageFrame.ReadFrom(conn) + authMessageRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, nil, err } - logger.DebugContext(ctx, "received authentication message", slog.Int("message.size", int(n))) + logger.DebugContext( + ctx, + "received authentication message", + slog.Int("message.size", len(authMessageRaw)), + ) // Decrypt the authentication message with the derived key. logger.DebugContext(ctx, "computing shared secret") @@ -259,7 +277,7 @@ func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.Privat } logger.DebugContext(ctx, "decrypting authentication message") - plaintext, err := encryption.Decrypt(derivedKey, authMessageFrame.Contents()) + plaintext, err := encryption.Decrypt(derivedKey, authMessageRaw) if err != nil { return nil, nil, err } @@ -309,19 +327,25 @@ func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.Privat // Send the server's signature back to the client. logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg))) - _, err = frame.New(boxedMsg).WriteTo(conn) + prefixedAuthMessage, err := buffer.NewLenPrefixed(boxedMsg) if err != nil { return nil, nil, err } + n, err = conn.Write(prefixedAuthMessage) + if err != nil { + return nil, nil, err + } + if n != len(prefixedAuthMessage) { + return nil, nil, errors.New("failed to write authentication message") + } logger.DebugContext(ctx, "waiting for handshake complete message") // Read the handshake complete message from the client. - handshakeCompleteFrame := frame.New(nil) - _, err = handshakeCompleteFrame.ReadFrom(conn) + handshakeCompleteMsg, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn) if err != nil { return nil, nil, err } - if !bytes.Equal(handshakeCompleteFrame.Contents(), []byte{0x01}) { + if !bytes.Equal(handshakeCompleteMsg, []byte{0x01}) { return nil, nil, errors.New("invalid handshake complete message") }