Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add TypeDescriptor for CUnion #183

Merged
merged 2 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/fr/hammons/slinc/CUnion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import fr.hammons.slinc.modules.ReadWriteModule
import fr.hammons.slinc.modules.DescriptorModule
import scala.NonEmptyTuple

class CUnion[T <: Tuple](mem: Mem):
class CUnion[T <: Tuple](private[slinc] val mem: Mem):
private inline def getHelper[T <: Tuple, A](using
dO: DescriptorOf[A],
rwm: ReadWriteModule
Expand Down
13 changes: 13 additions & 0 deletions core/src/fr/hammons/slinc/DescriptorOf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package fr.hammons.slinc

import fr.hammons.slinc.container.*
import scala.quoted.*
import scala.compiletime.{summonInline, erasedValue}
import scala.NonEmptyTuple

/** Typeclass that summons TypeDescriptors
*/
Expand Down Expand Up @@ -64,3 +66,14 @@ object DescriptorOf:
)

'{ $expr.descriptor }

private inline def helper[B <: Tuple]: Set[TypeDescriptor] =
inline erasedValue[B] match
case _: (a *: t) => helper[t] + summonInline[DescriptorOf[a]].descriptor
case _: EmptyTuple => Set.empty[TypeDescriptor]

inline given [A <: NonEmptyTuple]: DescriptorOf[CUnion[A]] =
new DescriptorOf[CUnion[A]]:
val descriptor: CUnionDescriptor { type Inner = CUnion[A] } =
CUnionDescriptor(helper[A])
.asInstanceOf[CUnionDescriptor { type Inner = CUnion[A] }]
2 changes: 0 additions & 2 deletions core/src/fr/hammons/slinc/MacroHelpers.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package fr.hammons.slinc

import scala.quoted.*
import scala.annotation.nowarn

private[slinc] object MacroHelpers:
def widenExpr(t: Expr[?])(using Quotes) =
import quotes.reflect.*
Expand Down
2 changes: 0 additions & 2 deletions core/src/fr/hammons/slinc/Mem.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package fr.hammons.slinc

import scala.annotation.nowarn

trait Mem:
import scala.compiletime.asMatchable
def offset(bytes: Bytes): Mem
Expand Down
1 change: 0 additions & 1 deletion core/src/fr/hammons/slinc/MethodHandleTools.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package fr.hammons.slinc
import scala.quoted.*
import java.lang.invoke.MethodHandle
import fr.hammons.slinc.modules.TransitionModule
import scala.annotation.nowarn

object MethodHandleTools:
def exprNameMapping(expr: Expr[Any])(using Quotes): String =
Expand Down
1 change: 0 additions & 1 deletion core/src/fr/hammons/slinc/SlincImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package fr.hammons.slinc

import scala.annotation.StaticAnnotation
import scala.quoted.*
import scala.annotation.nowarn

class SlincImpl(val version: Int) extends StaticAnnotation

Expand Down
3 changes: 0 additions & 3 deletions core/src/fr/hammons/slinc/Struct.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import fr.hammons.slinc.modules.TransitionModule
import fr.hammons.slinc.modules.ReadWriteModule
import fr.hammons.slinc.modules.Reader
import fr.hammons.slinc.modules.Writer
import scala.annotation.nowarn

trait Struct[A <: Product] extends DescriptorOf[A]

Expand Down Expand Up @@ -110,12 +109,10 @@ object Struct:
val reader = readGen[A]
val writer = writeGen[A]

@nowarn("msg=unused implicit parameter")
override val returnTransition = returnValue =>
val mem = summon[TransitionModule].memReturn(returnValue)
summon[ReadWriteModule].read(mem, Bytes(0), this)

@nowarn("msg=unused implicit parameter")
override val argumentTransition = argument =>
val mem = summon[Allocator].allocate(this, 1)
summon[ReadWriteModule].write(mem, Bytes(0), this, argument)
Expand Down
34 changes: 19 additions & 15 deletions core/src/fr/hammons/slinc/TypeDescriptor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import fr.hammons.slinc.modules.{
ArrayReader,
readWriteModule
}
import scala.annotation.nowarn
import scala.reflect.ClassTag
import scala.quoted.*
import fr.hammons.slinc.modules.TransitionModule
import fr.hammons.slinc.modules.{ArgumentTransition, ReturnTransition}
import scala.NonEmptyTuple

/** Describes types used by C interop
*/
Expand Down Expand Up @@ -106,31 +106,23 @@ case object LongDescriptor extends BasicDescriptor:

case object FloatDescriptor extends BasicDescriptor:
type Inner = Float
@nowarn("msg=unused implicit parameter")
val reader = readWriteModule.floatReader
@nowarn("msg=unused implicit parameter")
val writer = readWriteModule.floatWriter

case object DoubleDescriptor extends BasicDescriptor:
type Inner = Double
@nowarn("msg=unused implicit parameter")
val reader = readWriteModule.doubleReader
@nowarn("msg=unused implicit parameter")
val writer = readWriteModule.doubleWriter

case object PtrDescriptor extends TypeDescriptor:
type Inner = Ptr[?]
@nowarn("msg=unused implicit parameter")
override val reader = (mem, offset) =>
Ptr(readWriteModule.memReader(mem, offset), Bytes(0))
@nowarn("msg=unused implicit parameter")
override val writer = (mem, offset, a) =>
readWriteModule.memWriter(mem, offset, a.mem)

@nowarn("msg=unused implicit parameter")
override val argumentTransition = _.mem.asAddress

@nowarn("msg=unused implicit parameter")
override val returnTransition = o =>
Ptr[Any](summon[TransitionModule].addressReturn(o), Bytes(0))

Expand Down Expand Up @@ -174,11 +166,9 @@ case class AliasDescriptor[A](val real: TypeDescriptor) extends TypeDescriptor:
val writer: (ReadWriteModule, DescriptorModule) ?=> Writer[Inner] =
(rwm, _) ?=> (mem, bytes, a) => rwm.write(mem, bytes, real, a)

@nowarn("msg=unused implicit parameter")
override val argumentTransition =
summon[TransitionModule].methodArgument(real, _, summon[Allocator])

@nowarn("msg=unused implicit parameter")
override val returnTransition = summon[TransitionModule].methodReturn(real, _)
override def size(using dm: DescriptorModule): Bytes = dm.sizeOf(real)
override def alignment(using dm: DescriptorModule): Bytes =
Expand All @@ -189,22 +179,36 @@ case class AliasDescriptor[A](val real: TypeDescriptor) extends TypeDescriptor:
case object VaListDescriptor extends TypeDescriptor:
type Inner = VarArgs

@nowarn(TypeDescriptor.unusedImplicit)
override val reader: (ReadWriteModule, DescriptorModule) ?=> Reader[Inner] =
(mem, offset) => summon[ReadWriteModule].memReader(mem, offset).asVarArgs

@nowarn(TypeDescriptor.unusedImplicit)
override val argumentTransition
: (TransitionModule, ReadWriteModule, Allocator) ?=> ArgumentTransition[
Inner
] = _.mem.asAddress

@nowarn(TypeDescriptor.unusedImplicit)
override val writer: (ReadWriteModule, DescriptorModule) ?=> Writer[Inner] =
(mem, offset, value) =>
summon[ReadWriteModule].memWriter(mem, offset, value.mem)

@nowarn(TypeDescriptor.unusedImplicit)
override val returnTransition
: (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] = o =>
summon[TransitionModule].addressReturn(o).asVarArgs

case class CUnionDescriptor(possibleTypes: Set[TypeDescriptor])
extends TypeDescriptor:
type Inner = CUnion[? <: NonEmptyTuple]

override val reader: (ReadWriteModule, DescriptorModule) ?=> Reader[Inner] =
???

override val returnTransition
: (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] = ???

override val argumentTransition
: (TransitionModule, ReadWriteModule, Allocator) ?=> ArgumentTransition[
Inner
] = ???

override val writer: (ReadWriteModule, DescriptorModule) ?=> Writer[Inner] =
???
4 changes: 0 additions & 4 deletions core/src/fr/hammons/slinc/types/Arch.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package fr.hammons.slinc.types

import scala.annotation.nowarn

//todo: remove once https://github.com/lampepfl/dotty/issues/16878 is fixed
@nowarn("msg=unused explicit parameter")
private[slinc] enum Arch:
case I386
case X64
Expand Down
4 changes: 0 additions & 4 deletions core/src/fr/hammons/slinc/types/OS.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package fr.hammons.slinc.types

import scala.annotation.nowarn

//todo: remove once https://github.com/lampepfl/dotty/issues/16878 is fixed
@nowarn("msg=unused explicit parameter")
enum OS:
case Linux
case Darwin
Expand Down
27 changes: 27 additions & 0 deletions core/test/src/fr/hammons/slinc/TypeDescriptorSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package fr.hammons.slinc
import types.*

class TypeDescriptorSpec extends munit.FunSuite:
case class A(a: CInt, b: CInt) derives Struct
test("DescriptorOf[CUnion[(Int, Float, A)]] gives appropriate descriptor"):
assertEquals(
DescriptorOf[CUnion[(CInt, CFloat, A)]]: TypeDescriptor,
CUnionDescriptor(
Set(IntDescriptor, FloatDescriptor, summon[Struct[A]].descriptor)
): TypeDescriptor
)

test(
"DescriptorOf[CUnion[(CUnion[(Int, Float)], CUnion[(Int, Float)], A,A])]] doesn't double descriptors"
):
assertEquals(
DescriptorOf[
CUnion[(CUnion[(Int, Float)], CUnion[(Int, Float)], A, A)]
]: TypeDescriptor,
CUnionDescriptor(
Set(
CUnionDescriptor(Set(IntDescriptor, FloatDescriptor)),
summon[Struct[A]].descriptor
)
): TypeDescriptor
)
25 changes: 25 additions & 0 deletions core/test/src/fr/hammons/slinc/modules/DescriptorSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package fr.hammons.slinc.modules

import fr.hammons.slinc.Slinc
import fr.hammons.slinc.DescriptorOf
import fr.hammons.slinc.types.CLongLong
import fr.hammons.slinc.types.CInt
import fr.hammons.slinc.types.CFloat
import fr.hammons.slinc.CUnion
import fr.hammons.slinc.Struct

trait DescriptorSpec(val slinc: Slinc) extends munit.FunSuite:
import slinc.dm
case class A(a: CInt, b: CInt, c: CInt, d: CLongLong, e: CLongLong)
derives Struct
test("CUnionDescriptor.size gives the right size"):
assertEquals(
DescriptorOf[CUnion[(CInt, A, CFloat)]].size,
DescriptorOf[A].size
)

test("CUnionDescriptor.alignment gives the right alignment"):
assertEquals(
DescriptorOf[CUnion[(CInt, A, CFloat)]].alignment,
DescriptorOf[A].alignment
)
19 changes: 16 additions & 3 deletions j17/src/fr/hammons/slinc/Allocator17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ import jdk.incubator.foreign.{
MemorySegment,
ResourceScope,
CLinker,
FunctionDescriptor as JFunctionDescriptor
FunctionDescriptor as JFunctionDescriptor,
GroupLayout
}, CLinker.{C_POINTER, C_INT, C_LONG_LONG, C_DOUBLE, VaList}
import fr.hammons.slinc.modules.{descriptorModule17, transitionModule17}
import fr.hammons.slinc.modules.LinkageModule17
import scala.annotation.nowarn

@nowarn("msg=unused import")
class Allocator17(
segmentAllocator: SegmentAllocator,
scope: ResourceScope,
Expand Down Expand Up @@ -60,6 +59,20 @@ class Allocator17(
.asInstanceOf[Addressable]
)
)
case (cd: CUnionDescriptor, v: CUnion[?]) =>
builder.vargFromSegment(
descriptorModule17.toMemoryLayout(cd) match
case gl: GroupLayout => gl
case _ => throw Error("got a non group layout from CUnionDescriptor")
,
v.mem.asBase match
case ms: MemorySegment => ms
case _ => throw Error("base of mem was not J17 MemorySegment!!")
)
case (a, d) =>
throw Error(
s"Unsupported type descriptor/data pairing for VarArgs: $a - $d"
)

override def makeVarArgs(vbuilder: VarArgsBuilder): VarArgs =
VarArgs17(
Expand Down
10 changes: 6 additions & 4 deletions j17/src/fr/hammons/slinc/VarArgs17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ class VarArgs17(args: VaList) extends VarArgs:
)
.nn
)
case AliasDescriptor(real) => get(real)
case VaListDescriptor => args.vargAsAddress(C_POINTER).nn
case AliasDescriptor(real) => get(real)
case VaListDescriptor => args.vargAsAddress(C_POINTER).nn
case CUnionDescriptor(possibleTypes) => get(possibleTypes.maxBy(_.size))
def get[A](using d: DescriptorOf[A]): A =
transitionModule17.methodReturn[A](d.descriptor, get(d.descriptor))

Expand All @@ -46,8 +47,9 @@ class VarArgs17(args: VaList) extends VarArgs:
case PtrDescriptor => args.skip(C_POINTER)
case sd: StructDescriptor =>
args.skip(descriptorModule17.toGroupLayout(sd))
case AliasDescriptor(real) => skip(real)
case VaListDescriptor => args.skip(C_POINTER)
case AliasDescriptor(real) => skip(real)
case VaListDescriptor => args.skip(C_POINTER)
case CUnionDescriptor(possibleTypes) => skip(possibleTypes.maxBy(_.size))

def skip[A](using dO: DescriptorOf[A]): Unit = skip(dO.descriptor)

Expand Down
13 changes: 9 additions & 4 deletions j17/src/fr/hammons/slinc/modules/DescriptorModule17.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@ import jdk.incubator.foreign.CLinker.{
C_DOUBLE,
C_FLOAT,
C_LONG_LONG,
C_POINTER,
C_VA_LIST
C_POINTER
}
import jdk.incubator.foreign.{
MemoryLayout,
MemoryAddress,
MemorySegment,
GroupLayout,
CLinker
}, CLinker.VaList
}
import scala.collection.concurrent.TrieMap

given descriptorModule17: DescriptorModule with
Expand All @@ -36,6 +35,7 @@ given descriptorModule17: DescriptorModule with
case _: StructDescriptor => classOf[MemorySegment]
case VaListDescriptor => classOf[MemoryAddress]
case ad: AliasDescriptor[?] => toCarrierType(ad.real)
case ud: CUnionDescriptor => classOf[MemorySegment]

def genLayoutList(
layouts: Seq[MemoryLayout],
Expand Down Expand Up @@ -80,11 +80,14 @@ given descriptorModule17: DescriptorModule with
case sd: StructDescriptor =>
Bytes(toGroupLayout(sd).byteSize())
case VaListDescriptor => Bytes(toMemoryLayout(VaListDescriptor).byteSize())
case ad: AliasDescriptor[?] => sizeOf(ad.real)
case ad: AliasDescriptor[?] => sizeOf(ad.real)
case CUnionDescriptor(possibleTypes) => possibleTypes.map(sizeOf).max

override def alignmentOf(td: TypeDescriptor): Bytes = td match
case s: StructDescriptor =>
s.members.view.map(_.descriptor).map(alignmentOf).max
case CUnionDescriptor(possibleTypes) =>
possibleTypes.view.map(alignmentOf).max
case _ => sizeOf(td)

override def memberOffsets(sd: List[TypeDescriptor]): IArray[Bytes] =
Expand Down Expand Up @@ -125,6 +128,8 @@ given descriptorModule17: DescriptorModule with
case VaListDescriptor => C_POINTER.nn
case sd: StructDescriptor => toGroupLayout(sd)
case ad: AliasDescriptor[?] => toMemoryLayout(ad.real)
case CUnionDescriptor(possibleTypes) =>
MemoryLayout.unionLayout(possibleTypes.map(toMemoryLayout).toSeq*).nn

def toMemoryLayout(smd: StructMemberDescriptor): MemoryLayout =
toMemoryLayout(smd.descriptor).withName(smd.name).nn
Expand Down
5 changes: 5 additions & 0 deletions j17/test/src/fr/hammons/slinc/modules/Descriptor17Spec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package fr.hammons.slinc.modules

import fr.hammons.slinc.Slinc17

class Descriptor17Spec extends DescriptorSpec(Slinc17.default)
Loading