refactoring time

This commit is contained in:
Dustin Stiles 2025-03-09 17:09:21 -04:00
parent 9ba822d4df
commit ea027c6e4a
Signed by: duwstiles
GPG Key ID: BCD9912EC231FC87
14 changed files with 196 additions and 85 deletions

View File

@ -8,8 +8,17 @@ import (
"io"
)
// Encrypt uses AES-GCM to encrypt the given plaintext with the given key.
//16 24 32
// Encrypt uses AES-GCM to encrypt the given plaintext with the given key. The
// plaintext is sealed with a 12-byte nonce, which is prepended to the ciphertext.
func Encrypt(key, plaintext []byte) ([]byte, error) {
switch len(key) {
case 16, 24, 32: // AES-128, AES-192, AES-256
default:
return nil, errors.New("invalid key length")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
@ -30,8 +39,20 @@ func Encrypt(key, plaintext []byte) ([]byte, error) {
return ciphertext, nil
}
// Decrypt uses AES-GCM to decrypt the given ciphertext with the given key.
// Decrypt uses AES-GCM to decrypt the given ciphertext with the given key. This
// function expects that the first 12 bytes of the ciphertext are the nonce that
// was used to encrypt the plaintext.
func Decrypt(key, ciphertext []byte) ([]byte, error) {
switch len(key) {
case 16, 24, 32: // AES-128, AES-192, AES-256
default:
return nil, errors.New("invalid key length")
}
if len(ciphertext) < 12 {
return nil, errors.New("ciphertext too short")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err

View File

@ -5,16 +5,16 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"os"
)
var (
// DefaultSigningCurve is the default elliptic curve used for signing.
DefaultSigningCurve = elliptic.P384
ExportedPublicKeySize = 215
)
func LoadSigningKeyFromFile(filePath string) (*ecdsa.PrivateKey, error) {
@ -37,8 +37,8 @@ func LoadSigningKeyFromFile(filePath string) (*ecdsa.PrivateKey, error) {
return masterKey, nil
}
// GenerateSigningKey generates a new ECDSA private key for signing.
func GenerateSigningKey() (*ecdsa.PrivateKey, error) {
// Generate generates a new ECDSA private key for signing.
func Generate() (*ecdsa.PrivateKey, error) {
return ecdsa.GenerateKey(DefaultSigningCurve(), rand.Reader)
}
@ -52,35 +52,63 @@ func Verify(pub *ecdsa.PublicKey, data, signature []byte) bool {
return ecdsa.VerifyASN1(pub, data, signature)
}
// ExportPublicSigningKey exports the public key as a byte slice.
func ExportPublicSigningKey(pub *ecdsa.PublicKey) ([]byte, error) {
data := struct {
N []byte `json:"n"`
PubX []byte `json:"pub_x"`
PubY []byte `json:"pub_y"`
}{
N: pub.Curve.Params().N.Bytes(),
PubX: pub.X.Bytes(),
PubY: pub.Y.Bytes(),
}
return json.Marshal(data)
}
// ImportPublicSigningKey imports the public key from a byte slice.
func ImportPublicSigningKey(pubBytes []byte) (*ecdsa.PublicKey, error) {
var data struct {
N []byte `json:"n"`
PubX []byte `json:"pub_x"`
PubY []byte `json:"pub_y"`
}
if err := json.Unmarshal(pubBytes, &data); err != nil {
// ExportPrivateKey exports the private key as a byte slice.
func ExportPrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
der, err := x509.MarshalECPrivateKey(key)
if err != nil {
return nil, err
}
params := new(ecdsa.PublicKey)
params.Curve = DefaultSigningCurve()
params.X = new(big.Int).SetBytes(data.PubX)
params.Y = new(big.Int).SetBytes(data.PubY)
return params, nil
return pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: der,
}), nil
}
// ExportPublicKey exports the public key as a byte slice.
func ExportPublicKey(key *ecdsa.PublicKey) ([]byte, error) {
der, err := x509.MarshalPKIXPublicKey(key)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: der,
}), nil
}
// ImportPrivateKey imports the private key from a byte slice.
func ImportPrivateKey(keyBytes []byte) (*ecdsa.PrivateKey, error) {
block, _ := pem.Decode(keyBytes)
if block == nil {
return nil, fmt.Errorf("failed to decode private key")
}
privKey, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return nil, err
}
return privKey, nil
}
// ImportPublicKey imports the public key from a byte slice.
func ImportPublicKey(keyBytes []byte) (*ecdsa.PublicKey, error) {
block, _ := pem.Decode(keyBytes)
if block == nil {
return nil, fmt.Errorf("failed to decode public key")
}
pubKeyAny, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, err
}
pubKey, ok := pubKeyAny.(*ecdsa.PublicKey)
if !ok {
return nil, fmt.Errorf("not an ECDSA public key")
}
return pubKey, nil
}

View File

@ -0,0 +1,62 @@
package identity_test
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"log"
"testing"
"koti.casa/numenor-labs/dsfx/pkg/crypto/identity"
)
func TestImportExportPrivate(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
return
}
exported, err := identity.ExportPrivateKey(key)
if err != nil {
t.Fatalf("failed to export key: %v", err)
return
}
imported, err := identity.ImportPrivateKey(exported)
if err != nil {
t.Fatalf("failed to import key: %v", err)
return
}
if !key.Equal(imported) {
t.Fatalf("imported key does not match original")
return
}
}
func TestImportExportPublic(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
return
}
exported, err := identity.ExportPublicKey(&key.PublicKey)
if err != nil {
t.Fatalf("failed to export key: %v", err)
return
}
log.Println("keylen", len(exported))
imported, err := identity.ImportPublicKey(exported)
if err != nil {
t.Fatalf("failed to import key: %v", err)
return
}
if !key.PublicKey.Equal(imported) {
t.Fatalf("imported key does not match original")
return
}
}

View File

@ -32,12 +32,12 @@ func ComputeDHSecret(priv *ecdh.PrivateKey, pub *ecdh.PublicKey) ([]byte, error)
return key, nil
}
// ExportDHPublicKey exports the public key as a byte slice.
func ExportDHPublicKey(pub *ecdh.PublicKey) ([]byte, error) {
// ExportPublicKey exports the public key as a byte slice.
func ExportPublicKey(pub *ecdh.PublicKey) ([]byte, error) {
return pub.Bytes(), nil
}
// ImportDHPublicKey imports the public key from a byte slice.
func ImportDHPublicKey(data []byte) (*ecdh.PublicKey, error) {
// ImportPublicKey imports the public key from a byte slice.
func ImportPublicKey(data []byte) (*ecdh.PublicKey, error) {
return DefaultDHCurve().NewPublicKey(data)
}

View File

@ -17,8 +17,8 @@ type Frame struct {
contents []byte
}
// NewFrame creates a new Frame with a length prefix.
func NewFrame(contents []byte) *Frame {
// New creates a new Frame with a length prefix.
func New(contents []byte) *Frame {
return &Frame{
contents: contents,
}

View File

@ -16,7 +16,7 @@ func TestLenPrefixedWriteTo(t *testing.T) {
// When ...
n, err := frame.NewFrame(msg).WriteTo(buf)
n, err := frame.New(msg).WriteTo(buf)
// Then ...
@ -49,7 +49,7 @@ func TestLenPrefixedReadFrom(t *testing.T) {
expectedBytesRead := len(msg)
// When ...
f := frame.NewFrame(nil)
f := frame.New(nil)
n, err := f.ReadFrom(buf)
// Then ...

View File

@ -2,6 +2,6 @@ goos: linux
goarch: amd64
pkg: koti.casa/numenor-labs/dsfx/pkg/handshake
cpu: Intel(R) Core(TM) Ultra 9 185H
BenchmarkHandshake-22 530 2233341 ns/op
BenchmarkHandshake-22 530 2267189 ns/op
PASS
ok koti.casa/numenor-labs/dsfx/pkg/handshake 1.192s
ok koti.casa/numenor-labs/dsfx/pkg/handshake 1.212s

View File

@ -29,9 +29,9 @@ const (
BoxedServerAuthMessageSize = 130
)
// Handshake initiates the handshake process between the given actor
// Initiate initiates the handshake process between the given actor
// and the remote actor.
func Handshake(
func Initiate(
ctx context.Context,
conn io.ReadWriteCloser,
lPrivKey *ecdsa.PrivateKey,
@ -52,7 +52,7 @@ func Handshake(
logger.DebugContext(ctx, "exporting dh key")
// Export the public key of the actor's ECDH private key.
ourDHKeyRaw, err := keyexchange.ExportDHPublicKey(ourDHKey.PublicKey())
ourDHKeyRaw, err := keyexchange.ExportPublicKey(ourDHKey.PublicKey())
if err != nil {
return nil, err
}
@ -60,7 +60,7 @@ func Handshake(
// Write the actor's public key to the connection.
logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw)))
_, err = frame.NewFrame(ourDHKeyRaw).WriteTo(conn)
_, err = frame.New(ourDHKeyRaw).WriteTo(conn)
if err != nil {
return nil, err
}
@ -70,7 +70,7 @@ func Handshake(
// Read the remote actor's public key from the connection.
logger.DebugContext(ctx, "waiting for server's dh key")
remoteDHKeyFrame := frame.NewFrame(nil)
remoteDHKeyFrame := frame.New(nil)
_, err = remoteDHKeyFrame.ReadFrom(conn)
if err != nil {
return nil, err
@ -81,7 +81,7 @@ func Handshake(
// Import the remote actor's public key.
logger.DebugContext(ctx, "importing server's dh key")
remoteDHKey, err := keyexchange.ImportDHPublicKey(remoteDHKeyFrame.Contents())
remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyFrame.Contents())
if err != nil {
return nil, err
}
@ -91,13 +91,13 @@ func Handshake(
// Export the public key of the actor's signing key.
logger.DebugContext(ctx, "exporting public signing key")
ourPublicKeyRaw, err := identity.ExportPublicSigningKey(&lPrivKey.PublicKey)
ourPublicKeyRaw, err := identity.ExportPublicKey(&lPrivKey.PublicKey)
if err != nil {
return nil, err
}
logger.DebugContext(ctx, "exporting remote public signing key")
remotePublicKeyRaw, err := identity.ExportPublicSigningKey(rPubKey)
remotePublicKeyRaw, err := identity.ExportPublicKey(rPubKey)
if err != nil {
return nil, err
}
@ -143,7 +143,7 @@ func Handshake(
// Write the boxed message to the connection.
logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg)))
_, err = frame.NewFrame(boxedMsg).WriteTo(conn)
_, err = frame.New(boxedMsg).WriteTo(conn)
if err != nil {
return nil, err
}
@ -153,7 +153,7 @@ func Handshake(
// Read the authentication message from the connection.
logger.DebugContext(ctx, "waiting for server's authentication message")
authMessageFrame := frame.NewFrame(nil)
authMessageFrame := frame.New(nil)
n, err := authMessageFrame.ReadFrom(conn)
if err != nil {
return nil, err
@ -170,7 +170,7 @@ func Handshake(
// The server authentication is just verifying the signature it created of
// the client authentication message.
logger.DebugContext(ctx, "importing server's public signing key")
remotePublicKey, err := identity.ImportPublicSigningKey(remotePublicKeyRaw)
remotePublicKey, err := identity.ImportPublicKey(remotePublicKeyRaw)
if err != nil {
return nil, err
}
@ -183,7 +183,7 @@ func Handshake(
// Finally, we need to let the server know that the handshake is complete.
logger.DebugContext(ctx, "sending handshake complete message")
handshakeCompleteMsg := []byte{0x01}
_, err = frame.NewFrame(handshakeCompleteMsg).WriteTo(conn)
_, err = frame.New(handshakeCompleteMsg).WriteTo(conn)
if err != nil {
return nil, err
}
@ -192,9 +192,9 @@ func Handshake(
return derivedKey, nil
}
// AcceptHandshake accepts a handshake from the given actor and connection. It
// Accept accepts a handshake from the given actor and connection. It
// returns the shared secret between the actor and the remote actor.
func AcceptHandshake(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.PrivateKey) (*ecdsa.PublicKey, []byte, error) {
func Accept(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecdsa.PrivateKey) (*ecdsa.PublicKey, []byte, error) {
logger := logging.FromContext(ctx).WithGroup("handshake")
// ------------------------------------------------------------------------
@ -202,7 +202,7 @@ func AcceptHandshake(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecd
// Read the remote actor's public key from the connection.
logger.DebugContext(ctx, "waiting for client's dh key")
remoteDHKeyFrame := frame.NewFrame(nil)
remoteDHKeyFrame := frame.New(nil)
_, err := remoteDHKeyFrame.ReadFrom(conn)
if err != nil {
return nil, nil, err
@ -210,7 +210,7 @@ func AcceptHandshake(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecd
// Import the remote actor's public key.
logger.DebugContext(ctx, "importing client's dh key")
remoteDHKey, err := keyexchange.ImportDHPublicKey(remoteDHKeyFrame.Contents())
remoteDHKey, err := keyexchange.ImportPublicKey(remoteDHKeyFrame.Contents())
if err != nil {
return nil, nil, err
}
@ -227,14 +227,14 @@ func AcceptHandshake(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecd
// Export the public key of the actor's ECDH private key.
logger.DebugContext(ctx, "exporting dh key")
ourDHKeyRaw, err := keyexchange.ExportDHPublicKey(ourDHKey.PublicKey())
ourDHKeyRaw, err := keyexchange.ExportPublicKey(ourDHKey.PublicKey())
if err != nil {
return nil, nil, err
}
// Write the actor's public key to the connection.
logger.DebugContext(ctx, "sending dh key", slog.Int("key.size", len(ourDHKeyRaw)))
_, err = frame.NewFrame(ourDHKeyRaw).WriteTo(conn)
_, err = frame.New(ourDHKeyRaw).WriteTo(conn)
if err != nil {
return nil, nil, err
}
@ -244,7 +244,7 @@ func AcceptHandshake(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecd
// Read the authentication message from the connection.
logger.DebugContext(ctx, "waiting for client's authentication message")
authMessageFrame := frame.NewFrame(nil)
authMessageFrame := frame.New(nil)
n, err := authMessageFrame.ReadFrom(conn)
if err != nil {
return nil, nil, err
@ -264,12 +264,12 @@ func AcceptHandshake(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecd
return nil, nil, err
}
clientPublicKeyRaw := plaintext[:222]
signature := plaintext[222:]
clientPublicKeyRaw := plaintext[:identity.ExportedPublicKeySize]
signature := plaintext[identity.ExportedPublicKeySize:]
// Verify the client's public key and signature.
logger.DebugContext(ctx, "importing client's public signing key")
clientPublicKey, err := identity.ImportPublicSigningKey(clientPublicKeyRaw)
clientPublicKey, err := identity.ImportPublicKey(clientPublicKeyRaw)
if err != nil {
return nil, nil, err
}
@ -309,14 +309,14 @@ func AcceptHandshake(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecd
// Send the server's signature back to the client.
logger.DebugContext(ctx, "sending authentication message", slog.Int("message.size", len(boxedMsg)))
_, err = frame.NewFrame(boxedMsg).WriteTo(conn)
_, err = frame.New(boxedMsg).WriteTo(conn)
if err != nil {
return nil, nil, err
}
logger.DebugContext(ctx, "waiting for handshake complete message")
// Read the handshake complete message from the client.
handshakeCompleteFrame := frame.NewFrame(nil)
handshakeCompleteFrame := frame.New(nil)
_, err = handshakeCompleteFrame.ReadFrom(conn)
if err != nil {
return nil, nil, err
@ -333,11 +333,11 @@ func AcceptHandshake(ctx context.Context, conn io.ReadWriteCloser, lPrivKey *ecd
}
func buildMessage(clientPubKey *ecdh.PublicKey, serverPubKey *ecdh.PublicKey) ([]byte, error) {
clientPubKeyRaw, err := keyexchange.ExportDHPublicKey(clientPubKey)
clientPubKeyRaw, err := keyexchange.ExportPublicKey(clientPubKey)
if err != nil {
return nil, err
}
serverPubKeyRaw, err := keyexchange.ExportDHPublicKey(serverPubKey)
serverPubKeyRaw, err := keyexchange.ExportPublicKey(serverPubKey)
if err != nil {
return nil, err
}

View File

@ -18,9 +18,9 @@ func TestHandshake(t *testing.T) {
ctx := context.Background()
// alice, represented by an ecdsa key pair.
alice, _ := identity.GenerateSigningKey()
alice, _ := identity.Generate()
// bob, also represented by an ecdsa key pair.
bob, _ := identity.GenerateSigningKey()
bob, _ := identity.Generate()
var (
// the secret that alice should arrive at on her own
@ -47,11 +47,11 @@ func TestHandshake(t *testing.T) {
var wg sync.WaitGroup
wg.Add(2)
go func() {
aliceSecret, aliceErr = handshake.Handshake(ctx, client, alice, &bob.PublicKey)
aliceSecret, aliceErr = handshake.Initiate(ctx, client, alice, &bob.PublicKey)
wg.Done()
}()
go func() {
discoveredAlicePublicKey, bobSecret, bobErr = handshake.AcceptHandshake(ctx, server, bob)
discoveredAlicePublicKey, bobSecret, bobErr = handshake.Accept(ctx, server, bob)
wg.Done()
}()
wg.Wait()
@ -94,9 +94,9 @@ func runSimulation() error {
ctx := context.Background()
// alice, represented by an ecdsa key pair.
alice, _ := identity.GenerateSigningKey()
alice, _ := identity.Generate()
// bob, also represented by an ecdsa key pair.
bob, _ := identity.GenerateSigningKey()
bob, _ := identity.Generate()
var (
// the secret that alice should arrive at on her own
@ -123,11 +123,11 @@ func runSimulation() error {
var wg sync.WaitGroup
wg.Add(2)
go func() {
_, aliceErr = handshake.Handshake(ctx, client, alice, &bob.PublicKey)
_, aliceErr = handshake.Initiate(ctx, client, alice, &bob.PublicKey)
wg.Done()
}()
go func() {
_, _, bobErr = handshake.AcceptHandshake(ctx, server, bob)
_, _, bobErr = handshake.Accept(ctx, server, bob)
wg.Done()
}()
wg.Wait()

View File

@ -60,7 +60,7 @@ func ParseAddr(addrRaw string) (*Addr, error) {
return nil, ErrInvalidFormat
}
publicKey, err := identity.ImportPublicSigningKey(publicKeyBytes)
publicKey, err := identity.ImportPublicKey(publicKeyBytes)
if err != nil {
return nil, ErrInvalidFormat
}
@ -86,7 +86,7 @@ func (a *Addr) Network() string {
// String implements net.Addr.
func (a *Addr) String() string {
exported, _ := identity.ExportPublicSigningKey(a.publicKey)
exported, _ := identity.ExportPublicKey(a.publicKey)
exportedBase64 := base64.StdEncoding.EncodeToString(exported)
return fmt.Sprintf("%s://%s:%d#%s", a.network, a.ip, a.port, exportedBase64)
}

View File

@ -28,7 +28,7 @@ func NewConn(conn *net.TCPConn, sessionKey []byte, localIdentity, remoteIdentity
// The ciphertext that is actually transferred over the network is larger, so you
// should not rely on this number as an indication of network metrics.
func (c *Conn) Read(b []byte) (int, error) {
f := frame.NewFrame(nil)
f := frame.New(nil)
_, err := f.ReadFrom(c.conn)
if err != nil {
return 0, err
@ -47,7 +47,7 @@ func (c *Conn) Write(b []byte) (int, error) {
if err != nil {
return 0, err
}
_, err = frame.NewFrame(ciphertext).WriteTo(c.conn)
_, err = frame.New(ciphertext).WriteTo(c.conn)
if err != nil {
return 0, err
}

View File

@ -25,7 +25,7 @@ func (l *Listener) Accept() (net.Conn, error) {
return nil, err
}
clientIdentity, sessionKey, err := handshake.AcceptHandshake(ctx, conn, l.identity)
clientIdentity, sessionKey, err := handshake.Accept(ctx, conn, l.identity)
if err != nil {
return nil, err
}

View File

@ -21,7 +21,7 @@ func Dial(
return nil, err
}
sessionKey, err := handshake.Handshake(ctx, conn, identity, raddr.PublicKey())
sessionKey, err := handshake.Initiate(ctx, conn, identity, raddr.PublicKey())
if err != nil {
return nil, err
}

View File

@ -10,7 +10,7 @@ import (
)
func main() {
key, err := identity.GenerateSigningKey()
key, err := identity.Generate()
if err != nil {
panic(err)
}