package age

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"

	"filippo.io/age"
	"filippo.io/age/agessh"
	"filippo.io/age/armor"
	"filippo.io/age/plugin"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"

	"github.com/getsops/sops/v3/logging"
	"github.com/google/shlex"
)

// Environment variables this package reads (or exports to child commands)
// while locating age and SSH identities.
const (
	// SopsAgeKeyEnv names the environment variable whose value is a string
	// list of age keys.
	SopsAgeKeyEnv = "SOPS_AGE_KEY"
	// SopsAgeKeyFileEnv names the environment variable whose value is the
	// path of an age keys file.
	SopsAgeKeyFileEnv = "SOPS_AGE_KEY_FILE"
	// SopsAgeKeyCmdEnv names the environment variable whose value is a
	// command that prints the age keys.
	SopsAgeKeyCmdEnv = "SOPS_AGE_KEY_CMD"
	// SopsAgeRecipientEnv is exported to the command configured through
	// SopsAgeKeyCmdEnv. It carries the recipient (e.g. a Bech32-encoded age
	// public key) whose private key should be returned.
	SopsAgeRecipientEnv = "SOPS_AGE_RECIPIENT"
	// SopsAgeSshPrivateKeyCmdEnv names the environment variable whose value
	// is a command that prints a private SSH key.
	SopsAgeSshPrivateKeyCmdEnv = "SOPS_AGE_SSH_PRIVATE_KEY_CMD"
	// SopsAgeSshPrivateKeyFileEnv names the environment variable whose value
	// is the path of a private SSH key file.
	SopsAgeSshPrivateKeyFileEnv = "SOPS_AGE_SSH_PRIVATE_KEY_FILE"
)

const (
	// SopsAgeKeyUserConfigPath is the slash-separated location of the default
	// age keys file, relative to getUserConfigDir().
	SopsAgeKeyUserConfigPath = "sops/age/keys.txt"
	// xdgConfigHome is consulted by hand on macOS, where os.UserConfigDir()
	// ignores XDG_CONFIG_HOME.
	xdgConfigHome = "XDG_CONFIG_HOME"
	// KeyTypeIdentifier is the string used to identify an age MasterKey.
	KeyTypeIdentifier = "age"
)

// log is the global logger for any age MasterKey.
//
// It is initialized by its declaration rather than by an init function, so it
// is already set when any init function of this package runs, whatever the
// file order.
var log *logrus.Logger = logging.NewLogger("AGE")

// MasterKey is an age key used to Encrypt and Decrypt SOPS' data key.
type MasterKey struct {
	// Identity used to contain a Bech32-encoded private key.
	//
	// Deprecated: private keys are no longer publicly exposed.
	// Instead, they are either injected by a (local) key service server
	// using ParsedIdentities.ApplyToMasterKey, or loaded from the runtime
	// environment (variables) as defined by the `SopsAgeKey*` constants.
	Identity string
	// Recipient contains the public key used to Encrypt: a Bech32-encoded
	// age recipient (X25519, hybrid or plugin) or an SSH public key, in the
	// forms accepted by parseRecipient.
	Recipient string
	// EncryptedKey contains the SOPS data key encrypted with age, in the
	// ASCII-armored form written by Encrypt and read back by Decrypt.
	EncryptedKey string

	// parsedIdentities contains a slice of parsed age identities.
	// It is used to lazy-load the Identities at-most once: Decrypt only loads
	// identities from the environment while this slice is empty.
	// It can also be injected by a (local) keyservice.KeyServiceServer using
	// ParsedIdentities.ApplyToMasterKey().
	parsedIdentities []age.Identity
	// parsedRecipient contains a parsed age public key.
	// It is used to lazy-load the Recipient at-most once: Encrypt parses
	// Recipient into it while it is nil.
	parsedRecipient age.Recipient
}

// MasterKeysFromRecipients splits commaSeparatedRecipients on "," and turns
// each element into a MasterKey through MasterKeyFromRecipient.
//
// An empty input yields an empty, non-nil slice. The first element that fails
// to parse aborts the call, and its error is returned unchanged.
func MasterKeysFromRecipients(commaSeparatedRecipients string) ([]*MasterKey, error) {
	// strings.Split("", ",") would produce [""], which MasterKeyFromRecipient
	// rejects, so short-circuit the empty case.
	if len(commaSeparatedRecipients) == 0 {
		return []*MasterKey{}, nil
	}

	var result []*MasterKey
	for _, r := range strings.Split(commaSeparatedRecipients, ",") {
		mk, parseErr := MasterKeyFromRecipient(r)
		if parseErr != nil {
			return nil, parseErr
		}
		result = append(result, mk)
	}
	return result, nil
}

// errSet gathers the errors collected while probing several identity sources.
type errSet []error

// Error renders every collected error, in order, separated by "; ".
func (e errSet) Error() string {
	var sb strings.Builder
	for idx, item := range e {
		if idx > 0 {
			sb.WriteString("; ")
		}
		sb.WriteString(item.Error())
	}
	return sb.String()
}

// MasterKeyFromRecipient trims surrounding whitespace from recipient, parses
// it with parseRecipient, and returns a MasterKey holding both the trimmed
// text and the parsed value. Parse errors are returned unchanged.
func MasterKeyFromRecipient(recipient string) (*MasterKey, error) {
	trimmed := strings.TrimSpace(recipient)
	parsed, err := parseRecipient(trimmed)
	if err != nil {
		return nil, err
	}
	mk := &MasterKey{Recipient: trimmed}
	mk.parsedRecipient = parsed
	return mk, nil
}

// ParsedIdentities is a reusable collection of already-parsed age identities.
// A (local) keyservice.KeyServiceServer can parse its identities once and then
// hand them to every MasterKey it serves through ApplyToMasterKey.
type ParsedIdentities []age.Identity

// Import parses the given identities and appends them to i.
//
// Any argument may span several lines. Each non-empty line that does not
// start with "#" is parsed as one identity. On a parse error nothing is
// appended, and the error is returned wrapped.
//
// Import mutates i without synchronization. For concurrent use, parse the
// identities yourself (e.g. with age.ParseIdentities) and append them while
// holding a lock such as a sync.Mutex.
func (i *ParsedIdentities) Import(identity ...string) error {
	joined := strings.Join(identity, "\n")
	parsed, err := parseIdentities(strings.NewReader(joined))
	if err != nil {
		return fmt.Errorf("failed to parse and add to age identities: %w", err)
	}
	*i = append(*i, parsed...)
	return nil
}

// ApplyToMasterKey makes key use exactly these identities, replacing whatever
// identities it held before.
func (i ParsedIdentities) ApplyToMasterKey(key *MasterKey) {
	key.parsedIdentities = i
}

// Encrypt encrypts dataKey to the key's Recipient with age, ASCII-armors the
// result, and stores it in the EncryptedKey field.
//
// The Recipient is parsed lazily on first use and cached in parsedRecipient.
// On failure EncryptedKey is left untouched and the error is returned; every
// outcome is logged with the configured recipient string.
func (key *MasterKey) Encrypt(dataKey []byte) error {
	// fail logs an encryption failure and passes err through. It logs
	// key.Recipient rather than key.parsedRecipient, because the latter is
	// still nil when parsing the recipient itself is what failed.
	fail := func(err error) error {
		log.WithField("recipient", key.Recipient).Info("Encryption failed")
		return err
	}

	if key.parsedRecipient == nil {
		parsedRecipient, err := parseRecipient(key.Recipient)
		if err != nil {
			return fail(err)
		}
		key.parsedRecipient = parsedRecipient
	}

	var buffer bytes.Buffer
	aw := armor.NewWriter(&buffer)
	w, err := age.Encrypt(aw, key.parsedRecipient)
	if err != nil {
		return fail(fmt.Errorf("failed to create writer for encrypting sops data key with age: %w", err))
	}
	if _, err := w.Write(dataKey); err != nil {
		return fail(fmt.Errorf("failed to encrypt sops data key with age: %w", err))
	}
	// Close the age writer before the armor writer that wraps it.
	if err := w.Close(); err != nil {
		return fail(fmt.Errorf("failed to close writer for encrypting sops data key with age: %w", err))
	}
	if err := aw.Close(); err != nil {
		return fail(fmt.Errorf("failed to close armored writer: %w", err))
	}

	key.SetEncryptedDataKey(buffer.Bytes())
	log.WithField("recipient", key.Recipient).Info("Encryption succeeded")
	return nil
}

// EncryptIfNeeded calls Encrypt with dataKey only if the key holds no
// encrypted data key yet. Otherwise it does nothing and returns nil.
func (key *MasterKey) EncryptIfNeeded(dataKey []byte) error {
	if key.EncryptedKey != "" {
		return nil
	}
	return key.Encrypt(dataKey)
}

// EncryptedDataKey returns a fresh byte copy of the encrypted SOPS data key
// stored on this master key.
func (key *MasterKey) EncryptedDataKey() []byte {
	stored := key.EncryptedKey
	return []byte(stored)
}

// SetEncryptedDataKey stores a copy of enc as this master key's encrypted SOPS
// data key.
func (key *MasterKey) SetEncryptedDataKey(enc []byte) {
	value := string(enc)
	key.EncryptedKey = value
}

// formatError builds the error returned when decryption cannot proceed.
//
// The message has up to four parts, in this order:
//   - msg;
//   - ": <err>" when err is non-nil, wrapped with %w so errors.Is/As still work;
//   - the load errors in errs, joined;
//   - a sentence listing unusedLocations as quoted names joined in English:
//     "'a'", "'a' and 'b'", or "'a', 'b', and 'c'".
func formatError(msg string, err error, errs errSet, unusedLocations []string) error {
	loadSuffix := ""
	if len(errs) != 0 {
		loadSuffix = ". Errors while loading age identities: " + errs.Error()
	}

	unusedSuffix := ""
	if n := len(unusedLocations); n != 0 {
		var where string
		switch n {
		case 1:
			where = " '" + unusedLocations[0] + "'"
		case 2:
			where = "s '" + unusedLocations[0] + "' and '" + unusedLocations[1] + "'"
		default:
			where = "s '" + strings.Join(unusedLocations[:n-1], "', '") + "', and '" + unusedLocations[n-1] + "'"
		}
		unusedSuffix = ". Did not find keys in location" + where + "."
	}

	if err == nil {
		return fmt.Errorf("%s%s%s", msg, loadSuffix, unusedSuffix)
	}
	return fmt.Errorf("%s: %w%s%s", msg, err, loadSuffix, unusedSuffix)
}

// Decrypt decrypts the armored EncryptedKey with age and returns the SOPS
// data key.
//
// If no identities were injected (see ParsedIdentities.ApplyToMasterKey), they
// are first loaded from the environment by loadIdentities and cached on the
// key. When decryption cannot proceed, the load errors and unused locations
// are folded into the returned error.
func (key *MasterKey) Decrypt() ([]byte, error) {
	var (
		loadErrs errSet
		unused   []string
	)

	if len(key.parsedIdentities) == 0 {
		var loaded ParsedIdentities
		loaded, unused, loadErrs = key.loadIdentities()
		if len(loaded) == 0 {
			log.Info("Decryption failed")
			return nil, formatError("failed to load age identities", nil, loadErrs, unused)
		}
		loaded.ApplyToMasterKey(key)
	}

	armored := armor.NewReader(strings.NewReader(key.EncryptedKey))
	plain, err := age.Decrypt(armored, key.parsedIdentities...)
	if err != nil {
		log.Info("Decryption failed")
		return nil, formatError("failed to create reader for decrypting sops data key with age", err, loadErrs, unused)
	}

	var out bytes.Buffer
	if _, err = io.Copy(&out, plain); err != nil {
		log.Info("Decryption failed")
		return nil, fmt.Errorf("failed to copy age decrypted data into bytes.Buffer: %w", err)
	}

	log.Info("Decryption succeeded")
	return out.Bytes(), nil
}

// NeedsRotation reports whether the data key must be rotated. For age master
// keys it always returns false.
func (key *MasterKey) NeedsRotation() bool {
	return false
}

// ToString returns the textual form of the key, which is its Recipient.
func (key *MasterKey) ToString() string {
	return key.Recipient
}

// ToMap serializes the MasterKey into a map with a "recipient" entry (the
// Recipient) and an "enc" entry (the encrypted data key).
func (key *MasterKey) ToMap() map[string]interface{} {
	return map[string]interface{}{
		"recipient": key.Recipient,
		"enc":       key.EncryptedKey,
	}
}

// TypeToIdentifier returns KeyTypeIdentifier, the string naming the age
// MasterKey type.
func (key *MasterKey) TypeToIdentifier() string {
	return KeyTypeIdentifier
}

// getOutputFromCmd executes a shell command provided in param 'cmdString',
// optionally adding env vars provided in param 'envVars',
// and returns the command's output and error
func getOutputFromCmd(cmdString string, envVars []string) ([]byte, error) {
	var out []byte

	args, err := shlex.Split(cmdString)
	if err != nil {
		return nil, fmt.Errorf("failed to parse command %s: %w", cmdString, err)
	}
	cmd := exec.Command(args[0], args[1:]...)
	if envVars != nil {
		cmd.Env = append(os.Environ(), envVars[0:]...)
	}
	out, err = cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("failed to execute command %s: %w", cmdString, err)
	}

	return out, nil
}

// loadAgeSSHIdentities collects age identities derived from SSH private keys.
// It tries every source below and keeps all keys that parse:
//  1. the file named by the SopsAgeSshPrivateKeyFileEnv environment variable;
//  2. the output of the command in the SopsAgeSshPrivateKeyCmdEnv environment
//     variable, run with SopsAgeRecipientEnv set to key.Recipient;
//  3. ~/.ssh/id_ed25519 and then ~/.ssh/id_rsa, when they can be stat'ed.
//
// It returns three values:
//   - the parsed identities (nil if none were found);
//   - the locations that were not configured or not present, for error
//     reporting;
//   - any errors hit while running or parsing a configured source.
func (key *MasterKey) loadAgeSSHIdentities() ([]age.Identity, []string, errSet) {
	var identities []age.Identity
	var unusedLocations []string
	var errs errSet

	if sshKeyFilePath, ok := os.LookupEnv(SopsAgeSshPrivateKeyFileEnv); ok {
		identity, err := parseSSHIdentityFromPrivateKeyFile(sshKeyFilePath)
		if err != nil {
			errs = append(errs, err)
		} else {
			identities = append(identities, identity)
		}
	} else {
		unusedLocations = append(unusedLocations, SopsAgeSshPrivateKeyFileEnv)
	}

	if sshKeyCmd, ok := os.LookupEnv(SopsAgeSshPrivateKeyCmdEnv); ok {
		out, err := getOutputFromCmd(sshKeyCmd, []string{fmt.Sprintf("%s=%s", SopsAgeRecipientEnv, key.Recipient)})
		if err != nil {
			errs = append(errs, err)
		} else if identity, err := parseSSHIdentityFromPrivateKeyCmdOutput(out); err != nil {
			errs = append(errs, err)
		} else {
			identities = append(identities, identity)
		}
	} else {
		unusedLocations = append(unusedLocations, SopsAgeSshPrivateKeyCmdEnv)
	}

	userHomeDir, err := os.UserHomeDir()
	switch {
	case err != nil:
		errs = append(errs, err)
	case userHomeDir == "":
		// err is nil on this path, so there is no underlying cause to report.
		log.Warn("could not determine the user home directory")
	default:
		// Default SSH key locations, tried in this order.
		for _, name := range []string{"id_ed25519", "id_rsa"} {
			path := filepath.Join(userHomeDir, ".ssh", name)
			if _, statErr := os.Stat(path); statErr != nil {
				unusedLocations = append(unusedLocations, path)
				continue
			}
			identity, parseErr := parseSSHIdentityFromPrivateKeyFile(path)
			if parseErr != nil {
				errs = append(errs, parseErr)
				continue
			}
			identities = append(identities, identity)
		}
	}

	return identities, unusedLocations, errs
}

// getUserConfigDir returns the user's configuration directory.
//
// On macOS a non-empty XDG_CONFIG_HOME wins, because os.UserConfigDir()
// ignores that variable there. In every other case the result of
// os.UserConfigDir() is returned as-is.
func getUserConfigDir() (string, error) {
	if runtime.GOOS != "darwin" {
		return os.UserConfigDir()
	}
	if dir := os.Getenv(xdgConfigHome); dir != "" {
		return dir, nil
	}
	return os.UserConfigDir()
}

// loadIdentities gathers every age identity available from the runtime
// environment, in this order:
//   - SSH-derived identities (see loadAgeSSHIdentities);
//   - the SopsAgeKeyEnv, SopsAgeKeyFileEnv and SopsAgeKeyCmdEnv environment
//     variables;
//   - the SopsAgeKeyUserConfigPath file under getUserConfigDir().
//
// All configured sources are read. It returns the parsed identities, the
// locations that were not configured or yielded no identity, and the errors
// encountered. Finding no identity is not an error here; the caller decides
// how to report it.
//
// Sources are processed in a fixed order, so identities, unused locations and
// errors come back in a deterministic order.
func (key *MasterKey) loadIdentities() (ParsedIdentities, []string, errSet) {
	identities, unusedLocations, errs := key.loadAgeSSHIdentities()

	// source pairs a location name (used in messages) with its contents.
	type source struct {
		location string
		reader   io.Reader
	}
	// A slice rather than a map keeps the processing order stable.
	var sources []source

	if ageKey, ok := os.LookupEnv(SopsAgeKeyEnv); ok {
		sources = append(sources, source{location: SopsAgeKeyEnv, reader: strings.NewReader(ageKey)})
	} else {
		unusedLocations = append(unusedLocations, SopsAgeKeyEnv)
	}

	if ageKeyFile, ok := os.LookupEnv(SopsAgeKeyFileEnv); ok {
		f, err := os.Open(ageKeyFile)
		if err != nil {
			errs = append(errs, fmt.Errorf("failed to open %s file: %w", SopsAgeKeyFileEnv, err))
		} else {
			defer f.Close()
			sources = append(sources, source{location: SopsAgeKeyFileEnv, reader: f})
		}
	} else {
		unusedLocations = append(unusedLocations, SopsAgeKeyFileEnv)
	}

	if ageKeyCmd, ok := os.LookupEnv(SopsAgeKeyCmdEnv); ok {
		out, err := getOutputFromCmd(ageKeyCmd, []string{fmt.Sprintf("%s=%s", SopsAgeRecipientEnv, key.Recipient)})
		if err != nil {
			errs = append(errs, err)
		} else {
			sources = append(sources, source{location: SopsAgeKeyCmdEnv, reader: bytes.NewReader(out)})
		}
	} else {
		unusedLocations = append(unusedLocations, SopsAgeKeyCmdEnv)
	}

	// The default keys file is optional. A missing config directory or file
	// is only reported when no other source produced anything.
	userConfigDir, err := getUserConfigDir()
	if err != nil && len(sources) == 0 && len(identities) == 0 {
		errs = append(errs, fmt.Errorf("user config directory could not be determined: %w", err))
	} else if userConfigDir != "" {
		ageKeyFilePath := filepath.Join(userConfigDir, filepath.FromSlash(SopsAgeKeyUserConfigPath))
		f, err := os.Open(ageKeyFilePath)
		switch {
		case err == nil:
			defer f.Close()
			sources = append(sources, source{location: ageKeyFilePath, reader: f})
		case !errors.Is(err, os.ErrNotExist):
			errs = append(errs, fmt.Errorf("failed to open file: %w", err))
		case len(sources) == 0 && len(identities) == 0:
			unusedLocations = append(unusedLocations, ageKeyFilePath)
		}
	}

	for _, s := range sources {
		ids, err := unwrapIdentities(s.location, s.reader)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		identities = append(identities, ids...)
		if len(ids) == 0 {
			unusedLocations = append(unusedLocations, s.location)
		}
	}
	return identities, unusedLocations, errs
}

// parseRecipient turns a recipient string into an age.Recipient, choosing the
// parser from the string's prefix:
//   - "age1pq1...": hybrid (post-quantum) age recipient;
//   - "age1..." containing more than one "1": age plugin recipient;
//   - any other "age1...": X25519 age recipient;
//   - "ssh-...": SSH public key, via agessh.
//
// Any other input is rejected with an "unknown recipient type" error.
func parseRecipient(recipient string) (age.Recipient, error) {
	if strings.HasPrefix(recipient, "age1") {
		if strings.HasPrefix(recipient, "age1pq1") {
			hybrid, err := age.ParseHybridRecipient(recipient)
			if err != nil {
				return nil, fmt.Errorf("failed to parse input as Bech32-encoded age public key: %w", err)
			}
			return hybrid, nil
		}
		if strings.Count(recipient, "1") > 1 {
			pluginRecipient, err := plugin.NewRecipient(recipient, pluginTerminalUI)
			if err != nil {
				return nil, fmt.Errorf("failed to parse input as age key from age plugin: %w", err)
			}
			return pluginRecipient, nil
		}
		x25519, err := age.ParseX25519Recipient(recipient)
		if err != nil {
			return nil, fmt.Errorf("failed to parse input as Bech32-encoded age public key: %w", err)
		}
		return x25519, nil
	}

	if strings.HasPrefix(recipient, "ssh-") {
		sshRecipient, err := agessh.ParseRecipient(recipient)
		if err != nil {
			return nil, fmt.Errorf("failed to parse input as age-ssh public key: %w", err)
		}
		return sshRecipient, nil
	}

	return nil, fmt.Errorf("failed to parse input, unknown recipient type: %q", recipient)
}

// parseIdentities parses age identities from r, one per line. Empty lines and
// lines starting with "#" are skipped.
//
// Parsing stops at the first invalid line, and that line's error is returned.
// A read error from r, including a line too long for bufio.Scanner, is also
// returned, so the caller never gets a silently truncated list.
func parseIdentities(r io.Reader) (ParsedIdentities, error) {
	var identities ParsedIdentities

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		parsed, err := parseIdentity(line)
		if err != nil {
			return nil, err
		}
		identities = append(identities, parsed)
	}
	// Scan returns false both at EOF and on a read error; only Err tells the
	// two cases apart.
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("failed to read age identities: %w", err)
	}

	return identities, nil
}

// parseIdentity parses a single age identity line. The prefix selects the
// parser:
//   - "AGE-PLUGIN-": plugin identities;
//   - "AGE-SECRET-KEY-PQ-1": hybrid (post-quantum) keys;
//   - "AGE-SECRET-KEY-1": X25519 keys.
//
// The input is not echoed in the error. Keep it that way, because the input
// may be secret key material.
func parseIdentity(s string) (age.Identity, error) {
	if strings.HasPrefix(s, "AGE-PLUGIN-") {
		return plugin.NewIdentity(s, pluginTerminalUI)
	}
	if strings.HasPrefix(s, "AGE-SECRET-KEY-PQ-1") {
		return age.ParseHybridIdentity(s)
	}
	if strings.HasPrefix(s, "AGE-SECRET-KEY-1") {
		return age.ParseX25519Identity(s)
	}
	return nil, errors.New("unknown identity type")
}

// parseSSHIdentityFromPrivateKeyCmdOutput returns an age.Identity from the given
// private key. Note that encrypted private keys are not supported.
func parseSSHIdentityFromPrivateKeyCmdOutput(key []byte) (age.Identity, error) {
	id, err := agessh.ParseIdentity(key)
	if sshErr, ok := err.(*ssh.PassphraseMissingError); ok {
		return nil, fmt.Errorf("the SSH key returned by running SOPS_AGE_SSH_PRIVATE_KEY_CMD is password protected, which is unsupported. (%q)", sshErr)
	}
	if err != nil {
		return nil, fmt.Errorf("malformed SSH identity returned by running SOPS_AGE_SSH_PRIVATE_KEY_CMD: %q", err)
	}
	return id, nil
}
