use crate::error::*;
use core::marker::PhantomData;
pub use ec::{Curve, EcKey};
use nss::{ec, ecdh};
pub type EphemeralKeyPair = KeyPair<Ephemeral>;
#[derive(PartialEq)]
pub struct Algorithm {
    pub(crate) curve_id: ec::Curve,
}
pub static ECDH_P256: Algorithm = Algorithm {
    curve_id: ec::Curve::P256,
};
pub static ECDH_P384: Algorithm = Algorithm {
    curve_id: ec::Curve::P384,
};
pub trait Lifetime {}
pub struct Ephemeral {}
impl Lifetime for Ephemeral {}
pub struct Static {}
impl Lifetime for Static {}
pub struct KeyPair<U: Lifetime> {
    private_key: PrivateKey<U>,
    public_key: PublicKey,
}
impl<U: Lifetime> KeyPair<U> {
    
    pub fn generate(alg: &'static Algorithm) -> Result<Self> {
        let (prv_key, pub_key) = ec::generate_keypair(alg.curve_id)?;
        Ok(Self {
            private_key: PrivateKey {
                alg,
                wrapped: prv_key,
                usage: PhantomData,
            },
            public_key: PublicKey {
                alg,
                wrapped: pub_key,
            },
        })
    }
    pub fn from_private_key(private_key: PrivateKey<U>) -> Result<Self> {
        let public_key = private_key
            .compute_public_key()
            .map_err(|_| ErrorKind::InternalError)?;
        Ok(Self {
            private_key,
            public_key,
        })
    }
    
    pub fn private_key(&self) -> &PrivateKey<U> {
        &self.private_key
    }
    
    pub fn public_key(&self) -> &PublicKey {
        &self.public_key
    }
    
    pub fn split(self) -> (PrivateKey<U>, PublicKey) {
        (self.private_key, self.public_key)
    }
}
impl KeyPair<Static> {
    pub fn from(private_key: PrivateKey<Static>) -> Result<Self> {
        Self::from_private_key(private_key)
    }
}
pub struct PublicKey {
    wrapped: ec::PublicKey,
    alg: &'static Algorithm,
}
impl PublicKey {
    #[inline]
    pub fn to_bytes(&self) -> Result<Vec<u8>> {
        Ok(self.wrapped.to_bytes()?)
    }
    #[inline]
    pub fn algorithm(&self) -> &'static Algorithm {
        self.alg
    }
}
pub struct UnparsedPublicKey<'a> {
    alg: &'static Algorithm,
    bytes: &'a [u8],
}
impl<'a> UnparsedPublicKey<'a> {
    pub fn new(algorithm: &'static Algorithm, bytes: &'a [u8]) -> Self {
        Self {
            alg: algorithm,
            bytes,
        }
    }
    pub fn algorithm(&self) -> &'static Algorithm {
        self.alg
    }
    pub fn bytes(&self) -> &'a [u8] {
        &self.bytes
    }
}
pub struct PrivateKey<U: Lifetime> {
    wrapped: ec::PrivateKey,
    alg: &'static Algorithm,
    usage: PhantomData<U>,
}
impl<U: Lifetime> PrivateKey<U> {
    #[inline]
    pub fn algorithm(&self) -> &'static Algorithm {
        self.alg
    }
    pub fn compute_public_key(&self) -> Result<PublicKey> {
        let pub_key = self.wrapped.convert_to_public_key()?;
        Ok(PublicKey {
            wrapped: pub_key,
            alg: self.alg,
        })
    }
    
    
    
    pub fn agree(self, peer_public_key: &UnparsedPublicKey<'_>) -> Result<InputKeyMaterial> {
        agree_(&self.wrapped, self.alg, peer_public_key)
    }
}
impl PrivateKey<Static> {
    
    
    
    pub fn agree_static(
        &self,
        peer_public_key: &UnparsedPublicKey<'_>,
    ) -> Result<InputKeyMaterial> {
        agree_(&self.wrapped, self.alg, peer_public_key)
    }
    pub fn import(ec_key: &EcKey) -> Result<Self> {
        
        let alg = match ec_key.curve() {
            Curve::P256 => &ECDH_P256,
            Curve::P384 => &ECDH_P384,
        };
        let private_key = ec::PrivateKey::import(ec_key)?;
        Ok(Self {
            wrapped: private_key,
            alg,
            usage: PhantomData,
        })
    }
    pub fn export(&self) -> Result<EcKey> {
        Ok(self.wrapped.export()?)
    }
    
    
    
    pub fn _tests_only_dangerously_convert_to_ephemeral(self) -> PrivateKey<Ephemeral> {
        PrivateKey::<Ephemeral> {
            wrapped: self.wrapped,
            alg: self.alg,
            usage: PhantomData,
        }
    }
}
fn agree_(
    my_private_key: &ec::PrivateKey,
    my_alg: &Algorithm,
    peer_public_key: &UnparsedPublicKey<'_>,
) -> Result<InputKeyMaterial> {
    let alg = &my_alg;
    if peer_public_key.algorithm() != *alg {
        return Err(ErrorKind::InternalError.into());
    }
    let pub_key = ec::PublicKey::from_bytes(my_private_key.curve(), peer_public_key.bytes())?;
    let value = ecdh::ecdh_agreement(my_private_key, &pub_key)?;
    Ok(InputKeyMaterial { value })
}
#[must_use]
pub struct InputKeyMaterial {
    value: Vec<u8>,
}
impl InputKeyMaterial {
    
    
    
    pub fn derive<F, R>(self, kdf: F) -> R
    where
        F: FnOnce(&[u8]) -> R,
    {
        kdf(&self.value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    
    
    const PUB_KEY_1_B64: &str =
        "BLunVoWkR67xRdAohVblFBWn1Oosb3kH_baxw1yfIYFfthSm4LIY35vDD-5LE454eB7TShn919DVVGZ_7tWdjTE";
    const PRIV_KEY_1_JWK_D: &str = "CQ8uF_-zB1NftLO6ytwKM3Cnuol64PQw5qOuCzQJeFU";
    const PRIV_KEY_1_JWK_X: &str = "u6dWhaRHrvFF0CiFVuUUFafU6ixveQf9trHDXJ8hgV8";
    const PRIV_KEY_1_JWK_Y: &str = "thSm4LIY35vDD-5LE454eB7TShn919DVVGZ_7tWdjTE";
    const PRIV_KEY_2_JWK_D: &str = "uN2YSQvxuxhQQ9Y1XXjYi1vr2ZTdzuoDX18PYu4LU-0";
    const PRIV_KEY_2_JWK_X: &str = "S2S3tjygMB0DkM-N9jYUgGLt_9_H6km5P9V6V_KS4_4";
    const PRIV_KEY_2_JWK_Y: &str = "03j8Tyqgrc4R4FAUV2C7-im96yMmfmO_5Om6Kr8YP3o";
    const SHARED_SECRET_HEX: &str =
        "163FAA3FC4815D47345C8E959F707B2F1D3537E7B2EA1DAEC23CA8D0A242CFF3";
    fn load_priv_key_1() -> PrivateKey<Static> {
        let private_key = base64::decode_config(PRIV_KEY_1_JWK_D, base64::URL_SAFE_NO_PAD).unwrap();
        let x = base64::decode_config(PRIV_KEY_1_JWK_X, base64::URL_SAFE_NO_PAD).unwrap();
        let y = base64::decode_config(PRIV_KEY_1_JWK_Y, base64::URL_SAFE_NO_PAD).unwrap();
        PrivateKey::<Static>::import(
            &EcKey::from_coordinates(Curve::P256, &private_key, &x, &y).unwrap(),
        )
        .unwrap()
    }
    fn load_priv_key_2() -> PrivateKey<Static> {
        let private_key = base64::decode_config(PRIV_KEY_2_JWK_D, base64::URL_SAFE_NO_PAD).unwrap();
        let x = base64::decode_config(PRIV_KEY_2_JWK_X, base64::URL_SAFE_NO_PAD).unwrap();
        let y = base64::decode_config(PRIV_KEY_2_JWK_Y, base64::URL_SAFE_NO_PAD).unwrap();
        PrivateKey::<Static>::import(
            &EcKey::from_coordinates(Curve::P256, &private_key, &x, &y).unwrap(),
        )
        .unwrap()
    }
    #[test]
    fn test_static_agreement() {
        let pub_key_raw = base64::decode_config(PUB_KEY_1_B64, base64::URL_SAFE_NO_PAD).unwrap();
        let peer_pub_key = UnparsedPublicKey::new(&ECDH_P256, &pub_key_raw);
        let prv_key = load_priv_key_2();
        let ikm = prv_key.agree_static(&peer_pub_key).unwrap();
        let secret = ikm
            .derive(|z| -> Result<Vec<u8>> { Ok(z.to_vec()) })
            .unwrap();
        let secret_b64 = hex::encode_upper(&secret);
        assert_eq!(secret_b64, *SHARED_SECRET_HEX);
    }
    #[test]
    fn test_ephemeral_agreement_roundtrip() {
        let (our_prv_key, our_pub_key) =
            KeyPair::<Ephemeral>::generate(&ECDH_P256).unwrap().split();
        let (their_prv_key, their_pub_key) =
            KeyPair::<Ephemeral>::generate(&ECDH_P256).unwrap().split();
        let their_pub_key_raw = their_pub_key.to_bytes().unwrap();
        let peer_public_key_1 = UnparsedPublicKey::new(&ECDH_P256, &their_pub_key_raw);
        let ikm_1 = our_prv_key.agree(&peer_public_key_1).unwrap();
        let secret_1 = ikm_1
            .derive(|z| -> Result<Vec<u8>> { Ok(z.to_vec()) })
            .unwrap();
        let our_pub_key_raw = our_pub_key.to_bytes().unwrap();
        let peer_public_key_2 = UnparsedPublicKey::new(&ECDH_P256, &our_pub_key_raw);
        let ikm_2 = their_prv_key.agree(&peer_public_key_2).unwrap();
        let secret_2 = ikm_2
            .derive(|z| -> Result<Vec<u8>> { Ok(z.to_vec()) })
            .unwrap();
        assert_eq!(secret_1, secret_2);
    }
    #[test]
    fn test_compute_public_key() {
        let (prv_key, pub_key) = KeyPair::<Static>::generate(&ECDH_P256).unwrap().split();
        let computed_pub_key = prv_key.compute_public_key().unwrap();
        assert_eq!(
            computed_pub_key.to_bytes().unwrap(),
            pub_key.to_bytes().unwrap()
        );
    }
    #[test]
    fn test_compute_public_key_known_values() {
        let prv_key = load_priv_key_1();
        let pub_key = base64::decode_config(PUB_KEY_1_B64, base64::URL_SAFE_NO_PAD).unwrap();
        let computed_pub_key = prv_key.compute_public_key().unwrap();
        assert_eq!(computed_pub_key.to_bytes().unwrap(), pub_key.as_slice());
        let prv_key = load_priv_key_2();
        let computed_pub_key = prv_key.compute_public_key().unwrap();
        assert_ne!(computed_pub_key.to_bytes().unwrap(), pub_key.as_slice());
    }
    #[test]
    fn test_keys_byte_representations_roundtrip() {
        let key_pair = KeyPair::<Static>::generate(&ECDH_P256).unwrap();
        let prv_key = key_pair.private_key;
        let extracted_pub_key = prv_key.compute_public_key().unwrap();
        let ec_key = prv_key.export().unwrap();
        let prv_key_reconstructed = PrivateKey::<Static>::import(&ec_key).unwrap();
        let extracted_pub_key_reconstructed = prv_key.compute_public_key().unwrap();
        let ec_key_reconstructed = prv_key_reconstructed.export().unwrap();
        assert_eq!(ec_key.curve(), ec_key_reconstructed.curve());
        assert_eq!(ec_key.public_key(), ec_key_reconstructed.public_key());
        assert_eq!(ec_key.private_key(), ec_key_reconstructed.private_key());
        assert_eq!(
            extracted_pub_key.to_bytes().unwrap(),
            extracted_pub_key_reconstructed.to_bytes().unwrap()
        );
    }
    #[test]
    fn test_agreement_rejects_invalid_pubkeys() {
        let prv_key = load_priv_key_2();
        let mut invalid_pub_key =
            base64::decode_config(PUB_KEY_1_B64, base64::URL_SAFE_NO_PAD).unwrap();
        invalid_pub_key[0] = invalid_pub_key[0].wrapping_add(1);
        assert!(prv_key
            .agree_static(&UnparsedPublicKey::new(&ECDH_P256, &invalid_pub_key))
            .is_err());
        let mut invalid_pub_key =
            base64::decode_config(PUB_KEY_1_B64, base64::URL_SAFE_NO_PAD).unwrap();
        invalid_pub_key[0] = 0x02;
        assert!(prv_key
            .agree_static(&UnparsedPublicKey::new(&ECDH_P256, &invalid_pub_key))
            .is_err());
        let mut invalid_pub_key =
            base64::decode_config(PUB_KEY_1_B64, base64::URL_SAFE_NO_PAD).unwrap();
        invalid_pub_key[64] = invalid_pub_key[0].wrapping_add(1);
        assert!(prv_key
            .agree_static(&UnparsedPublicKey::new(&ECDH_P256, &invalid_pub_key))
            .is_err());
        let mut invalid_pub_key = [0u8; 65];
        assert!(prv_key
            .agree_static(&UnparsedPublicKey::new(&ECDH_P256, &invalid_pub_key))
            .is_err());
        invalid_pub_key[0] = 0x04;
        let mut invalid_pub_key = base64::decode_config(PUB_KEY_1_B64, base64::URL_SAFE_NO_PAD)
            .unwrap()
            .to_vec();
        invalid_pub_key = invalid_pub_key[0..64].to_vec();
        assert!(prv_key
            .agree_static(&UnparsedPublicKey::new(&ECDH_P256, &invalid_pub_key))
            .is_err());
        
        
        let invalid_pub_key_b64 = "BEogZ-rnm44oJkKsOE6Tc7NwFMgmntf7Btm_Rc4atxcqq99Xq1RWNTFpk99pdQOSjUvwELss51PkmAGCXhLfMV0";
        let invalid_pub_key =
            base64::decode_config(invalid_pub_key_b64, base64::URL_SAFE_NO_PAD).unwrap();
        assert!(prv_key
            .agree_static(&UnparsedPublicKey::new(&ECDH_P256, &invalid_pub_key))
            .is_err());
    }
}