Skip to content

Commit

Permalink
generic func infer lambda expr
Browse files Browse the repository at this point in the history
  • Loading branch information
visualfc committed Mar 25, 2024
1 parent 0a7dbfe commit 564f50f
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 10 deletions.
55 changes: 45 additions & 10 deletions cl/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package cl

import (
"bytes"
"errors"
goast "go/ast"
gotoken "go/token"
"go/types"
Expand All @@ -27,6 +28,7 @@ import (
"strconv"
"strings"
"syscall"
_ "unsafe"

"github.com/goplus/gogen"
"github.com/goplus/gogen/cpackages"
Expand Down Expand Up @@ -566,12 +568,14 @@ func identVal(ctx *blockCtx, x *ast.Ident, flags int, v types.Object, alias bool
}

type fnType struct {
next *fnType
params *types.Tuple
base int
size int
variadic bool
typetype bool
next *fnType
params *types.Tuple
sig *types.Signature
base int
size int
variadic bool
typetype bool
typeparam bool
}

func (p *fnType) arg(i int, ellipsis bool) types.Type {
Expand All @@ -590,7 +594,8 @@ func (p *fnType) arg(i int, ellipsis bool) types.Type {

func (p *fnType) init(base int, t *types.Signature) {
p.base = base
p.params, p.variadic = t.Params(), t.Variadic()
p.sig = t
p.params, p.variadic, p.typeparam = t.Params(), t.Variadic(), t.TypeParams() != nil
p.size = p.params.Len()
if p.variadic {
p.size--
Expand Down Expand Up @@ -669,7 +674,6 @@ func compileCallExpr(ctx *blockCtx, v *ast.CallExpr, inFlags int) {
var err error
var stk = ctx.cb.InternalStack()
var base = stk.Len()
var fnt = stk.Get(-1).Type
var flags gogen.InstrFlags
var ellipsis = v.Ellipsis != gotoken.NoPos
if ellipsis {
Expand All @@ -678,10 +682,12 @@ func compileCallExpr(ctx *blockCtx, v *ast.CallExpr, inFlags int) {
if (inFlags & clCallWithTwoValue) != 0 {
flags |= gogen.InstrFlagTwoValue
}
pfn := stk.Get(-1)
fnt := pfn.Type
fn := &fnType{}
fn.load(fnt)
for fn != nil {
if err = compileCallArgs(fn, ctx, v, ellipsis, flags); err == nil {
if err = compileCallArgs(ctx, pfn, fn, v, ellipsis, flags); err == nil {
if rec := ctx.recorder(); rec != nil {
rec.recordCallExpr(ctx, v, fnt)
}
Expand Down Expand Up @@ -736,10 +742,16 @@ func fnCall(ctx *blockCtx, v *ast.CallExpr, flags gogen.InstrFlags, extra int) e
return ctx.cb.CallWithEx(len(v.Args)+extra, flags, v)
}

func compileCallArgs(fn *fnType, ctx *blockCtx, v *ast.CallExpr, ellipsis bool, flags gogen.InstrFlags) (err error) {
func compileCallArgs(ctx *blockCtx, pfn *gogen.Element, fn *fnType, v *ast.CallExpr, ellipsis bool, flags gogen.InstrFlags) (err error) {
var needInferFunc bool
for i, arg := range v.Args {
switch expr := arg.(type) {
case *ast.LambdaExpr:
if fn.typeparam {
needInferFunc = true
compileIdent(ctx, ast.NewIdent("nil"), 0)
continue
}
sig, e := checkLambdaFuncType(ctx, expr, fn.arg(i, ellipsis), clLambaArgument, v.Fun)
if e != nil {
return e
Expand All @@ -748,6 +760,11 @@ func compileCallArgs(fn *fnType, ctx *blockCtx, v *ast.CallExpr, ellipsis bool,
return
}
case *ast.LambdaExpr2:
if fn.typeparam {
needInferFunc = true
compileIdent(ctx, ast.NewIdent("nil"), 0)
continue
}
sig, e := checkLambdaFuncType(ctx, expr, fn.arg(i, ellipsis), clLambaArgument, v.Fun)
if e != nil {
return e
Expand Down Expand Up @@ -784,9 +801,27 @@ func compileCallArgs(fn *fnType, ctx *blockCtx, v *ast.CallExpr, ellipsis bool,
compileExpr(ctx, arg)
}
}
if needInferFunc {
typ, err := inferFunc(ctx.pkg, pfn, fn.sig, nil, ctx.cb.InternalStack().GetArgs(len(v.Args)), 0)
if err != nil {
return err
}
next := &fnType{}
next.init(fn.base, typ.(*types.Signature))
next.next = fn.next
fn.next = next
return errCallNext
}
return ctx.cb.CallWithEx(len(v.Args), flags, v)
}

//go:linkname inferFunc github.com/goplus/gogen.inferFunc
func inferFunc(pkg *gogen.Package, fn *gogen.Element, sig *types.Signature, targs []types.Type, args []*gogen.Element, flags gogen.InstrFlags) (types.Type, error)

var (
errCallNext = errors.New("call next")
)

type clLambaFlag string

const (
Expand Down
71 changes: 71 additions & 0 deletions cl/typeparams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,3 +622,74 @@ func main() {
}
`)
}

func TestInferFuncLambda(t *testing.T) {
gopMixedClTest(t, "main", `package main
func ListMap[T any](ar []T, fn func(v T) T)[]T {
for i, v := range ar {
ar[i] = fn(v)
}
return ar
}
`, `
println ListMap([1,2,3,4], x => x*x)
ListMap [1,2,3,4], x => {
println x
return x
}
`, `package main
import "fmt"
func main() {
fmt.Println(ListMap([]int{1, 2, 3, 4}, func(x int) int {
return x * x
}))
ListMap([]int{1, 2, 3, 4}, func(x int) int {
fmt.Println(x)
return x
})
}
`)
}

func TestInferOverloadFuncLambda(t *testing.T) {
gopMixedClTest(t, "main", `package main
func ListMap__0[T any](ar []T, fn func(v T) T)[]T {
for i, v := range ar {
ar[i] = fn(v)
}
return ar
}
func ListMap__1(a string, fn func(s string)) {
for _, c := range a {
fn(string(c))
}
}
`, `
println ListMap([1,2,3,4], x => x*x)
ListMap [1,2,3,4], x => {
println x
return x
}
ListMap "hello", x => {
println x
}
`, `package main
import "fmt"
func main() {
fmt.Println(ListMap__0([]int{1, 2, 3, 4}, func(x int) int {
return x * x
}))
ListMap__0([]int{1, 2, 3, 4}, func(x int) int {
fmt.Println(x)
return x
})
ListMap__1("hello", func(x string) {
fmt.Println(x)
})
}
`)
}

0 comments on commit 564f50f

Please sign in to comment.