-
Notifications
You must be signed in to change notification settings - Fork 0
/
segment.py
224 lines (169 loc) · 7.2 KB
/
segment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zlib
import six
from cassandra import DriverException
from cassandra.marshal import int32_pack
from cassandra.protocol import write_uint_le, read_uint_le
CRC24_INIT = 0x875060
CRC24_POLY = 0x1974F0B
CRC24_LENGTH = 3
CRC32_LENGTH = 4
CRC32_INITIAL = zlib.crc32(b"\xfa\x2d\x55\xca")
class CrcException(Exception):
"""
CRC mismatch error.
TODO: here to avoid import cycles with cassandra.connection. In the next
major, the exceptions should be declared in a separated exceptions.py
file.
"""
pass
def compute_crc24(data, length):
crc = CRC24_INIT
for _ in range(length):
crc ^= (data & 0xff) << 16
data >>= 8
for i in range(8):
crc <<= 1
if crc & 0x1000000 != 0:
crc ^= CRC24_POLY
return crc
def compute_crc32(data, value):
crc32 = zlib.crc32(data, value)
if six.PY2:
crc32 &= 0xffffffff
return crc32
class SegmentHeader(object):
payload_length = None
uncompressed_payload_length = None
is_self_contained = None
def __init__(self, payload_length, uncompressed_payload_length, is_self_contained):
self.payload_length = payload_length
self.uncompressed_payload_length = uncompressed_payload_length
self.is_self_contained = is_self_contained
@property
def segment_length(self):
"""
Return the total length of the segment, including the CRC.
"""
hl = SegmentCodec.UNCOMPRESSED_HEADER_LENGTH if self.uncompressed_payload_length < 1 \
else SegmentCodec.COMPRESSED_HEADER_LENGTH
return hl + CRC24_LENGTH + self.payload_length + CRC32_LENGTH
class Segment(object):
MAX_PAYLOAD_LENGTH = 128 * 1024 - 1
payload = None
is_self_contained = None
def __init__(self, payload, is_self_contained):
self.payload = payload
self.is_self_contained = is_self_contained
class SegmentCodec(object):
COMPRESSED_HEADER_LENGTH = 5
UNCOMPRESSED_HEADER_LENGTH = 3
FLAG_OFFSET = 17
compressor = None
decompressor = None
def __init__(self, compressor=None, decompressor=None):
self.compressor = compressor
self.decompressor = decompressor
@property
def header_length(self):
return self.COMPRESSED_HEADER_LENGTH if self.compression \
else self.UNCOMPRESSED_HEADER_LENGTH
@property
def header_length_with_crc(self):
return (self.COMPRESSED_HEADER_LENGTH if self.compression
else self.UNCOMPRESSED_HEADER_LENGTH) + CRC24_LENGTH
@property
def compression(self):
return self.compressor and self.decompressor
def compress(self, data):
# the uncompressed length is already encoded in the header, so
# we remove it here
return self.compressor(data)[4:]
def decompress(self, encoded_data, uncompressed_length):
return self.decompressor(int32_pack(uncompressed_length) + encoded_data)
def encode_header(self, buffer, payload_length, uncompressed_length, is_self_contained):
if payload_length > Segment.MAX_PAYLOAD_LENGTH:
raise DriverException('Payload length exceed Segment.MAX_PAYLOAD_LENGTH')
header_data = payload_length
flag_offset = self.FLAG_OFFSET
if self.compression:
header_data |= uncompressed_length << flag_offset
flag_offset += 17
if is_self_contained:
header_data |= 1 << flag_offset
write_uint_le(buffer, header_data, size=self.header_length)
header_crc = compute_crc24(header_data, self.header_length)
write_uint_le(buffer, header_crc, size=CRC24_LENGTH)
def _encode_segment(self, buffer, payload, is_self_contained):
"""
Encode a message to a single segment.
"""
uncompressed_payload = payload
uncompressed_payload_length = len(payload)
if self.compression:
compressed_payload = self.compress(uncompressed_payload)
if len(compressed_payload) >= uncompressed_payload_length:
encoded_payload = uncompressed_payload
uncompressed_payload_length = 0
else:
encoded_payload = compressed_payload
else:
encoded_payload = uncompressed_payload
payload_length = len(encoded_payload)
self.encode_header(buffer, payload_length, uncompressed_payload_length, is_self_contained)
payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
buffer.write(encoded_payload)
write_uint_le(buffer, payload_crc)
def encode(self, buffer, msg):
"""
Encode a message to one of more segments.
"""
msg_length = len(msg)
if msg_length > Segment.MAX_PAYLOAD_LENGTH:
payloads = []
for i in range(0, msg_length, Segment.MAX_PAYLOAD_LENGTH):
payloads.append(msg[i:i + Segment.MAX_PAYLOAD_LENGTH])
else:
payloads = [msg]
is_self_contained = len(payloads) == 1
for payload in payloads:
self._encode_segment(buffer, payload, is_self_contained)
def decode_header(self, buffer):
header_data = read_uint_le(buffer, self.header_length)
expected_header_crc = read_uint_le(buffer, CRC24_LENGTH)
actual_header_crc = compute_crc24(header_data, self.header_length)
if actual_header_crc != expected_header_crc:
raise CrcException('CRC mismatch on header {:x}. Received {:x}", computed {:x}.'.format(
header_data, expected_header_crc, actual_header_crc))
payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
header_data >>= 17
if self.compression:
uncompressed_payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
header_data >>= 17
else:
uncompressed_payload_length = -1
is_self_contained = (header_data & 1) == 1
return SegmentHeader(payload_length, uncompressed_payload_length, is_self_contained)
def decode(self, buffer, header):
encoded_payload = buffer.read(header.payload_length)
expected_payload_crc = read_uint_le(buffer)
actual_payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
if actual_payload_crc != expected_payload_crc:
raise CrcException('CRC mismatch on payload. Received {:x}", computed {:x}.'.format(
expected_payload_crc, actual_payload_crc))
payload = encoded_payload
if self.compression and header.uncompressed_payload_length > 0:
payload = self.decompress(encoded_payload, header.uncompressed_payload_length)
return Segment(payload, header.is_self_contained)