Skip to content

Commit f82d6f2

Browse files
committed
Adds Persistent{Pre,Post}Run hook chaining
PersistentPreRun and PersistentPostRun are chained together so that each child PersistentPreRun is ran, and the PersistentPostRun are ran in reverse order. For example: Commands: root -> subcommand-a -> subcommand-b root - PersistentPreRun subcommand-a - PersistentPreRun subcommand-b - PersistentPreRun subcommand-b - Run subcommand-b - PersistentPostRun subcommand-a - PersistentPostRun root - PersistentPostRun fixes spf13#252
1 parent b97b5ea commit f82d6f2

File tree

2 files changed

+382
-44
lines changed

2 files changed

+382
-44
lines changed

command.go

Lines changed: 130 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,21 @@ type Command struct {
105105
// * PersistentPostRun()
106106
// All functions get the same args, the arguments after the command name.
107107
//
108+
// When TraverseChildrenHooks is set, PersistentPreRun and
109+
// PersistentPostRun are chained together so that each child
110+
// PersistentPreRun is ran, and the PersistentPostRun are ran in reverse
111+
// order. For example:
112+
//
113+
// Commands: root -> subcommand-a -> subcommand-b
114+
//
115+
// root - PersistentPreRun
116+
// subcommand-a - PersistentPreRun
117+
// subcommand-b - PersistentPreRun
118+
// subcommand-b - Run
119+
// subcommand-b - PersistentPostRun
120+
// subcommand-a - PersistentPostRun
121+
// root - PersistentPostRun
122+
//
108123
// PersistentPreRun: children of this command will inherit and execute.
109124
PersistentPreRun func(cmd *Command, args []string)
110125
// PersistentPreRunE: PersistentPreRun but returns an error.
@@ -154,6 +169,11 @@ type Command struct {
154169
// TraverseChildren parses flags on all parents before executing child command.
155170
TraverseChildren bool
156171

172+
// TraverseChildrenHooks will have each subcommand's PersistentPreRun and
173+
// PersistentPostRun instead of overriding. It should be set on the root
174+
// command.
175+
TraverseChildrenHooks bool
176+
157177
// FParseErrWhitelist flag parse errors to be ignored
158178
FParseErrWhitelist FParseErrWhitelist
159179

@@ -824,55 +844,130 @@ func (c *Command) execute(a []string) (err error) {
824844
return err
825845
}
826846

827-
for p := c; p != nil; p = p.Parent() {
828-
if p.PersistentPreRunE != nil {
829-
if err := p.PersistentPreRunE(c, argWoFlags); err != nil {
830-
return err
847+
// Look to see if TraverseChildrenHooks is set on the root command.
848+
if _, err := c.runTree(c, argWoFlags, c.traverseChildrenHooks()); err != nil {
849+
return err
850+
}
851+
852+
return nil
853+
}
854+
855+
func (c *Command) traverseChildrenHooks() bool {
856+
if c.HasParent() {
857+
return c.Parent().traverseChildrenHooks()
858+
}
859+
860+
return c.TraverseChildrenHooks
861+
}
862+
863+
func (c *Command) runTree(
864+
cmd *Command,
865+
args []string,
866+
traverseChildrenHooks bool,
867+
) (
868+
persistentPostRunEs []func(cmd *Command, args []string) error,
869+
err error,
870+
) {
871+
if c == nil {
872+
return nil, nil
873+
}
874+
875+
// Traverse command tree and save the PersistentPostRun{,E} functions.
876+
persistentPostRunEs, err = c.Parent().runTree(cmd, args, traverseChildrenHooks)
877+
if err != nil {
878+
return nil, err
879+
}
880+
881+
if traverseChildrenHooks || c == cmd {
882+
// PersistentPreRun/PersistentPreRunE
883+
switch {
884+
case c.PersistentPreRun != nil:
885+
c.PersistentPreRun(cmd, args)
886+
case c.PersistentPreRunE != nil:
887+
if err := c.PersistentPreRunE(cmd, args); err != nil {
888+
return nil, err
831889
}
832-
break
833-
} else if p.PersistentPreRun != nil {
834-
p.PersistentPreRun(c, argWoFlags)
835-
break
890+
default:
891+
// Doesn't have a registered PersistentPreRun{,E}. Move on...
892+
}
893+
894+
// PersistentPostRun/PersistentPostRunE
895+
switch {
896+
case c.PersistentPostRun != nil:
897+
persistentPostRunEs = append(
898+
persistentPostRunEs,
899+
func(cmd *Command, args []string) error {
900+
c.PersistentPostRun(cmd, args)
901+
return nil
902+
},
903+
)
904+
case c.PersistentPostRunE != nil:
905+
persistentPostRunEs = append(
906+
persistentPostRunEs,
907+
c.PersistentPostRunE,
908+
)
909+
default:
910+
// Doesn't have a registered PersistentPostRun{,E}. Move on...
836911
}
837912
}
838-
if c.PreRunE != nil {
839-
if err := c.PreRunE(c, argWoFlags); err != nil {
840-
return err
913+
914+
if c != cmd {
915+
// Don't run a parent command.
916+
return persistentPostRunEs, nil
917+
}
918+
919+
// PreRun/PreRunE
920+
switch {
921+
case c.PreRun != nil:
922+
c.PreRun(cmd, args)
923+
case c.PreRunE != nil:
924+
if err := c.PreRunE(cmd, args); err != nil {
925+
return nil, err
841926
}
842-
} else if c.PreRun != nil {
843-
c.PreRun(c, argWoFlags)
927+
default:
928+
// Doesn't have a registered PreRun{,E}. Move on...
844929
}
845930

846931
if err := c.validateRequiredFlags(); err != nil {
847-
return err
932+
return nil, err
848933
}
849-
if c.RunE != nil {
850-
if err := c.RunE(c, argWoFlags); err != nil {
851-
return err
934+
935+
// Run/RunE
936+
switch {
937+
case c.RunE != nil:
938+
if err := c.RunE(cmd, args); err != nil {
939+
return nil, err
852940
}
853-
} else {
854-
c.Run(c, argWoFlags)
855-
}
856-
if c.PostRunE != nil {
857-
if err := c.PostRunE(c, argWoFlags); err != nil {
858-
return err
941+
case c.Run != nil:
942+
c.Run(cmd, args)
943+
default:
944+
// Both RunE and Run are nil...
945+
panic(fmt.Sprintf("command %q does not have a non-nil RunE or Run function", c.Use))
946+
}
947+
948+
// PostRun/PostRunE
949+
switch {
950+
case c.PostRun != nil:
951+
c.PostRun(cmd, args)
952+
case c.PostRunE != nil:
953+
if err := c.PostRunE(cmd, args); err != nil {
954+
return nil, err
859955
}
860-
} else if c.PostRun != nil {
861-
c.PostRun(c, argWoFlags)
956+
default:
957+
// Doesn't have a registered PostRun{,E}. Move on...
862958
}
863-
for p := c; p != nil; p = p.Parent() {
864-
if p.PersistentPostRunE != nil {
865-
if err := p.PersistentPostRunE(c, argWoFlags); err != nil {
866-
return err
867-
}
868-
break
869-
} else if p.PersistentPostRun != nil {
870-
p.PersistentPostRun(c, argWoFlags)
871-
break
959+
960+
// PersistentPostRun/PersistentPostRunE
961+
// Iterate through the list in reverse order. Similar to a defer, allow
962+
// the topmost commands to cleanup first.
963+
for i := range persistentPostRunEs {
964+
r := persistentPostRunEs[len(persistentPostRunEs)-1-i]
965+
if err := r(cmd, args); err != nil {
966+
return nil, err
872967
}
873968
}
874969

875-
return nil
970+
return nil, nil
876971
}
877972

878973
func (c *Command) preRun() {

0 commit comments

Comments
 (0)