diff --git a/hclwrite/ast_attribute.go b/hclwrite/ast_attribute.go index 3edc68c0..70ca8275 100644 --- a/hclwrite/ast_attribute.go +++ b/hclwrite/ast_attribute.go @@ -49,3 +49,10 @@ func (a *Attribute) init(name string, expr *Expression) { func (a *Attribute) Expr() *Expression { return a.expr.content.(*Expression) } + +// SetName updates the name of the attribute to a given name. +func (a *Attribute) SetName(name string) { + nameTok := newIdentToken(name) + nameObj := newIdentifier(nameTok) + a.name.ReplaceWith(nameObj) +} diff --git a/hclwrite/ast_attribute_test.go b/hclwrite/ast_attribute_test.go new file mode 100644 index 00000000..85ea8d7e --- /dev/null +++ b/hclwrite/ast_attribute_test.go @@ -0,0 +1,81 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package hclwrite + +import ( + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/hclsyntax" +) + +func TestAttributeSetName(t *testing.T) { + t.Parallel() + + tests := []struct { + src string + oldName string + newName string + want Tokens + }{ + { + "old = 123", + "old", + "new", + Tokens{ + { + Type: hclsyntax.TokenIdent, + Bytes: []byte(`new`), + SpacesBefore: 0, + }, + { + Type: hclsyntax.TokenEqual, + Bytes: []byte{'='}, + SpacesBefore: 1, + }, + { + Type: hclsyntax.TokenNumberLit, + Bytes: []byte(`123`), + SpacesBefore: 1, + }, + { + Type: hclsyntax.TokenEOF, + Bytes: []byte{}, + SpacesBefore: 0, + }, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(fmt.Sprintf("%s %s in %s", test.oldName, test.newName, test.src), func(t *testing.T) { + t.Parallel() + + f, diags := ParseConfig([]byte(test.src), "", hcl.Pos{Line: 1, Column: 1}) + + if len(diags) != 0 { + for _, diag := range diags { + t.Logf("- %s", diag.Error()) + } + t.Fatalf("unexpected diagnostics") + } + + attr := f.Body().GetAttribute(test.oldName) + attr.SetName(test.newName) + got := f.BuildTokens(nil) + format(got) + + if !reflect.DeepEqual(got, test.want) { + diff := cmp.Diff(test.want, got) + t.Errorf("wrong result\ngot: %s\nwant: %s\ndiff:\n%s", spew.Sdump(got), spew.Sdump(test.want), diff) + } + }) + } +}