Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal: Add Function Type Annotation Syntax #439

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions resolve/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ type Module struct {
// A Function contains resolver information about a named or anonymous function.
// The resolver populates the Function field of each syntax.DefStmt and syntax.LambdaExpr.
type Function struct {
Pos syntax.Position // of DEF or LAMBDA
Name string // name of def, or "lambda"
Params []syntax.Expr // param = ident | ident=expr | * | *ident | **ident
Body []syntax.Stmt // contains synthetic 'return expr' for lambda
Pos syntax.Position // of DEF or LAMBDA
Name string // name of def, or "lambda"
Params []syntax.Expr // param = ident | ident=expr | * | *ident | **ident
Body []syntax.Stmt // contains synthetic 'return expr' for lambda
ReturnType syntax.Expr // can be nil, type hint expression after '->'

HasVarargs bool // whether params includes *args (convenience)
HasKwargs bool // whether params includes **kwargs (convenience)
Expand Down
41 changes: 33 additions & 8 deletions resolve/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,11 @@ const doesnt = "this Starlark dialect does not "
// These features are either not standard Starlark (yet), or deprecated
// features of the BUILD language, so we put them behind flags.
var (
AllowSet = false // allow the 'set' built-in
AllowGlobalReassign = false // allow reassignment to top-level names; also, allow if/for/while at top-level
AllowRecursion = false // allow while statements and recursive functions
LoadBindsGlobally = false // load creates global not file-local bindings (deprecated)
AllowSet = false // allow the 'set' built-in
AllowGlobalReassign = false // allow reassignment to top-level names; also, allow if/for/while at top-level
AllowRecursion = false // allow while statements and recursive functions
ResolveTypeHintIdents = false // resolve identifiers in type hints
LoadBindsGlobally = false // load creates global not file-local bindings (deprecated)

// obsolete flags for features that are now standard. No effect.
AllowNestedDef = true
Expand Down Expand Up @@ -510,10 +511,11 @@ func (r *resolver) stmt(stmt syntax.Stmt) {
case *syntax.DefStmt:
r.bind(stmt.Name)
fn := &Function{
Name: stmt.Name.Name,
Pos: stmt.Def,
Params: stmt.Params,
Body: stmt.Body,
Name: stmt.Name.Name,
Pos: stmt.Def,
Params: stmt.Params,
Body: stmt.Body,
ReturnType: stmt.ReturnType,
}
stmt.Function = fn
r.function(fn, stmt.Def)
Expand Down Expand Up @@ -804,6 +806,29 @@ func (r *resolver) function(function *Function, pos syntax.Position) {
}
}

// Resolve function type hints in enclosing environment.
if ResolveTypeHintIdents {
if function.ReturnType != nil {
r.expr(function.ReturnType)
}
for _, param := range function.Params {
switch param := param.(type) {
case *syntax.Ident:
if param.TypeHint != nil {
r.expr(param.TypeHint)
}
case *syntax.BinaryExpr:
if param.X.(*syntax.Ident).TypeHint != nil {
r.expr(param.X.(*syntax.Ident).TypeHint)
}
case *syntax.UnaryExpr:
if param.X.(*syntax.Ident).TypeHint != nil {
r.expr(param.X.(*syntax.Ident).TypeHint)
}
}
}
}

// Enter function block.
b := &block{function: function}
r.push(b)
Expand Down
4 changes: 2 additions & 2 deletions syntax/grammar.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ File = {Statement | newline} eof .

Statement = DefStmt | IfStmt | ForStmt | WhileStmt | SimpleStmt .

DefStmt = 'def' identifier '(' [Parameters [',']] ')' ':' Suite .
DefStmt = 'def' identifier '(' [Parameters [',']] ')' ['->' Test] ':' Suite .

Parameters = Parameter {',' Parameter}.

Parameter = identifier | identifier '=' Test | '*' | '*' identifier | '**' identifier .
Parameter = identifier [':' Test] | identifier [':' Test] '=' Test | '*' | '*' identifier [':' Test] | '**' identifier [':' Test] .

IfStmt = 'if' Test ':' Suite {'elif' Test ':' Suite} ['else' ':' Suite] .

Expand Down
109 changes: 71 additions & 38 deletions syntax/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,24 @@ func (p *parser) parseDefStmt() Stmt {
defpos := p.nextToken() // consume DEF
id := p.parseIdent()
p.consume(LPAREN)
params := p.parseParams()
params := p.parseParams(false)
p.consume(RPAREN)

var returnType Expr
// def fn() -> type:
if p.tok == ARROW {
p.consume(ARROW)
returnType = p.parseTest()
}

p.consume(COLON)
body := p.parseSuite()
return &DefStmt{
Def: defpos,
Name: id,
Params: params,
Body: body,
Def: defpos,
Name: id,
Params: params,
Body: body,
ReturnType: returnType,
}
}

Expand Down Expand Up @@ -275,10 +284,11 @@ func (p *parser) parseSimpleStmt(stmts []Stmt, consumeNL bool) []Stmt {
}

// small_stmt = RETURN expr?
// | PASS | BREAK | CONTINUE
// | LOAD ...
// | expr ('=' | '+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' | '<<=' | '>>=') expr // assign
// | expr
//
// | PASS | BREAK | CONTINUE
// | LOAD ...
// | expr ('=' | '+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' | '<<=' | '>>=') expr // assign
// | expr
func (p *parser) parseSmallStmt() Stmt {
switch p.tok {
case RETURN:
Expand All @@ -300,6 +310,7 @@ func (p *parser) parseSmallStmt() Stmt {

// Assignment
x := p.parseExpr(false)

switch p.tok {
case EQ, PLUS_EQ, MINUS_EQ, STAR_EQ, SLASH_EQ, SLASHSLASH_EQ, PERCENT_EQ, AMP_EQ, PIPE_EQ, CIRCUMFLEX_EQ, LTLT_EQ, GTGT_EQ:
op := p.tok
Expand Down Expand Up @@ -415,22 +426,24 @@ func (p *parser) consume(t Token) Position {
}

// params = (param COMMA)* param COMMA?
// |
//
// |
//
// param = IDENT
// | IDENT EQ test
// | STAR
// | STAR IDENT
// | STARSTAR IDENT
//
// | IDENT EQ test
// | STAR
// | STAR IDENT
// | STARSTAR IDENT
//
// parseParams parses a parameter list. The resulting expressions are of the form:
//
// *Ident x
// *Binary{Op: EQ, X: *Ident, Y: Expr} x=y
// *Unary{Op: STAR} *
// *Unary{Op: STAR, X: *Ident} *args
// *Unary{Op: STARSTAR, X: *Ident} **kwargs
func (p *parser) parseParams() []Expr {
// *Ident x
// *Binary{Op: EQ, X: *Ident, Y: Expr} x=y
// *Unary{Op: STAR} *
// *Unary{Op: STAR, X: *Ident} *args
// *Unary{Op: STARSTAR, X: *Ident} **kwargs
func (p *parser) parseParams(lambda bool) []Expr {
var params []Expr
for p.tok != RPAREN && p.tok != COLON && p.tok != EOF {
if len(params) > 0 {
Expand All @@ -446,7 +459,9 @@ func (p *parser) parseParams() []Expr {
pos := p.nextToken()
var x Expr
if op == STARSTAR || p.tok == IDENT {
x = p.parseIdent()
id := p.parseIdent()
id.TypeHint = p.maybeTypeHint()
x = id
}
params = append(params, &UnaryExpr{
OpPos: pos,
Expand All @@ -459,6 +474,9 @@ func (p *parser) parseParams() []Expr {
// IDENT
// IDENT = test
id := p.parseIdent()
if !lambda { // type hint syntax not compatible with lambdas
id.TypeHint = p.maybeTypeHint()
}
if p.tok == EQ { // default value
eq := p.nextToken()
dflt := p.parseTest()
Expand All @@ -476,6 +494,16 @@ func (p *parser) parseParams() []Expr {
return params
}

// potentially consume a type hint expression
// returns nil if there is no type hint
func (p *parser) maybeTypeHint() Expr {
if p.tok == COLON {
p.consume(COLON)
return p.parseTest()
}
return nil
}

// parseExpr parses an expression, possible consisting of a
// comma-separated list of 'test' expressions.
//
Expand Down Expand Up @@ -547,7 +575,7 @@ func (p *parser) parseLambda(allowCond bool) Expr {
lambda := p.nextToken()
var params []Expr
if p.tok != COLON {
params = p.parseParams()
params = p.parseParams(true)
}
p.consume(COLON)

Expand Down Expand Up @@ -651,9 +679,10 @@ func init() {
}

// primary_with_suffix = primary
// | primary '.' IDENT
// | primary slice_suffix
// | primary call_suffix
//
// | primary '.' IDENT
// | primary slice_suffix
// | primary call_suffix
func (p *parser) parsePrimaryWithSuffix() Expr {
x := p.parsePrimary()
for {
Expand Down Expand Up @@ -770,12 +799,13 @@ func (p *parser) parseArgs() []Expr {
return args
}

// primary = IDENT
// | INT | FLOAT | STRING | BYTES
// | '[' ... // list literal or comprehension
// | '{' ... // dict literal or comprehension
// | '(' ... // tuple or parenthesized expression
// | ('-'|'+'|'~') primary_with_suffix
// primary = IDENT
//
// | INT | FLOAT | STRING | BYTES
// | '[' ... // list literal or comprehension
// | '{' ... // dict literal or comprehension
// | '(' ... // tuple or parenthesized expression
// | ('-'|'+'|'~') primary_with_suffix
func (p *parser) parsePrimary() Expr {
switch p.tok {
case IDENT:
Expand Down Expand Up @@ -836,9 +866,10 @@ func (p *parser) parsePrimary() Expr {
}

// list = '[' ']'
// | '[' expr ']'
// | '[' expr expr_list ']'
// | '[' expr (FOR loop_variables IN expr)+ ']'
//
// | '[' expr ']'
// | '[' expr expr_list ']'
// | '[' expr (FOR loop_variables IN expr)+ ']'
func (p *parser) parseList() Expr {
lbrack := p.nextToken()
if p.tok == RBRACK {
Expand All @@ -865,8 +896,9 @@ func (p *parser) parseList() Expr {
}

// dict = '{' '}'
// | '{' dict_entry_list '}'
// | '{' dict_entry FOR loop_variables IN expr '}'
//
// | '{' dict_entry_list '}'
// | '{' dict_entry FOR loop_variables IN expr '}'
func (p *parser) parseDict() Expr {
lbrace := p.nextToken()
if p.tok == RBRACE {
Expand Down Expand Up @@ -904,8 +936,9 @@ func (p *parser) parseDictEntry() *DictEntry {
}

// comp_suffix = FOR loopvars IN expr comp_suffix
// | IF expr comp_suffix
// | ']' or ')' (end)
//
// | IF expr comp_suffix
// | ']' or ')' (end)
//
// There can be multiple FOR/IF clauses; the first is always a FOR.
func (p *parser) parseComprehensionSuffix(lbrace Position, body Expr, endBrace Token) Expr {
Expand Down
3 changes: 3 additions & 0 deletions syntax/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ else:
`(DefStmt Name=f Params=(x (UnaryExpr Op=* X=args) (UnaryExpr Op=** X=kwargs)) Body=((BranchStmt Token=pass)))`},
{`def f(**kwargs, *args): pass`,
`(DefStmt Name=f Params=((UnaryExpr Op=** X=kwargs) (UnaryExpr Op=* X=args)) Body=((BranchStmt Token=pass)))`},
{`def f(x, y: str, z: list[str]=None) -> int:
pass`,
`(DefStmt Name=f Params=(x y (BinaryExpr X=z Op== Y=None)) Body=((BranchStmt Token=pass)) ReturnType=int)`},
{`def f(a, b, c=d): pass`,
`(DefStmt Name=f Params=(a b (BinaryExpr X=c Op== Y=d)) Body=((BranchStmt Token=pass)))`},
{`def f(a, b=c, d): pass`,
Expand Down
6 changes: 6 additions & 0 deletions syntax/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ const (
LTLT_EQ // <<=
GTGT_EQ // >>=
STARSTAR // **
ARROW // ->

// Keywords
AND
Expand Down Expand Up @@ -164,6 +165,7 @@ var tokenNames = [...]string{
LTLT_EQ: "<<=",
GTGT_EQ: ">>=",
STARSTAR: "**",
ARROW: "->",
AND: "and",
BREAK: "break",
CONTINUE: "continue",
Expand Down Expand Up @@ -772,6 +774,10 @@ start:
case '+':
return PLUS
case '-':
if sc.peekRune() == '>' {
sc.readRune()
return ARROW
}
return MINUS
case '/':
if sc.peekRune() == '/' {
Expand Down
21 changes: 14 additions & 7 deletions syntax/syntax.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ func (*LoadStmt) stmt() {}
func (*ReturnStmt) stmt() {}

// An AssignStmt represents an assignment:
//
// x = 0
// x, y = y, x
// x += 1
// x += 1
type AssignStmt struct {
commentsRef
OpPos Position
Expand All @@ -119,10 +120,11 @@ func (x *AssignStmt) Span() (start, end Position) {
// A DefStmt represents a function definition.
type DefStmt struct {
commentsRef
Def Position
Name *Ident
Params []Expr // param = ident | ident=expr | * | *ident | **ident
Body []Stmt
Def Position
Name *Ident
Params []Expr // param = ident | ident=expr | * | *ident | **ident
Body []Stmt
ReturnType Expr

Function interface{} // a *resolve.Function, set by resolver
}
Expand Down Expand Up @@ -238,13 +240,18 @@ func (*UnaryExpr) expr() {}
// An Ident represents an identifier.
type Ident struct {
commentsRef
NamePos Position
Name string
NamePos Position
Name string
TypeHint Expr

Binding interface{} // a *resolver.Binding, set by resolver
}

func (x *Ident) Span() (start, end Position) {
if x.TypeHint != nil {
_, end := x.TypeHint.Span()
return x.NamePos, end
}
return x.NamePos, x.NamePos.add(x.Name)
}

Expand Down