package network

import (
	"crypto/ed25519"
	"net"
	"time"

	"numenor-labs.us/dsfx/dsfx/internal/lib/crypto/encryption"
	"numenor-labs.us/dsfx/dsfx/internal/lib/frame"
)

// Conn is a wrapper around net.TCPConn that encrypts and decrypts data as it is
// read and written. Conn is a valid implementation of net.Conn.
//
// Every Write is encrypted with the session key and sent as one frame, and
// every frame received is decrypted as a whole. Decrypted plaintext that does
// not fit into the buffer passed to Read is kept and handed out by later Read
// calls.
type Conn struct {
	conn           *net.TCPConn
	sessionKey     []byte
	localIdentity  ed25519.PublicKey
	remoteIdentity ed25519.PublicKey

	// pending holds decrypted plaintext from the most recently received frame
	// that has not yet been returned to a caller of Read.
	pending []byte
}

// NewConn creates a new Conn that wraps conn and protects its traffic with
// sessionKey. localIdentity and remoteIdentity are stored and later passed to
// NewAddr when building the values returned by LocalAddr and RemoteAddr.
//
// Close zeroes sessionKey in place, so the caller's slice must not be reused
// after the connection has been closed.
func NewConn(conn *net.TCPConn, sessionKey []byte, localIdentity, remoteIdentity ed25519.PublicKey) *Conn {
	return &Conn{
		conn:           conn,
		sessionKey:     sessionKey,
		localIdentity:  localIdentity,
		remoteIdentity: remoteIdentity,
	}
}

// Read implements io.Reader.
//
// If plaintext from an earlier frame is still buffered, Read returns that
// first. Otherwise it reads exactly one frame from the underlying connection,
// decrypts it with the session key, and copies as much plaintext as fits into
// b. Any remainder is kept for the next call. Read therefore never reports
// more than len(b) bytes and never discards decrypted data.
//
// Please note that number of bytes returned is the length of the decrypted data.
// The ciphertext that is actually transferred over the network is larger, so you
// should not rely on this number as an indication of network metrics.
func (c *Conn) Read(b []byte) (int, error) {
	// A zero-length read must not consume a frame: its plaintext would have
	// nowhere to go.
	if len(b) == 0 {
		return 0, nil
	}

	if len(c.pending) == 0 {
		f := frame.New(nil)
		if _, err := f.ReadFrom(c.conn); err != nil {
			return 0, err
		}

		plaintext, err := encryption.Decrypt(c.sessionKey, f.Contents())
		if err != nil {
			return 0, err
		}
		c.pending = plaintext
	}

	n := copy(b, c.pending)
	c.pending = c.pending[n:]
	if len(c.pending) == 0 {
		// Drop the reference so the frame's backing array can be collected.
		c.pending = nil
	}
	return n, nil
}

// Write implements io.Writer.
//
// b is encrypted with the session key and written as a single frame. On
// success the returned count is len(b), the number of plaintext bytes
// consumed, not the number of bytes put on the wire.
func (c *Conn) Write(b []byte) (int, error) {
	ciphertext, err := encryption.Encrypt(c.sessionKey, b)
	if err != nil {
		return 0, err
	}

	if _, err := frame.New(ciphertext).WriteTo(c.conn); err != nil {
		return 0, err
	}

	return len(b), nil
}

// Close implements io.Closer.
//
// Close overwrites the session key and any buffered plaintext with zeros,
// drops its references to key and identity material, and then closes the
// underlying TCP connection. The key slice is the one passed to NewConn, so the
// caller's copy is zeroed as well.
func (c *Conn) Close() error {
	// x-security: overwrite key material in place to mitigate memory attacks.
	// Assigning nil alone would leave the key bytes readable in memory.
	for i := range c.sessionKey {
		c.sessionKey[i] = 0
	}
	for i := range c.pending {
		c.pending[i] = 0
	}

	c.sessionKey = nil
	c.pending = nil
	c.localIdentity = nil
	c.remoteIdentity = nil
	return c.conn.Close()
}

// LocalAddr implements net.Conn. It returns the local TCP endpoint of the
// underlying connection, combined with the local identity via NewAddr.
func (c *Conn) LocalAddr() net.Addr {
	laddr := c.conn.LocalAddr().(*net.TCPAddr)
	return NewAddr(laddr.IP, laddr.Port, c.localIdentity)
}

// RemoteAddr implements net.Conn.
func (c *Conn) RemoteAddr() net.Addr { raddr := c.conn.RemoteAddr().(*net.TCPAddr) return NewAddr(raddr.IP, raddr.Port, c.remoteIdentity) } // SetDeadline implements net.Conn. func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } // SetReadDeadline implements net.Conn. func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } // SetWriteDeadline implements net.Conn. func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }