mirror of
https://git.numenor-labs.us/dsfx.git
synced 2025-04-29 16:20:34 +00:00
use buffer pkg over frame
This commit is contained in:
parent
e674be0399
commit
9eaf838069
57
pkg/buffer/lenprefixed.go
Normal file
57
pkg/buffer/lenprefixed.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package buffer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MaxUint16 is the maximum value of a uint16. It is used to check if the
|
||||||
|
// length of the data is too large to be encoded in a uint16.
|
||||||
|
const MaxUint16 = 0xFFFF
|
||||||
|
|
||||||
|
// ErrInvalidLength is the error message returned when the length of the data
|
||||||
|
// is too large to be encoded in a uint16.
|
||||||
|
var ErrInvalidLength = errors.New("data length is too large to be encoded in a uint16")
|
||||||
|
|
||||||
|
// NewLenPrefixed returns a new buffer with a length prefix. The length prefix is
|
||||||
|
// 2 bytes long and is encoded in big-endian order. The length prefix is
|
||||||
|
// followed by the data.
|
||||||
|
func NewLenPrefixed(data []byte) ([]byte, error) {
|
||||||
|
length := len(data)
|
||||||
|
// Overflow Guard: If the length of the data is greater than the maximum
|
||||||
|
// value of a uint16, return an error.
|
||||||
|
if length > MaxUint16 {
|
||||||
|
return nil, ErrInvalidLength
|
||||||
|
}
|
||||||
|
buf := make([]byte, 2+len(data))
|
||||||
|
binary.BigEndian.PutUint16(buf, uint16(len(data)))
|
||||||
|
copy(buf[2:], data)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadLenPrefixed(maxSize uint16, r io.Reader) ([]byte, error) {
|
||||||
|
lenBuf := make([]byte, 2)
|
||||||
|
if _, err := io.ReadFull(r, lenBuf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(lenBuf) < 2 {
|
||||||
|
return nil, errors.New("buffer is too small to contain length prefix")
|
||||||
|
}
|
||||||
|
length := binary.BigEndian.Uint16(lenBuf)
|
||||||
|
|
||||||
|
if length > maxSize {
|
||||||
|
return nil, errors.New("data length is too large to be encoded in a uint16")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, length)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(buf) < int(length) {
|
||||||
|
return nil, errors.New("buffer is too small to contain data")
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
@ -11,22 +11,16 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
"koti.casa/numenor-labs/dsfx/pkg/assert"
|
"koti.casa/numenor-labs/dsfx/pkg/assert"
|
||||||
|
"koti.casa/numenor-labs/dsfx/pkg/buffer"
|
||||||
"koti.casa/numenor-labs/dsfx/pkg/crypto/encryption"
|
"koti.casa/numenor-labs/dsfx/pkg/crypto/encryption"
|
||||||
"koti.casa/numenor-labs/dsfx/pkg/crypto/identity"
|
"koti.casa/numenor-labs/dsfx/pkg/crypto/identity"
|
||||||
"koti.casa/numenor-labs/dsfx/pkg/crypto/keyexchange"
|
"koti.casa/numenor-labs/dsfx/pkg/crypto/keyexchange"
|
||||||
"koti.casa/numenor-labs/dsfx/pkg/frame"
|
|
||||||
"koti.casa/numenor-labs/dsfx/pkg/logging"
|
"koti.casa/numenor-labs/dsfx/pkg/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ECDHPublicKeySize is the size of an ECDH public key in bytes.
|
DHKeySize = 32
|
||||||
ECDHPublicKeySize = 97
|
IdentityKeySize = 32
|
||||||
// ECDSAPublicKeySize is the size of an ECDSA public key in bytes.
|
|
||||||
ECDSAPublicKeySize = 222
|
|
||||||
// BoxedClientAuthMessageSize is the size of a boxed client authentication message in bytes.
|
|
||||||
BoxedClientAuthMessageSize = 353
|
|
||||||
// BoxedServerAuthMessageSize is the size of a boxed server authentication message in bytes.
|
|
||||||
BoxedServerAuthMessageSize = 130
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Initiate initiates the handshake process between the given actor
|
// Initiate initiates the handshake process between the given actor
|
||||||
@ -56,32 +50,35 @@ func Initiate(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
assert.Assert(len(ourDHKeyRaw) == ECDHPublicKeySize, "invalid dh key size")
|
|
||||||
|
welcomeMessage, err := buffer.NewLenPrefixed(ourDHKeyRaw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Write the actor's public key to the connection.
|
// Write the actor's public key to the connection.
|
||||||
logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw)))
|
logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw)))
|
||||||
_, err = frame.New(ourDHKeyRaw).WriteTo(conn)
|
n, err := conn.Write(welcomeMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if n != len(welcomeMessage) {
|
||||||
|
return nil, errors.New("failed to write dh key")
|
||||||
|
}
|
||||||
|
|
||||||
// ------------------------------------------------------------------------
|
// ------------------------------------------------------------------------
|
||||||
// Step 2: Ephemeral Key Exchange From Server
|
// Step 2: Ephemeral Key Exchange From Server
|
||||||
|
|
||||||
// Read the remote actor's public key from the connection.
|
// Read the remote actor's public key from the connection.
|
||||||
logger.DebugContext(ctx, "waiting for server's dh key")
|
logger.DebugContext(ctx, "waiting for server's dh key")
|
||||||
remoteDHKeyFrame := frame.New(nil)
|
remoteDHKeyRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
|
||||||
_, err = remoteDHKeyFrame.ReadFrom(conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(remoteDHKeyFrame.Contents()) != ECDHPublicKeySize {
|
|
||||||
return nil, errors.New("invalid dh key size")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Import the remote actor's public key.
|
// Import the remote actor's public key.
|
||||||
logger.DebugContext(ctx, "importing server's dh key")
|
logger.DebugContext(ctx, "importing server's dh key")
|
||||||
remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyFrame.Contents())
|
remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -143,18 +140,24 @@ func Initiate(
|
|||||||
|
|
||||||
// Write the boxed message to the connection.
|
// Write the boxed message to the connection.
|
||||||
logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg)))
|
logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg)))
|
||||||
_, err = frame.New(boxedMsg).WriteTo(conn)
|
boxedMsgPrepared, err := buffer.NewLenPrefixed(boxedMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
n, err = conn.Write(boxedMsgPrepared)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if n != len(boxedMsgPrepared) {
|
||||||
|
return nil, errors.New("failed to write authentication message")
|
||||||
|
}
|
||||||
|
|
||||||
// ------------------------------------------------------------------------
|
// ------------------------------------------------------------------------
|
||||||
// Step 4: Server Authentication
|
// Step 4: Server Authentication
|
||||||
|
|
||||||
// Read the authentication message from the connection.
|
// Read the authentication message from the connection.
|
||||||
logger.DebugContext(ctx, "waiting for server's authentication message")
|
logger.DebugContext(ctx, "waiting for server's authentication message")
|
||||||
authMessageFrame := frame.New(nil)
|
authMessageBoxed, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
|
||||||
n, err := authMessageFrame.ReadFrom(conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -162,7 +165,7 @@ func Initiate(
|
|||||||
|
|
||||||
// Decrypt the authentication message with the derived key.
|
// Decrypt the authentication message with the derived key.
|
||||||
logger.DebugContext(ctx, "decrypting authentication message")
|
logger.DebugContext(ctx, "decrypting authentication message")
|
||||||
plaintext, err = encryption.Decrypt(derivedKey, authMessageFrame.Contents())
|
plaintext, err = encryption.Decrypt(derivedKey, authMessageBoxed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -182,11 +185,17 @@ func Initiate(
|
|||||||
|
|
||||||
// Finally, we need to let the server know that the handshake is complete.
|
// Finally, we need to let the server know that the handshake is complete.
|
||||||
logger.DebugContext(ctx, "sending handshake complete message")
|
logger.DebugContext(ctx, "sending handshake complete message")
|
||||||
handshakeCompleteMsg := []byte{0x01}
|
handshakeCompleteMsg, err := buffer.NewLenPrefixed([]byte{0x01})
|
||||||
_, err = frame.New(handshakeCompleteMsg).WriteTo(conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
n, err = conn.Write(handshakeCompleteMsg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if n != len(handshakeCompleteMsg) {
|
||||||
|
return nil, errors.New("failed to write handshake complete message")
|
||||||
|
}
|
||||||
|
|
||||||
logger.DebugContext(ctx, "handshake complete")
|
logger.DebugContext(ctx, "handshake complete")
|
||||||
return derivedKey, nil
|
return derivedKey, nil
|
||||||
@ -202,15 +211,14 @@ func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.Privat
|
|||||||
|
|
||||||
// Read the remote actor's public key from the connection.
|
// Read the remote actor's public key from the connection.
|
||||||
logger.DebugContext(ctx, "waiting for client's dh key")
|
logger.DebugContext(ctx, "waiting for client's dh key")
|
||||||
remoteDHKeyFrame := frame.New(nil)
|
remoteDHKeyRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
|
||||||
_, err := remoteDHKeyFrame.ReadFrom(conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Import the remote actor's public key.
|
// Import the remote actor's public key.
|
||||||
logger.DebugContext(ctx, "importing client's dh key")
|
logger.DebugContext(ctx, "importing client's dh key")
|
||||||
remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyFrame.Contents())
|
remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@ -234,22 +242,32 @@ func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.Privat
|
|||||||
|
|
||||||
// Write the actor's public key to the connection.
|
// Write the actor's public key to the connection.
|
||||||
logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw)))
|
logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw)))
|
||||||
_, err = frame.New(ourDHKeyRaw).WriteTo(conn)
|
ourPrefixedDHKey, err := buffer.NewLenPrefixed(ourDHKeyRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
n, err := conn.Write(ourPrefixedDHKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if n != len(ourPrefixedDHKey) {
|
||||||
|
return nil, nil, errors.New("failed to write dh key")
|
||||||
|
}
|
||||||
|
|
||||||
// ------------------------------------------------------------------------
|
// ------------------------------------------------------------------------
|
||||||
// Step 3: Server Authentication
|
// Step 3: Server Authentication
|
||||||
|
|
||||||
// Read the authentication message from the connection.
|
// Read the authentication message from the connection.
|
||||||
logger.DebugContext(ctx, "waiting for client's authentication message")
|
logger.DebugContext(ctx, "waiting for client's authentication message")
|
||||||
authMessageFrame := frame.New(nil)
|
authMessageRaw, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
|
||||||
n, err := authMessageFrame.ReadFrom(conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "received authentication message", slog.Int("message.size", int(n)))
|
logger.DebugContext(
|
||||||
|
ctx,
|
||||||
|
"received authentication message",
|
||||||
|
slog.Int("message.size", len(authMessageRaw)),
|
||||||
|
)
|
||||||
|
|
||||||
// Decrypt the authentication message with the derived key.
|
// Decrypt the authentication message with the derived key.
|
||||||
logger.DebugContext(ctx, "computing shared secret")
|
logger.DebugContext(ctx, "computing shared secret")
|
||||||
@ -259,7 +277,7 @@ func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.Privat
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.DebugContext(ctx, "decrypting authentication message")
|
logger.DebugContext(ctx, "decrypting authentication message")
|
||||||
plaintext, err := encryption.Decrypt(derivedKey, authMessageFrame.Contents())
|
plaintext, err := encryption.Decrypt(derivedKey, authMessageRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@ -309,19 +327,25 @@ func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.Privat
|
|||||||
|
|
||||||
// Send the server's signature back to the client.
|
// Send the server's signature back to the client.
|
||||||
logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg)))
|
logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg)))
|
||||||
_, err = frame.New(boxedMsg).WriteTo(conn)
|
prefixedAuthMessage, err := buffer.NewLenPrefixed(boxedMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
n, err = conn.Write(prefixedAuthMessage)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if n != len(prefixedAuthMessage) {
|
||||||
|
return nil, nil, errors.New("failed to write authentication message")
|
||||||
|
}
|
||||||
|
|
||||||
logger.DebugContext(ctx, "waiting for handshake complete message")
|
logger.DebugContext(ctx, "waiting for handshake complete message")
|
||||||
// Read the handshake complete message from the client.
|
// Read the handshake complete message from the client.
|
||||||
handshakeCompleteFrame := frame.New(nil)
|
handshakeCompleteMsg, err := buffer.ReadLenPrefixed(buffer.MaxUint16, conn)
|
||||||
_, err = handshakeCompleteFrame.ReadFrom(conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if !bytes.Equal(handshakeCompleteFrame.Contents(), []byte{0x01}) {
|
if !bytes.Equal(handshakeCompleteMsg, []byte{0x01}) {
|
||||||
return nil, nil, errors.New("invalid handshake complete message")
|
return nil, nil, errors.New("invalid handshake complete message")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user