1use bitvec::prelude::*;
4
/// Running 16-bit one's-complement sum used by the Internet checksum (RFC 1071).
///
/// Feed big-endian 16-bit words with [`Csum::add`] or one of the byte helpers,
/// then read the final checksum with [`Csum::result`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Csum(u16);

impl Csum {
    /// Adds the big-endian word formed by `a` (high byte) and `b` (low byte)
    /// to the running sum using one's-complement (end-around carry) addition.
    pub fn add(&mut self, a: u8, b: u8) {
        let (sum, carried) = self.0.overflowing_add(u16::from_be_bytes([a, b]));
        // A carry out of bit 15 is folded back into bit 0. After a wrap the
        // sum is at most 0xFFFE, so adding the carry cannot overflow again.
        self.0 = sum + u16::from(carried);
    }

    /// Adds `data` as consecutive big-endian 16-bit words.
    ///
    /// An odd trailing byte is treated as the high byte of a final word whose
    /// low byte is zero.
    pub fn add_bytes(&mut self, data: &[u8]) {
        let mut words = data.chunks_exact(2);
        for word in &mut words {
            self.add(word[0], word[1]);
        }
        if let &[last] = words.remainder() {
            self.add(last, 0);
        }
    }

    /// Adds 16 bytes as eight big-endian words.
    pub fn add128(&mut self, data: [u8; 16]) {
        self.add_bytes(&data);
    }

    /// Adds 4 bytes as two big-endian words.
    pub fn add32(&mut self, data: [u8; 4]) {
        self.add_bytes(&data);
    }

    /// Adds 2 bytes as one big-endian word.
    pub fn add16(&mut self, data: [u8; 2]) {
        self.add_bytes(&data);
    }

    /// Returns the checksum: the one's complement of the running sum.
    #[must_use]
    pub fn result(&self) -> u16 {
        !self.0
    }
}

/// Computes the UDP checksum for an IPv6 packet that carries UDP directly,
/// i.e. with no extension headers, so the UDP header starts at byte 40.
///
/// The one's-complement sum covers the IPv6 pseudo-header (source address,
/// destination address, upper-layer length, and the next-header byte at
/// offset 6), followed by the UDP header and payload. The UDP checksum field
/// (bytes 46..48) is skipped, so it need not be zeroed beforehand.
///
/// The datagram's extent comes from the UDP length field (bytes 44..46), so
/// bytes past the end of the datagram, such as link-layer padding, are not
/// summed. If the buffer is shorter than the declared length, only the bytes
/// present are summed. A declared length below 8 sums no payload.
///
/// NOTE(review): a computed checksum of 0 is returned unchanged. RFC 8200
/// §8.1 requires it to be transmitted as 0xFFFF, so callers writing the field
/// may need to map it.
///
/// # Panics
///
/// Panics if `data` is shorter than 48 bytes (IPv6 header plus UDP header).
pub fn udp6_checksum(data: &[u8]) -> u16 {
    const IPV6_HEADER_LEN: usize = 40;
    const UDP_HEADER_LEN: usize = 8;
    const UDP_START: usize = IPV6_HEADER_LEN;
    const PAYLOAD_START: usize = IPV6_HEADER_LEN + UDP_HEADER_LEN;

    assert!(
        data.len() >= PAYLOAD_START,
        "udp6_checksum: {} bytes is too short for IPv6 + UDP headers",
        data.len()
    );

    let udp_len = usize::from(u16::from_be_bytes([data[44], data[45]]));
    let end = (UDP_START + udp_len).clamp(PAYLOAD_START, data.len());

    let mut csum = Csum::default();

    // Pseudo-header.
    csum.add_bytes(&data[8..24]); // source address
    csum.add_bytes(&data[24..40]); // destination address
    // RFC 8200 §8.1: protocols carrying their own length (UDP) use that length.
    csum.add(data[44], data[45]);
    csum.add(0, data[6]); // next header, zero-extended to 16 bits

    // UDP header: ports and length. The checksum field at 46..48 is skipped.
    csum.add_bytes(&data[UDP_START..UDP_START + 6]);

    // Payload, zero-padded to a whole word if its length is odd.
    csum.add_bytes(&data[PAYLOAD_START..end]);

    csum.result()
}
79
80pub trait Checksum {
81 fn csum(&self) -> BitVec<u8, Msb0>;
82}
83
84fn bvec_csum(bv: &BitVec<u8, Msb0>) -> BitVec<u8, Msb0> {
85 let x: u128 = bv.load();
86 let buf = x.to_be_bytes();
87 let mut c: u16 = 0;
88 for i in (0..16).step_by(2) {
89 c += u16::from_be_bytes([buf[i], buf[i + 1]])
90 }
91 let c = !c;
92 let mut result = bitvec![u8, Msb0; 0u8, 16];
93 result.store(c);
94 result
95}
96
97impl Checksum for BitVec<u8, Msb0> {
98 fn csum(&self) -> BitVec<u8, Msb0> {
99 bvec_csum(self)
100 }
101}
102
103impl Checksum for &BitVec<u8, Msb0> {
104 fn csum(&self) -> BitVec<u8, Msb0> {
105 bvec_csum(self)
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use pnet::packet::udp;
    use std::f32::consts::PI;
    use std::net::Ipv6Addr;

    /// Source and destination addresses shared by every test packet.
    const SRC: &str = "fd00::1";
    const DST: &str = "fd00::2";

    /// Builds an IPv6 packet with no extension headers carrying one UDP
    /// datagram from port 47 to port 74. Both length fields match `payload`,
    /// and the UDP checksum field (bytes 46..48) is left zero.
    fn build_packet(payload: &[u8]) -> Vec<u8> {
        let src: Ipv6Addr = SRC.parse().unwrap();
        let dst: Ipv6Addr = DST.parse().unwrap();
        // Test payloads are far below 64 KiB, so this narrowing cast is lossless.
        let udp_len = (8 + payload.len()) as u16;

        let mut packet = vec![0u8; 48];
        packet[0] = 0x60; // IP version 6 occupies the high nibble of byte 0
        packet[4..6].copy_from_slice(&udp_len.to_be_bytes()); // IPv6 payload length
        packet[6] = 17; // next header: UDP
        packet[7] = 255; // hop limit
        packet[8..24].copy_from_slice(&src.octets());
        packet[24..40].copy_from_slice(&dst.octets());
        packet[40..42].copy_from_slice(&47u16.to_be_bytes()); // source port
        packet[42..44].copy_from_slice(&74u16.to_be_bytes()); // destination port
        packet[44..46].copy_from_slice(&udp_len.to_be_bytes()); // UDP length
        packet.extend_from_slice(payload);
        packet
    }

    /// Asserts that `udp6_checksum` agrees with pnet's UDP-over-IPv6 checksum.
    fn assert_matches_pnet(packet: &[u8]) {
        let src: Ipv6Addr = SRC.parse().unwrap();
        let dst: Ipv6Addr = DST.parse().unwrap();
        let udp_packet = udp::UdpPacket::new(&packet[40..]).unwrap();
        let expected = udp::ipv6_checksum(&udp_packet, &src, &dst);
        assert_eq!(udp6_checksum(packet), expected);
    }

    #[test]
    fn udp_checksum() {
        // Payload bytes are a function of their absolute offset (48..200),
        // giving a 200-byte packet with a UDP length of 160.
        let payload: Vec<u8> = (48..200)
            .map(|i| ((i as f32) * (PI / 32.0) * 10.0) as u8)
            .collect();
        assert_matches_pnet(&build_packet(&payload));
    }

    #[test]
    fn udp_checksum_odd_payload() {
        assert_matches_pnet(&build_packet(&[0x12, 0x34, 0x56]));
    }
}