diff --git a/physical/raft/testing.go b/physical/raft/testing.go index 5b2ea601bed8..0a72e3f13cc6 100644 --- a/physical/raft/testing.go +++ b/physical/raft/testing.go @@ -14,25 +14,32 @@ import ( ) func GetRaft(t testing.TB, bootstrap bool, noStoreState bool) (*RaftBackend, string) { - return getRaftInternal(t, bootstrap, defaultRaftConfig(t, bootstrap, noStoreState), nil, nil) + return getRaftInternal(t, bootstrap, defaultRaftConfig(t, bootstrap, noStoreState), nil, nil, nil) } func GetRaftWithConfig(t testing.TB, bootstrap bool, noStoreState bool, conf map[string]string) (*RaftBackend, string) { defaultConf := defaultRaftConfig(t, bootstrap, noStoreState) conf["path"] = defaultConf["path"] conf["doNotStoreLatestState"] = defaultConf["doNotStoreLatestState"] - return getRaftInternal(t, bootstrap, conf, nil, nil) + return getRaftInternal(t, bootstrap, conf, nil, nil, nil) +} + +func GetRaftWithConfigAndSetupOpts(t testing.TB, bootstrap bool, noStoreState bool, conf map[string]string, setupOpts *SetupOpts) (*RaftBackend, string) { + defaultConf := defaultRaftConfig(t, bootstrap, noStoreState) + conf["path"] = defaultConf["path"] + conf["doNotStoreLatestState"] = defaultConf["doNotStoreLatestState"] + return getRaftInternal(t, bootstrap, conf, setupOpts, nil, nil) } func GetRaftWithConfigAndInitFn(t testing.TB, bootstrap bool, noStoreState bool, conf map[string]string, initFn func(b *RaftBackend)) (*RaftBackend, string) { defaultConf := defaultRaftConfig(t, bootstrap, noStoreState) conf["path"] = defaultConf["path"] conf["doNotStoreLatestState"] = defaultConf["doNotStoreLatestState"] - return getRaftInternal(t, bootstrap, conf, nil, initFn) + return getRaftInternal(t, bootstrap, conf, nil, nil, initFn) } func GetRaftWithLogOutput(t testing.TB, bootstrap bool, noStoreState bool, logOutput io.Writer) (*RaftBackend, string) { - return getRaftInternal(t, bootstrap, defaultRaftConfig(t, bootstrap, noStoreState), logOutput, nil) + return getRaftInternal(t, bootstrap, defaultRaftConfig(t, bootstrap, noStoreState), nil, logOutput, nil) } func defaultRaftConfig(t testing.TB, bootstrap bool, noStoreState bool) map[string]string { @@ -51,7 +58,7 @@ func defaultRaftConfig(t testing.TB, bootstrap bool, noStoreState bool) map[stri return conf } -func getRaftInternal(t testing.TB, bootstrap bool, conf map[string]string, logOutput io.Writer, initFn func(b *RaftBackend)) (*RaftBackend, string) { +func getRaftInternal(t testing.TB, bootstrap bool, conf map[string]string, setupOpts *SetupOpts, logOutput io.Writer, initFn func(b *RaftBackend)) (*RaftBackend, string) { id, err := uuid.GenerateUUID() if err != nil { t.Fatal(err) @@ -85,7 +92,12 @@ func getRaftInternal(t testing.TB, bootstrap bool, conf map[string]string, logOu t.Fatal(err) } - err = backend.SetupCluster(context.Background(), SetupOpts{}) + so := SetupOpts{} + if setupOpts != nil { + so = *setupOpts + } + + err = backend.SetupCluster(context.Background(), so) if err != nil { t.Fatal(err) }