1use libc::IPPROTO_TCP;
6use num_enum::{IntoPrimitive, TryFromPrimitive};
7use socket2::{Domain, Protocol, SockAddr, Socket, Type};
8use std::io::{Read, Write};
9use std::net::SocketAddr;
10use std::{mem::size_of, time::Duration};
11use winnow::binary::{le_u16, le_u32, le_u64, le_u8};
12use winnow::combinator::repeat;
13use winnow::error::{ContextError, ErrMode};
14use winnow::token::take;
15use winnow::{ModalResult, Parser};
16
/// Protocol/address family number used to open PF_KEY sockets.
/// NOTE(review): 27 appears to be the illumos value of `PF_KEY`/`AF_KEY`; other
/// platforms differ (e.g. 15 on Linux) — confirm the supported targets.
const PF_KEY: i32 = 27;
/// PF_KEY protocol version 2: written into every request `Header::version` and
/// also passed as the socket protocol number.
const PF_KEY_V2: u8 = 2;
21
/// PF_KEY v2 message type, carried in [`Header::typ`].
///
/// NOTE(review): the discriminants appear to mirror the platform's `SADB_*`
/// message-type constants (values 11 and up look like illumos extensions) —
/// confirm against `<net/pfkeyv2.h>`.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum MessageType {
    Reserved = 0,
    GetSpi = 1,
    Update = 2,
    Add = 3,
    Delete = 4,
    Get = 5,
    Acquire = 6,
    Register = 7,
    Expire = 8,
    Flush = 9,
    Dump = 10,
    Promisc = 11,
    InverseAcquire = 12,
    UpdatePair = 13,
    DelPair = 14,
    DelPairState = 15,
}
43
/// Security-association type, carried in [`Header::sa_typ`].
///
/// Every request built in this module uses `TcpSig`.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum SaType {
    Unspec = 0,
    Ah = 2,
    Esp = 3,
    TcpSig = 4,
    Rsvp = 5,
    // NOTE(review): presumably "OSPFv2"; the misspelt name is public API, so it
    // is left as-is.
    OspvV2 = 6,
    RipV2 = 7,
    Mip = 8,
}
57
/// Extension types that this module builds or decodes.
///
/// This is the 16-bit type word that follows each extension's length. When
/// parsing a reply, any value not listed here ends the extension list.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u16)]
pub enum SaExtType {
    Sa = 1,
    LifetimeCurrent = 2,
    LifetimeHard = 3,
    LifetimeSoft = 4,
    AddressSrc = 5,
    AddressDst = 6,
    StrAuth = 28,
}
91
/// Authentication algorithm of an SA (`Association::auth`).
///
/// Requests built with `Association::default()` use `Md5`.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum SaAuthType {
    None = 0,
    Md5 = 1,
    Md5Hmac = 2,
    Sha1Hmac = 3,
    Sha256Hmac = 5,
    Sha384Hmac = 6,
    Sha512Hmac = 7,
}
104
/// Encryption algorithm of an SA (`Association::encrypt`).
///
/// Requests built with `Association::default()` use `None`.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum SaEncryptType {
    None = 0,
    DesCbc = 2,
    DesCbc3 = 3,
    Blowfish = 7,
    Null = 11,
    Aes = 12,
    AesCcm8 = 14,
    AesCcm12 = 15,
    AesCcm16 = 16,
    AesGcm8 = 18,
    AesGcm12 = 19,
    AesGcm16 = 20,
}
122
/// Lifecycle state of an SA (`Association::state`).
///
/// Requests built with `Association::default()` use `Mature`.
#[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone)]
#[repr(u8)]
pub enum SaState {
    Larval = 0,
    Mature = 1,
    Dying = 2,
    Dead = 3,
}
132
/// Fixed 16-byte message header at the start of every PF_KEY request and reply.
///
/// `repr(C, packed)`: field order is the wire layout with no padding. Requests
/// are sent by viewing the containing request struct as raw bytes.
#[derive(Debug)]
#[repr(C, packed)]
pub struct Header {
    /// Protocol version; requests set `PF_KEY_V2`.
    pub version: u8,
    /// Message type.
    pub typ: MessageType,
    /// Error number; in a reply, non-zero means the request was rejected.
    pub errno: u8,
    /// Security-association type.
    pub sa_typ: SaType,
    /// Total message length in 8-byte units.
    pub len: u16,
    /// Reserved; on error replies the callers in this file report it as the
    /// `diagnostic` of `Error::PfKey`.
    pub reserved: u16,
    /// Sequence number; requests use a random value that is compared with the
    /// reply's.
    pub seq: u32,
    /// Process id; requests use the current process id.
    pub pid: u32,
}
155
156impl Header {
157 pub fn new(typ: MessageType, sa_typ: SaType, len: usize) -> Self {
158 Header {
159 version: PF_KEY_V2,
160 typ,
161 errno: 0,
162 sa_typ,
163 len: u16::try_from(len).unwrap() >> 3,
164 reserved: 0,
165 seq: rand::random(),
166 pid: std::process::id(),
167 }
168 }
169}
170
/// One decoded extension from a PF_KEY reply.
#[derive(Debug)]
pub enum Extension {
    /// An `SaExtType::Sa` extension.
    Association(Association),
    /// A `LifetimeCurrent`, `LifetimeHard` or `LifetimeSoft` extension.
    Lifetime(Lifetime),
    /// An `AddressSrc` or `AddressDst` extension.
    Address(Address),
    /// A `StrAuth` (string authentication key) extension.
    StrAuth(StrAuth),
}
180
/// SA extension describing one security association.
///
/// `repr(C, packed)`: field order is the 16-byte wire layout.
#[derive(Clone, Copy, Debug)]
#[repr(C, packed)]
pub struct Association {
    /// Extension length in 8-byte units.
    pub len: u16,
    /// Extension type; `SaExtType::Sa` when built or parsed by this module.
    pub typ: SaExtType,
    /// Security Parameter Index (0 in requests built by `Default`).
    pub spi: u32,
    /// Replay field (0 in requests built by `Default`).
    pub replay: u8,
    /// SA state.
    pub state: SaState,
    /// Authentication algorithm.
    pub auth: SaAuthType,
    /// Encryption algorithm.
    pub encrypt: SaEncryptType,
    /// Flag bits (0 in requests built by `Default`).
    pub flags: u32,
}
202
203impl Default for Association {
204 fn default() -> Self {
205 Association {
206 len: u16::try_from(size_of::<Association>()).unwrap() >> 3,
207 typ: SaExtType::Sa,
208 spi: 0, replay: 0,
210 state: SaState::Mature,
211 auth: SaAuthType::Md5,
212 encrypt: SaEncryptType::None,
213 flags: 0,
214 }
215 }
216}
217
/// Lifetime extension; `typ` says whether it is the current, hard or soft
/// lifetime.
///
/// `repr(C, packed)`: field order is the 32-byte wire layout.
#[derive(Debug)]
#[repr(C, packed)]
pub struct Lifetime {
    /// Extension length in 8-byte units.
    pub len: u16,
    /// `LifetimeCurrent`, `LifetimeHard` or `LifetimeSoft`.
    pub typ: SaExtType,
    /// Allocations limit/count (0 in requests built here).
    pub alloc: u32,
    /// Byte limit/count (0 in requests built here).
    pub bytes: u64,
    /// Add-time limit/value; requests built here store whole seconds.
    pub addtime: u64,
    /// Use-time limit/value (0 in requests built here).
    pub usetime: u64,
}
235
236impl Lifetime {
237 pub fn hard(addtime: Duration) -> Self {
239 Lifetime {
240 len: u16::try_from(size_of::<Lifetime>()).unwrap() >> 3,
241 typ: SaExtType::LifetimeHard,
242 alloc: 0, bytes: 0, addtime: addtime.as_secs(),
245 usetime: 0,
246 }
247 }
248 pub fn soft(addtime: Duration) -> Self {
250 Lifetime {
251 len: u16::try_from(size_of::<Lifetime>()).unwrap() >> 3,
252 typ: SaExtType::LifetimeSoft,
253 alloc: 0, bytes: 0, addtime: addtime.as_secs(),
256 usetime: 0,
257 }
258 }
259}
260
/// Address extension (`AddressSrc` or `AddressDst`).
///
/// `repr(C, packed)`: the first 8 bytes are the extension header. The rest is
/// a socket2 `SockAddr` value copied verbatim, so `len` covers the whole
/// `SockAddr`.
/// NOTE(review): this relies on `SockAddr`'s in-memory layout beginning with
/// the raw sockaddr storage (it is not a documented `repr(C)` type) — confirm.
#[repr(C, packed)]
pub struct Address {
    /// Extension length in 8-byte units.
    pub len: u16,
    /// `AddressSrc` or `AddressDst`.
    pub typ: SaExtType,
    /// Upper-layer protocol number (TCP in requests built here).
    pub proto: u8,
    /// Prefix length (0 in requests built here).
    pub prefix_len: u8,
    /// Reserved; 0 in requests built here.
    pub reserved: u16,
    /// The socket address itself.
    pub sockaddr: SockAddr,
}
277
278impl std::fmt::Debug for Address {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 let s = unsafe { (self as *const Address).read_unaligned() };
281 let len = s.len;
282 let typ = s.typ;
283 let proto = s.proto;
284 let plen = s.prefix_len;
285 let res = s.reserved;
286 let sa = s.sockaddr;
287 let sa = sa.as_socket();
288 f.debug_struct("Address")
289 .field("len", &len)
290 .field("typ", &typ)
291 .field("proto", &proto)
292 .field("prefix_len", &plen)
293 .field("reserved", &res)
294 .field("sockaddr", &sa)
295 .finish()
296 }
297}
298
299impl Address {
300 pub fn src(sockaddr: SockAddr, proto: u8) -> Self {
302 Self::new(sockaddr, proto, SaExtType::AddressSrc)
303 }
304
305 pub fn dst(sockaddr: SockAddr, proto: u8) -> Self {
307 Self::new(sockaddr, proto, SaExtType::AddressDst)
308 }
309
310 pub fn new(sockaddr: SockAddr, proto: u8, typ: SaExtType) -> Self {
312 Address {
313 len: u16::try_from(size_of::<Address>()).unwrap() >> 3,
314 typ,
315 proto,
316 prefix_len: 0,
317 reserved: 0,
318 sockaddr,
319 }
320 }
321
322 pub fn get_sockaddr(&self) -> Option<SocketAddr> {
324 let s = unsafe { (self as *const Address).read_unaligned() };
325 let sa = s.sockaddr;
326 sa.as_socket()
327 }
328}
329
/// String-authentication extension carrying a TCP-MD5 key.
///
/// `repr(C, packed)`: an 8-byte header followed by a fixed-size key buffer.
#[repr(C, packed)]
pub struct StrAuth {
    /// Extension length in 8-byte units.
    pub len: u16,
    /// Extension type; `SaExtType::StrAuth`.
    pub typ: SaExtType,
    /// Key length in bits.
    pub bits: u16,
    /// Reserved; 0 in requests built here.
    pub reserved: u16,
    /// Key bytes, zero-padded to `MAX_KEY_LEN`.
    pub data: [u8; StrAuth::MAX_KEY_LEN],
}
344
345impl std::fmt::Debug for StrAuth {
346 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347 let s = unsafe { (self as *const StrAuth).read_unaligned() };
348 let len = s.len;
349 let typ = s.typ;
350 let bits = s.bits;
351 let res = s.reserved;
352 let key = s.key();
353 f.debug_struct("StrAuth")
354 .field("len", &len)
355 .field("typ", &typ)
356 .field("bits", &bits)
357 .field("reserved", &res)
358 .field("data", &key)
359 .finish()
360 }
361}
362
363impl StrAuth {
364 pub const MAX_KEY_LEN: usize = 80;
366 const HEADER_LEN: usize = size_of::<Self>() - Self::MAX_KEY_LEN;
368
369 pub fn new(authstring: &str) -> Result<Self, Error> {
371 let key_len = authstring.len();
372 if key_len > Self::MAX_KEY_LEN {
373 return Err(Error::PfKeyParse(format!(
374 "authstring exceeds {} bytes",
375 Self::MAX_KEY_LEN
376 )));
377 }
378 let mut key = StrAuth {
379 len: u16::try_from(size_of::<StrAuth>()).unwrap() >> 3,
380 typ: SaExtType::StrAuth,
381 bits: u16::try_from(key_len << 3).unwrap(),
382 reserved: 0,
383 data: [0; Self::MAX_KEY_LEN],
384 };
385 key.data[..key_len].copy_from_slice(authstring.as_bytes());
386 Ok(key)
387 }
388
389 pub fn key(&self) -> String {
391 let s = unsafe { (self as *const StrAuth).read_unaligned() };
392 let bits = s.bits;
393 let bytelen = ((bits >> 3) as usize).min(Self::MAX_KEY_LEN);
396 let data = s.data;
397 String::from_utf8_lossy(&data[..bytelen]).to_string()
398 }
399}
400
/// Complete `Add` request that installs a TCP-MD5 key.
///
/// `repr(C, packed)`: it is written to the PF_KEY socket as raw bytes in this
/// field order.
#[repr(C, packed)]
pub struct TcpMd5AddKeyRequest {
    /// Message header; its `len` covers this whole struct.
    pub header: Header,
    /// SA extension.
    pub association: Association,
    /// Hard lifetime extension.
    pub lifetime: Lifetime,
    /// Source address extension.
    pub src: Address,
    /// Destination address extension.
    pub dst: Address,
    /// Key extension.
    pub key: StrAuth,
}
417
418impl TcpMd5AddKeyRequest {
419 pub fn new(
421 src: SockAddr,
422 dst: SockAddr,
423 authstring: &str,
424 valid_time: Duration,
425 ) -> Result<Self, Error> {
426 Ok(Self {
427 header: Header::new(
428 MessageType::Add,
429 SaType::TcpSig,
430 size_of::<Self>(),
431 ),
432 association: Association::default(),
433 lifetime: Lifetime::hard(valid_time),
434 src: Address::src(src, IPPROTO_TCP as u8),
435 dst: Address::dst(dst, IPPROTO_TCP as u8),
436 key: StrAuth::new(authstring)?,
437 })
438 }
439}
440
/// Complete `Update` request for an existing TCP-MD5 SA (no key extension).
///
/// `repr(C, packed)`: it is written to the PF_KEY socket as raw bytes in this
/// field order.
#[repr(C, packed)]
pub struct TcpMd5UpdateKeyRequest {
    /// Message header; its `len` covers this whole struct.
    pub header: Header,
    /// SA extension.
    pub association: Association,
    /// Hard lifetime extension.
    pub lifetime: Lifetime,
    /// Source address extension.
    pub src: Address,
    /// Destination address extension.
    pub dst: Address,
}
455
456impl TcpMd5UpdateKeyRequest {
457 pub fn new(src: SockAddr, dst: SockAddr, valid_time: Duration) -> Self {
459 Self {
460 header: Header::new(
461 MessageType::Update,
462 SaType::TcpSig,
463 size_of::<Self>(),
464 ),
465 association: Association::default(),
466 lifetime: Lifetime::hard(valid_time),
467 src: Address::src(src, IPPROTO_TCP as u8),
468 dst: Address::dst(dst, IPPROTO_TCP as u8),
469 }
470 }
471}
472
/// Complete `Delete` request that removes a TCP-MD5 SA.
///
/// `repr(C, packed)`: it is written to the PF_KEY socket as raw bytes in this
/// field order.
#[repr(C, packed)]
pub struct TcpMd5DeleteKeyRequest {
    /// Message header; its `len` covers this whole struct.
    pub header: Header,
    /// SA extension.
    pub association: Association,
    /// Source address extension.
    pub src: Address,
    /// Destination address extension.
    pub dst: Address,
}
485
486impl TcpMd5DeleteKeyRequest {
487 pub fn new(src: SockAddr, dst: SockAddr) -> Self {
489 Self {
490 header: Header::new(
491 MessageType::Delete,
492 SaType::TcpSig,
493 size_of::<Self>(),
494 ),
495 association: Association::default(),
496 src: Address::src(src, IPPROTO_TCP as u8),
497 dst: Address::dst(dst, IPPROTO_TCP as u8),
498 }
499 }
500}
501
/// Complete `Get` request that looks up a TCP-MD5 SA.
///
/// `repr(C, packed)`: it is written to the PF_KEY socket as raw bytes in this
/// field order.
#[repr(C, packed)]
pub struct TcpMd5GetKeyRequest {
    /// Message header; its `len` covers this whole struct.
    pub header: Header,
    /// SA extension.
    pub association: Association,
    /// Source address extension.
    pub src: Address,
    /// Destination address extension.
    pub dst: Address,
}
514
515impl TcpMd5GetKeyRequest {
516 pub fn new(src: SockAddr, dst: SockAddr) -> Self {
518 Self {
519 header: Header::new(
520 MessageType::Get,
521 SaType::TcpSig,
522 size_of::<Self>(),
523 ),
524 association: Association::default(),
525 src: Address::src(src, IPPROTO_TCP as u8),
526 dst: Address::dst(dst, IPPROTO_TCP as u8),
527 }
528 }
529}
530
/// Decoded reply to a `Get` request.
///
/// It holds the reply header plus each extension the parser recognised, in
/// wire order.
#[derive(Debug)]
pub struct GetAssociationResponse {
    /// Reply header.
    pub header: Header,
    /// Decoded extensions.
    pub extensions: Vec<Extension>,
}
537
538pub fn tcp_md5_key_add(
543 src: SockAddr,
544 dst: SockAddr,
545 authstring: &str,
546 valid_time: Duration,
547) -> Result<(), Error> {
548 let msg = TcpMd5AddKeyRequest::new(src, dst, authstring, valid_time)?;
549 let mut sock = Socket::new(
550 Domain::from(PF_KEY),
551 Type::RAW,
552 Some(Protocol::from(i32::from(PF_KEY_V2))),
553 )?;
554 let data = unsafe {
555 std::slice::from_raw_parts(
556 (&msg as *const TcpMd5AddKeyRequest) as *const u8,
557 size_of::<TcpMd5AddKeyRequest>(),
558 )
559 };
560 let n = sock.write(data)?;
561 if n != data.len() {
562 return Err(std::io::Error::new(
563 std::io::ErrorKind::UnexpectedEof,
564 format!("short write {} != {}", n, data.len()),
565 )
566 .into());
567 }
568
569 let mut buf = [0u8; 1024];
570 let _n = sock.read(&mut buf)?;
571 let response = unsafe { &*(buf.as_ptr() as *const Header) };
572
573 if response.errno != 0 {
574 if response.seq != msg.header.seq {
575 return Err(Error::PfKeySequenceMismatch {
576 expected: msg.header.seq,
577 received: response.seq,
578 });
579 }
580 return Err(Error::PfKey {
581 errno: response.errno,
582 typ: response.typ,
583 sa_typ: response.sa_typ,
584 diagnostic: response.reserved,
585 });
586 }
587
588 Ok(())
589}
590
591pub fn tcp_md5_key_update(
595 src: SockAddr,
596 dst: SockAddr,
597 valid_time: Duration,
598) -> Result<(), Error> {
599 let msg = TcpMd5UpdateKeyRequest::new(src, dst, valid_time);
600 let mut sock = Socket::new(
601 Domain::from(PF_KEY),
602 Type::RAW,
603 Some(Protocol::from(i32::from(PF_KEY_V2))),
604 )?;
605 let data = unsafe {
606 std::slice::from_raw_parts(
607 (&msg as *const TcpMd5UpdateKeyRequest) as *const u8,
608 size_of::<TcpMd5UpdateKeyRequest>(),
609 )
610 };
611 let n = sock.write(data)?;
612 if n != data.len() {
613 return Err(std::io::Error::new(
614 std::io::ErrorKind::UnexpectedEof,
615 format!("short write {} != {}", n, data.len()),
616 )
617 .into());
618 }
619
620 let mut buf = [0u8; 1024];
621 let _n = sock.read(&mut buf)?;
622 let response = unsafe { &*(buf.as_ptr() as *const Header) };
623
624 if response.errno != 0 {
625 if response.seq != msg.header.seq {
626 return Err(Error::PfKeySequenceMismatch {
627 expected: msg.header.seq,
628 received: response.seq,
629 });
630 }
631 return Err(Error::PfKey {
632 errno: response.errno,
633 typ: response.typ,
634 sa_typ: response.sa_typ,
635 diagnostic: response.reserved,
636 });
637 }
638
639 Ok(())
640}
641
642pub fn tcp_md5_key_get(
645 src: SockAddr,
646 dst: SockAddr,
647) -> Result<GetAssociationResponse, Error> {
648 let msg = TcpMd5GetKeyRequest::new(src, dst);
649 let mut sock = Socket::new(
650 Domain::from(PF_KEY),
651 Type::RAW,
652 Some(Protocol::from(i32::from(PF_KEY_V2))),
653 )?;
654 let data = unsafe {
655 std::slice::from_raw_parts(
656 (&msg as *const TcpMd5GetKeyRequest) as *const u8,
657 size_of::<TcpMd5GetKeyRequest>(),
658 )
659 };
660 let n = sock.write(data)?;
661 if n != data.len() {
662 return Err(std::io::Error::new(
663 std::io::ErrorKind::UnexpectedEof,
664 format!("short write {} != {}", n, data.len()),
665 )
666 .into());
667 }
668
669 let mut buf = [0u8; 1024];
670 let _n = sock.read(&mut buf)?;
671 let response = unsafe { &*(buf.as_ptr() as *const Header) };
672
673 if response.errno != 0 {
674 return Err(Error::PfKey {
675 errno: response.errno,
676 typ: response.typ,
677 sa_typ: response.sa_typ,
678 diagnostic: response.reserved,
679 });
680 }
681
682 let cursor = &mut buf.as_slice();
683 parse::association_response
684 .parse_next(cursor)
685 .map_err(|e| Error::PfKeyParse(format!("{e:?}")))
686}
687
688pub fn tcp_md5_key_remove(src: SockAddr, dst: SockAddr) -> Result<(), Error> {
691 let msg = TcpMd5DeleteKeyRequest::new(src, dst);
692 let mut sock = Socket::new(
693 Domain::from(PF_KEY),
694 Type::RAW,
695 Some(Protocol::from(i32::from(PF_KEY_V2))),
696 )?;
697 let data = unsafe {
698 std::slice::from_raw_parts(
699 (&msg as *const TcpMd5DeleteKeyRequest) as *const u8,
700 size_of::<TcpMd5DeleteKeyRequest>(),
701 )
702 };
703 let n = sock.write(data)?;
704 if n != data.len() {
705 return Err(std::io::Error::new(
706 std::io::ErrorKind::UnexpectedEof,
707 format!("short write {} != {}", n, data.len()),
708 )
709 .into());
710 }
711
712 let mut buf = [0u8; 1024];
713 let _n = sock.read(&mut buf)?;
714 let response = unsafe { &*(buf.as_ptr() as *const Header) };
715
716 if response.errno != 0 {
717 if response.seq != msg.header.seq {
718 return Err(Error::PfKeySequenceMismatch {
719 expected: msg.header.seq,
720 received: response.seq,
721 });
722 }
723 return Err(Error::PfKey {
724 errno: response.errno,
725 typ: response.typ,
726 sa_typ: response.sa_typ,
727 diagnostic: response.reserved,
728 });
729 }
730
731 Ok(())
732}
733
/// Errors returned by the PF_KEY helpers in this file.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Socket creation, write or read failed, or a request was only partially
    /// written.
    #[error("io error {0}")]
    Io(#[from] std::io::Error),

    /// The kernel replied with a non-zero errno. `typ` and `sa_typ` come from
    /// the reply header, and `diagnostic` is its `reserved` field.
    #[error("pfkey {typ:?}/{sa_typ:?} {errno}/{diagnostic}")]
    PfKey {
        typ: MessageType,
        sa_typ: SaType,
        errno: u8,
        diagnostic: u16,
    },

    /// A reply could not be decoded, or an auth string was longer than
    /// `StrAuth::MAX_KEY_LEN` bytes.
    #[error("pfkey parse {0}")]
    PfKeyParse(String),

    /// A reply's sequence number did not match the request's.
    #[error("pfkey sequence mismatch {expected} {received}")]
    PfKeySequenceMismatch { expected: u32, received: u32 },
}
754
755mod parse {
756 use super::*;
757
758 pub fn association_response(
759 buf: &mut &[u8],
760 ) -> ModalResult<GetAssociationResponse> {
761 Ok(GetAssociationResponse {
762 header: header.parse_next(buf)?,
763 extensions: repeat(0.., extension).parse_next(buf)?,
764 })
765 }
766
767 pub fn header(buf: &mut &[u8]) -> ModalResult<Header> {
768 Ok(Header {
769 version: le_u8.parse_next(buf)?,
770 typ: message_type.parse_next(buf)?,
771 errno: le_u8.parse_next(buf)?,
772 sa_typ: sa_type.parse_next(buf)?,
773 len: le_u16.parse_next(buf)?,
774 reserved: le_u16.parse_next(buf)?,
775 seq: le_u32.parse_next(buf)?,
776 pid: le_u32.parse_next(buf)?,
777 })
778 }
779
780 pub fn extension(buf: &mut &[u8]) -> ModalResult<Extension> {
781 let len = le_u16.parse_next(buf)?;
782 let typ = sa_ext_type.parse_next(buf)?;
783 Ok(match typ {
784 SaExtType::Sa => {
785 Extension::Association(association(len).parse_next(buf)?)
786 }
787 SaExtType::LifetimeCurrent
788 | SaExtType::LifetimeHard
789 | SaExtType::LifetimeSoft => {
790 Extension::Lifetime(lifetime(len, typ).parse_next(buf)?)
791 }
792 SaExtType::AddressSrc | SaExtType::AddressDst => {
793 Extension::Address(address(len, typ).parse_next(buf)?)
794 }
795 SaExtType::StrAuth => {
796 Extension::StrAuth(str_auth(len).parse_next(buf)?)
797 }
798 })
799 }
800
801 pub fn association(
802 len: u16,
803 ) -> impl FnMut(&mut &[u8]) -> ModalResult<Association> {
804 move |buf: &mut &[u8]| -> ModalResult<Association> {
805 Ok(Association {
806 len,
807 typ: SaExtType::Sa,
808 spi: le_u32.parse_next(buf)?,
809 replay: le_u8.parse_next(buf)?,
810 state: sa_state.parse_next(buf)?,
811 auth: sa_auth_type.parse_next(buf)?,
812 encrypt: sa_encrypt_type.parse_next(buf)?,
813 flags: le_u32.parse_next(buf)?,
814 })
815 }
816 }
817
818 pub fn lifetime(
819 len: u16,
820 typ: SaExtType,
821 ) -> impl FnMut(&mut &[u8]) -> ModalResult<Lifetime> {
822 move |buf: &mut &[u8]| -> ModalResult<Lifetime> {
823 Ok(Lifetime {
824 len,
825 typ,
826 alloc: le_u32.parse_next(buf)?,
827 bytes: le_u64.parse_next(buf)?,
828 addtime: le_u64.parse_next(buf)?,
829 usetime: le_u64.parse_next(buf)?,
830 })
831 }
832 }
833
834 pub fn address(
835 len: u16,
836 typ: SaExtType,
837 ) -> impl FnMut(&mut &[u8]) -> ModalResult<Address> {
838 move |buf: &mut &[u8]| -> ModalResult<Address> {
839 let sockaddr_len = ((len as usize) << 3)
840 - (size_of::<Address>() - size_of::<SockAddr>());
841 Ok(Address {
842 len,
843 typ,
844 proto: le_u8.parse_next(buf)?,
845 prefix_len: le_u8.parse_next(buf)?,
846 reserved: le_u16.parse_next(buf)?,
847 sockaddr: unsafe {
848 let x = take(sockaddr_len).parse_next(buf)?;
849 let mut buf = [0; size_of::<SockAddr>()];
850 buf[0..sockaddr_len].copy_from_slice(x);
851 (buf.as_ptr() as *const SockAddr).read_unaligned()
852 },
853 })
854 }
855 }
856
857 pub fn str_auth(
858 len: u16,
859 ) -> impl FnMut(&mut &[u8]) -> ModalResult<StrAuth> {
860 move |buf: &mut &[u8]| -> ModalResult<StrAuth> {
864 let data_len = ((len as usize) << 3)
865 .checked_sub(StrAuth::HEADER_LEN)
866 .filter(|n| *n <= StrAuth::MAX_KEY_LEN)
867 .ok_or_else(|| ErrMode::Cut(ContextError::new()))?;
868 Ok(StrAuth {
869 len,
870 typ: SaExtType::StrAuth,
871 bits: le_u16.parse_next(buf)?,
872 reserved: le_u16.parse_next(buf)?,
873 data: {
874 let x = take(data_len).parse_next(buf)?;
875 let mut buf = [0; StrAuth::MAX_KEY_LEN];
876 buf[0..data_len].copy_from_slice(x);
877 buf
878 },
879 })
880 }
881 }
882
883 pub fn message_type(buf: &mut &[u8]) -> ModalResult<MessageType> {
884 let value = le_u8.parse_next(buf)?;
885 MessageType::try_from_primitive(value)
886 .map_err(|_| ErrMode::Backtrack(ContextError::new()))
887 }
888
889 pub fn sa_type(buf: &mut &[u8]) -> ModalResult<SaType> {
890 let value = le_u8.parse_next(buf)?;
891 SaType::try_from_primitive(value)
892 .map_err(|_| ErrMode::Backtrack(ContextError::new()))
893 }
894
895 fn sa_ext_type(buf: &mut &[u8]) -> ModalResult<SaExtType> {
896 let value = le_u16.parse_next(buf)?;
897 SaExtType::try_from_primitive(value)
898 .map_err(|_| ErrMode::Backtrack(ContextError::new()))
899 }
900
901 fn sa_auth_type(buf: &mut &[u8]) -> ModalResult<SaAuthType> {
902 let value = le_u8.parse_next(buf)?;
903 SaAuthType::try_from_primitive(value)
904 .map_err(|_| ErrMode::Backtrack(ContextError::new()))
905 }
906
907 fn sa_encrypt_type(buf: &mut &[u8]) -> ModalResult<SaEncryptType> {
908 let value = le_u8.parse_next(buf)?;
909 SaEncryptType::try_from_primitive(value)
910 .map_err(|_| ErrMode::Backtrack(ContextError::new()))
911 }
912
913 fn sa_state(buf: &mut &[u8]) -> ModalResult<SaState> {
914 let value = le_u8.parse_next(buf)?;
915 SaState::try_from_primitive(value)
916 .map_err(|_| ErrMode::Backtrack(ContextError::new()))
917 }
918
919 #[cfg(test)]
920 mod tests {
921 use super::*;
922
923 fn str_auth_body(bits: u16, key: &[u8]) -> Vec<u8> {
926 let mut v = Vec::new();
927 v.extend_from_slice(&bits.to_le_bytes());
928 v.extend_from_slice(&0u16.to_le_bytes()); v.extend_from_slice(key);
930 v
931 }
932
933 #[test]
934 fn str_auth_parses_max_key() {
935 let key = [b'k'; StrAuth::MAX_KEY_LEN];
936 let body = str_auth_body((key.len() as u16) << 3, &key);
937 let len = ((StrAuth::HEADER_LEN + key.len()) / 8) as u16;
939
940 let mut input = body.as_slice();
941 let sa = str_auth(len).parse_next(&mut input).expect("valid");
942 assert!(input.is_empty(), "parser should consume the whole body");
943 assert_eq!(sa.key().as_bytes(), &key[..]);
944 }
945
946 #[test]
947 fn str_auth_rejects_overlong_len() {
948 let mut input: &[u8] = &[];
952 assert!(str_auth(12).parse_next(&mut input).is_err());
953 }
954
955 #[test]
956 fn str_auth_rejects_undersized_len() {
957 let mut input: &[u8] = &[];
960 assert!(str_auth(0).parse_next(&mut input).is_err());
961 }
962
963 #[test]
964 fn key_clamps_bogus_bits() {
965 let sa = StrAuth {
968 len: 0,
969 typ: SaExtType::StrAuth,
970 bits: u16::MAX,
971 reserved: 0,
972 data: [b'a'; StrAuth::MAX_KEY_LEN],
973 };
974 assert_eq!(sa.key().len(), StrAuth::MAX_KEY_LEN);
975 }
976 }
977}