use crate::error::{self, ErrorKind, Result};
use crate::util::ServerTimestamp;
use rc_crypto::hawk;
use serde_derive::*;
use std::borrow::{Borrow, Cow};
use std::cell::RefCell;
use std::fmt;
use std::time::{Duration, SystemTime};
use url::Url;
use viaduct::{header_names, Request};
const RETRY_AFTER_DEFAULT_MS: u64 = 10000;
#[derive(Deserialize, Clone, PartialEq, Eq)]
struct TokenserverToken {
id: String,
key: String,
api_endpoint: String,
uid: u64,
duration: u64,
hashed_fxa_uid: String,
}
impl std::fmt::Debug for TokenserverToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TokenserverToken")
.field("api_endpoint", &self.api_endpoint)
.field("uid", &self.uid)
.field("duration", &self.duration)
.field("hashed_fxa_uid", &self.hashed_fxa_uid)
.finish()
}
}
struct TokenFetchResult {
token: TokenserverToken,
server_timestamp: ServerTimestamp,
}
trait TokenFetcher {
fn fetch_token(&self) -> super::Result<TokenFetchResult>;
fn now(&self) -> SystemTime;
}
#[derive(Debug)]
struct TokenServerFetcher {
server_url: Url,
access_token: String,
key_id: String,
}
fn fixup_server_url(mut url: Url) -> Result<Url> {
if url.as_str().ends_with("1.0/sync/1.5") {
Ok(url)
} else if url.as_str().ends_with("1.0/sync/1.5/") {
if let Ok(mut path) = url.path_segments_mut() {
path.pop();
}
Ok(url)
} else {
if let Ok(mut path) = url.path_segments_mut() {
path.pop_if_empty();
path.extend(&["1.0", "sync", "1.5"]);
}
Ok(url)
}
}
impl TokenServerFetcher {
fn new(base_url: Url, access_token: String, key_id: String) -> Result<TokenServerFetcher> {
Ok(TokenServerFetcher {
server_url: fixup_server_url(base_url)?,
access_token,
key_id,
})
}
}
impl TokenFetcher for TokenServerFetcher {
fn fetch_token(&self) -> Result<TokenFetchResult> {
log::debug!("Fetching token from {}", self.server_url);
let resp = Request::get(self.server_url.clone())
.header(
header_names::AUTHORIZATION,
format!("Bearer {}", self.access_token),
)?
.header(header_names::X_KEYID, self.key_id.clone())?
.send()?;
if !resp.is_success() {
log::warn!("Non-success status when fetching token: {}", resp.status);
log::trace!(" Response body {}", resp.text());
if let Some(res) = resp.headers.get_as::<f64, _>(header_names::RETRY_AFTER) {
let ms = res
.ok()
.map_or(RETRY_AFTER_DEFAULT_MS, |f| (f * 1000f64) as u64);
let when = self.now() + Duration::from_millis(ms);
return Err(ErrorKind::BackoffError(when).into());
}
let status = resp.status;
return Err(ErrorKind::TokenserverHttpError(status).into());
}
let token: TokenserverToken = resp.json()?;
let server_timestamp = resp
.headers
.try_get::<ServerTimestamp, _>(header_names::X_TIMESTAMP)
.ok_or_else(|| ErrorKind::MissingServerTimestamp)?;
Ok(TokenFetchResult {
token,
server_timestamp,
})
}
fn now(&self) -> SystemTime {
SystemTime::now()
}
}
struct TokenContext {
token: TokenserverToken,
credentials: hawk::Credentials,
server_timestamp: ServerTimestamp,
valid_until: SystemTime,
}
impl fmt::Debug for TokenContext {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> ::std::result::Result<(), fmt::Error> {
f.debug_struct("TokenContext")
.field("token", &self.token)
.field("credentials", &"(omitted)")
.field("server_timestamp", &self.server_timestamp)
.field("valid_until", &self.valid_until)
.finish()
}
}
impl TokenContext {
fn new(
token: TokenserverToken,
credentials: hawk::Credentials,
server_timestamp: ServerTimestamp,
valid_until: SystemTime,
) -> Self {
Self {
token,
credentials,
server_timestamp,
valid_until,
}
}
fn is_valid(&self, now: SystemTime) -> bool {
now < self.valid_until
}
fn authorization(&self, req: &Request) -> Result<String> {
let url = &req.url;
let path_and_query = match url.query() {
None => Cow::from(url.path()),
Some(qs) => Cow::from(format!("{}?{}", url.path(), qs)),
};
let host = url
.host_str()
.ok_or_else(|| ErrorKind::UnacceptableUrl("Storage URL has no host".into()))?;
let port = url.port_or_known_default().ok_or_else(|| {
ErrorKind::UnacceptableUrl(
"Storage URL has no port and no default port is known for the protocol".into(),
)
})?;
let header =
hawk::RequestBuilder::new(req.method.as_str(), host, port, path_and_query.borrow())
.request()
.make_header(&self.credentials)?;
Ok(format!("Hawk {}", header))
}
}
#[derive(Debug)]
enum TokenState {
NoToken,
Token(TokenContext),
Failed(Option<error::Error>, Option<String>),
Backoff(SystemTime, Option<String>),
NodeReassigned,
}
#[derive(Debug)]
struct TokenProviderImpl<TF: TokenFetcher> {
fetcher: TF,
current_state: RefCell<TokenState>,
}
impl<TF: TokenFetcher> TokenProviderImpl<TF> {
fn new(fetcher: TF) -> Self {
rc_crypto::ensure_initialized();
TokenProviderImpl {
fetcher,
current_state: RefCell::new(TokenState::NoToken),
}
}
fn fetch_context(&self) -> Result<TokenContext> {
let result = self.fetcher.fetch_token()?;
let token = result.token;
let valid_until = SystemTime::now() + Duration::from_secs(token.duration);
let credentials = hawk::Credentials {
id: token.id.clone(),
key: hawk::Key::new(token.key.as_bytes(), hawk::SHA256)?,
};
Ok(TokenContext::new(
token,
credentials,
result.server_timestamp,
valid_until,
))
}
fn fetch_token(&self, previous_endpoint: Option<&str>) -> TokenState {
match self.fetch_context() {
Ok(tc) => {
match previous_endpoint {
Some(prev) => {
if prev == tc.token.api_endpoint {
TokenState::Token(tc)
} else {
log::warn!(
"api_endpoint changed from {} to {}",
prev,
tc.token.api_endpoint
);
TokenState::NodeReassigned
}
}
None => {
TokenState::Token(tc)
}
}
}
Err(e) => {
if let ErrorKind::BackoffError(be) = e.kind() {
return TokenState::Backoff(*be, previous_endpoint.map(ToString::to_string));
}
TokenState::Failed(Some(e), previous_endpoint.map(ToString::to_string))
}
}
}
fn advance_state(&self, state: &TokenState) -> Option<TokenState> {
match state {
TokenState::NoToken => Some(self.fetch_token(None)),
TokenState::Failed(_, existing_endpoint) => {
Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str)))
}
TokenState::Token(existing_context) => {
if existing_context.is_valid(self.fetcher.now()) {
None
} else {
Some(self.fetch_token(Some(existing_context.token.api_endpoint.as_str())))
}
}
TokenState::Backoff(ref until, ref existing_endpoint) => {
if let Ok(remaining) = until.duration_since(self.fetcher.now()) {
log::debug!("enforcing existing backoff - {:?} remains", remaining);
None
} else {
Some(self.fetch_token(existing_endpoint.as_ref().map(String::as_str)))
}
}
TokenState::NodeReassigned => {
None
}
}
}
fn with_token<T, F>(&self, func: F) -> Result<T>
where
F: FnOnce(&TokenContext) -> Result<T>,
{
let state: &mut TokenState = &mut self.current_state.borrow_mut();
if let Some(new_state) = self.advance_state(state) {
*state = new_state;
}
match state {
TokenState::NoToken => {
panic!("Can't be in NoToken state after advancing");
}
TokenState::Token(ref token_context) => {
func(token_context)
}
TokenState::Failed(e, _) => {
Err(e.take().unwrap())
}
TokenState::NodeReassigned => {
Err(ErrorKind::StorageResetError.into())
}
TokenState::Backoff(ref remaining, _) => {
Err(ErrorKind::BackoffError(*remaining).into())
}
}
}
fn hashed_uid(&self) -> Result<String> {
self.with_token(|ctx| Ok(ctx.token.hashed_fxa_uid.clone()))
}
fn authorization(&self, req: &Request) -> Result<String> {
self.with_token(|ctx| ctx.authorization(req))
}
fn api_endpoint(&self) -> Result<String> {
self.with_token(|ctx| Ok(ctx.token.api_endpoint.clone()))
}
}
#[derive(Debug)]
pub struct TokenProvider {
imp: TokenProviderImpl<TokenServerFetcher>,
}
impl TokenProvider {
pub fn new(url: Url, access_token: String, key_id: String) -> Result<Self> {
let fetcher = TokenServerFetcher::new(url, access_token, key_id)?;
Ok(Self {
imp: TokenProviderImpl::new(fetcher),
})
}
pub fn hashed_uid(&self) -> Result<String> {
self.imp.hashed_uid()
}
pub fn authorization(&self, req: &Request) -> Result<String> {
self.imp.authorization(req)
}
pub fn api_endpoint(&self) -> Result<String> {
self.imp.api_endpoint()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::cell::Cell;
struct TestFetcher<FF, FN>
where
FF: Fn() -> Result<TokenFetchResult>,
FN: Fn() -> SystemTime,
{
fetch: FF,
now: FN,
}
impl<FF, FN> TokenFetcher for TestFetcher<FF, FN>
where
FF: Fn() -> Result<TokenFetchResult>,
FN: Fn() -> SystemTime,
{
fn fetch_token(&self) -> Result<TokenFetchResult> {
(self.fetch)()
}
fn now(&self) -> SystemTime {
(self.now)()
}
}
fn make_tsc<FF, FN>(fetch: FF, now: FN) -> TokenProviderImpl<TestFetcher<FF, FN>>
where
FF: Fn() -> Result<TokenFetchResult>,
FN: Fn() -> SystemTime,
{
let fetcher: TestFetcher<FF, FN> = TestFetcher { fetch, now };
TokenProviderImpl::new(fetcher)
}
#[test]
fn test_endpoint() {
let counter: Cell<u32> = Cell::new(0);
let fetch = || {
counter.set(counter.get() + 1);
Ok(TokenFetchResult {
token: TokenserverToken {
id: "id".to_string(),
key: "key".to_string(),
api_endpoint: "api_endpoint".to_string(),
uid: 1,
duration: 1000,
hashed_fxa_uid: "hash".to_string(),
},
server_timestamp: ServerTimestamp(0i64),
})
};
let tsc = make_tsc(fetch, SystemTime::now);
let e = tsc.api_endpoint().expect("should work");
assert_eq!(e, "api_endpoint".to_string());
assert_eq!(counter.get(), 1);
let e2 = tsc.api_endpoint().expect("should work");
assert_eq!(e2, "api_endpoint".to_string());
assert_eq!(counter.get(), 1);
}
#[test]
fn test_backoff() {
let counter: Cell<u32> = Cell::new(0);
let fetch = || {
counter.set(counter.get() + 1);
let when = SystemTime::now() + Duration::from_millis(10000);
Err(error::Error::from(ErrorKind::BackoffError(when)))
};
let now: Cell<SystemTime> = Cell::new(SystemTime::now());
let tsc = make_tsc(fetch, || now.get());
tsc.api_endpoint().expect_err("should bail");
assert_eq!(counter.get(), 1);
tsc.api_endpoint().expect_err("should bail");
assert_eq!(counter.get(), 1);
now.set(now.get() + Duration::new(20, 0));
tsc.api_endpoint().expect_err("should bail");
assert_eq!(counter.get(), 2);
}
#[test]
fn test_validity() {
let counter: Cell<u32> = Cell::new(0);
let fetch = || {
counter.set(counter.get() + 1);
Ok(TokenFetchResult {
token: TokenserverToken {
id: "id".to_string(),
key: "key".to_string(),
api_endpoint: "api_endpoint".to_string(),
uid: 1,
duration: 10,
hashed_fxa_uid: "hash".to_string(),
},
server_timestamp: ServerTimestamp(0i64),
})
};
let now: Cell<SystemTime> = Cell::new(SystemTime::now());
let tsc = make_tsc(fetch, || now.get());
tsc.api_endpoint().expect("should get a valid token");
assert_eq!(counter.get(), 1);
tsc.api_endpoint().expect("should reuse existing token");
assert_eq!(counter.get(), 1);
now.set(now.get() + Duration::new(20, 0));
tsc.api_endpoint().expect("should re-fetch");
assert_eq!(counter.get(), 2);
}
#[test]
fn test_server_url() {
assert_eq!(
fixup_server_url(
Url::parse("https://token.services.mozilla.com/1.0/sync/1.5").unwrap()
)
.unwrap()
.as_str(),
"https://token.services.mozilla.com/1.0/sync/1.5"
);
assert_eq!(
fixup_server_url(
Url::parse("https://token.services.mozilla.com/1.0/sync/1.5/").unwrap()
)
.unwrap()
.as_str(),
"https://token.services.mozilla.com/1.0/sync/1.5"
);
assert_eq!(
fixup_server_url(Url::parse("https://token.services.mozilla.com").unwrap())
.unwrap()
.as_str(),
"https://token.services.mozilla.com/1.0/sync/1.5"
);
assert_eq!(
fixup_server_url(Url::parse("https://token.services.mozilla.com/").unwrap())
.unwrap()
.as_str(),
"https://token.services.mozilla.com/1.0/sync/1.5"
);
assert_eq!(
fixup_server_url(
Url::parse("https://selfhosted.example.com/token/1.0/sync/1.5").unwrap()
)
.unwrap()
.as_str(),
"https://selfhosted.example.com/token/1.0/sync/1.5"
);
assert_eq!(
fixup_server_url(
Url::parse("https://selfhosted.example.com/token/1.0/sync/1.5/").unwrap()
)
.unwrap()
.as_str(),
"https://selfhosted.example.com/token/1.0/sync/1.5"
);
assert_eq!(
fixup_server_url(Url::parse("https://selfhosted.example.com/token/").unwrap())
.unwrap()
.as_str(),
"https://selfhosted.example.com/token/1.0/sync/1.5"
);
assert_eq!(
fixup_server_url(Url::parse("https://selfhosted.example.com/token").unwrap())
.unwrap()
.as_str(),
"https://selfhosted.example.com/token/1.0/sync/1.5"
);
}
}