mirror of
https://git.numenor-labs.us/dsfx.git
synced 2025-04-29 16:20:34 +00:00
92 lines
2.3 KiB
Go
92 lines
2.3 KiB
Go
|
package dnet
|
||
|
|
||
|
import (
|
||
|
"crypto/ecdsa"
|
||
|
"net"
|
||
|
"time"
|
||
|
|
||
|
"koti.casa/numenor-labs/dsfx/shared/dcrypto"
|
||
|
)
|
||
|
|
||
|
// Conn is a wrapper around net.TCPConn that encrypts and decrypts data as it is
|
||
|
// read and written. Conn is a valid implementation of net.Conn.
|
||
|
type Conn struct {
|
||
|
conn *net.TCPConn
|
||
|
sessionKey []byte
|
||
|
localIdentity *ecdsa.PublicKey
|
||
|
remoteIdentity *ecdsa.PublicKey
|
||
|
}
|
||
|
|
||
|
// NewConn creates a new Conn.
|
||
|
func NewConn(conn *net.TCPConn, sessionKey []byte, localIdentity, remoteIdentity *ecdsa.PublicKey) *Conn {
|
||
|
return &Conn{conn, sessionKey, localIdentity, remoteIdentity}
|
||
|
}
|
||
|
|
||
|
// Read implements io.Reader.
|
||
|
// Please note that number of bytes returned is the length of the decrypted data.
|
||
|
// The ciphertext that is actually transferred over the network is larger, so you
|
||
|
// should not rely on this number as an indication of network metrics.
|
||
|
func (c *Conn) Read(b []byte) (int, error) {
|
||
|
f := NewFrame(nil)
|
||
|
_, err := f.ReadFrom(c.conn)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
plaintext, err := dcrypto.Decrypt(c.sessionKey, f.Contents())
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
copy(b, plaintext)
|
||
|
return len(plaintext), nil
|
||
|
}
|
||
|
|
||
|
// Write implements io.Writer.
|
||
|
func (c *Conn) Write(b []byte) (int, error) {
|
||
|
ciphertext, err := dcrypto.Encrypt(c.sessionKey, b)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
_, err = NewFrame(ciphertext).WriteTo(c.conn)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
return len(b), nil
|
||
|
}
|
||
|
|
||
|
// Close implements io.Closer.
|
||
|
func (c *Conn) Close() error {
|
||
|
// x-security: clear key to mitigate memory attacks
|
||
|
c.sessionKey = nil
|
||
|
c.localIdentity = nil
|
||
|
c.remoteIdentity = nil
|
||
|
|
||
|
return c.conn.Close()
|
||
|
}
|
||
|
|
||
|
// LocalAddr implements net.Conn.
|
||
|
func (c *Conn) LocalAddr() net.Addr {
|
||
|
raddr := c.conn.RemoteAddr().(*net.TCPAddr)
|
||
|
return NewAddr(raddr.IP, raddr.Port, c.localIdentity)
|
||
|
}
|
||
|
|
||
|
// RemoteAddr implements net.Conn.
|
||
|
func (c *Conn) RemoteAddr() net.Addr {
|
||
|
raddr := c.conn.RemoteAddr().(*net.TCPAddr)
|
||
|
return NewAddr(raddr.IP, raddr.Port, c.remoteIdentity)
|
||
|
}
|
||
|
|
||
|
// SetDeadline implements net.Conn.
|
||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||
|
return c.conn.SetDeadline(t)
|
||
|
}
|
||
|
|
||
|
// SetReadDeadline implements net.Conn.
|
||
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||
|
return c.conn.SetReadDeadline(t)
|
||
|
}
|
||
|
|
||
|
// SetWriteDeadline implements net.Conn.
|
||
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||
|
return c.conn.SetWriteDeadline(t)
|
||
|
}
|