// Copyright 2022 Oxide Computer Company

use bitvec::prelude::*;

/// Accumulator for a 16-bit ones'-complement sum.
///
/// Bytes are fed in as big-endian 16-bit words. A carry out of the top bit is
/// added back into the low end of the running sum, and [`Csum::result`]
/// returns the bitwise complement of that sum.
#[derive(Default)]
pub struct Csum(u16);

impl Csum {
    /// Fold the big-endian word `[a, b]` into the running sum.
    pub fn add(&mut self, a: u8, b: u8) {
        let word = u16::from_be_bytes([a, b]);
        self.0 = match self.0.overflowing_add(word) {
            // The addition wrapped past 0xffff: feed the carry back in.
            (wrapped, true) => wrapped + 1,
            (sum, false) => sum,
        };
    }

    /// Fold sixteen bytes into the sum as eight consecutive words.
    pub fn add128(&mut self, data: [u8; 16]) {
        for pair in data.chunks_exact(2) {
            self.add(pair[0], pair[1]);
        }
    }

    /// Fold four bytes into the sum as two consecutive words.
    pub fn add32(&mut self, data: [u8; 4]) {
        let [a, b, c, d] = data;
        self.add(a, b);
        self.add(c, d);
    }

    /// Fold two bytes into the sum as a single word.
    pub fn add16(&mut self, data: [u8; 2]) {
        let [a, b] = data;
        self.add(a, b);
    }

    /// Return the bitwise complement of everything summed so far.
    pub fn result(&self) -> u16 {
        !self.0
    }
}

/// Compute the checksum of a UDP datagram carried directly inside IPv6.
///
/// `data` is read as a 40-byte IPv6 header immediately followed by an 8-byte
/// UDP header and the payload. The sum covers, in order: the source and
/// destination addresses (bytes 8..40), the IPv6 payload length (bytes 4..6),
/// the next-header byte (byte 6), the UDP source port, destination port and
/// length (bytes 40..46), and every byte from offset 48 onward. The UDP
/// checksum field itself (bytes 46..48) is not read. An odd trailing payload
/// byte is padded with a zero low byte.
///
/// NOTE(review): the complemented sum is returned unchanged, including when it
/// is 0x0000 -- confirm whether callers need the 0xffff substitution UDP uses
/// for a zero checksum.
///
/// # Panics
///
/// Panics if `data` is shorter than 48 bytes.
pub fn udp6_checksum(data: &[u8]) -> u16 {
    let mut sum = Csum(0);

    // Pseudo-header, part one: source address then destination address.
    for word in data[8..40].chunks_exact(2) {
        sum.add(word[0], word[1]);
    }

    // Pseudo-header, part two: length and next header, both taken from the
    // IPv6 header.
    sum.add(data[4], data[5]);
    //TODO assuming no jumbo
    sum.add(0, data[6]);

    // UDP header: source port, destination port, length. The checksum field
    // that follows is skipped.
    for word in data[40..46].chunks_exact(2) {
        sum.add(word[0], word[1]);
    }

    // Payload, one word at a time; a leftover byte becomes the high byte of a
    // final zero-padded word.
    let mut words = data[48..].chunks_exact(2);
    for word in words.by_ref() {
        sum.add(word[0], word[1]);
    }
    if let [last] = words.remainder() {
        sum.add(*last, 0);
    }

    sum.result()
}

/// A value that can produce a checksum of itself.
pub trait Checksum {
    /// Compute the checksum of `self`, returned as a `BitVec<u8, Msb0>`.
    fn csum(&self) -> BitVec<u8, Msb0>;
}

fn bvec_csum(bv: &BitVec<u8, Msb0>) -> BitVec<u8, Msb0> {
    let x: u128 = bv.load();
    let buf = x.to_be_bytes();
    let mut c: u16 = 0;
    for i in (0..16).step_by(2) {
        c += u16::from_be_bytes([buf[i], buf[i + 1]])
    }
    let c = !c;
    let mut result = bitvec![u8, Msb0; 0u8, 16];
    result.store(c);
    result
}

/// Checksum support for owned bit vectors.
impl Checksum for BitVec<u8, Msb0> {
    /// Delegates to `bvec_csum`.
    fn csum(&self) -> BitVec<u8, Msb0> {
        let bits: &BitVec<u8, Msb0> = self;
        bvec_csum(bits)
    }
}

/// Checksum support for borrowed bit vectors.
impl Checksum for &BitVec<u8, Msb0> {
    /// Delegates to `bvec_csum` on the referenced vector.
    fn csum(&self) -> BitVec<u8, Msb0> {
        bvec_csum(*self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use pnet::packet::udp;
    use std::f32::consts::PI;
    use std::net::Ipv6Addr;

    /// Cross-check `udp6_checksum` against pnet's UDP-over-IPv6 checksum for a
    /// hand-built 200-byte packet.
    #[test]
    fn udp_checksum() {
        let mut packet = [0u8; 200];

        //
        // ipv6
        //

        // The version field is the high nibble of the first byte.
        packet[0] = 0x60; // version = 6
        packet[5] = 160; // payload length = 160 (200-byte packet - 40-byte header)
        packet[6] = 17; // next header = udp
        packet[7] = 255; // hop limit = 255

        // src = fd00::1
        packet[8] = 0xfd;
        packet[23] = 0x01;

        // dst = fd00::2
        packet[24] = 0xfd;
        packet[39] = 0x02;

        //
        // udp
        //

        packet[41] = 47; // source port = 47
        packet[43] = 74; // destination port = 74
        packet[45] = 160; // udp header + payload = 160 bytes

        // Fill the remainder with arbitrary data. This also writes bytes
        // 46..48 (the UDP checksum field), which `udp6_checksum` does not read.
        for (i, data_point) in packet.iter_mut().enumerate().skip(46) {
            *data_point = ((i as f32) * (PI / 32.0) * 10.0) as u8;
        }

        let x = udp6_checksum(&packet);

        let p = udp::UdpPacket::new(&packet[40..]).unwrap();
        let src: Ipv6Addr = "fd00::1".parse().unwrap();
        let dst: Ipv6Addr = "fd00::2".parse().unwrap();
        let y = udp::ipv6_checksum(&p, &src, &dst);

        assert_eq!(x, y);
    }
}