diff --git a/starlark/hashtable.go b/starlark/hashtable.go index 7a427618..e1bbeaaf 100644 --- a/starlark/hashtable.go +++ b/starlark/hashtable.go @@ -416,6 +416,16 @@ func (it *keyIterator) Done() { } } +// entries is a go1.23 iterator over the entries of the hash table. +func (ht *hashtable) entries(yield func(k, v Value) bool) { + if !ht.frozen { + ht.itercount++ + defer func() { ht.itercount-- }() + } + for e := ht.head; e != nil && yield(e.key, e.value); e = e.next { + } +} + var seed = maphash.MakeSeed() // hashString computes the hash of s. diff --git a/starlark/iterator_test.go b/starlark/iterator_test.go new file mode 100644 index 00000000..bcf23495 --- /dev/null +++ b/starlark/iterator_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 The Bazel Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.23 + +package starlark_test + +// This file defines tests of the starlark.Value Go API's go1.23 iterators: +// +// ({Tuple,*List,(Set}).Elements +// Elements +// (*Dict).Entries +// Entries + +import ( + "fmt" + "reflect" + "testing" + + . "go.starlark.net/starlark" +) + +func TestTupleElements(t *testing.T) { + tuple := Tuple{MakeInt(1), MakeInt(2), MakeInt(3)} + + var got []string + for elem := range tuple.Elements { + got = append(got, fmt.Sprint(elem)) + if len(got) == 2 { + break // skip 3 + } + } + for elem := range Elements(tuple) { + got = append(got, fmt.Sprint(elem)) + if len(got) == 4 { + break // skip 3 + } + } + want := []string{"1", "2", "1", "2"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestListElements(t *testing.T) { + list := NewList([]Value{MakeInt(1), MakeInt(2), MakeInt(3)}) + + var got []string + for elem := range list.Elements { + got = append(got, fmt.Sprint(elem)) + if len(got) == 2 { + break // skip 3 + } + } + for elem := range Elements(list) { + got = append(got, fmt.Sprint(elem)) + if len(got) == 4 { + break // skip 3 + } + } + want := []string{"1", "2", "1", "2"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestSetElements(t *testing.T) { + set := NewSet(3) + set.Insert(MakeInt(1)) + set.Insert(MakeInt(2)) + set.Insert(MakeInt(3)) + + var got []string + for elem := range set.Elements { + got = append(got, fmt.Sprint(elem)) + if len(got) == 2 { + break // skip 3 + } + } + for elem := range Elements(set) { + got = append(got, fmt.Sprint(elem)) + if len(got) == 4 { + break // skip 3 + } + } + want := []string{"1", "2", "1", "2"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestDictEntries(t *testing.T) { + d := NewDict(2) + d.SetKey(String("one"), MakeInt(1)) + d.SetKey(String("two"), MakeInt(2)) + d.SetKey(String("three"), MakeInt(3)) + + var got []string + for k, v := range d.Entries { + got = append(got, fmt.Sprintf("%v %v", k, v)) + if len(got) == 2 { + break // skip 3 + } + } + for k, v := range Entries(d) { + got = append(got, fmt.Sprintf("%v %v", k, v)) + if len(got) == 4 { + break // skip 3 + } + } + want := []string{`"one" 1`, `"two" 2`, `"one" 1`, `"two" 2`} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} diff --git a/starlark/value.go b/starlark/value.go index 22a37c8a..c43be962 100644 --- a/starlark/value.go +++ b/starlark/value.go @@ -254,12 +254,17 @@ var ( // // Example usage: // -// iter := iterable.Iterator() +// var seq Iterator = ... +// iter := seq.Iterate() // defer iter.Done() -// var x Value -// for iter.Next(&x) { +// var elem Value +// for iter.Next(elem) { // ... // } +// +// Or, using go1.23 iterators: +// +// for elem := range Elements(seq) { ... } type Iterator interface { // If the iterator is exhausted, Next returns false. // Otherwise it sets *p to the current element of the sequence, @@ -283,6 +288,8 @@ type Mapping interface { } // An IterableMapping is a mapping that supports key enumeration. +// +// See [Entries] for example use. type IterableMapping interface { Mapping Iterate() Iterator // see Iterable interface @@ -847,6 +854,7 @@ func (d *Dict) Type() string { return "dict" func (d *Dict) Freeze() { d.ht.freeze() } func (d *Dict) Truth() Bool { return d.Len() > 0 } func (d *Dict) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable type: dict") } +func (d *Dict) Entries(yield func(k, v Value) bool) { d.ht.entries(yield) } func (x *Dict) Union(y *Dict) *Dict { z := new(Dict) @@ -892,6 +900,10 @@ func dictsEqual(x, y *Dict, depth int) (bool, error) { } // A *List represents a Starlark list value. +// +// Example of go1.23 iteration: +// +// for elem := range list.Elements { ... } type List struct { elems []Value frozen bool @@ -954,6 +966,23 @@ func (l *List) Iterate() Iterator { return &listIterator{l: l} } +// Elements is a go1.23 iterator over the elements of the list. +// +// Example: +// +// for elem := range l.Elements { ... } +func (l *List) Elements(yield func(Value) bool) { + if !l.frozen { + l.itercount++ + defer func() { l.itercount-- }() + } + for _, x := range l.elems { + if !yield(x) { + break + } + } +} + func (x *List) CompareSameType(op syntax.Token, y_ Value, depth int) (bool, error) { y := y_.(*List) // It's tempting to check x == y as an optimization here, @@ -1053,6 +1082,20 @@ func (t Tuple) Slice(start, end, step int) Value { } func (t Tuple) Iterate() Iterator { return &tupleIterator{elems: t} } + +// Elements is a go1.23 iterator over the elements of the tuple. +// +// (A Tuple is a slice, so it is of course directly iterable. This +// method exists to provide a fast path for the [Elements] standalone +// function.) +func (t Tuple) Elements(yield func(Value) bool) { + for _, x := range t { + if !yield(x) { + break + } + } +} + func (t Tuple) Freeze() { for _, elem := range t { elem.Freeze() @@ -1124,6 +1167,9 @@ func (s *Set) Truth() Bool { return s.Len() > 0 } func (s *Set) Attr(name string) (Value, error) { return builtinAttr(s, name, setMethods) } func (s *Set) AttrNames() []string { return builtinAttrNames(setMethods) } +func (s *Set) Elements(yield func(k Value) bool) { + s.ht.entries(func(k, _ Value) bool { return yield(k) }) +} func (x *Set) CompareSameType(op syntax.Token, y_ Value, depth int) (bool, error) { y := y_.(*Set) @@ -1561,6 +1607,74 @@ func Iterate(x Value) Iterator { return nil } +// Elements returns a "push" iterator for the elements of the iterable value. +// +// Example of go1.23 iteration: +// +// for elem := range Elements(iterable) { ... } +// +// Push iterators are provided as a convience for Go client code. The +// core iteration behavior of Starlark for-loops is defined by the +// [Iterable] interface. +// +// TODO(adonovan): change return type to go1.23 iter.Seq[Value]. +func Elements(iterable Iterable) func(yield func(Value) bool) { + // Use specialized push iterator if available (*List, Tuple, *Set). + type hasElements interface { + Elements(yield func(k Value) bool) + } + if iterable, ok := iterable.(hasElements); ok { + return iterable.Elements + } + + iter := iterable.Iterate() + return func(yield func(Value) bool) { + defer iter.Done() + var x Value + for iter.Next(&x) && yield(x) { + } + } +} + +// Entries returns an iterator over the entries (key/value pairs) of +// the iterable mapping. +// +// Example of go1.23 iteration: +// +// for k, v := range Entries(mapping) { ... } +// +// Push iterators are provided as a convience for Go client code. The +// core iteration behavior of Starlark for-loops is defined by the +// [Iterable] interface. +// +// TODO(adonovan): change return type to go1.23 iter.Seq2[Value, Value]. +func Entries(mapping IterableMapping) func(yield func(k, v Value) bool) { + // If available (e.g *Dict) use specialized push iterator, + // as it gets k and v in one shot. + type hasEntries interface { + Entries(yield func(k, v Value) bool) + } + if mapping, ok := mapping.(hasEntries); ok { + return mapping.Entries + } + + iter := mapping.Iterate() + return func(yield func(k, v Value) bool) { + defer iter.Done() + var k Value + for iter.Next(&k) { + v, found, err := mapping.Get(k) + if err != nil || !found { + panic(fmt.Sprintf("Iterate and Get are inconsistent (mapping=%v, key=%v)", + mapping.Type(), k.Type())) + } + if !yield(k, v) { + break + } + } + } +} + // Bytes is the type of a Starlark binary string. // // A Bytes encapsulates an immutable sequence of bytes.