diff --git a/binary.go b/binary.go index d2fb738..7a0ee74 100644 --- a/binary.go +++ b/binary.go @@ -2,7 +2,7 @@ package raknet import ( "bytes" - "fmt" + "io" ) // uint24 represents an integer existing out of 3 bytes. It is actually a @@ -11,20 +11,52 @@ type uint24 uint32 // readUint24 reads 3 bytes from the buffer passed and combines it into a // uint24. If there were no 3 bytes to read, an error is returned. -func readUint24(b *bytes.Buffer) (uint24, error) { - ba, _ := b.ReadByte() - bb, _ := b.ReadByte() - bc, err := b.ReadByte() - if err != nil { - return 0, fmt.Errorf("error reading uint24: %v", err) +func readUint24(buf *bytes.Buffer) (uint24, error) { + b := make([]byte, 3) + if _, err := buf.Read(b); err != nil { + return 0, io.ErrUnexpectedEOF } - return uint24(ba) | (uint24(bb) << 8) | (uint24(bc) << 16), nil + return uint24(b[0]) | (uint24(b[1]) << 8) | (uint24(b[2]) << 16), nil +} + +func readUint16(buf *bytes.Buffer) (uint16, error) { + b := make([]byte, 2) + if _, err := buf.Read(b); err != nil { + return 0, io.ErrUnexpectedEOF + } + return (uint16(b[0]) << 8) | uint16(b[1]), nil +} + +func readUint32(buf *bytes.Buffer) (uint32, error) { + b := make([]byte, 4) + if _, err := buf.Read(b); err != nil { + return 0, io.ErrUnexpectedEOF + } + return (uint32(b[0]) << 24) | (uint32(b[1]) << 16) | (uint32(b[2]) << 8) | uint32(b[3]), nil } // writeUint24 writes a uint24 to the buffer passed as 3 bytes. If not // successful, an error is returned. -func writeUint24(b *bytes.Buffer, value uint24) { - b.WriteByte(byte(value)) - b.WriteByte(byte(value >> 8)) - b.WriteByte(byte(value >> 16)) +func writeUint24(b *bytes.Buffer, v uint24) { + b.Write([]byte{ + byte(v), + byte(v >> 8), + byte(v >> 16), + }) +} + +func writeUint16(b *bytes.Buffer, v uint16) { + b.Write([]byte{ + byte(v >> 8), + byte(v), + }) +} + +func writeUint32(b *bytes.Buffer, v uint32) { + b.Write([]byte{ + byte(v >> 24), + byte(v >> 16), + byte(v >> 8), + byte(v), + }) } diff --git a/conn.go b/conn.go index f43feb9..33de724 100644 --- a/conn.go +++ b/conn.go @@ -526,14 +526,14 @@ func (conn *Conn) receiveDatagram(b *bytes.Buffer) error { func (conn *Conn) handleDatagram(b *bytes.Buffer) error { for b.Len() > 0 { if err := conn.pk.read(b); err != nil { - return fmt.Errorf("error decoding datagram packet: %v", err) + return fmt.Errorf("handle datagram: decode packet: %v", err) } handle := conn.receivePacket if conn.pk.split { handle = conn.receiveSplitPacket } if err := handle(conn.pk); err != nil { - return fmt.Errorf("error handling packet in datagram: %v", err) + return fmt.Errorf("handle datagram: handle packet: %v", err) } } return nil @@ -732,10 +732,7 @@ func (conn *Conn) sendAcknowledgement(packets []uint24, bitflag byte, buf *bytes for len(ack.packets) != 0 { buf.WriteByte(bitflag | bitFlagDatagram) - n, err := ack.write(buf, conn.mtu) - if err != nil { - panic(fmt.Sprintf("error encoding ACK packet: %v", err)) - } + n := ack.write(buf, conn.mtu) // We managed to write n packets in the ACK with this MTU size, write // the next of the packets in a new ACK. ack.packets = ack.packets[n:] diff --git a/internal/message/open_connection_request_1.go b/internal/message/open_connection_request_1.go index 5488ab6..d88028f 100644 --- a/internal/message/open_connection_request_1.go +++ b/internal/message/open_connection_request_1.go @@ -9,11 +9,20 @@ type OpenConnectionRequest1 struct { MaximumSizeNotDropped uint16 } +var cachedOCR1 = map[uint16][]byte{} + func (pk *OpenConnectionRequest1) MarshalBinary() (data []byte, err error) { + if b, ok := cachedOCR1[pk.MaximumSizeNotDropped]; ok { + // Cache OpenConnectionRequest1 data. These are independent of any other + // inputs and are pretty big. + return b, nil + } b := make([]byte, pk.MaximumSizeNotDropped-20-8) // IP Header: 20 bytes, UDP Header: 8 bytes. b[0] = IDOpenConnectionRequest1 copy(b[1:], unconnectedMessageSequence[:]) b[17] = pk.Protocol + + cachedOCR1[pk.MaximumSizeNotDropped] = b return b, nil } diff --git a/packet.go b/packet.go index d22f8db..473c3c8 100644 --- a/packet.go +++ b/packet.go @@ -3,8 +3,9 @@ package raknet import ( "bytes" "encoding/binary" - "fmt" - "sort" + "errors" + "io" + "slices" ) const ( @@ -65,93 +66,87 @@ type packet struct { } // write writes the packet and its content to the buffer passed. -func (packet *packet) write(b *bytes.Buffer) { - header := packet.reliability << 5 - if packet.split { +func (pk *packet) write(buf *bytes.Buffer) { + header := pk.reliability << 5 + if pk.split { header |= splitFlag } - b.WriteByte(header) - _ = binary.Write(b, binary.BigEndian, uint16(len(packet.content))<<3) - if packet.reliable() { - writeUint24(b, packet.messageIndex) + + buf.WriteByte(header) + writeUint16(buf, uint16(len(pk.content))<<3) + if pk.reliable() { + writeUint24(buf, pk.messageIndex) } - if packet.sequenced() { - writeUint24(b, packet.sequenceIndex) + if pk.sequenced() { + writeUint24(buf, pk.sequenceIndex) } - if packet.sequencedOrOrdered() { - writeUint24(b, packet.orderIndex) + if pk.sequencedOrOrdered() { + writeUint24(buf, pk.orderIndex) // Order channel, we don't care about this. - b.WriteByte(0) + buf.WriteByte(0) } - if packet.split { - _ = binary.Write(b, binary.BigEndian, packet.splitCount) - _ = binary.Write(b, binary.BigEndian, packet.splitID) - _ = binary.Write(b, binary.BigEndian, packet.splitIndex) + if pk.split { + writeUint32(buf, pk.splitCount) + writeUint16(buf, pk.splitID) + writeUint32(buf, pk.splitIndex) } - b.Write(packet.content) + buf.Write(pk.content) } // read reads a packet and its content from the buffer passed. -func (packet *packet) read(b *bytes.Buffer) error { - header, err := b.ReadByte() +func (pk *packet) read(buf *bytes.Buffer) error { + header, err := buf.ReadByte() if err != nil { - return fmt.Errorf("error reading packet header: %v", err) + return io.ErrUnexpectedEOF } - packet.split = (header & splitFlag) != 0 - packet.reliability = (header & 224) >> 5 - var packetLength uint16 - if err := binary.Read(b, binary.BigEndian, &packetLength); err != nil { - return fmt.Errorf("error reading packet length: %v", err) + pk.split = (header & splitFlag) != 0 + pk.reliability = (header & 224) >> 5 + packetLength, err := readUint16(buf) + if err != nil { + return io.ErrUnexpectedEOF } packetLength >>= 3 if packetLength == 0 { - return fmt.Errorf("invalid packet length: cannot be 0") + return errors.New("invalid packet length: cannot be 0") } - if packet.reliable() { - packet.messageIndex, err = readUint24(b) - if err != nil { - return fmt.Errorf("error reading packet message index: %v", err) + if pk.reliable() { + if pk.messageIndex, err = readUint24(buf); err != nil { + return io.ErrUnexpectedEOF } } - if packet.sequenced() { - packet.sequenceIndex, err = readUint24(b) - if err != nil { - return fmt.Errorf("error reading packet sequence index: %v", err) + if pk.sequenced() { + if pk.sequenceIndex, err = readUint24(buf); err != nil { + return io.ErrUnexpectedEOF } } - if packet.sequencedOrOrdered() { - packet.orderIndex, err = readUint24(b) - if err != nil { - return fmt.Errorf("error reading packet order index: %v", err) + if pk.sequencedOrOrdered() { + if pk.orderIndex, err = readUint24(buf); err != nil { + return io.ErrUnexpectedEOF } // Order channel (byte), we don't care about this. - b.Next(1) + buf.Next(1) } - if packet.split { - if err := binary.Read(b, binary.BigEndian, &packet.splitCount); err != nil { - return fmt.Errorf("error reading packet split count: %v", err) - } - if err := binary.Read(b, binary.BigEndian, &packet.splitID); err != nil { - return fmt.Errorf("error reading packet split ID: %v", err) - } - if err := binary.Read(b, binary.BigEndian, &packet.splitIndex); err != nil { - return fmt.Errorf("error reading packet split index: %v", err) + if pk.split { + pk.splitCount, _ = readUint32(buf) + pk.splitID, _ = readUint16(buf) + if pk.splitIndex, err = readUint32(buf); err != nil { + return io.ErrUnexpectedEOF } } - packet.content = make([]byte, packetLength) - if n, err := b.Read(packet.content); err != nil || n != int(packetLength) { - return fmt.Errorf("not enough data in packet: %v bytes read but need %v", n, packetLength) + pk.content = make([]byte, packetLength) + if n, err := buf.Read(pk.content); err != nil || n != int(packetLength) { + return io.ErrUnexpectedEOF } return nil } -func (packet *packet) reliable() bool { - switch packet.reliability { +func (pk *packet) reliable() bool { + switch pk.reliability { case reliabilityReliable, reliabilityReliableOrdered, reliabilityReliableSequenced: @@ -160,8 +155,8 @@ func (packet *packet) reliable() bool { return false } -func (packet *packet) sequencedOrOrdered() bool { - switch packet.reliability { +func (pk *packet) sequencedOrOrdered() bool { + switch pk.reliability { case reliabilityUnreliableSequenced, reliabilityReliableOrdered, reliabilityReliableSequenced: @@ -170,8 +165,8 @@ func (packet *packet) sequencedOrOrdered() bool { return false } -func (packet *packet) sequenced() bool { - switch packet.reliability { +func (pk *packet) sequenced() bool { + switch pk.reliability { case reliabilityUnreliableSequenced, reliabilityReliableSequenced: return true @@ -195,23 +190,24 @@ type acknowledgement struct { // write encodes an acknowledgement packet and returns an error if not // successful. -func (ack *acknowledgement) write(b *bytes.Buffer, mtu uint16) (n int, err error) { +func (ack *acknowledgement) write(buf *bytes.Buffer, mtu uint16) int { + lenOffset := buf.Len() + writeUint16(buf, 0) // Placeholder for record count. + packets := ack.packets if len(packets) == 0 { - return 0, binary.Write(b, binary.BigEndian, int16(0)) + return 0 } - buffer := bytes.NewBuffer(nil) - // Sort packets before encoding to ensure packets are encoded correctly. - sort.Slice(packets, func(i, j int) bool { - return packets[i] < packets[j] - }) - var firstPacketInRange uint24 - var lastPacketInRange uint24 - var recordCount int16 + var firstPacketInRange, lastPacketInRange uint24 + var records uint16 + n := 0 + + // Sort packets before encoding to ensure packets are encoded correctly. + slices.Sort(packets) - for index, packet := range packets { - if buffer.Len() >= int(mtu-10) { + for index, pk := range packets { + if buf.Len() >= int(mtu-(28+10)) { // We must make sure the final packet length doesn't exceed the MTU // size. break @@ -219,112 +215,83 @@ func (ack *acknowledgement) write(b *bytes.Buffer, mtu uint16) (n int, err error n++ if index == 0 { // The first packet, set the first and last packet to it. - firstPacketInRange = packet - lastPacketInRange = packet + firstPacketInRange, lastPacketInRange = pk, pk continue } - if packet == lastPacketInRange+1 { + if pk == lastPacketInRange+1 { // Packet is still part of the current range, as it's sequenced // properly with the last packet. Set the last packet in range to // the packet and continue to the next packet. - lastPacketInRange = packet + lastPacketInRange = pk continue - } else { - // We got to the end of a range/single packet. We need to write - // those down now. - if firstPacketInRange == lastPacketInRange { - // First packet equals last packet, so we have a single packet - // record. Write down the packet, and set the first and last - // packet to the current packet. - if err := buffer.WriteByte(packetSingle); err != nil { - return 0, err - } - writeUint24(buffer, firstPacketInRange) - - firstPacketInRange = packet - lastPacketInRange = packet - } else { - // There's a gap between the first and last packet, so we have a - // range of packets. Write the first and last packet of the - // range and set both to the current packet. - if err := buffer.WriteByte(packetRange); err != nil { - return 0, err - } - writeUint24(buffer, firstPacketInRange) - writeUint24(buffer, lastPacketInRange) - - firstPacketInRange = packet - lastPacketInRange = packet - } - // Keep track of the amount of records as we need to write that - // first. - recordCount++ } + ack.writeRecord(buf, firstPacketInRange, lastPacketInRange, &records) + firstPacketInRange, lastPacketInRange = pk, pk } - // Make sure the last single packet/range is written, as we always need to // know one packet ahead to know how we should write the current. - if firstPacketInRange == lastPacketInRange { - if err := buffer.WriteByte(packetSingle); err != nil { - return 0, err - } - writeUint24(buffer, firstPacketInRange) + ack.writeRecord(buf, firstPacketInRange, lastPacketInRange, &records) + + binary.BigEndian.PutUint16(buf.Bytes()[lenOffset:], records) + return n +} + +func (ack *acknowledgement) writeRecord(buf *bytes.Buffer, first, last uint24, count *uint16) { + if first == last { + // First packet equals last packet, so we have a single packet + // record. Write down the packet, and set the first and last + // packet to the current packet. + buf.WriteByte(packetSingle) + writeUint24(buf, first) } else { - if err := buffer.WriteByte(packetRange); err != nil { - return 0, err - } - writeUint24(buffer, firstPacketInRange) - writeUint24(buffer, lastPacketInRange) - } - recordCount++ - if err := binary.Write(b, binary.BigEndian, recordCount); err != nil { - return 0, err + // There's a gap between the first and last packet, so we have a + // range of packets. Write the first and last packet of the + // range and set both to the current packet. + buf.WriteByte(packetRange) + writeUint24(buf, first) + writeUint24(buf, last) } - if _, err := b.Write(buffer.Bytes()); err != nil { - return 0, err - } - return n, nil + *count++ } // read decodes an acknowledgement packet and returns an error if not // successful. -func (ack *acknowledgement) read(b *bytes.Buffer) error { +func (ack *acknowledgement) read(buf *bytes.Buffer) error { const maxAcknowledgementPackets = 8192 - var recordCount int16 - if err := binary.Read(b, binary.BigEndian, &recordCount); err != nil { - return err + recordCount, err := readUint16(buf) + if err != nil { + return io.ErrUnexpectedEOF } - for i := int16(0); i < recordCount; i++ { - recordType, err := b.ReadByte() + for i := uint16(0); i < recordCount; i++ { + recordType, err := buf.ReadByte() if err != nil { - return err + return io.ErrUnexpectedEOF } switch recordType { case packetRange: - start, err := readUint24(b) + start, _ := readUint24(buf) + end, err := readUint24(buf) if err != nil { - return err + return io.ErrUnexpectedEOF } - end, err := readUint24(b) - if err != nil { - return err + if uint24(len(ack.packets))+end-start > maxAcknowledgementPackets { + return errMaxAcknowledgement } - for pack := start; pack <= end; pack++ { - ack.packets = append(ack.packets, pack) - if len(ack.packets) > maxAcknowledgementPackets { - return fmt.Errorf("maximum amount of packets in acknowledgement exceeded") - } + for pk := start; pk <= end; pk++ { + ack.packets = append(ack.packets, pk) } case packetSingle: - packet, err := readUint24(b) - if err != nil { - return err + if len(ack.packets)+1 > maxAcknowledgementPackets { + return errMaxAcknowledgement } - ack.packets = append(ack.packets, packet) - if len(ack.packets) > maxAcknowledgementPackets { - return fmt.Errorf("maximum amount of packets in acknowledgement exceeded") + pk, err := readUint24(buf) + if err != nil { + return io.ErrUnexpectedEOF } + ack.packets = append(ack.packets, pk) } } return nil } + +var errMaxAcknowledgement = errors.New("maximum amount of packets in acknowledgement exceeded")