package network

import (
	"crypto/ed25519"
	"net"
	"time"

	"koti.casa/numenor-labs/dsfx/internal/lib/crypto/encryption"
	"koti.casa/numenor-labs/dsfx/internal/lib/frame"
)

// Conn is a wrapper around net.TCPConn that encrypts and decrypts data as it is
// read and written. Conn is a valid implementation of net.Conn.
//
// Each Write is encrypted with the session key and sent as a single frame; each
// frame received is decrypted as a unit. A decrypted frame may be larger than
// the buffer a caller hands to Read, so any unread remainder is kept in readBuf
// and returned by subsequent Read calls.
type Conn struct {
	conn           *net.TCPConn
	sessionKey     []byte
	localIdentity  ed25519.PublicKey
	remoteIdentity ed25519.PublicKey

	// readBuf holds decrypted plaintext from the most recently received frame
	// that has not yet been returned to a caller of Read.
	readBuf []byte
}

// NewConn creates a new Conn.
//
// The sessionKey slice is retained (not copied) and is zeroed in place by
// Close, so callers should not reuse it after handing it to NewConn.
func NewConn(conn *net.TCPConn, sessionKey []byte, localIdentity, remoteIdentity ed25519.PublicKey) *Conn {
	return &Conn{
		conn:           conn,
		sessionKey:     sessionKey,
		localIdentity:  localIdentity,
		remoteIdentity: remoteIdentity,
	}
}

// Read implements io.Reader.
// Please note that the number of bytes returned is the length of the decrypted
// data. The ciphertext that is actually transferred over the network is larger,
// so you should not rely on this number as an indication of network metrics.
//
// Read never reports more than len(b) bytes. If a decrypted frame does not fit
// in b, the remainder is buffered and returned by later calls rather than being
// dropped. A zero-length b returns (0, nil) without touching the network.
func (c *Conn) Read(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}
	// Only pull a new frame off the wire once the previous one has been fully
	// handed out. Frames that decrypt to empty plaintext are skipped so that a
	// successful Read always returns at least one byte.
	for len(c.readBuf) == 0 {
		f := frame.New(nil)
		if _, err := f.ReadFrom(c.conn); err != nil {
			return 0, err
		}
		plaintext, err := encryption.Decrypt(c.sessionKey, f.Contents())
		if err != nil {
			return 0, err
		}
		c.readBuf = plaintext
	}
	n := copy(b, c.readBuf)
	c.readBuf = c.readBuf[n:]
	if len(c.readBuf) == 0 {
		// Release the drained frame's backing array.
		c.readBuf = nil
	}
	return n, nil
}

// Write implements io.Writer.
// The whole of b is encrypted with the session key and sent as one frame. On
// success it reports len(b), the plaintext length, not the number of bytes
// written to the network.
// NOTE(review): confirm the frame package has no size limit that a very large
// b could exceed.
func (c *Conn) Write(b []byte) (int, error) {
	ciphertext, err := encryption.Encrypt(c.sessionKey, b)
	if err != nil {
		return 0, err
	}
	_, err = frame.New(ciphertext).WriteTo(c.conn)
	if err != nil {
		return 0, err
	}
	return len(b), nil
}

// Close implements io.Closer.
// It overwrites the session key with zeros (this mutates the slice passed to
// NewConn), discards any buffered plaintext, and closes the underlying TCP
// connection, returning its error.
func (c *Conn) Close() error {
	// x-security: overwrite the key bytes to mitigate memory attacks. Assigning
	// nil alone would leave the key readable in its backing array until that
	// memory is reclaimed and reused.
	for i := range c.sessionKey {
		c.sessionKey[i] = 0
	}
	c.sessionKey = nil

	// Scrub plaintext that was decrypted but never read.
	for i := range c.readBuf {
		c.readBuf[i] = 0
	}
	c.readBuf = nil

	// Identities are public keys that may be referenced elsewhere, so they are
	// only dropped, not overwritten.
	c.localIdentity = nil
	c.remoteIdentity = nil

	return c.conn.Close()
}

// LocalAddr implements net.Conn.
// It reports the local TCP endpoint of the underlying connection paired with
// this side's identity key. It returns nil if the underlying connection does
// not report a *net.TCPAddr as its local address.
func (c *Conn) LocalAddr() net.Addr {
	// Use the socket's *local* address; the remote address belongs to the peer.
	laddr, ok := c.conn.LocalAddr().(*net.TCPAddr)
	if !ok {
		return nil
	}
	return NewAddr(laddr.IP, laddr.Port, c.localIdentity)
}

// RemoteAddr implements net.Conn.
// It reports the peer's TCP endpoint paired with the peer's identity key. It
// returns nil if the underlying connection does not report a *net.TCPAddr as
// its remote address.
func (c *Conn) RemoteAddr() net.Addr {
	// Comma-ok so a nil or non-TCP address yields nil instead of a panic.
	raddr, ok := c.conn.RemoteAddr().(*net.TCPAddr)
	if !ok {
		return nil
	}
	return NewAddr(raddr.IP, raddr.Port, c.remoteIdentity)
}

// SetDeadline implements net.Conn by delegating to the underlying TCP
// connection, which applies t to both reads and writes. See net.Conn for the
// exact deadline semantics.
func (c *Conn) SetDeadline(t time.Time) error {
	tcp := c.conn
	return tcp.SetDeadline(t)
}

// SetReadDeadline implements net.Conn by delegating to the underlying TCP
// connection's read deadline.
// NOTE(review): if a deadline expires while a frame is only partly read, later
// reads may see a misaligned frame stream; confirm how frame.ReadFrom behaves.
func (c *Conn) SetReadDeadline(t time.Time) error {
	if err := c.conn.SetReadDeadline(t); err != nil {
		return err
	}
	return nil
}

// SetWriteDeadline implements net.Conn by delegating to the underlying TCP
// connection's write deadline.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	err := c.conn.SetWriteDeadline(t)
	return err
}