Skip to content

Commit abc7083

Browse files
authored
support smart anchor option (#662)
1 parent b46780d commit abc7083

File tree

3 files changed

+270
-11
lines changed

3 files changed

+270
-11
lines changed

encode.go

Lines changed: 127 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ type Encoder struct {
3232
isFlowStyle bool
3333
isJSONStyle bool
3434
useJSONMarshaler bool
35+
enableSmartAnchor bool
36+
aliasRefToName map[uintptr]string
37+
anchorRefToName map[uintptr]string
38+
anchorNameMap map[string]struct{}
3539
anchorCallback func(*ast.AnchorNode, interface{}) error
36-
anchorPtrToNameMap map[uintptr]string
3740
customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error)
3841
useLiteralStyleIfMultiline bool
3942
commentMap map[*Path][]*Comment
@@ -53,12 +56,14 @@ func NewEncoder(w io.Writer, opts ...EncodeOption) *Encoder {
5356
return &Encoder{
5457
writer: w,
5558
opts: opts,
56-
anchorPtrToNameMap: map[uintptr]string{},
5759
customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){},
5860
line: 1,
5961
column: 1,
6062
offset: 0,
6163
indentNum: DefaultIndentSpaces,
64+
anchorRefToName: make(map[uintptr]string),
65+
anchorNameMap: make(map[string]struct{}),
66+
aliasRefToName: make(map[uintptr]string),
6267
}
6368
}
6469

@@ -110,6 +115,13 @@ func (e *Encoder) EncodeToNodeContext(ctx context.Context, v interface{}) (ast.N
110115
return nil, err
111116
}
112117
}
118+
if e.enableSmartAnchor {
119+
// during the first encoding, store all mappings between alias addresses and their names.
120+
if _, err := e.encodeValue(ctx, reflect.ValueOf(v), 1); err != nil {
121+
return nil, err
122+
}
123+
e.clearSmartAnchorRef()
124+
}
113125
node, err := e.encodeValue(ctx, reflect.ValueOf(v), 1)
114126
if err != nil {
115127
return nil, err
@@ -445,12 +457,8 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
445457
case reflect.Float64:
446458
return e.encodeFloat(v.Float(), 64), nil
447459
case reflect.Ptr:
448-
anchorName := e.anchorPtrToNameMap[v.Pointer()]
449-
if anchorName != "" {
450-
aliasName := anchorName
451-
alias := ast.Alias(token.New("*", "*", e.pos(column)))
452-
alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column)))
453-
return alias, nil
460+
if value := e.encodePtrAnchor(v, column); value != nil {
461+
return value, nil
454462
}
455463
return e.encodeValue(ctx, v.Elem(), column)
456464
case reflect.Interface:
@@ -463,6 +471,9 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
463471
if mapSlice, ok := v.Interface().(MapSlice); ok {
464472
return e.encodeMapSlice(ctx, mapSlice, column)
465473
}
474+
if value := e.encodePtrAnchor(v, column); value != nil {
475+
return value, nil
476+
}
466477
return e.encodeSlice(ctx, v)
467478
case reflect.Array:
468479
return e.encodeArray(ctx, v)
@@ -477,12 +488,27 @@ func (e *Encoder) encodeValue(ctx context.Context, v reflect.Value, column int)
477488
}
478489
return e.encodeStruct(ctx, v, column)
479490
case reflect.Map:
491+
if value := e.encodePtrAnchor(v, column); value != nil {
492+
return value, nil
493+
}
480494
return e.encodeMap(ctx, v, column), nil
481495
default:
482496
return nil, fmt.Errorf("unknown value type %s", v.Type().String())
483497
}
484498
}
485499

500+
func (e *Encoder) encodePtrAnchor(v reflect.Value, column int) ast.Node {
501+
anchorName, exists := e.getAnchor(v.Pointer())
502+
if !exists {
503+
return nil
504+
}
505+
aliasName := anchorName
506+
alias := ast.Alias(token.New("*", "*", e.pos(column)))
507+
alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column)))
508+
e.setSmartAlias(aliasName, v.Pointer())
509+
return alias
510+
}
511+
486512
func (e *Encoder) pos(column int) *token.Position {
487513
return &token.Position{
488514
Line: e.line,
@@ -676,11 +702,23 @@ func (e *Encoder) encodeMap(ctx context.Context, value reflect.Value, column int
676702
if e.isTagAndMapNode(value) {
677703
value.AddColumn(e.indentNum)
678704
}
705+
keyText := fmt.Sprint(key)
706+
vRef := e.toPointer(v)
707+
708+
// during the second encoding, an anchor is assigned if it is found to be used by an alias.
709+
if aliasName, exists := e.getSmartAlias(vRef); exists {
710+
anchorName := aliasName
711+
anchorNode := ast.Anchor(token.New("&", "&", e.pos(column)))
712+
anchorNode.Name = ast.String(token.New(anchorName, anchorName, e.pos(column)))
713+
anchorNode.Value = value
714+
value = anchorNode
715+
}
679716
node.Values = append(node.Values, ast.MappingValue(
680717
nil,
681-
e.encodeString(fmt.Sprint(key), column),
718+
e.encodeString(keyText, column),
682719
value,
683720
))
721+
e.setSmartAnchor(vRef, keyText)
684722
}
685723
return node
686724
}
@@ -761,7 +799,7 @@ func (e *Encoder) encodeAnchor(anchorName string, value ast.Node, fieldValue ref
761799
}
762800
}
763801
if fieldValue.Kind() == reflect.Ptr {
764-
e.anchorPtrToNameMap[fieldValue.Pointer()] = anchorName
802+
e.setAnchor(fieldValue.Pointer(), anchorName)
765803
}
766804
return anchorNode, nil
767805
}
@@ -876,9 +914,87 @@ func (e *Encoder) encodeStruct(ctx context.Context, value reflect.Value, column
876914
}
877915
}
878916
if inlineAnchorValue.Kind() == reflect.Ptr {
879-
e.anchorPtrToNameMap[inlineAnchorValue.Pointer()] = anchorName
917+
e.setAnchor(inlineAnchorValue.Pointer(), anchorName)
880918
}
881919
return anchorNode, nil
882920
}
883921
return node, nil
884922
}
923+
924+
func (e *Encoder) toPointer(v reflect.Value) uintptr {
925+
if e.isInvalidValue(v) {
926+
return 0
927+
}
928+
929+
switch v.Type().Kind() {
930+
case reflect.Ptr:
931+
return v.Pointer()
932+
case reflect.Interface:
933+
return e.toPointer(v.Elem())
934+
case reflect.Slice:
935+
return v.Pointer()
936+
case reflect.Map:
937+
return v.Pointer()
938+
}
939+
return 0
940+
}
941+
942+
func (e *Encoder) clearSmartAnchorRef() {
943+
if !e.enableSmartAnchor {
944+
return
945+
}
946+
e.anchorRefToName = make(map[uintptr]string)
947+
e.anchorNameMap = make(map[string]struct{})
948+
}
949+
950+
func (e *Encoder) setSmartAnchor(ptr uintptr, name string) {
951+
if !e.enableSmartAnchor {
952+
return
953+
}
954+
e.setAnchor(ptr, e.generateAnchorName(name))
955+
}
956+
957+
func (e *Encoder) setAnchor(ptr uintptr, name string) {
958+
if ptr == 0 {
959+
return
960+
}
961+
if name == "" {
962+
return
963+
}
964+
e.anchorRefToName[ptr] = name
965+
e.anchorNameMap[name] = struct{}{}
966+
}
967+
968+
func (e *Encoder) generateAnchorName(base string) string {
969+
if _, exists := e.anchorNameMap[base]; !exists {
970+
return base
971+
}
972+
for i := 1; i < 100; i++ {
973+
name := base + strconv.Itoa(i)
974+
if _, exists := e.anchorNameMap[name]; exists {
975+
continue
976+
}
977+
return name
978+
}
979+
return ""
980+
}
981+
982+
func (e *Encoder) getAnchor(ref uintptr) (string, bool) {
983+
anchorName, exists := e.anchorRefToName[ref]
984+
return anchorName, exists
985+
}
986+
987+
func (e *Encoder) setSmartAlias(name string, ref uintptr) {
988+
if !e.enableSmartAnchor {
989+
return
990+
}
991+
e.aliasRefToName[ref] = name
992+
}
993+
994+
func (e *Encoder) getSmartAlias(ref uintptr) (string, bool) {
995+
if !e.enableSmartAnchor {
996+
return "", false
997+
}
998+
aliasName, exists := e.aliasRefToName[ref]
999+
return aliasName, exists
1000+
}

option.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,18 @@ func Flow(isFlowStyle bool) EncodeOption {
143143
}
144144
}
145145

146+
// WithSmartAnchor when multiple map values share the same pointer,
147+
// an anchor is automatically assigned to the first occurrence, and aliases are used for subsequent elements.
148+
// The map key name is used as the anchor name by default.
149+
// If key names conflict, a suffix is automatically added to avoid collisions.
150+
// This is an experimental feature and cannot be used simultaneously with anchor tags.
151+
func WithSmartAnchor() EncodeOption {
152+
return func(e *Encoder) error {
153+
e.enableSmartAnchor = true
154+
return nil
155+
}
156+
}
157+
146158
// UseLiteralStyleIfMultiline causes encoding multiline strings with a literal syntax,
147159
// no matter what characters they include
148160
func UseLiteralStyleIfMultiline(useLiteralStyleIfMultiline bool) EncodeOption {

yaml_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,134 @@ d: &d [*c,*c]
9898
}
9999
}
100100
}
101+
102+
func TestSmartAnchor(t *testing.T) {
103+
var data = `
104+
a: &a [_,_,_,_,_,_,_,_,_,_,_,_,_,_,_]
105+
b: &b [*a,*a,*a,*a,*a,*a,*a,*a,*a,*a]
106+
c: &c [*b,*b,*b,*b,*b,*b,*b,*b,*b,*b]
107+
d: &d [*c,*c,*c,*c,*c,*c,*c,*c,*c,*c]
108+
e: &e [*d,*d,*d,*d,*d,*d,*d,*d,*d,*d]
109+
f: &f [*e,*e,*e,*e,*e,*e,*e,*e,*e,*e]
110+
g: &g [*f,*f,*f,*f,*f,*f,*f,*f,*f,*f]
111+
h: &h [*g,*g,*g,*g,*g,*g,*g,*g,*g,*g]
112+
i: &i [*h,*h,*h,*h,*h,*h,*h,*h,*h,*h]
113+
`
114+
var v any
115+
if err := yaml.Unmarshal([]byte(data), &v); err != nil {
116+
t.Fatal(err)
117+
}
118+
got, err := yaml.MarshalWithOptions(v, yaml.WithSmartAnchor())
119+
if err != nil {
120+
t.Fatal(err)
121+
}
122+
expected := `
123+
a: &a
124+
- _
125+
- _
126+
- _
127+
- _
128+
- _
129+
- _
130+
- _
131+
- _
132+
- _
133+
- _
134+
- _
135+
- _
136+
- _
137+
- _
138+
- _
139+
b: &b
140+
- *a
141+
- *a
142+
- *a
143+
- *a
144+
- *a
145+
- *a
146+
- *a
147+
- *a
148+
- *a
149+
- *a
150+
c: &c
151+
- *b
152+
- *b
153+
- *b
154+
- *b
155+
- *b
156+
- *b
157+
- *b
158+
- *b
159+
- *b
160+
- *b
161+
d: &d
162+
- *c
163+
- *c
164+
- *c
165+
- *c
166+
- *c
167+
- *c
168+
- *c
169+
- *c
170+
- *c
171+
- *c
172+
e: &e
173+
- *d
174+
- *d
175+
- *d
176+
- *d
177+
- *d
178+
- *d
179+
- *d
180+
- *d
181+
- *d
182+
- *d
183+
f: &f
184+
- *e
185+
- *e
186+
- *e
187+
- *e
188+
- *e
189+
- *e
190+
- *e
191+
- *e
192+
- *e
193+
- *e
194+
g: &g
195+
- *f
196+
- *f
197+
- *f
198+
- *f
199+
- *f
200+
- *f
201+
- *f
202+
- *f
203+
- *f
204+
- *f
205+
h: &h
206+
- *g
207+
- *g
208+
- *g
209+
- *g
210+
- *g
211+
- *g
212+
- *g
213+
- *g
214+
- *g
215+
- *g
216+
i:
217+
- *h
218+
- *h
219+
- *h
220+
- *h
221+
- *h
222+
- *h
223+
- *h
224+
- *h
225+
- *h
226+
- *h
227+
`
228+
if strings.TrimPrefix(expected, "\n") != string(got) {
229+
t.Fatalf("failed to encode: %s", string(got))
230+
}
231+
}

0 commit comments

Comments
 (0)