diff --git a/internal/client/client.go b/internal/client/client.go index da33501..1bc5c48 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -3,8 +3,11 @@ package client import ( "context" "crypto/ed25519" + "crypto/rand" "encoding/base64" "errors" + "fmt" + "io" "log/slog" "net" @@ -19,16 +22,12 @@ import ( const ( // DefaultConfigDir is the default directory for the dsfxctl configuration. DefaultConfigDir = "/etc/dsfxctl" - // DefaultHost is the default host for the dsfxctl application. - DefaultHost = "0.0.0.0" ) // Conf holds the configuration for the dsfxctl application. type Conf struct { // Directories ConfigDir string - // Networking - Host string } func loadConfigFromSystem(sys system.System) Conf { @@ -39,11 +38,6 @@ func loadConfigFromSystem(sys system.System) Conf { c.ConfigDir = DefaultConfigDir } - c.Host = sys.GetEnv("DSFXCTL_HOST") - if c.Host == "" { - c.Host = DefaultHost - } - return c } @@ -81,10 +75,16 @@ func (a *Client) Run(ctx context.Context) error { opts := &slog.HandlerOptions{ AddSource: false, - Level: slog.LevelDebug, + Level: slog.LevelInfo, } logger := slog.New(slog.NewTextHandler(a.system.Stdout(), opts)) + err := a.disk.MkdirAll(a.conf.ConfigDir, 0755) + if err != nil { + logger.ErrorContext(ctx, "failed to create config directory", slog.Any("error", err)) + return err + } + // Everything in the application will attempt to use the logger in stored in // the context, but we also set the default with slog as a fallback. In cases // where the context is not available, or the context is not a child of the @@ -92,43 +92,11 @@ func (a *Client) Run(ctx context.Context) error { slog.SetDefault(logger) ctx = logging.WithContext(ctx, logger) - keyFile, err := a.configScope.Open("key") + id, err := a.loadIdentity() if err != nil { - logger.WarnContext(ctx, "key file is missing, reinitializing") - logger.WarnContext(ctx, "if this is your first time running dsfxctl, you can ignore this") - } - if keyFile == nil { - logger.InfoContext(ctx, "generating new key") - keyFile, err = a.configScope.Create("key") - if err != nil { - logger.ErrorContext(ctx, "failed to create key file", slog.Any("error", err)) - return err - } - privkey, err := identity.Generate() - if err != nil { - logger.ErrorContext(ctx, "failed to generate key", slog.Any("error", err)) - return err - } - - _, err = keyFile.Write([]byte(base64.StdEncoding.EncodeToString(privkey))) - if err != nil { - logger.ErrorContext(ctx, "failed to write key", slog.Any("error", err)) - return err - } - } - defer keyFile.Close() - - keyRaw := make([]byte, ed25519.PrivateKeySize) - n, err := keyFile.Read(keyRaw) - if err != nil { - logger.ErrorContext(ctx, "failed to read key file", slog.Any("error", err)) + logger.ErrorContext(ctx, "failed to load identity", slog.Any("error", err)) return err } - if n != ed25519.PrivateKeySize { - logger.ErrorContext(ctx, "key file is not the correct size", slog.Int("size", n)) - return err - } - id := ed25519.PrivateKey(keyRaw) laddr := network.NewAddr( net.ParseIP("0.0.0.0"), @@ -171,3 +139,75 @@ func testConnection(ctx context.Context, id ed25519.PrivateKey, laddr *network.A } defer conn.Close() } + +// loadIdentity ... +func (c *Client) loadIdentity() (ed25519.PrivateKey, error) { + hasKeyFile, err := c.hasKeyFile() + if err != nil { + return nil, fmt.Errorf("failed to check for admins file: %w", err) + } + + if !hasKeyFile { + if err := c.createKeyFile(); err != nil { + return nil, fmt.Errorf("failed to create admins file: %w", err) + } + } + + return c.readKeyFile() +} + +// hasKeyFile ... +func (c *Client) hasKeyFile() (bool, error) { + f, err := c.configScope.Open("ed25519.key") + if err != nil { + return false, nil + } + defer f.Close() + + return true, nil +} + +// createKeyFile ... +func (c *Client) createKeyFile() error { + f, err := c.configScope.Create("ed25519.key") + if err != nil { + return err + } + defer f.Close() + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return err + } + + _, err = f.Write([]byte(base64.StdEncoding.EncodeToString(privateKey))) + if err != nil { + return err + } + + return nil +} + +// readKeyFile ... +func (c *Client) readKeyFile() (ed25519.PrivateKey, error) { + f, err := c.configScope.Open("ed25519.key") + if err != nil { + return nil, err + } + defer f.Close() + + keyRawBase64, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + keyRaw, err := base64.StdEncoding.DecodeString(string(keyRawBase64)) + if err != nil { + return nil, fmt.Errorf("failed to decode key: %w", err) + } + if len(keyRaw) != ed25519.PrivateKeySize { + return nil, fmt.Errorf("key file is not the correct size: %d", len(keyRaw)) + } + + return ed25519.PrivateKey(keyRaw), nil +}