use serde_derive::*;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::iter::FromIterator;
use url::Url;
use viaduct::{header_names, status_codes, Headers, Request};
use crate::config::PushConfiguration;
use crate::error::{
self,
ErrorKind::{AlreadyRegisteredError, CommunicationError, CommunicationServerError},
};
use crate::storage::Store;
/// Result of a successful [`Connection::subscribe`] call.
#[derive(Debug)]
pub struct RegisterResponse {
    /// UAID associated with the connection the subscription belongs to.
    pub uaid: String,
    /// Channel ID of the created subscription.
    pub channel_id: String,
    /// Connection auth secret, if one is known.
    pub secret: Option<String>,
    /// Endpoint URL for the subscription
    /// (NOTE(review): presumably the URL application servers push to — confirm).
    pub endpoint: String,
    /// Sender ID associated with the subscription, if any.
    pub senderid: Option<String>,
}
/// A broadcast payload: either a single string value or a map of named,
/// possibly nested, broadcast values.
///
/// The enum is serialized untagged, so `Value` round-trips as a bare JSON
/// string and `Nested` as a JSON object whose values are `BroadcastValue`s.
// The `derive` must come before `#[serde(...)]`: a derive helper attribute
// used before the derive that introduces it triggers the
// `legacy_derive_helpers` future-incompatibility warning.
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
pub enum BroadcastValue {
    /// A single broadcast value.
    Value(String),
    /// A map of broadcast names to (possibly nested) values.
    Nested(HashMap<String, BroadcastValue>),
}
/// Operations supported by a connection to the push server.
///
/// Implementations may report an operation as unsupported by returning an
/// error (for example, `ConnectHttp` does this for the broadcast methods).
pub trait Connection {
    // --- Registration management -------------------------------------------

    /// Creates a subscription for `channel_id`, optionally sending an
    /// application server key along with the request.
    fn subscribe(
        &mut self,
        channel_id: &str,
        app_server_key: Option<&str>,
    ) -> error::Result<RegisterResponse>;

    /// Replaces the registration token with `new_token`.
    fn update(&mut self, new_token: &str) -> error::Result<bool>;

    /// Removes the subscription for `channel_id`; with `None`, the
    /// implementation decides what is removed (NOTE(review): `ConnectHttp`
    /// removes the whole registration in that case).
    fn unsubscribe(&self, channel_id: Option<&str>) -> error::Result<bool>;

    // --- Consistency checks ------------------------------------------------

    /// Returns the channel IDs the server has for this connection.
    fn channel_list(&self) -> error::Result<Vec<String>>;

    /// Compares `channels` with the server's view of the connection and
    /// returns whether they agree.
    fn verify_connection(&self, channels: &[String]) -> error::Result<bool>;

    // --- Broadcasts --------------------------------------------------------

    /// Subscribes to the given broadcast values.
    fn broadcast_subscribe(&self, broadcast: BroadcastValue) -> error::Result<BroadcastValue>;

    /// Returns the current broadcast values.
    fn broadcasts(&self) -> error::Result<BroadcastValue>;
}
/// [`Connection`] implementation that talks to the push server over HTTP.
///
/// Build one with [`connect`], which validates the configuration first.
pub struct ConnectHttp {
    /// UAID for this connection, if one is known yet.
    pub uaid: Option<String>,
    /// Secret sent as `Authorization: webpush <secret>` on requests, if known.
    pub auth: Option<String>,
    /// Server, bridge and registration settings used to build requests.
    pub options: PushConfiguration,
}
/// Validates `options` and returns an HTTP-backed [`ConnectHttp`] carrying
/// the given `uaid` and `auth` secret.
///
/// # Errors
///
/// Returns a `CommunicationError` when:
/// * both a socket and an HTTP protocol are configured,
/// * only a socket protocol is configured (sockets are unsupported), or
/// * a bridge type is configured without a registration ID.
pub fn connect(
    options: PushConfiguration,
    uaid: Option<String>,
    auth: Option<String>,
) -> error::Result<ConnectHttp> {
    // Any socket protocol is rejected; the message depends on whether an
    // HTTP protocol was configured alongside it.
    if options.socket_protocol.is_some() {
        let reason = if options.http_protocol.is_some() {
            "Both socket and HTTP protocols cannot be set."
        } else {
            "Unsupported"
        };
        return Err(CommunicationError(reason.to_owned()).into());
    }
    if options.registration_id.is_none() && options.bridge_type.is_some() {
        return Err(CommunicationError(
            "Missing Registration ID, please register with OS first".to_owned(),
        )
        .into());
    }
    Ok(ConnectHttp {
        uaid,
        options,
        auth,
    })
}
impl ConnectHttp {
fn headers(&self) -> error::Result<Headers> {
let mut headers = Headers::new();
if self.auth.is_some() {
headers
.insert(
header_names::AUTHORIZATION,
format!("webpush {}", self.auth.clone().unwrap()),
)
.map_err(|e| {
error::ErrorKind::CommunicationError(format!("Header error: {:?}", e))
})?;
};
Ok(headers)
}
}
impl Connection for ConnectHttp {
fn subscribe(
&mut self,
channel_id: &str,
app_server_key: Option<&str>,
) -> error::Result<RegisterResponse> {
if self.options.http_protocol.is_none() || self.options.bridge_type.is_none() {
return Err(
CommunicationError("Bridge type or application id not set.".to_owned()).into(),
);
}
let options = self.options.clone();
let bridge_type = &options.bridge_type.unwrap();
let mut url = format!(
"{}://{}/v1/{}/{}/registration",
&options.http_protocol.unwrap(),
&options.server_host,
&bridge_type,
&options.sender_id
);
if let Some(uaid) = &self.uaid {
url.push('/');
url.push_str(&uaid);
url.push_str("/subscription");
}
let mut body = HashMap::new();
body.insert("token", options.registration_id.unwrap());
body.insert("channelID", channel_id.to_owned());
if let Some(key) = app_server_key {
body.insert("key", key.to_owned());
}
if &self.options.sender_id == "test" {
self.uaid = Some("abad1d3a00000000aabbccdd00000000".to_owned());
self.auth = Some("LsuUOBKVQRY6-l7_Ajo-Ag".to_owned());
return Ok(RegisterResponse {
uaid: self.uaid.clone().unwrap(),
channel_id: "deadbeef00000000decafbad00000000".to_owned(),
secret: self.auth.clone(),
endpoint: "http://push.example.com/test/opaque".to_owned(),
senderid: Some(self.options.sender_id.clone()),
});
}
let url = Url::parse(&url)?;
let requested = match Request::post(url)
.headers(self.headers()?)
.json(&body)
.send()
{
Ok(v) => v,
Err(e) => {
return Err(
CommunicationServerError(format!("Could not fetch endpoint: {}", e)).into(),
);
}
};
if requested.is_server_error() {
return Err(CommunicationServerError("General Server error".to_string()).into());
}
if requested.is_client_error() {
if requested.status == status_codes::CONFLICT {
return Err(AlreadyRegisteredError.into());
}
return Err(CommunicationError(format!(
"Unhandled client error {} : {:?}",
requested.status,
String::from_utf8_lossy(&requested.body)
))
.into());
}
let response: Value = match requested.json() {
Ok(v) => v,
Err(e) => {
return Err(
CommunicationServerError(format!("Could not parse response: {:?}", e)).into(),
);
}
};
if self.uaid.is_none() {
self.uaid = response["uaid"].as_str().map(ToString::to_string);
}
if self.auth.is_none() {
self.auth = response["secret"].as_str().map(ToString::to_string);
}
let channel_id = response["channelID"].as_str().map(ToString::to_string);
let endpoint = response["endpoint"].as_str().map(ToString::to_string);
Ok(RegisterResponse {
uaid: self.uaid.clone().unwrap(),
channel_id: channel_id.unwrap(),
secret: self.auth.clone(),
endpoint: endpoint.unwrap(),
senderid: response["senderid"].as_str().map(ToString::to_string),
})
}
fn unsubscribe(&self, channel_id: Option<&str>) -> error::Result<bool> {
if self.auth.is_none() {
return Err(CommunicationError("Connection is unauthorized".into()).into());
}
if self.uaid.is_none() {
return Err(CommunicationError("No UAID set".into()).into());
}
let options = self.options.clone();
let mut url = format!(
"{}://{}/v1/{}/{}/registration/{}",
&options.http_protocol.unwrap(),
&options.server_host,
&options.bridge_type.unwrap(),
&options.sender_id,
&self.uaid.clone().unwrap(),
);
if let Some(channel_id) = channel_id {
url = format!("{}/subscription/{}", url, channel_id)
}
if &self.options.sender_id == "test" {
return Ok(true);
}
match Request::delete(Url::parse(&url)?)
.headers(self.headers()?)
.send()
{
Ok(_) => Ok(true),
Err(e) => Err(CommunicationServerError(format!("Could not unsubscribe: {}", e)).into()),
}
}
fn update(&mut self, new_token: &str) -> error::Result<bool> {
if self.options.sender_id == "test" {
self.uaid = Some("abad1d3a00000000aabbccdd00000000".to_owned());
self.auth = Some("LsuUOBKVQRY6-l7_Ajo-Ag".to_owned());
return Ok(true);
}
if self.auth.is_none() {
return Err(CommunicationError("Connection is unauthorized".into()).into());
}
if self.uaid.is_none() {
return Err(CommunicationError("No UAID set".into()).into());
}
self.options.registration_id = Some(new_token.to_owned());
let options = self.options.clone();
let url = format!(
"{}://{}/v1/{}/{}/registration/{}",
&options.http_protocol.unwrap(),
&options.server_host,
&options.bridge_type.unwrap(),
&options.sender_id,
&self.uaid.clone().unwrap()
);
let mut body = HashMap::new();
body.insert("token", new_token);
match Request::put(Url::parse(&url)?)
.json(&body)
.headers(self.headers()?)
.send()
{
Ok(_) => Ok(true),
Err(e) => {
Err(CommunicationServerError(format!("Could not update token: {}", e)).into())
}
}
}
fn channel_list(&self) -> error::Result<Vec<String>> {
#[derive(Deserialize, Debug)]
struct Payload {
uaid: String,
#[serde(rename = "channelIDs")]
channel_ids: Vec<String>,
};
if self.auth.is_none() {
return Err(CommunicationError("Connection is unauthorized".into()).into());
}
if self.uaid.is_none() {
return Err(CommunicationError("No UAID set".into()).into());
}
let options = self.options.clone();
if options.bridge_type.is_none() {
return Err(CommunicationError("No Bridge Type set".into()).into());
}
let url = format!(
"{}://{}/v1/{}/{}/registration/{}",
&options.http_protocol.unwrap_or_else(|| "https".to_owned()),
&options.server_host,
&options.bridge_type.unwrap(),
&options.sender_id,
&self.uaid.clone().unwrap(),
);
let request = match Request::get(Url::parse(&url)?)
.headers(self.headers()?)
.send()
{
Ok(v) => v,
Err(e) => {
return Err(CommunicationServerError(format!(
"Could not fetch channel list: {}",
e
))
.into());
}
};
if request.is_server_error() {
return Err(CommunicationServerError("Server error".to_string()).into());
}
if request.is_client_error() {
return Err(CommunicationError(format!("Unhandled client error {:?}", request)).into());
}
let payload: Payload = match request.json() {
Ok(p) => p,
Err(e) => {
return Err(CommunicationServerError(format!(
"Could not fetch channel_list: Bad Response {:?}",
e
))
.into());
}
};
if payload.uaid != self.uaid.clone().unwrap() {
return Err(
CommunicationServerError("Invalid Response from server".to_string()).into(),
);
}
Ok(payload
.channel_ids
.iter()
.map(|s| Store::normalize_uuid(&s))
.collect())
}
fn broadcast_subscribe(&self, _broadcast: BroadcastValue) -> error::Result<BroadcastValue> {
Err(CommunicationError("Unsupported".to_string()).into())
}
fn broadcasts(&self) -> error::Result<BroadcastValue> {
Err(CommunicationError("Unsupported".to_string()).into())
}
fn verify_connection(&self, channels: &[String]) -> error::Result<bool> {
if self.auth.is_none() {
return Err(CommunicationError("Connection uninitiated".to_owned()).into());
}
if &self.options.sender_id == "test" {
return Ok(false);
}
let local_channels: HashSet<String> = HashSet::from_iter(channels.iter().cloned());
let remote_channels: HashSet<String> = HashSet::from_iter(self.channel_list()?);
if remote_channels != local_channels {
self.unsubscribe(None)?;
return Ok(false);
}
Ok(true)
}
}
#[cfg(test)]
mod test {
    // Exercises `ConnectHttp` against a local mockito server. Each scoped
    // block below creates one mock, drives one `Connection` method, and
    // calls `assert()` on the mock before checking the result.
    use super::*;
    use super::Connection;
    use mockito::{mock, server_address};
    use serde_json::json;
    const DUMMY_CHID: &str = "deadbeef00000000decafbad00000000";
    const DUMMY_UAID: &str = "abad1dea00000000aabbccdd00000000";
    const SENDER_ID: &str = "FakeSenderID";
    const SECRET: &str = "SuP3rS1kRet";
    #[test]
    fn test_communications() {
        viaduct_reqwest::use_reqwest_backend();
        // Plain-HTTP config pointed at the mock server. `sender_id` is not
        // "test", so the connection's test-mode short-circuits are not taken.
        let config = PushConfiguration {
            http_protocol: Some("http".to_owned()),
            server_host: server_address().to_string(),
            sender_id: SENDER_ID.to_owned(),
            bridge_type: Some("test".to_owned()),
            registration_id: Some("SomeRegistrationValue".to_owned()),
            ..Default::default()
        };
        // subscribe on a fresh connection: the response's secret is stored as `auth`.
        {
            let body = json!({
                "uaid": DUMMY_UAID,
                "channelID": DUMMY_CHID,
                "endpoint": "https://example.com/update",
                "senderid": SENDER_ID,
                "secret": SECRET,
            })
            .to_string();
            let ap_mock = mock(
                "POST",
                format!("/v1/test/{}/registration", SENDER_ID).as_ref(),
            )
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create();
            let mut conn = connect(config.clone(), None, None).unwrap();
            let channel_id = hex::encode(crate::crypto::get_bytes(16).unwrap());
            let response = conn.subscribe(&channel_id, None).unwrap();
            ap_mock.assert();
            assert_eq!(response.uaid, DUMMY_UAID);
            assert_eq!(conn.auth, Some(SECRET.to_owned()));
        }
        // subscribe where the server returns a null secret: `auth` stays unset.
        {
            let body = json!({
                "uaid": DUMMY_UAID,
                "channelID": DUMMY_CHID,
                "endpoint": "https://example.com/update",
                "senderid": SENDER_ID,
                "secret": null,
            })
            .to_string();
            let ap_ns_mock = mock(
                "POST",
                format!("/v1/test/{}/registration", SENDER_ID).as_ref(),
            )
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create();
            let mut conn = connect(config.clone(), None, None).unwrap();
            let channel_id = hex::encode(crate::crypto::get_bytes(16).unwrap());
            let response = conn.subscribe(&channel_id, None).unwrap();
            ap_ns_mock.assert();
            assert_eq!(response.uaid, DUMMY_UAID);
            assert_eq!(conn.auth, None);
        }
        // unsubscribe one channel: DELETE .../registration/{uaid}/subscription/{chid},
        // sent with the `webpush` Authorization header.
        {
            let ap_mock = mock(
                "DELETE",
                format!(
                    "/v1/test/{}/registration/{}/subscription/{}",
                    SENDER_ID, DUMMY_UAID, DUMMY_CHID
                )
                .as_ref(),
            )
            .match_header("authorization", format!("webpush {}", SECRET).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("{}")
            .create();
            let conn = connect(
                config.clone(),
                Some(DUMMY_UAID.to_owned()),
                Some(SECRET.to_owned()),
            )
            .unwrap();
            let response = conn.unsubscribe(Some(DUMMY_CHID)).unwrap();
            ap_mock.assert();
            assert!(response);
        }
        // unsubscribe with no channel: DELETE of the whole .../registration/{uaid}.
        {
            let ap_mock = mock(
                "DELETE",
                format!("/v1/test/{}/registration/{}", SENDER_ID, DUMMY_UAID).as_ref(),
            )
            .match_header("authorization", format!("webpush {}", SECRET).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("{}")
            .create();
            let conn = connect(
                config.clone(),
                Some(DUMMY_UAID.to_owned()),
                Some(SECRET.to_owned()),
            )
            .unwrap();
            let response = conn.unsubscribe(None).unwrap();
            ap_mock.assert();
            assert!(response);
        }
        // update: PUT .../registration/{uaid}; the new token is stored in the options.
        {
            let ap_mock = mock(
                "PUT",
                format!("/v1/test/{}/registration/{}", SENDER_ID, DUMMY_UAID).as_ref(),
            )
            .match_header("authorization", format!("webpush {}", SECRET).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body("{}")
            .create();
            let mut conn = connect(
                config.clone(),
                Some(DUMMY_UAID.to_owned()),
                Some(SECRET.to_owned()),
            )
            .unwrap();
            let response = conn.update("NewTokenValue").unwrap();
            ap_mock.assert();
            assert_eq!(
                conn.options.registration_id,
                Some("NewTokenValue".to_owned())
            );
            assert!(response);
        }
        // channel_list: GET .../registration/{uaid} returns the server's channel IDs.
        {
            let body_cl_success = json!({
                "uaid": DUMMY_UAID,
                "channelIDs": [DUMMY_CHID],
            })
            .to_string();
            let ap_mock = mock(
                "GET",
                format!("/v1/test/{}/registration/{}", SENDER_ID, DUMMY_UAID).as_ref(),
            )
            .match_header("authorization", format!("webpush {}", SECRET).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body_cl_success)
            .create();
            let conn =
                connect(config, Some(DUMMY_UAID.to_owned()), Some(SECRET.to_owned())).unwrap();
            let response = conn.channel_list().unwrap();
            ap_mock.assert();
            assert!(response == [DUMMY_CHID.to_owned()]);
        }
    }
}