libnet/
pf_key.rs

1// Copyright 2024 Oxide Computer Company
2
3// This file implements the PF_KEY protocol as described in RFC 2367.
4
5use libc::IPPROTO_TCP;
6use num_enum::{IntoPrimitive, TryFromPrimitive};
7use socket2::{Domain, Protocol, SockAddr, Socket, Type};
8use std::io::{Read, Write};
9use std::net::SocketAddr;
10use std::{mem::size_of, time::Duration};
11use winnow::binary::{le_u16, le_u32, le_u64, le_u8};
12use winnow::combinator::repeat;
13use winnow::error::{ContextError, ErrMode};
14use winnow::token::take;
15use winnow::{ModalResult, Parser};
16
/// The PF_KEY protocol family, passed as the socket domain.
// NOTE(review): 27 is hard-coded rather than taken from libc; presumably it
// is the target OS's `PF_KEY` value — confirm for every supported platform.
const PF_KEY: i32 = 27;
/// The PF_KEY protocol version. Used both as the socket protocol and as the
/// `version` field of every outgoing [`Header`].
const PF_KEY_V2: u8 = 2;
21
/// PF_KEY message types (the `typ` field of a [`Header`]).
///
/// The explicit discriminants are the on-wire values. `TryFromPrimitive` is
/// used when decoding replies, so a byte outside this list fails to parse.
/// RFC 2367 defines the core types; the higher values are presumably
/// platform extensions — confirm against the target OS headers.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum MessageType {
    Reserved = 0,
    GetSpi = 1,
    Update = 2,
    Add = 3,
    Delete = 4,
    Get = 5,
    Acquire = 6,
    Register = 7,
    Expire = 8,
    Flush = 9,
    Dump = 10,
    Promisc = 11,
    InverseAcquire = 12,
    UpdatePair = 13,
    DelPair = 14,
    DelPairState = 15,
}
43
/// PF_KEY security association types (the `sa_typ` field of a [`Header`]).
///
/// Discriminants are the on-wire values. Value 1 has no variant, so decoding
/// a header carrying it fails.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum SaType {
    Unspec = 0,
    Ah = 2,
    Esp = 3,
    /// TCP signature (TCP-MD5) associations; the only type this module sends.
    TcpSig = 4,
    Rsvp = 5,
    // NOTE(review): presumably meant to be "OspfV2"; the name is public API,
    // so renaming it would break callers.
    OspvV2 = 6,
    RipV2 = 7,
    Mip = 8,
}
57
/// PF_KEY security association extension types.
///
/// Only the uncommented variants are represented. Any other value fails
/// `TryFromPrimitive`, so the parser in this module cannot decode those
/// extensions. `StrAuth` (28) and the commented-out values above 18 are not
/// in RFC 2367; they are presumably platform extensions — confirm.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u16)]
pub enum SaExtType {
    Sa = 1,
    LifetimeCurrent = 2,
    LifetimeHard = 3,
    LifetimeSoft = 4,
    AddressSrc = 5,
    AddressDst = 6,
    //TODO AddressProxy = 7,
    //TODO KeyAuth = 8,
    //TODO KeyEncrypt = 9,
    //TODO IdentitySrc = 10,
    //TODO IdentityDst = 11,
    //TODO Sensitivity = 12,
    //TODO Proposal = 13,
    //TODO SupportedAuth = 14,
    //TODO SupportedEncrypt = 15,
    //TODO SpiRange = 16,
    //TODO Ereg = 17,
    //TODO Eprop = 18,
    //TODO KmCookie = 19,
    //TODO AddressNattLoc = 20,
    //TODO AddressNattRem = 21,
    //TODO AddressInnerDst = 22,
    //TODO Pair = 23,
    //TODO ReplayValue = 24,
    //TODO Edump = 25,
    //TODO LifetimeIdle = 26,
    //TODO OuterSens = 27,
    StrAuth = 28,
}
91
/// PF_KEY security association authentication types (the `auth` field of an
/// [`Association`]).
///
/// Discriminants are the on-wire values; unlisted values fail to decode.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum SaAuthType {
    None = 0,
    /// Plain MD5; the default used by [`Association::default`] for TCP-MD5.
    Md5 = 1,
    Md5Hmac = 2,
    Sha1Hmac = 3,
    Sha256Hmac = 5,
    Sha384Hmac = 6,
    Sha512Hmac = 7,
}
104
/// PF_KEY security association encryption types (the `encrypt` field of an
/// [`Association`]).
///
/// Discriminants are the on-wire values; unlisted values fail to decode.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum SaEncryptType {
    None = 0,
    DesCbc = 2,
    DesCbc3 = 3,
    Blowfish = 7,
    Null = 11,
    Aes = 12,
    AesCcm8 = 14,
    AesCcm12 = 15,
    AesCcm16 = 16,
    AesGcm8 = 18,
    AesGcm12 = 19,
    AesGcm16 = 20,
}
122
/// PF_KEY security association states (the `state` field of an
/// [`Association`]).
///
/// Discriminants are the on-wire values; unlisted values fail to decode.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum SaState {
    Larval = 0,
    Mature = 1,
    Dying = 2,
    Dead = 3,
}
132
/// The base header that begins every PF_KEY message (`struct sadb_msg` in
/// RFC 2367).
///
/// `repr(C, packed)` keeps the in-memory layout identical to the bytes
/// exchanged with the kernel. Because the struct is packed, fields must be
/// read by value, never borrowed.
#[derive(Debug)]
#[repr(C, packed)]
pub struct Header {
    /// Protocol version. Always `PF_KEY_V2` in headers built by `Header::new`.
    pub version: u8,
    /// The message type.
    pub typ: MessageType,
    /// Error returned by the kernel, if any (0 means success).
    pub errno: u8,
    /// Security association type.
    pub sa_typ: SaType,
    /// Length of the whole message, this header included, in 8-byte units.
    pub len: u16,
    /// Reserved when going to the kernel, diagnostic code when coming from the
    /// kernel.
    pub reserved: u16,
    /// Sequence number for this message, compared against replies to match
    /// them to their request.
    pub seq: u32,
    /// Process id of the sender.
    pub pid: u32,
}
155
156impl Header {
157    pub fn new(typ: MessageType, sa_typ: SaType, len: usize) -> Self {
158        Header {
159            version: PF_KEY_V2,
160            typ,
161            errno: 0,
162            sa_typ,
163            len: u16::try_from(len).unwrap() >> 3,
164            reserved: 0,
165            seq: rand::random(),
166            pid: std::process::id(),
167        }
168    }
169}
170
/// A decoded PF_KEY extension. Only the extension kinds this module can parse
/// are represented.
#[derive(Debug)]
pub enum Extension {
    /// SA extension (`SaExtType::Sa`).
    Association(Association),
    /// Current, hard, or soft lifetime extension.
    Lifetime(Lifetime),
    /// Source or destination address extension.
    Address(Address),
    /// String authentication key extension.
    StrAuth(StrAuth),
}
180
/// Basic information about a security association (`struct sadb_sa` in
/// RFC 2367).
///
/// `repr(C, packed)`: the field order is the wire layout of the extension.
#[derive(Clone, Copy, Debug)]
#[repr(C, packed)]
pub struct Association {
    /// Length of this extension in 8-byte units.
    pub len: u16,
    /// The type of this extension.
    pub typ: SaExtType,
    /// Security parameters index.
    pub spi: u32,
    /// Replay window size.
    pub replay: u8,
    /// State of the association.
    pub state: SaState,
    /// Authentication type.
    pub auth: SaAuthType,
    /// Encryption type.
    pub encrypt: SaEncryptType,
    /// Optional flags.
    pub flags: u32,
}
202
203impl Default for Association {
204    fn default() -> Self {
205        Association {
206            len: u16::try_from(size_of::<Association>()).unwrap() >> 3,
207            typ: SaExtType::Sa,
208            spi: 0, // This is not for IPsec
209            replay: 0,
210            state: SaState::Mature,
211            auth: SaAuthType::Md5,
212            encrypt: SaEncryptType::None,
213            flags: 0,
214        }
215    }
216}
217
/// Lifetime information for a security association (`struct sadb_lifetime`
/// in RFC 2367).
///
/// `repr(C, packed)`: the field order is the wire layout of the extension.
#[derive(Debug)]
#[repr(C, packed)]
pub struct Lifetime {
    /// Length of this extension in 8-byte units.
    pub len: u16,
    /// The type of this extension (current, hard, or soft lifetime).
    pub typ: SaExtType,
    /// How many allocations this lifetime lasts for.
    pub alloc: u32,
    /// How many bytes this lifetime lasts for.
    pub bytes: u64,
    /// How long after creation this lifetime expires in seconds.
    pub addtime: u64,
    /// How long after first use this lifetime expires in seconds.
    pub usetime: u64,
}
235
236impl Lifetime {
237    /// Create a hard lifetime extension.
238    pub fn hard(addtime: Duration) -> Self {
239        Lifetime {
240            len: u16::try_from(size_of::<Lifetime>()).unwrap() >> 3,
241            typ: SaExtType::LifetimeHard,
242            alloc: 0, // no allocation limit
243            bytes: 0, // no byte limit
244            addtime: addtime.as_secs(),
245            usetime: 0,
246        }
247    }
248    /// Create a soft lifetime extension.
249    pub fn soft(addtime: Duration) -> Self {
250        Lifetime {
251            len: u16::try_from(size_of::<Lifetime>()).unwrap() >> 3,
252            typ: SaExtType::LifetimeSoft,
253            alloc: 0, // no allocation limit
254            bytes: 0, // no byte limit
255            addtime: addtime.as_secs(),
256            usetime: 0,
257        }
258    }
259}
260
/// Address information for a security association (`struct sadb_address` in
/// RFC 2367).
///
/// `repr(C, packed)`: the field order is the wire layout of the extension.
// NOTE(review): `sockaddr` embeds socket2's `SockAddr` verbatim, so the bytes
// sent to the kernel are that type's full in-memory representation. This
// presumes the type begins with the raw sockaddr storage — confirm against
// the socket2 version in use.
#[repr(C, packed)]
pub struct Address {
    /// Length of this extension in 8-byte units.
    pub len: u16,
    /// The type of this extension (source or destination address).
    pub typ: SaExtType,
    /// Protocol identifier for this address (constructors in this module pass
    /// `IPPROTO_TCP`).
    pub proto: u8,
    /// Prefix length associated with the address.
    pub prefix_len: u8,
    /// Reserved bits.
    pub reserved: u16,
    /// Address and port the security association binds to.
    pub sockaddr: SockAddr,
}
277
278impl std::fmt::Debug for Address {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        let s = unsafe { (self as *const Address).read_unaligned() };
281        let len = s.len;
282        let typ = s.typ;
283        let proto = s.proto;
284        let plen = s.prefix_len;
285        let res = s.reserved;
286        let sa = s.sockaddr;
287        let sa = sa.as_socket();
288        f.debug_struct("Address")
289            .field("len", &len)
290            .field("typ", &typ)
291            .field("proto", &proto)
292            .field("prefix_len", &plen)
293            .field("reserved", &res)
294            .field("sockaddr", &sa)
295            .finish()
296    }
297}
298
299impl Address {
300    /// Create a new source address extension.
301    pub fn src(sockaddr: SockAddr, proto: u8) -> Self {
302        Self::new(sockaddr, proto, SaExtType::AddressSrc)
303    }
304
305    /// Create a new destination address extension.
306    pub fn dst(sockaddr: SockAddr, proto: u8) -> Self {
307        Self::new(sockaddr, proto, SaExtType::AddressDst)
308    }
309
310    /// Create a new address extension.
311    pub fn new(sockaddr: SockAddr, proto: u8, typ: SaExtType) -> Self {
312        Address {
313            len: u16::try_from(size_of::<Address>()).unwrap() >> 3,
314            typ,
315            proto,
316            prefix_len: 0,
317            reserved: 0,
318            sockaddr,
319        }
320    }
321
322    /// Get the socket address.
323    pub fn get_sockaddr(&self) -> Option<SocketAddr> {
324        let s = unsafe { (self as *const Address).read_unaligned() };
325        let sa = s.sockaddr;
326        sa.as_socket()
327    }
328}
329
/// String authentication key for a security association.
///
/// `repr(C, packed)`: the field order is the wire layout of the extension.
/// The key buffer always has the fixed size `MAX_KEY_LEN`.
// NOTE(review): extension type 28 is not among RFC 2367's extension types;
// presumably a platform extension — confirm.
#[repr(C, packed)]
pub struct StrAuth {
    /// Length of this extension in 8-byte units.
    pub len: u16,
    /// The type of this extension.
    pub typ: SaExtType,
    /// Length of the key in bits.
    pub bits: u16,
    /// Reserved.
    pub reserved: u16,
    /// Key bytes, zero-padded to `MAX_KEY_LEN`.
    pub data: [u8; StrAuth::MAX_KEY_LEN],
}
344
345impl std::fmt::Debug for StrAuth {
346    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347        let s = unsafe { (self as *const StrAuth).read_unaligned() };
348        let len = s.len;
349        let typ = s.typ;
350        let bits = s.bits;
351        let res = s.reserved;
352        let key = s.key();
353        f.debug_struct("StrAuth")
354            .field("len", &len)
355            .field("typ", &typ)
356            .field("bits", &bits)
357            .field("reserved", &res)
358            .field("data", &key)
359            .finish()
360    }
361}
362
363impl StrAuth {
364    /// Maximum size of a StrAuth key, in bytes.
365    pub const MAX_KEY_LEN: usize = 80;
366    /// Fixed-header bytes preceding the variable-length key payload.
367    const HEADER_LEN: usize = size_of::<Self>() - Self::MAX_KEY_LEN;
368
369    /// Create a new string authentication extension for a given key.
370    pub fn new(authstring: &str) -> Result<Self, Error> {
371        let key_len = authstring.len();
372        if key_len > Self::MAX_KEY_LEN {
373            return Err(Error::PfKeyParse(format!(
374                "authstring exceeds {} bytes",
375                Self::MAX_KEY_LEN
376            )));
377        }
378        let mut key = StrAuth {
379            len: u16::try_from(size_of::<StrAuth>()).unwrap() >> 3,
380            typ: SaExtType::StrAuth,
381            bits: u16::try_from(key_len << 3).unwrap(),
382            reserved: 0,
383            data: [0; Self::MAX_KEY_LEN],
384        };
385        key.data[..key_len].copy_from_slice(authstring.as_bytes());
386        Ok(key)
387    }
388
389    /// Return the key in string form.
390    pub fn key(&self) -> String {
391        let s = unsafe { (self as *const StrAuth).read_unaligned() };
392        let bits = s.bits;
393        // `bits` is wire-supplied; clamp so a bogus value can't index past the
394        // fixed key buffer.
395        let bytelen = ((bits >> 3) as usize).min(Self::MAX_KEY_LEN);
396        let data = s.data;
397        String::from_utf8_lossy(&data[..bytelen]).to_string()
398    }
399}
400
/// A packet to add a TCP-MD5 security association.
///
/// `repr(C, packed)`: the struct's bytes are written to the PF_KEY socket
/// as-is, so the field order is the on-wire extension order.
#[repr(C, packed)]
pub struct TcpMd5AddKeyRequest {
    /// Packet header.
    pub header: Header,
    /// Association info.
    pub association: Association,
    /// Lifetime info.
    pub lifetime: Lifetime,
    /// Source socket address to bind to.
    pub src: Address,
    /// Destination socket address to bind to.
    pub dst: Address,
    /// String-based key.
    pub key: StrAuth,
}
417
418impl TcpMd5AddKeyRequest {
419    /// Create a new TCP-MD5 add key request.
420    pub fn new(
421        src: SockAddr,
422        dst: SockAddr,
423        authstring: &str,
424        valid_time: Duration,
425    ) -> Result<Self, Error> {
426        Ok(Self {
427            header: Header::new(
428                MessageType::Add,
429                SaType::TcpSig,
430                size_of::<Self>(),
431            ),
432            association: Association::default(),
433            lifetime: Lifetime::hard(valid_time),
434            src: Address::src(src, IPPROTO_TCP as u8),
435            dst: Address::dst(dst, IPPROTO_TCP as u8),
436            key: StrAuth::new(authstring)?,
437        })
438    }
439}
440
/// A packet to update a TCP-MD5 security association.
///
/// `repr(C, packed)`: the struct's bytes are written to the PF_KEY socket
/// as-is, so the field order is the on-wire extension order.
#[repr(C, packed)]
pub struct TcpMd5UpdateKeyRequest {
    /// Packet header.
    pub header: Header,
    /// Association info.
    pub association: Association,
    /// Lifetime info.
    pub lifetime: Lifetime,
    /// Source socket address to bind to.
    pub src: Address,
    /// Destination socket address to bind to.
    pub dst: Address,
}
455
456impl TcpMd5UpdateKeyRequest {
457    /// Create a new TCP-MD5 update key request.
458    pub fn new(src: SockAddr, dst: SockAddr, valid_time: Duration) -> Self {
459        Self {
460            header: Header::new(
461                MessageType::Update,
462                SaType::TcpSig,
463                size_of::<Self>(),
464            ),
465            association: Association::default(),
466            lifetime: Lifetime::hard(valid_time),
467            src: Address::src(src, IPPROTO_TCP as u8),
468            dst: Address::dst(dst, IPPROTO_TCP as u8),
469        }
470    }
471}
472
/// A packet to delete a TCP-MD5 security association.
///
/// `repr(C, packed)`: the struct's bytes are written to the PF_KEY socket
/// as-is, so the field order is the on-wire extension order.
#[repr(C, packed)]
pub struct TcpMd5DeleteKeyRequest {
    /// Packet header.
    pub header: Header,
    /// Association info.
    pub association: Association,
    /// Source socket address to unbind.
    pub src: Address,
    /// Destination socket address to unbind.
    pub dst: Address,
}
485
486impl TcpMd5DeleteKeyRequest {
487    /// Create a new TCP-MD5 delete key request.
488    pub fn new(src: SockAddr, dst: SockAddr) -> Self {
489        Self {
490            header: Header::new(
491                MessageType::Delete,
492                SaType::TcpSig,
493                size_of::<Self>(),
494            ),
495            association: Association::default(),
496            src: Address::src(src, IPPROTO_TCP as u8),
497            dst: Address::dst(dst, IPPROTO_TCP as u8),
498        }
499    }
500}
501
/// A packet to request info about a TCP-MD5 security association.
///
/// `repr(C, packed)`: the struct's bytes are written to the PF_KEY socket
/// as-is, so the field order is the on-wire extension order.
#[repr(C, packed)]
pub struct TcpMd5GetKeyRequest {
    /// Packet header.
    pub header: Header,
    /// Association info.
    pub association: Association,
    /// Source socket address predicate.
    pub src: Address,
    /// Destination socket address predicate.
    pub dst: Address,
}
514
515impl TcpMd5GetKeyRequest {
516    /// Create a new TCP-MD5 get key request.
517    pub fn new(src: SockAddr, dst: SockAddr) -> Self {
518        Self {
519            header: Header::new(
520                MessageType::Get,
521                SaType::TcpSig,
522                size_of::<Self>(),
523            ),
524            association: Association::default(),
525            src: Address::src(src, IPPROTO_TCP as u8),
526            dst: Address::dst(dst, IPPROTO_TCP as u8),
527        }
528    }
529}
530
/// The kernel's reply to a key-association `Get` request.
#[derive(Debug)]
pub struct GetAssociationResponse {
    /// Decoded reply header.
    pub header: Header,
    /// Extensions that followed the header, in wire order. Decoding stops at
    /// the first extension this module cannot parse.
    pub extensions: Vec<Extension>,
}
537
538/// Add a TCP-MD5 security association for the provided source and destination
539/// address with `authstring` as the key that is valid for `valid_time` after
540/// creation. If update is true, this is treated as an update to an existing
541/// association, otherwise a new association is created.
542pub fn tcp_md5_key_add(
543    src: SockAddr,
544    dst: SockAddr,
545    authstring: &str,
546    valid_time: Duration,
547) -> Result<(), Error> {
548    let msg = TcpMd5AddKeyRequest::new(src, dst, authstring, valid_time)?;
549    let mut sock = Socket::new(
550        Domain::from(PF_KEY),
551        Type::RAW,
552        Some(Protocol::from(i32::from(PF_KEY_V2))),
553    )?;
554    let data = unsafe {
555        std::slice::from_raw_parts(
556            (&msg as *const TcpMd5AddKeyRequest) as *const u8,
557            size_of::<TcpMd5AddKeyRequest>(),
558        )
559    };
560    let n = sock.write(data)?;
561    if n != data.len() {
562        return Err(std::io::Error::new(
563            std::io::ErrorKind::UnexpectedEof,
564            format!("short write {} != {}", n, data.len()),
565        )
566        .into());
567    }
568
569    let mut buf = [0u8; 1024];
570    let _n = sock.read(&mut buf)?;
571    let response = unsafe { &*(buf.as_ptr() as *const Header) };
572
573    if response.errno != 0 {
574        if response.seq != msg.header.seq {
575            return Err(Error::PfKeySequenceMismatch {
576                expected: msg.header.seq,
577                received: response.seq,
578            });
579        }
580        return Err(Error::PfKey {
581            errno: response.errno,
582            typ: response.typ,
583            sa_typ: response.sa_typ,
584            diagnostic: response.reserved,
585        });
586    }
587
588    Ok(())
589}
590
591/// Update a TCP-MD5 security association for the provided source and
592/// destination. This function is primarily for updating the lifetime
593/// of the security association.
594pub fn tcp_md5_key_update(
595    src: SockAddr,
596    dst: SockAddr,
597    valid_time: Duration,
598) -> Result<(), Error> {
599    let msg = TcpMd5UpdateKeyRequest::new(src, dst, valid_time);
600    let mut sock = Socket::new(
601        Domain::from(PF_KEY),
602        Type::RAW,
603        Some(Protocol::from(i32::from(PF_KEY_V2))),
604    )?;
605    let data = unsafe {
606        std::slice::from_raw_parts(
607            (&msg as *const TcpMd5UpdateKeyRequest) as *const u8,
608            size_of::<TcpMd5UpdateKeyRequest>(),
609        )
610    };
611    let n = sock.write(data)?;
612    if n != data.len() {
613        return Err(std::io::Error::new(
614            std::io::ErrorKind::UnexpectedEof,
615            format!("short write {} != {}", n, data.len()),
616        )
617        .into());
618    }
619
620    let mut buf = [0u8; 1024];
621    let _n = sock.read(&mut buf)?;
622    let response = unsafe { &*(buf.as_ptr() as *const Header) };
623
624    if response.errno != 0 {
625        if response.seq != msg.header.seq {
626            return Err(Error::PfKeySequenceMismatch {
627                expected: msg.header.seq,
628                received: response.seq,
629            });
630        }
631        return Err(Error::PfKey {
632            errno: response.errno,
633            typ: response.typ,
634            sa_typ: response.sa_typ,
635            diagnostic: response.reserved,
636        });
637    }
638
639    Ok(())
640}
641
642/// Get info on a TCP-MD5 security association for the provided source and
643/// destination address with
644pub fn tcp_md5_key_get(
645    src: SockAddr,
646    dst: SockAddr,
647) -> Result<GetAssociationResponse, Error> {
648    let msg = TcpMd5GetKeyRequest::new(src, dst);
649    let mut sock = Socket::new(
650        Domain::from(PF_KEY),
651        Type::RAW,
652        Some(Protocol::from(i32::from(PF_KEY_V2))),
653    )?;
654    let data = unsafe {
655        std::slice::from_raw_parts(
656            (&msg as *const TcpMd5GetKeyRequest) as *const u8,
657            size_of::<TcpMd5GetKeyRequest>(),
658        )
659    };
660    let n = sock.write(data)?;
661    if n != data.len() {
662        return Err(std::io::Error::new(
663            std::io::ErrorKind::UnexpectedEof,
664            format!("short write {} != {}", n, data.len()),
665        )
666        .into());
667    }
668
669    let mut buf = [0u8; 1024];
670    let _n = sock.read(&mut buf)?;
671    let response = unsafe { &*(buf.as_ptr() as *const Header) };
672
673    if response.errno != 0 {
674        return Err(Error::PfKey {
675            errno: response.errno,
676            typ: response.typ,
677            sa_typ: response.sa_typ,
678            diagnostic: response.reserved,
679        });
680    }
681
682    let cursor = &mut buf.as_slice();
683    parse::association_response
684        .parse_next(cursor)
685        .map_err(|e| Error::PfKeyParse(format!("{e:?}")))
686}
687
688/// Delete the TCP-MD5 security association for the provided source and
689/// destination.
690pub fn tcp_md5_key_remove(src: SockAddr, dst: SockAddr) -> Result<(), Error> {
691    let msg = TcpMd5DeleteKeyRequest::new(src, dst);
692    let mut sock = Socket::new(
693        Domain::from(PF_KEY),
694        Type::RAW,
695        Some(Protocol::from(i32::from(PF_KEY_V2))),
696    )?;
697    let data = unsafe {
698        std::slice::from_raw_parts(
699            (&msg as *const TcpMd5DeleteKeyRequest) as *const u8,
700            size_of::<TcpMd5DeleteKeyRequest>(),
701        )
702    };
703    let n = sock.write(data)?;
704    if n != data.len() {
705        return Err(std::io::Error::new(
706            std::io::ErrorKind::UnexpectedEof,
707            format!("short write {} != {}", n, data.len()),
708        )
709        .into());
710    }
711
712    let mut buf = [0u8; 1024];
713    let _n = sock.read(&mut buf)?;
714    let response = unsafe { &*(buf.as_ptr() as *const Header) };
715
716    if response.errno != 0 {
717        if response.seq != msg.header.seq {
718            return Err(Error::PfKeySequenceMismatch {
719                expected: msg.header.seq,
720                received: response.seq,
721            });
722        }
723        return Err(Error::PfKey {
724            errno: response.errno,
725            typ: response.typ,
726            sa_typ: response.sa_typ,
727            diagnostic: response.reserved,
728        });
729    }
730
731    Ok(())
732}
733
/// Errors that can be returned from PF_KEY operations.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Creating, writing to, or reading from the PF_KEY socket failed, or a
    /// request was only partially written.
    #[error("io error {0}")]
    Io(#[from] std::io::Error),

    /// The kernel answered a request with a non-zero `errno`.
    #[error("pfkey {typ:?}/{sa_typ:?} {errno}/{diagnostic}")]
    PfKey {
        /// Message type from the kernel's reply header.
        typ: MessageType,
        /// SA type from the kernel's reply header.
        sa_typ: SaType,
        /// Error number reported by the kernel.
        errno: u8,
        /// Diagnostic code taken from the reply header's `reserved` field.
        diagnostic: u16,
    },

    /// A message could not be decoded or built. Also returned by
    /// `StrAuth::new` for an over-long authentication string.
    #[error("pfkey parse {0}")]
    PfKeyParse(String),

    /// A reply's sequence number did not match that of the request it was
    /// read for.
    #[error("pfkey sequence mismatch {expected} {received}")]
    PfKeySequenceMismatch { expected: u32, received: u32 },
}
754
755mod parse {
756    use super::*;
757
758    pub fn association_response(
759        buf: &mut &[u8],
760    ) -> ModalResult<GetAssociationResponse> {
761        Ok(GetAssociationResponse {
762            header: header.parse_next(buf)?,
763            extensions: repeat(0.., extension).parse_next(buf)?,
764        })
765    }
766
767    pub fn header(buf: &mut &[u8]) -> ModalResult<Header> {
768        Ok(Header {
769            version: le_u8.parse_next(buf)?,
770            typ: message_type.parse_next(buf)?,
771            errno: le_u8.parse_next(buf)?,
772            sa_typ: sa_type.parse_next(buf)?,
773            len: le_u16.parse_next(buf)?,
774            reserved: le_u16.parse_next(buf)?,
775            seq: le_u32.parse_next(buf)?,
776            pid: le_u32.parse_next(buf)?,
777        })
778    }
779
780    pub fn extension(buf: &mut &[u8]) -> ModalResult<Extension> {
781        let len = le_u16.parse_next(buf)?;
782        let typ = sa_ext_type.parse_next(buf)?;
783        Ok(match typ {
784            SaExtType::Sa => {
785                Extension::Association(association(len).parse_next(buf)?)
786            }
787            SaExtType::LifetimeCurrent
788            | SaExtType::LifetimeHard
789            | SaExtType::LifetimeSoft => {
790                Extension::Lifetime(lifetime(len, typ).parse_next(buf)?)
791            }
792            SaExtType::AddressSrc | SaExtType::AddressDst => {
793                Extension::Address(address(len, typ).parse_next(buf)?)
794            }
795            SaExtType::StrAuth => {
796                Extension::StrAuth(str_auth(len).parse_next(buf)?)
797            }
798        })
799    }
800
801    pub fn association(
802        len: u16,
803    ) -> impl FnMut(&mut &[u8]) -> ModalResult<Association> {
804        move |buf: &mut &[u8]| -> ModalResult<Association> {
805            Ok(Association {
806                len,
807                typ: SaExtType::Sa,
808                spi: le_u32.parse_next(buf)?,
809                replay: le_u8.parse_next(buf)?,
810                state: sa_state.parse_next(buf)?,
811                auth: sa_auth_type.parse_next(buf)?,
812                encrypt: sa_encrypt_type.parse_next(buf)?,
813                flags: le_u32.parse_next(buf)?,
814            })
815        }
816    }
817
818    pub fn lifetime(
819        len: u16,
820        typ: SaExtType,
821    ) -> impl FnMut(&mut &[u8]) -> ModalResult<Lifetime> {
822        move |buf: &mut &[u8]| -> ModalResult<Lifetime> {
823            Ok(Lifetime {
824                len,
825                typ,
826                alloc: le_u32.parse_next(buf)?,
827                bytes: le_u64.parse_next(buf)?,
828                addtime: le_u64.parse_next(buf)?,
829                usetime: le_u64.parse_next(buf)?,
830            })
831        }
832    }
833
834    pub fn address(
835        len: u16,
836        typ: SaExtType,
837    ) -> impl FnMut(&mut &[u8]) -> ModalResult<Address> {
838        move |buf: &mut &[u8]| -> ModalResult<Address> {
839            let sockaddr_len = ((len as usize) << 3)
840                - (size_of::<Address>() - size_of::<SockAddr>());
841            Ok(Address {
842                len,
843                typ,
844                proto: le_u8.parse_next(buf)?,
845                prefix_len: le_u8.parse_next(buf)?,
846                reserved: le_u16.parse_next(buf)?,
847                sockaddr: unsafe {
848                    let x = take(sockaddr_len).parse_next(buf)?;
849                    let mut buf = [0; size_of::<SockAddr>()];
850                    buf[0..sockaddr_len].copy_from_slice(x);
851                    (buf.as_ptr() as *const SockAddr).read_unaligned()
852                },
853            })
854        }
855    }
856
857    pub fn str_auth(
858        len: u16,
859    ) -> impl FnMut(&mut &[u8]) -> ModalResult<StrAuth> {
860        // `len` is the whole-extension size in 8-byte words and is taken
861        // straight off the wire. Derive the key-payload size defensively so a
862        // bogus length can neither underflow nor overrun the fixed key buffer.
863        move |buf: &mut &[u8]| -> ModalResult<StrAuth> {
864            let data_len = ((len as usize) << 3)
865                .checked_sub(StrAuth::HEADER_LEN)
866                .filter(|n| *n <= StrAuth::MAX_KEY_LEN)
867                .ok_or_else(|| ErrMode::Cut(ContextError::new()))?;
868            Ok(StrAuth {
869                len,
870                typ: SaExtType::StrAuth,
871                bits: le_u16.parse_next(buf)?,
872                reserved: le_u16.parse_next(buf)?,
873                data: {
874                    let x = take(data_len).parse_next(buf)?;
875                    let mut buf = [0; StrAuth::MAX_KEY_LEN];
876                    buf[0..data_len].copy_from_slice(x);
877                    buf
878                },
879            })
880        }
881    }
882
883    pub fn message_type(buf: &mut &[u8]) -> ModalResult<MessageType> {
884        let value = le_u8.parse_next(buf)?;
885        MessageType::try_from_primitive(value)
886            .map_err(|_| ErrMode::Backtrack(ContextError::new()))
887    }
888
889    pub fn sa_type(buf: &mut &[u8]) -> ModalResult<SaType> {
890        let value = le_u8.parse_next(buf)?;
891        SaType::try_from_primitive(value)
892            .map_err(|_| ErrMode::Backtrack(ContextError::new()))
893    }
894
895    fn sa_ext_type(buf: &mut &[u8]) -> ModalResult<SaExtType> {
896        let value = le_u16.parse_next(buf)?;
897        SaExtType::try_from_primitive(value)
898            .map_err(|_| ErrMode::Backtrack(ContextError::new()))
899    }
900
901    fn sa_auth_type(buf: &mut &[u8]) -> ModalResult<SaAuthType> {
902        let value = le_u8.parse_next(buf)?;
903        SaAuthType::try_from_primitive(value)
904            .map_err(|_| ErrMode::Backtrack(ContextError::new()))
905    }
906
907    fn sa_encrypt_type(buf: &mut &[u8]) -> ModalResult<SaEncryptType> {
908        let value = le_u8.parse_next(buf)?;
909        SaEncryptType::try_from_primitive(value)
910            .map_err(|_| ErrMode::Backtrack(ContextError::new()))
911    }
912
913    fn sa_state(buf: &mut &[u8]) -> ModalResult<SaState> {
914        let value = le_u8.parse_next(buf)?;
915        SaState::try_from_primitive(value)
916            .map_err(|_| ErrMode::Backtrack(ContextError::new()))
917    }
918
919    #[cfg(test)]
920    mod tests {
921        use super::*;
922
923        /// Build the bytes a `str_auth` parser expects (positioned after the
924        /// extension's len/typ, i.e. starting at `bits`).
925        fn str_auth_body(bits: u16, key: &[u8]) -> Vec<u8> {
926            let mut v = Vec::new();
927            v.extend_from_slice(&bits.to_le_bytes());
928            v.extend_from_slice(&0u16.to_le_bytes()); // reserved
929            v.extend_from_slice(key);
930            v
931        }
932
933        #[test]
934        fn str_auth_parses_max_key() {
935            let key = [b'k'; StrAuth::MAX_KEY_LEN];
936            let body = str_auth_body((key.len() as u16) << 3, &key);
937            // len is the whole-extension size in 8-byte words.
938            let len = ((StrAuth::HEADER_LEN + key.len()) / 8) as u16;
939
940            let mut input = body.as_slice();
941            let sa = str_auth(len).parse_next(&mut input).expect("valid");
942            assert!(input.is_empty(), "parser should consume the whole body");
943            assert_eq!(sa.key().as_bytes(), &key[..]);
944        }
945
946        #[test]
947        fn str_auth_rejects_overlong_len() {
948            // len encoding a payload > StrAuth::MAX_KEY_LEN must error rather
949            // than panic on the slice copy. Validation happens before any bytes
950            // are read, so an empty input is fine.
951            let mut input: &[u8] = &[];
952            assert!(str_auth(12).parse_next(&mut input).is_err());
953        }
954
955        #[test]
956        fn str_auth_rejects_undersized_len() {
957            // len too small to contain the fixed header must error rather than
958            // underflow the payload-length subtraction.
959            let mut input: &[u8] = &[];
960            assert!(str_auth(0).parse_next(&mut input).is_err());
961        }
962
963        #[test]
964        fn key_clamps_bogus_bits() {
965            // A wire-supplied bit length larger than the key buffer must not
966            // panic when read back via key().
967            let sa = StrAuth {
968                len: 0,
969                typ: SaExtType::StrAuth,
970                bits: u16::MAX,
971                reserved: 0,
972                data: [b'a'; StrAuth::MAX_KEY_LEN],
973            };
974            assert_eq!(sa.key().len(), StrAuth::MAX_KEY_LEN);
975        }
976    }
977}