1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

use std::borrow::Cow;

/// A header name that we know to be both valid and lowercase.
///
/// Internally, this avoids allocating for headers that are constant strings,
/// like the predefined ones in this crate. Even without that optimization, we
/// would still likely have an equivalent of this for use as a
/// case-insensitive string guaranteed to only have valid characters.
///
/// [`HeaderName::new`] and the `From` conversions validate the name and
/// lowercase it when needed. The inner field is `pub(super)`, so code in the
/// parent module can build a value directly without going through validation.
#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
pub struct HeaderName(pub(super) Cow<'static, str>);

/// Indicates an invalid header name.
///
/// Note that we only emit this for response headers; for request headers, we
/// panic instead. This is because it would likely come through as a network
/// error if we emitted it for local headers, when it's actually a bug that
/// we'd need to fix.
///
/// The wrapped string is the rejected name, available through
/// [`InvalidHeaderName::name`].
#[derive(thiserror::Error, Debug, Clone, PartialEq)]
#[error("Invalid header name: {0:?}")]
pub struct InvalidHeaderName(Cow<'static, str>);

/// Conversion for locally specified headers given as `'static` string slices.
impl From<&'static str> for HeaderName {
    /// Validates `s` via [`HeaderName::new`].
    ///
    /// # Panics
    ///
    /// Panics if `s` is rejected by validation.
    fn from(s: &'static str) -> HeaderName {
        HeaderName::new(s).unwrap_or_else(|e| panic!("Illegal locally specified header {}", e))
    }
}

/// Conversion for locally specified headers given as owned strings.
impl From<String> for HeaderName {
    /// Validates `s` via [`HeaderName::new`].
    ///
    /// # Panics
    ///
    /// Panics if `s` is rejected by validation.
    fn from(s: String) -> HeaderName {
        HeaderName::new(s).unwrap_or_else(|e| panic!("Illegal locally specified header {}", e))
    }
}

/// Conversion for locally specified headers given as a `Cow<'static, str>`.
impl From<Cow<'static, str>> for HeaderName {
    /// Validates `s` via [`HeaderName::new`].
    ///
    /// # Panics
    ///
    /// Panics if `s` is rejected by validation.
    fn from(s: Cow<'static, str>) -> HeaderName {
        HeaderName::new(s).unwrap_or_else(|e| panic!("Illegal locally specified header {}", e))
    }
}

impl InvalidHeaderName {
    /// Returns the header name that was rejected.
    pub fn name(&self) -> &str {
        // `&Cow<str>` coerces to `&str` through `Deref`.
        &self.0
    }
}

/// Validates `name` against `VALID_HEADER_LUT` and normalizes it to lowercase.
///
/// # Errors
///
/// Returns [`InvalidHeaderName`] (carrying the name unmodified) when `name`
/// is empty or contains any byte whose table entry is `0`.
///
/// On success the returned [`HeaderName`] holds the same string, lowercased
/// in place only if at least one byte was flagged as needing it.
fn validate_header(mut name: Cow<'static, str>) -> Result<HeaderName, InvalidHeaderName> {
    if name.is_empty() {
        return Err(invalid_header_name(name));
    }
    let mut need_lower_case = false;
    for b in name.bytes() {
        // Table entries: 0 = invalid, 1 = valid, 2 = valid but needs lowercasing.
        match VALID_HEADER_LUT[usize::from(b)] {
            0 => return Err(invalid_header_name(name)),
            2 => need_lower_case = true,
            _ => {}
        }
    }
    if need_lower_case {
        // Only do this if needed, since `to_mut` turns a borrowed name into
        // an owned one.
        name.to_mut().make_ascii_lowercase();
    }
    Ok(HeaderName(name))
}

impl HeaderName {
    /// Builds a header name from anything convertible into a
    /// `Cow<'static, str>`, validating it and lowercasing it if necessary.
    ///
    /// For headers specified locally you likely want `HeaderName::from(s)`
    /// instead: it panics rather than returning a `Result`, since we control
    /// those headers and want to know if we specify an illegal one.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidHeaderName`] when validation rejects the name.
    #[inline]
    pub fn new<S>(s: S) -> Result<Self, InvalidHeaderName>
    where
        S: Into<Cow<'static, str>>,
    {
        let candidate: Cow<'static, str> = s.into();
        validate_header(candidate)
    }

    /// Borrows the header name as a string slice.
    #[inline]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

/// Formats the header name as its string form.
impl std::fmt::Display for HeaderName {
    /// Writes the name through `Formatter::pad` so that width, fill/alignment
    /// and precision flags (e.g. `{:>20}`) are honored. With a plain `{}` the
    /// output is exactly the name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}

// Builds the error for a rejected header name and logs a warning about it.
// Kept as a separate, cold, never-inlined function purely as a
// micro-optimization for the callers' success path.
#[cold]
#[inline(never)]
fn invalid_header_name(s: Cow<'static, str>) -> InvalidHeaderName {
    let rejected = InvalidHeaderName(s);
    log::warn!("Invalid header name: {}", rejected.0);
    rejected
}
// Per-byte classification of header-name characters:
// 0 = invalid, 1 = valid, 2 = valid but needs lowercasing.
//
// Valid bytes are 0-9, a-z, A-Z (as 2), and ('!' | '#' | '$' | '%' | '&' |
// '\'' | '*' | '+' | '-' | '.' | '^' | '_' | '`' | '|' | '~'), matching the
// field-name token production defined at
// https://tools.ietf.org/html/rfc7230#section-3.2. Plain integers are used
// rather than an enum to keep the table easy to index and compare.
static VALID_HEADER_LUT: [u8; 256] = build_valid_header_lut();

// Computes `VALID_HEADER_LUT` at compile time from the character classes
// listed above; every byte not matched explicitly stays 0 (invalid).
const fn build_valid_header_lut() -> [u8; 256] {
    let mut table = [0u8; 256];
    let mut index = 0usize;
    while index < 256 {
        // `index` is always below 256 here, so the cast is lossless.
        table[index] = match index as u8 {
            b'A'..=b'Z' => 2,
            b'0'..=b'9' | b'a'..=b'z' => 1,
            b'!' | b'#' | b'$' | b'%' | b'&' | b'\'' | b'*' | b'+' | b'-' | b'.' | b'^'
            | b'_' | b'`' | b'|' | b'~' => 1,
            _ => 0,
        };
        index += 1;
    }
    table
}

/// Lets a `HeaderName` be used wherever a `&str` is expected.
impl std::ops::Deref for HeaderName {
    type Target = str;
    /// Borrows the header name as a string slice.
    #[inline]
    fn deref(&self) -> &str {
        &self.0
    }
}

/// Cheap borrow of the header name as a string slice.
impl AsRef<str> for HeaderName {
    #[inline]
    fn as_ref(&self) -> &str {
        &self.0
    }
}

/// Cheap borrow of the header name as its underlying bytes.
impl AsRef<[u8]> for HeaderName {
    #[inline]
    fn as_ref(&self) -> &[u8] {
        self.0.as_bytes()
    }
}

/// Consumes the header name and yields it as a `String`.
impl From<HeaderName> for String {
    #[inline]
    fn from(h: HeaderName) -> Self {
        String::from(h.0)
    }
}

impl From<HeaderName> for Cow<'static, str> {
    #[inline]
    fn from(h: HeaderName) -> Self {
        h.0
    }
}

/// Consumes the header name and yields its bytes.
impl From<HeaderName> for Vec<u8> {
    #[inline]
    fn from(h: HeaderName) -> Self {
        let text: String = h.0.into();
        Vec::from(text)
    }
}

// Generates `PartialEq` in both directions between `$T0` and `$T1`, comparing
// the two as strings while ignoring ASCII case. Both types must be usable as
// `&str` at the call below (either directly or through `Deref`). The `'a`
// lifetime is declared on each impl so that `$T1` may mention it.
macro_rules! partialeq_boilerplate {
    ($T0:ty, $T1:ty) => {
        impl<'a> PartialEq<$T0> for $T1 {
            fn eq(&self, other: &$T0) -> bool {
                // Both arguments are deref-coerced to `&str` here.
                str::eq_ignore_ascii_case(self, other)
            }
        }
        impl<'a> PartialEq<$T1> for $T0 {
            fn eq(&self, other: &$T1) -> bool {
                // Delegate to the impl above with the operands swapped.
                <$T1 as PartialEq<$T0>>::eq(other, self)
            }
        }
    };
}

// Case-insensitive `==` between `HeaderName` and the common string types, in
// both directions. The `'a` in the borrowed forms refers to the lifetime
// parameter declared by the `impl<'a>` headers inside the macro.
partialeq_boilerplate!(HeaderName, str);
partialeq_boilerplate!(HeaderName, &'a str);
partialeq_boilerplate!(HeaderName, String);
partialeq_boilerplate!(HeaderName, &'a String);
partialeq_boilerplate!(HeaderName, Cow<'a, str>);

#[cfg(test)]
mod test {
    use super::*;

    /// Recomputes the expected classification of every byte from the token
    /// grammar and checks the lookup table against it.
    #[test]
    fn test_lut() {
        let mut expect = [0u8; 256];
        for (byte, slot) in (0u8..=255).zip(expect.iter_mut()) {
            *slot = if byte.is_ascii_uppercase() {
                2
            } else if byte.is_ascii_lowercase()
                || byte.is_ascii_digit()
                || b"!#$%&'*+-.^_`|~".contains(&byte)
            {
                1
            } else {
                0
            };
        }
        assert_eq!(&VALID_HEADER_LUT[..], &expect[..]);
    }

    /// Invalid names are rejected; valid names come back lowercased.
    #[test]
    fn test_validate() {
        for &rejected in &["", " foo ", "a=b"] {
            assert!(validate_header(rejected.into()).is_err());
        }
        for &accepted in &["content-type", "Content-Type"] {
            assert_eq!(
                validate_header(accepted.into()),
                Ok(HeaderName("content-type".into()))
            );
        }
    }
}