Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generic func infer lambda expr #1826

Merged
merged 1 commit into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 41 additions & 10 deletions cl/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package cl

import (
"bytes"
"errors"
goast "go/ast"
gotoken "go/token"
"go/types"
Expand Down Expand Up @@ -566,12 +567,14 @@ func identVal(ctx *blockCtx, x *ast.Ident, flags int, v types.Object, alias bool
}

type fnType struct {
next *fnType
params *types.Tuple
base int
size int
variadic bool
typetype bool
next *fnType
params *types.Tuple
sig *types.Signature
base int
size int
variadic bool
typetype bool
typeparam bool
}

func (p *fnType) arg(i int, ellipsis bool) types.Type {
Expand All @@ -590,7 +593,8 @@ func (p *fnType) arg(i int, ellipsis bool) types.Type {

func (p *fnType) init(base int, t *types.Signature) {
p.base = base
p.params, p.variadic = t.Params(), t.Variadic()
p.sig = t
p.params, p.variadic, p.typeparam = t.Params(), t.Variadic(), t.TypeParams() != nil
p.size = p.params.Len()
if p.variadic {
p.size--
Expand Down Expand Up @@ -669,7 +673,6 @@ func compileCallExpr(ctx *blockCtx, v *ast.CallExpr, inFlags int) {
var err error
var stk = ctx.cb.InternalStack()
var base = stk.Len()
var fnt = stk.Get(-1).Type
var flags gogen.InstrFlags
var ellipsis = v.Ellipsis != gotoken.NoPos
if ellipsis {
Expand All @@ -678,10 +681,12 @@ func compileCallExpr(ctx *blockCtx, v *ast.CallExpr, inFlags int) {
if (inFlags & clCallWithTwoValue) != 0 {
flags |= gogen.InstrFlagTwoValue
}
pfn := stk.Get(-1)
fnt := pfn.Type
fn := &fnType{}
fn.load(fnt)
for fn != nil {
if err = compileCallArgs(fn, ctx, v, ellipsis, flags); err == nil {
if err = compileCallArgs(ctx, pfn, fn, v, ellipsis, flags); err == nil {
if rec := ctx.recorder(); rec != nil {
rec.recordCallExpr(ctx, v, fnt)
}
Expand Down Expand Up @@ -736,10 +741,16 @@ func fnCall(ctx *blockCtx, v *ast.CallExpr, flags gogen.InstrFlags, extra int) e
return ctx.cb.CallWithEx(len(v.Args)+extra, flags, v)
}

func compileCallArgs(fn *fnType, ctx *blockCtx, v *ast.CallExpr, ellipsis bool, flags gogen.InstrFlags) (err error) {
func compileCallArgs(ctx *blockCtx, pfn *gogen.Element, fn *fnType, v *ast.CallExpr, ellipsis bool, flags gogen.InstrFlags) (err error) {
var needInferFunc bool
for i, arg := range v.Args {
switch expr := arg.(type) {
case *ast.LambdaExpr:
if fn.typeparam {
needInferFunc = true
compileIdent(ctx, ast.NewIdent("nil"), 0)
continue
}
sig, e := checkLambdaFuncType(ctx, expr, fn.arg(i, ellipsis), clLambaArgument, v.Fun)
if e != nil {
return e
Expand All @@ -748,6 +759,11 @@ func compileCallArgs(fn *fnType, ctx *blockCtx, v *ast.CallExpr, ellipsis bool,
return
}
case *ast.LambdaExpr2:
if fn.typeparam {
needInferFunc = true
compileIdent(ctx, ast.NewIdent("nil"), 0)
continue
}
sig, e := checkLambdaFuncType(ctx, expr, fn.arg(i, ellipsis), clLambaArgument, v.Fun)
if e != nil {
return e
Expand Down Expand Up @@ -784,9 +800,24 @@ func compileCallArgs(fn *fnType, ctx *blockCtx, v *ast.CallExpr, ellipsis bool,
compileExpr(ctx, arg)
}
}
if needInferFunc {
typ, err := gogen.InferFunc(ctx.pkg, pfn, fn.sig, nil, ctx.cb.InternalStack().GetArgs(len(v.Args)), flags)
if err != nil {
return err
}
next := &fnType{}
next.init(fn.base, typ.(*types.Signature))
next.next = fn.next
fn.next = next
return errCallNext
}
return ctx.cb.CallWithEx(len(v.Args), flags, v)
}

var (
errCallNext = errors.New("call next")
)

type clLambaFlag string

const (
Expand Down
71 changes: 71 additions & 0 deletions cl/typeparams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,3 +622,74 @@ func main() {
}
`)
}

func TestInferFuncLambda(t *testing.T) {
gopMixedClTest(t, "main", `package main
func ListMap[T any](ar []T, fn func(v T) T)[]T {
for i, v := range ar {
ar[i] = fn(v)
}
return ar
}
`, `
println ListMap([1,2,3,4], x => x*x)
ListMap [1,2,3,4], x => {
println x
return x
}
`, `package main
import "fmt"
func main() {
fmt.Println(ListMap([]int{1, 2, 3, 4}, func(x int) int {
return x * x
}))
ListMap([]int{1, 2, 3, 4}, func(x int) int {
fmt.Println(x)
return x
})
}
`)
}

func TestInferOverloadFuncLambda(t *testing.T) {
gopMixedClTest(t, "main", `package main
func ListMap__0[T any](ar []T, fn func(v T) T)[]T {
for i, v := range ar {
ar[i] = fn(v)
}
return ar
}
func ListMap__1(a string, fn func(s string)) {
for _, c := range a {
fn(string(c))
}
}
`, `
println ListMap([1,2,3,4], x => x*x)
ListMap [1,2,3,4], x => {
println x
return x
}
ListMap "hello", x => {
println x
}
`, `package main
import "fmt"
func main() {
fmt.Println(ListMap__0([]int{1, 2, 3, 4}, func(x int) int {
return x * x
}))
ListMap__0([]int{1, 2, 3, 4}, func(x int) int {
fmt.Println(x)
return x
})
ListMap__1("hello", func(x string) {
fmt.Println(x)
})
}
`)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.18
require (
github.com/fsnotify/fsnotify v1.7.0
github.com/goplus/c2go v0.7.25
github.com/goplus/gogen v1.15.1
github.com/goplus/gogen v1.15.2-0.20240325030304-38a18ebdfb1f
github.com/goplus/mod v0.13.9
github.com/qiniu/x v1.13.9
golang.org/x/tools v0.19.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyT
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/goplus/c2go v0.7.25 h1:QvQfOwGVGKFYOTLry8i4p9U55jnIOQWeLsnSyFg0hhc=
github.com/goplus/c2go v0.7.25/go.mod h1:e9oe4jDVhGFMJLEGmPSrVkLuXbLZAEmAu0/uD6fSz5E=
github.com/goplus/gogen v1.15.1 h1:iz/fFpOeldjwmnjLzEdNsZF2mCf+sOHJavbAvV3o7sY=
github.com/goplus/gogen v1.15.1/go.mod h1:92qEzVgv7y8JEFICWG9GvYI5IzfEkxYdsA1DbmnTkqk=
github.com/goplus/gogen v1.15.2-0.20240325030304-38a18ebdfb1f h1:ScHZQ/KkVjVyea/1ivGUaBh24KnqG3PpQXn0LffjwUc=
github.com/goplus/gogen v1.15.2-0.20240325030304-38a18ebdfb1f/go.mod h1:92qEzVgv7y8JEFICWG9GvYI5IzfEkxYdsA1DbmnTkqk=
github.com/goplus/mod v0.13.9 h1:B9zZoHi2AzMltTSOFqZNVjqGlSMlhhNTWwEzVqhTQzg=
github.com/goplus/mod v0.13.9/go.mod h1:MibsLSftGmxaQq78YzUzNviyFwB9RtpMaoscufvEKH4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
Expand Down