diff --git a/src/commonMain/kotlin/NbtDecoder.kt b/src/commonMain/kotlin/NbtDecoder.kt index f66e3675..0f24d41a 100644 --- a/src/commonMain/kotlin/NbtDecoder.kt +++ b/src/commonMain/kotlin/NbtDecoder.kt @@ -6,12 +6,14 @@ import kotlinx.serialization.PolymorphicSerializer import kotlinx.serialization.builtins.ByteArraySerializer import kotlinx.serialization.builtins.IntArraySerializer import kotlinx.serialization.builtins.LongArraySerializer +import kotlinx.serialization.descriptors.PolymorphicKind import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.descriptors.StructureKind import kotlinx.serialization.encoding.AbstractDecoder import kotlinx.serialization.encoding.CompositeDecoder import kotlinx.serialization.encoding.Decoder import net.benwoodworth.knbt.internal.NbtDecodingException +import net.benwoodworth.knbt.internal.NbtEncodingException public sealed interface NbtDecoder : Decoder { public val nbt: NbtFormat @@ -81,10 +83,14 @@ internal abstract class AbstractNbtDecoder : AbstractDecoder(), NbtDecoder, Comp ByteArraySerializer() -> decodeByteArray() as T IntArraySerializer() -> decodeIntArray() as T LongArraySerializer() -> decodeLongArray() as T - is PolymorphicSerializer<*> -> when (deserializer.baseClass) { - NbtTag::class -> decodeNbtTag() as T - else -> throw NbtDecodingException("Polymorphic serialization is not yet supported") + else -> when { + deserializer is PolymorphicSerializer && deserializer.baseClass == NbtTag::class -> { + decodeNbtTag() as T + } + deserializer.descriptor.kind is PolymorphicKind -> { + throw NbtEncodingException("Polymorphic serialization is not yet supported") + } + else -> super.decodeSerializableValue(deserializer) } - else -> super.decodeSerializableValue(deserializer) } } diff --git a/src/commonMain/kotlin/NbtEncoder.kt b/src/commonMain/kotlin/NbtEncoder.kt index 0f3ab8b7..b69f1985 100644 --- a/src/commonMain/kotlin/NbtEncoder.kt +++ b/src/commonMain/kotlin/NbtEncoder.kt @@ -91,12 +91,16 @@ internal abstract class AbstractNbtEncoder : AbstractEncoder(), NbtEncoder, Comp override fun encodeSerializableValue(serializer: SerializationStrategy, value: T): Unit = when (serializer) { - NbtTag.serializer() -> encodeNbtTag(value as NbtTag) ByteArraySerializer() -> encodeByteArray(value as ByteArray) IntArraySerializer() -> encodeIntArray(value as IntArray) LongArraySerializer() -> encodeLongArray(value as LongArray) - else -> when (serializer.descriptor.kind) { - is PolymorphicKind -> throw NbtEncodingException("Polymorphic serialization is not yet supported") + else -> when { + serializer is PolymorphicSerializer && serializer.baseClass == NbtTag::class -> { + encodeNbtTag(value as NbtTag) + } + serializer.descriptor.kind is PolymorphicKind -> { + throw NbtEncodingException("Polymorphic serialization is not yet supported") + } else -> super.encodeSerializableValue(serializer, value) } } diff --git a/src/commonTest/kotlin/NbtTagPolymorphismTest.kt b/src/commonTest/kotlin/NbtTagPolymorphismTest.kt new file mode 100644 index 00000000..0f29f649 --- /dev/null +++ b/src/commonTest/kotlin/NbtTagPolymorphismTest.kt @@ -0,0 +1,51 @@ +package net.benwoodworth.knbt + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlin.test.Test +import kotlin.test.assertEquals + +class NbtTagPolymorphismTest { + private val nbt = Nbt { + variant = NbtVariant.Java + compression = NbtCompression.None + } + + @Serializable + @SerialName("") + private data class NbtTagContainer( + val nbtTag: NbtTag, + ) + + @Test + fun Should_encode_NbtCompound_to_NbtTag_property_correctly() { + val compound = buildNbtCompound { + put("entry", "Hello, world!") + } + + val toEncode = NbtTagContainer(compound) + + assertEquals( + expected = buildNbtCompound("") { + put(NbtTagContainer::nbtTag.name, compound) + }, + actual = nbt.encodeToNbtTag(NbtTagContainer.serializer(), toEncode), + ) + } + + @Test + fun Should_decode_NbtCompound_from_NbtTag_property_correctly() { + val compound = buildNbtCompound { + put("entry", "Hello, world!") + } + + val toDecode = buildNbtCompound("") { + put(NbtTagContainer::nbtTag.name, compound) + } + + assertEquals( + expected = NbtTagContainer(compound), + actual = nbt.decodeFromNbtTag(NbtTagContainer.serializer(), toDecode), + ) + } +}