p4rs/
checksum.rs

1// Copyright 2022 Oxide Computer Company
2
3use bitvec::prelude::*;
4
/// Running 16-bit one's-complement sum (the Internet checksum accumulator).
///
/// Input is fed as big-endian byte pairs. Each pair forms a 16-bit word that is
/// added with end-around carry. [`Csum::result`] yields the complemented sum.
#[derive(Default)]
pub struct Csum(u16);

impl Csum {
    /// Adds the big-endian 16-bit word `[a, b]` to the running sum.
    pub fn add(&mut self, a: u8, b: u8) {
        let word = u16::from_be_bytes([a, b]);
        let (wrapped, carry) = self.0.overflowing_add(word);
        // End-around carry: fold the overflow bit back into bit 0. When `carry`
        // is set, `wrapped` is at most 0xfffe, so this cannot overflow again.
        self.0 = wrapped + u16::from(carry);
    }

    /// Adds a 128-bit quantity given as 16 big-endian bytes (eight words).
    pub fn add128(&mut self, data: [u8; 16]) {
        self.add_be_words(&data);
    }

    /// Adds a 32-bit quantity given as 4 big-endian bytes (two words).
    pub fn add32(&mut self, data: [u8; 4]) {
        self.add_be_words(&data);
    }

    /// Adds a single 16-bit word given as 2 big-endian bytes.
    pub fn add16(&mut self, data: [u8; 2]) {
        let [hi, lo] = data;
        self.add(hi, lo);
    }

    /// Returns the one's complement of the accumulated sum.
    pub fn result(&self) -> u16 {
        !self.0
    }

    /// Adds consecutive big-endian byte pairs from an even-length slice, in order.
    fn add_be_words(&mut self, bytes: &[u8]) {
        for pair in bytes.chunks_exact(2) {
            self.add(pair[0], pair[1]);
        }
    }
}
38
39pub fn udp6_checksum(data: &[u8]) -> u16 {
40    let src = &data[8..24];
41    let dst = &data[24..40];
42    let udp_len = &data[4..6];
43    let next_header = &data[6];
44    let src_port = &data[40..42];
45    let dst_port = &data[42..44];
46    let payload_len = &data[44..46];
47    let payload = &data[48..];
48
49    let mut csum = Csum(0);
50
51    for i in (0..src.len()).step_by(2) {
52        csum.add(src[i], src[i + 1]);
53    }
54    for i in (0..dst.len()).step_by(2) {
55        csum.add(dst[i], dst[i + 1]);
56    }
57    csum.add(udp_len[0], udp_len[1]);
58    //TODO assuming no jumbo
59    csum.add(0, *next_header);
60    csum.add(src_port[0], src_port[1]);
61    csum.add(dst_port[0], dst_port[1]);
62    csum.add(payload_len[0], payload_len[1]);
63
64    let len = payload.len();
65    let (odd, len) = if len % 2 == 0 {
66        (false, len)
67    } else {
68        (true, len - 1)
69    };
70    for i in (0..len).step_by(2) {
71        csum.add(payload[i], payload[i + 1]);
72    }
73    if odd {
74        csum.add(payload[len], 0);
75    }
76
77    csum.result()
78}
79
/// Types that can produce a checksum of their contents as a bit vector.
pub trait Checksum {
    /// Returns the checksum of `self`, encoded as a `BitVec<u8, Msb0>`.
    fn csum(&self) -> BitVec<u8, Msb0>;
}
83
84fn bvec_csum(bv: &BitVec<u8, Msb0>) -> BitVec<u8, Msb0> {
85    let x: u128 = bv.load();
86    let buf = x.to_be_bytes();
87    let mut c: u16 = 0;
88    for i in (0..16).step_by(2) {
89        c += u16::from_be_bytes([buf[i], buf[i + 1]])
90    }
91    let c = !c;
92    let mut result = bitvec![u8, Msb0; 0u8, 16];
93    result.store(c);
94    result
95}
96
97impl Checksum for BitVec<u8, Msb0> {
98    fn csum(&self) -> BitVec<u8, Msb0> {
99        bvec_csum(self)
100    }
101}
102
103impl Checksum for &BitVec<u8, Msb0> {
104    fn csum(&self) -> BitVec<u8, Msb0> {
105        bvec_csum(self)
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use pnet::packet::udp;
    use std::f32::consts::PI;
    use std::net::Ipv6Addr;

    /// Builds a 200-byte IPv6/UDP packet and checks that `udp6_checksum` agrees
    /// with pnet's reference `udp::ipv6_checksum`.
    #[test]
    fn udp_checksum() {
        let src: Ipv6Addr = "fd00::1".parse().unwrap();
        let dst: Ipv6Addr = "fd00::2".parse().unwrap();

        let mut packet = [0u8; 200];

        // IPv6 fixed header (bytes 0..40).
        // Byte 0 holds version in its high nibble; 6 here (not 0x60) does not
        // affect the checksum, so it is left as-is.
        packet[0] = 6;
        // Payload length: 200 total bytes minus the 40-byte IPv6 header.
        packet[4..6].copy_from_slice(&160u16.to_be_bytes());
        packet[6] = 17; // next header = UDP
        packet[7] = 255; // hop limit
        packet[8..24].copy_from_slice(&src.octets());
        packet[24..40].copy_from_slice(&dst.octets());

        // UDP header (bytes 40..48).
        packet[40..42].copy_from_slice(&47u16.to_be_bytes()); // source port
        packet[42..44].copy_from_slice(&74u16.to_be_bytes()); // destination port
        packet[44..46].copy_from_slice(&160u16.to_be_bytes()); // UDP header + payload length

        // Deterministic filler from byte 46 on. This also covers the checksum
        // field, which both implementations skip.
        for (i, byte) in packet.iter_mut().enumerate().skip(46) {
            *byte = ((i as f32) * (PI / 32.0) * 10.0) as u8;
        }

        let ours = udp6_checksum(&packet);

        let udp_packet = udp::UdpPacket::new(&packet[40..]).unwrap();
        let reference = udp::ipv6_checksum(&udp_packet, &src, &dst);

        assert_eq!(ours, reference);
    }
}