Skip to content

Commit

Permalink
feat: Add TypeDescriptor for CUnion (#183)
Browse files Browse the repository at this point in the history
* Add TypeDescriptor for CUnion
Fixes #174

* chore: add unit tests for descriptor logic
  • Loading branch information
markehammons authored May 10, 2023
1 parent 1fc1bf5 commit a33580c
Show file tree
Hide file tree
Showing 20 changed files with 150 additions and 45 deletions.
2 changes: 1 addition & 1 deletion core/src/fr/hammons/slinc/CUnion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import fr.hammons.slinc.modules.ReadWriteModule
import fr.hammons.slinc.modules.DescriptorModule
import scala.NonEmptyTuple

class CUnion[T <: Tuple](mem: Mem):
class CUnion[T <: Tuple](private[slinc] val mem: Mem):
private inline def getHelper[T <: Tuple, A](using
dO: DescriptorOf[A],
rwm: ReadWriteModule
Expand Down
13 changes: 13 additions & 0 deletions core/src/fr/hammons/slinc/DescriptorOf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package fr.hammons.slinc

import fr.hammons.slinc.container.*
import scala.quoted.*
import scala.compiletime.{summonInline, erasedValue}
import scala.NonEmptyTuple

/** Typeclass that summons TypeDescriptors
*/
Expand Down Expand Up @@ -64,3 +66,14 @@ object DescriptorOf:
)

'{ $expr.descriptor }

private inline def helper[B <: Tuple]: Set[TypeDescriptor] =
inline erasedValue[B] match
case _: (a *: t) => helper[t] + summonInline[DescriptorOf[a]].descriptor
case _: EmptyTuple => Set.empty[TypeDescriptor]

inline given [A <: NonEmptyTuple]: DescriptorOf[CUnion[A]] =
new DescriptorOf[CUnion[A]]:
val descriptor: CUnionDescriptor { type Inner = CUnion[A] } =
CUnionDescriptor(helper[A])
.asInstanceOf[CUnionDescriptor { type Inner = CUnion[A] }]
2 changes: 0 additions & 2 deletions core/src/fr/hammons/slinc/MacroHelpers.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package fr.hammons.slinc

import scala.quoted.*
import scala.annotation.nowarn

private[slinc] object MacroHelpers:
def widenExpr(t: Expr[?])(using Quotes) =
import quotes.reflect.*
Expand Down
2 changes: 0 additions & 2 deletions core/src/fr/hammons/slinc/Mem.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package fr.hammons.slinc

import scala.annotation.nowarn

trait Mem:
import scala.compiletime.asMatchable
def offset(bytes: Bytes): Mem
Expand Down
1 change: 0 additions & 1 deletion core/src/fr/hammons/slinc/MethodHandleTools.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package fr.hammons.slinc
import scala.quoted.*
import java.lang.invoke.MethodHandle
import fr.hammons.slinc.modules.TransitionModule
import scala.annotation.nowarn

object MethodHandleTools:
def exprNameMapping(expr: Expr[Any])(using Quotes): String =
Expand Down
1 change: 0 additions & 1 deletion core/src/fr/hammons/slinc/SlincImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package fr.hammons.slinc

import scala.annotation.StaticAnnotation
import scala.quoted.*
import scala.annotation.nowarn

class SlincImpl(val version: Int) extends StaticAnnotation

Expand Down
3 changes: 0 additions & 3 deletions core/src/fr/hammons/slinc/Struct.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import fr.hammons.slinc.modules.TransitionModule
import fr.hammons.slinc.modules.ReadWriteModule
import fr.hammons.slinc.modules.Reader
import fr.hammons.slinc.modules.Writer
import scala.annotation.nowarn

trait Struct[A <: Product] extends DescriptorOf[A]

Expand Down Expand Up @@ -110,12 +109,10 @@ object Struct:
val reader = readGen[A]
val writer = writeGen[A]

@nowarn("msg=unused implicit parameter")
override val returnTransition = returnValue =>
val mem = summon[TransitionModule].memReturn(returnValue)
summon[ReadWriteModule].read(mem, Bytes(0), this)

@nowarn("msg=unused implicit parameter")
override val argumentTransition = argument =>
val mem = summon[Allocator].allocate(this, 1)
summon[ReadWriteModule].write(mem, Bytes(0), this, argument)
Expand Down
34 changes: 19 additions & 15 deletions core/src/fr/hammons/slinc/TypeDescriptor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import fr.hammons.slinc.modules.{
ArrayReader,
readWriteModule
}
import scala.annotation.nowarn
import scala.reflect.ClassTag
import scala.quoted.*
import fr.hammons.slinc.modules.TransitionModule
import fr.hammons.slinc.modules.{ArgumentTransition, ReturnTransition}
import scala.NonEmptyTuple

/** Describes types used by C interop
*/
Expand Down Expand Up @@ -106,31 +106,23 @@ case object LongDescriptor extends BasicDescriptor:

case object FloatDescriptor extends BasicDescriptor:
type Inner = Float
@nowarn("msg=unused implicit parameter")
val reader = readWriteModule.floatReader
@nowarn("msg=unused implicit parameter")
val writer = readWriteModule.floatWriter

case object DoubleDescriptor extends BasicDescriptor:
type Inner = Double
@nowarn("msg=unused implicit parameter")
val reader = readWriteModule.doubleReader
@nowarn("msg=unused implicit parameter")
val writer = readWriteModule.doubleWriter

case object PtrDescriptor extends TypeDescriptor:
type Inner = Ptr[?]
@nowarn("msg=unused implicit parameter")
override val reader = (mem, offset) =>
Ptr(readWriteModule.memReader(mem, offset), Bytes(0))
@nowarn("msg=unused implicit parameter")
override val writer = (mem, offset, a) =>
readWriteModule.memWriter(mem, offset, a.mem)

@nowarn("msg=unused implicit parameter")
override val argumentTransition = _.mem.asAddress

@nowarn("msg=unused implicit parameter")
override val returnTransition = o =>
Ptr[Any](summon[TransitionModule].addressReturn(o), Bytes(0))

Expand Down Expand Up @@ -174,11 +166,9 @@ case class AliasDescriptor[A](val real: TypeDescriptor) extends TypeDescriptor:
val writer: (ReadWriteModule, DescriptorModule) ?=> Writer[Inner] =
(rwm, _) ?=> (mem, bytes, a) => rwm.write(mem, bytes, real, a)

@nowarn("msg=unused implicit parameter")
override val argumentTransition =
summon[TransitionModule].methodArgument(real, _, summon[Allocator])

@nowarn("msg=unused implicit parameter")
override val returnTransition = summon[TransitionModule].methodReturn(real, _)
override def size(using dm: DescriptorModule): Bytes = dm.sizeOf(real)
override def alignment(using dm: DescriptorModule): Bytes =
Expand All @@ -189,22 +179,36 @@ case class AliasDescriptor[A](val real: TypeDescriptor) extends TypeDescriptor:
case object VaListDescriptor extends TypeDescriptor:
type Inner = VarArgs

@nowarn(TypeDescriptor.unusedImplicit)
override val reader: (ReadWriteModule, DescriptorModule) ?=> Reader[Inner] =
(mem, offset) => summon[ReadWriteModule].memReader(mem, offset).asVarArgs

@nowarn(TypeDescriptor.unusedImplicit)
override val argumentTransition
: (TransitionModule, ReadWriteModule, Allocator) ?=> ArgumentTransition[
Inner
] = _.mem.asAddress

@nowarn(TypeDescriptor.unusedImplicit)
override val writer: (ReadWriteModule, DescriptorModule) ?=> Writer[Inner] =
(mem, offset, value) =>
summon[ReadWriteModule].memWriter(mem, offset, value.mem)

@nowarn(TypeDescriptor.unusedImplicit)
override val returnTransition
: (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] = o =>
summon[TransitionModule].addressReturn(o).asVarArgs

case class CUnionDescriptor(possibleTypes: Set[TypeDescriptor])
extends TypeDescriptor:
type Inner = CUnion[? <: NonEmptyTuple]

override val reader: (ReadWriteModule, DescriptorModule) ?=> Reader[Inner] =
???

override val returnTransition
: (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] = ???

override val argumentTransition
: (TransitionModule, ReadWriteModule, Allocator) ?=> ArgumentTransition[
Inner
] = ???

override val writer: (ReadWriteModule, DescriptorModule) ?=> Writer[Inner] =
???
4 changes: 0 additions & 4 deletions core/src/fr/hammons/slinc/types/Arch.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package fr.hammons.slinc.types

import scala.annotation.nowarn

//todo: remove once https://github.com/lampepfl/dotty/issues/16878 is fixed
@nowarn("msg=unused explicit parameter")
private[slinc] enum Arch:
case I386
case X64
Expand Down
4 changes: 0 additions & 4 deletions core/src/fr/hammons/slinc/types/OS.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package fr.hammons.slinc.types

import scala.annotation.nowarn

//todo: remove once https://github.com/lampepfl/dotty/issues/16878 is fixed
@nowarn("msg=unused explicit parameter")
enum OS:
case Linux
case Darwin
Expand Down
27 changes: 27 additions & 0 deletions core/test/src/fr/hammons/slinc/TypeDescriptorSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package fr.hammons.slinc
import types.*

class TypeDescriptorSpec extends munit.FunSuite:
case class A(a: CInt, b: CInt) derives Struct
test("DescriptorOf[CUnion[(Int, Float, A)]] gives appropriate descriptor"):
assertEquals(
DescriptorOf[CUnion[(CInt, CFloat, A)]]: TypeDescriptor,
CUnionDescriptor(
Set(IntDescriptor, FloatDescriptor, summon[Struct[A]].descriptor)
): TypeDescriptor
)

test(
"DescriptorOf[CUnion[(CUnion[(Int, Float)], CUnion[(Int, Float)], A,A])]] doesn't double descriptors"
):
assertEquals(
DescriptorOf[
CUnion[(CUnion[(Int, Float)], CUnion[(Int, Float)], A, A)]
]: TypeDescriptor,
CUnionDescriptor(
Set(
CUnionDescriptor(Set(IntDescriptor, FloatDescriptor)),
summon[Struct[A]].descriptor
)
): TypeDescriptor
)
25 changes: 25 additions & 0 deletions core/test/src/fr/hammons/slinc/modules/DescriptorSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package fr.hammons.slinc.modules

import fr.hammons.slinc.Slinc
import fr.hammons.slinc.DescriptorOf
import fr.hammons.slinc.types.CLongLong
import fr.hammons.slinc.types.CInt
import fr.hammons.slinc.types.CFloat
import fr.hammons.slinc.CUnion
import fr.hammons.slinc.Struct

trait DescriptorSpec(val slinc: Slinc) extends munit.FunSuite:
import slinc.dm
case class A(a: CInt, b: CInt, c: CInt, d: CLongLong, e: CLongLong)
derives Struct
test("CUnionDescriptor.size gives the right size"):
assertEquals(
DescriptorOf[CUnion[(CInt, A, CFloat)]].size,
DescriptorOf[A].size
)

test("CUnionDescriptor.alignment gives the right alignment"):
assertEquals(
DescriptorOf[CUnion[(CInt, A, CFloat)]].alignment,
DescriptorOf[A].alignment
)
19 changes: 16 additions & 3 deletions j17/src/fr/hammons/slinc/Allocator17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ import jdk.incubator.foreign.{
MemorySegment,
ResourceScope,
CLinker,
FunctionDescriptor as JFunctionDescriptor
FunctionDescriptor as JFunctionDescriptor,
GroupLayout
}, CLinker.{C_POINTER, C_INT, C_LONG_LONG, C_DOUBLE, VaList}
import fr.hammons.slinc.modules.{descriptorModule17, transitionModule17}
import fr.hammons.slinc.modules.LinkageModule17
import scala.annotation.nowarn

@nowarn("msg=unused import")
class Allocator17(
segmentAllocator: SegmentAllocator,
scope: ResourceScope,
Expand Down Expand Up @@ -60,6 +59,20 @@ class Allocator17(
.asInstanceOf[Addressable]
)
)
case (cd: CUnionDescriptor, v: CUnion[?]) =>
builder.vargFromSegment(
descriptorModule17.toMemoryLayout(cd) match
case gl: GroupLayout => gl
case _ => throw Error("got a non group layout from CUnionDescriptor")
,
v.mem.asBase match
case ms: MemorySegment => ms
case _ => throw Error("base of mem was not J17 MemorySegment!!")
)
case (a, d) =>
throw Error(
s"Unsupported type descriptor/data pairing for VarArgs: $a - $d"
)

override def makeVarArgs(vbuilder: VarArgsBuilder): VarArgs =
VarArgs17(
Expand Down
10 changes: 6 additions & 4 deletions j17/src/fr/hammons/slinc/VarArgs17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ class VarArgs17(args: VaList) extends VarArgs:
)
.nn
)
case AliasDescriptor(real) => get(real)
case VaListDescriptor => args.vargAsAddress(C_POINTER).nn
case AliasDescriptor(real) => get(real)
case VaListDescriptor => args.vargAsAddress(C_POINTER).nn
case CUnionDescriptor(possibleTypes) => get(possibleTypes.maxBy(_.size))
def get[A](using d: DescriptorOf[A]): A =
transitionModule17.methodReturn[A](d.descriptor, get(d.descriptor))

Expand All @@ -46,8 +47,9 @@ class VarArgs17(args: VaList) extends VarArgs:
case PtrDescriptor => args.skip(C_POINTER)
case sd: StructDescriptor =>
args.skip(descriptorModule17.toGroupLayout(sd))
case AliasDescriptor(real) => skip(real)
case VaListDescriptor => args.skip(C_POINTER)
case AliasDescriptor(real) => skip(real)
case VaListDescriptor => args.skip(C_POINTER)
case CUnionDescriptor(possibleTypes) => skip(possibleTypes.maxBy(_.size))

def skip[A](using dO: DescriptorOf[A]): Unit = skip(dO.descriptor)

Expand Down
13 changes: 9 additions & 4 deletions j17/src/fr/hammons/slinc/modules/DescriptorModule17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@ import jdk.incubator.foreign.CLinker.{
C_DOUBLE,
C_FLOAT,
C_LONG_LONG,
C_POINTER,
C_VA_LIST
C_POINTER
}
import jdk.incubator.foreign.{
MemoryLayout,
MemoryAddress,
MemorySegment,
GroupLayout,
CLinker
}, CLinker.VaList
}
import scala.collection.concurrent.TrieMap

given descriptorModule17: DescriptorModule with
Expand All @@ -36,6 +35,7 @@ given descriptorModule17: DescriptorModule with
case _: StructDescriptor => classOf[MemorySegment]
case VaListDescriptor => classOf[MemoryAddress]
case ad: AliasDescriptor[?] => toCarrierType(ad.real)
case ud: CUnionDescriptor => classOf[MemorySegment]

def genLayoutList(
layouts: Seq[MemoryLayout],
Expand Down Expand Up @@ -80,11 +80,14 @@ given descriptorModule17: DescriptorModule with
case sd: StructDescriptor =>
Bytes(toGroupLayout(sd).byteSize())
case VaListDescriptor => Bytes(toMemoryLayout(VaListDescriptor).byteSize())
case ad: AliasDescriptor[?] => sizeOf(ad.real)
case ad: AliasDescriptor[?] => sizeOf(ad.real)
case CUnionDescriptor(possibleTypes) => possibleTypes.map(sizeOf).max

override def alignmentOf(td: TypeDescriptor): Bytes = td match
case s: StructDescriptor =>
s.members.view.map(_.descriptor).map(alignmentOf).max
case CUnionDescriptor(possibleTypes) =>
possibleTypes.view.map(alignmentOf).max
case _ => sizeOf(td)

override def memberOffsets(sd: List[TypeDescriptor]): IArray[Bytes] =
Expand Down Expand Up @@ -125,6 +128,8 @@ given descriptorModule17: DescriptorModule with
case VaListDescriptor => C_POINTER.nn
case sd: StructDescriptor => toGroupLayout(sd)
case ad: AliasDescriptor[?] => toMemoryLayout(ad.real)
case CUnionDescriptor(possibleTypes) =>
MemoryLayout.unionLayout(possibleTypes.map(toMemoryLayout).toSeq*).nn

def toMemoryLayout(smd: StructMemberDescriptor): MemoryLayout =
toMemoryLayout(smd.descriptor).withName(smd.name).nn
Expand Down
5 changes: 5 additions & 0 deletions j17/test/src/fr/hammons/slinc/modules/Descriptor17Spec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package fr.hammons.slinc.modules

import fr.hammons.slinc.Slinc17

class Descriptor17Spec extends DescriptorSpec(Slinc17.default)
Loading

0 comments on commit a33580c

Please sign in to comment.