diff --git a/error.go b/error.go index aaf4a0b..fd92ae9 100644 --- a/error.go +++ b/error.go @@ -10,7 +10,8 @@ type JobErrorCode string const ( ErrPrecedentStepFailure JobErrorCode = "precedent step failed" ErrStepFailed JobErrorCode = "step failed" - ErrStepNotInJob JobErrorCode = "trying to reference to a step not registered in job" + ErrRefStepNotInJob JobErrorCode = "trying to reference to a step not registered in job" + ErrAddStepInSealedJob JobErrorCode = "trying to add step to a sealed job definition" ) func (code JobErrorCode) Error() string { diff --git a/go.mod b/go.mod index f56ddb7..c319393 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,15 @@ -module github.com/Azure/go-asyncjob/v2 +module github.com/Azure/go-asyncjob go 1.18 require ( - github.com/Azure/go-asyncjob/graph v0.2.0 // indirect - github.com/Azure/go-asynctask v1.3.1 // indirect + github.com/Azure/go-asyncjob/graph v0.2.0 + github.com/Azure/go-asynctask v1.3.1 + github.com/stretchr/testify v1.8.1 +) + +require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.5.0 // indirect - github.com/stretchr/testify v1.8.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4262283..b1aea8a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/Azure/go-asyncjob/graph v0.1.0 h1:qisFc4PtgaE2FDE41GRcbk2eASsR12OecJeD8qk6fkc= -github.com/Azure/go-asyncjob/graph v0.1.0/go.mod h1:3Z7w9aUBIrDriypH8O+hK0aeqKWKYuKSNxwrDxFy34s= github.com/Azure/go-asyncjob/graph v0.2.0 h1:0GFnQit3+ZUxpc67ogusooa38GSFRPH2e1+h+L/33hc= github.com/Azure/go-asyncjob/graph v0.2.0/go.mod h1:3Z7w9aUBIrDriypH8O+hK0aeqKWKYuKSNxwrDxFy34s= github.com/Azure/go-asynctask v1.3.1 h1:zE/7Zwbdg7/+V2kRKb3IV4RTqmn8DUKriVzXcNq7ubg= @@ -11,12 +9,12 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/job.go b/job.go index 820a68b..05883a5 100644 --- a/job.go +++ b/job.go @@ -13,6 +13,8 @@ import ( type JobDefinitionMeta interface { GetName() string GetStep(stepName string) (StepDefinitionMeta, bool) // TODO: switch bool to error + Seal() + Sealed() bool // not exposing for now. addStep(step StepDefinitionMeta, precedingSteps ...StepDefinitionMeta) @@ -21,7 +23,9 @@ type JobDefinitionMeta interface { // JobDefinition defines a job with child steps, and step is organized in a Directed Acyclic Graph (DAG). type JobDefinition[T any] struct { - Name string + name string + + sealed bool steps map[string]StepDefinitionMeta stepsDag *graph.Graph[StepDefinitionMeta] rootStep *StepDefinition[T] @@ -31,7 +35,7 @@ type JobDefinition[T any] struct { // it is suggest to build jobDefinition statically on process start, and reuse it for each job instance. func NewJobDefinition[T any](name string) *JobDefinition[T] { j := &JobDefinition[T]{ - Name: name, + name: name, steps: make(map[string]StepDefinitionMeta), stepsDag: graph.NewGraph[StepDefinitionMeta](connectStepDefinition), } @@ -49,6 +53,9 @@ func NewJobDefinition[T any](name string) *JobDefinition[T] { // this will create and return new instance of the job // caller will then be able to wait for the job instance func (jd *JobDefinition[T]) Start(ctx context.Context, input *T, jobOptions ...JobOptionPreparer) *JobInstance[T] { + if !jd.Sealed() { + jd.Seal() + } ji := newJobInstance(jd, input, jobOptions...) ji.start(ctx) @@ -61,7 +68,18 @@ func (jd *JobDefinition[T]) getRootStep() StepDefinitionMeta { } func (jd *JobDefinition[T]) GetName() string { - return jd.Name + return jd.name +} + +func (jd *JobDefinition[T]) Seal() { + if jd.sealed { + return + } + jd.sealed = true +} + +func (jd *JobDefinition[T]) Sealed() bool { + return jd.sealed } // GetStep returns the stepDefinition by name @@ -155,7 +173,7 @@ func (ji *JobInstance[T]) start(ctx context.Context) { // construct job instance graph, with TopologySort ordering orderedSteps := ji.Definition.stepsDag.TopologicalSort() for _, stepDef := range orderedSteps { - if stepDef.GetName() == ji.Definition.Name { + if stepDef.GetName() == ji.Definition.GetName() { continue } ji.steps[stepDef.GetName()] = stepDef.createStepInstance(ctx, ji) diff --git a/job_result.go b/job_result.go index 6a9d979..fca0cec 100644 --- a/job_result.go +++ b/job_result.go @@ -12,7 +12,7 @@ type JobDefinitionWithResult[Tin, Tout any] struct { func JobWithResult[Tin, Tout any](jd *JobDefinition[Tin], resultStep *StepDefinition[Tout]) (*JobDefinitionWithResult[Tin, Tout], error) { sdGet, ok := jd.GetStep(resultStep.GetName()) if !ok || sdGet != resultStep { - return nil, ErrStepNotInJob + return nil, ErrRefStepNotInJob } return &JobDefinitionWithResult[Tin, Tout]{ diff --git a/job_result_test.go b/job_result_test.go index b5f824b..a79f558 100644 --- a/job_result_test.go +++ b/job_result_test.go @@ -4,18 +4,13 @@ import ( "context" "testing" - "github.com/Azure/go-asyncjob/v2" "github.com/stretchr/testify/assert" ) func TestSimpleJobWithResult(t *testing.T) { t.Parallel() - jd, err := BuildJobWithResult(context.Background(), map[string]asyncjob.RetryPolicy{}) - assert.NoError(t, err) - renderGraph(t, jd) - - jobInstance := jd.Start(context.WithValue(context.Background(), testLoggingContextKey, t), &SqlSummaryJobLibAdvanced{ + jobInstance := SqlSummaryAsyncJobDefinition.Start(context.WithValue(context.Background(), testLoggingContextKey, t), &SqlSummaryJobLib{ Params: &SqlSummaryJobParameters{ ServerName: "server2", Table1: "table3", @@ -23,7 +18,6 @@ func TestSimpleJobWithResult(t *testing.T) { Table2: "table4", Query2: "query4", }, - SqlSummaryJobLib: SqlSummaryJobLib{}, }) jobErr := jobInstance.Wait(context.Background()) assert.NoError(t, jobErr) diff --git a/job_test.go b/job_test.go index 7a46082..0d6fbd5 100644 --- a/job_test.go +++ b/job_test.go @@ -7,18 +7,14 @@ import ( "testing" "time" - "github.com/Azure/go-asyncjob/v2" + "github.com/Azure/go-asyncjob" "github.com/stretchr/testify/assert" ) func TestSimpleJob(t *testing.T) { t.Parallel() - jd, err := BuildJob(context.Background(), map[string]asyncjob.RetryPolicy{}) - assert.NoError(t, err) - renderGraph(t, jd) - - jobInstance := jd.Start(context.WithValue(context.Background(), testLoggingContextKey, t), &SqlSummaryJobLibAdvanced{ + jobInstance := SqlSummaryAsyncJobDefinition.Start(context.WithValue(context.Background(), testLoggingContextKey, t), &SqlSummaryJobLib{ Params: &SqlSummaryJobParameters{ ServerName: "server1", Table1: "table1", @@ -26,13 +22,12 @@ func TestSimpleJob(t *testing.T) { Table2: "table2", Query2: "query2", }, - SqlSummaryJobLib: SqlSummaryJobLib{}, }) jobErr := jobInstance.Wait(context.Background()) assert.NoError(t, jobErr) renderGraph(t, jobInstance) - jobInstance2 := jd.Start(context.WithValue(context.Background(), testLoggingContextKey, t), &SqlSummaryJobLibAdvanced{ + jobInstance2 := SqlSummaryAsyncJobDefinition.Start(context.WithValue(context.Background(), testLoggingContextKey, t), &SqlSummaryJobLib{ Params: &SqlSummaryJobParameters{ ServerName: "server2", Table1: "table3", @@ -40,7 +35,6 @@ func TestSimpleJob(t *testing.T) { Table2: "table4", Query2: "query4", }, - SqlSummaryJobLib: SqlSummaryJobLib{}, }) jobErr = jobInstance2.Wait(context.Background()) assert.NoError(t, jobErr) @@ -50,23 +44,21 @@ func TestSimpleJob(t *testing.T) { func TestJobError(t *testing.T) { t.Parallel() - jd, err := BuildJob(context.Background(), map[string]asyncjob.RetryPolicy{}) - assert.NoError(t, err) - ctx := context.WithValue(context.Background(), testLoggingContextKey, t) - ctx = context.WithValue(ctx, "error-injection.server1.table1", fmt.Errorf("table1 not exists")) - jobInstance := jd.Start(ctx, &SqlSummaryJobLibAdvanced{ + jobInstance := SqlSummaryAsyncJobDefinition.Start(ctx, &SqlSummaryJobLib{ Params: &SqlSummaryJobParameters{ ServerName: "server1", Table1: "table1", Query1: "query1", Table2: "table2", Query2: "query2", + ErrorInjection: map[string]func() error{ + "GetTableClient.server1.table1": func() error { return fmt.Errorf("table1 not exists") }, + }, }, - SqlSummaryJobLib: SqlSummaryJobLib{}, }) - err = jobInstance.Wait(context.Background()) + err := jobInstance.Wait(context.Background()) assert.Error(t, err) jobErr := &asyncjob.JobError{} @@ -77,30 +69,28 @@ func TestJobError(t *testing.T) { func TestJobPanic(t *testing.T) { t.Parallel() - jd, err := BuildJob(context.Background(), map[string]asyncjob.RetryPolicy{}) - assert.NoError(t, err) ctx := context.WithValue(context.Background(), testLoggingContextKey, t) - ctx = context.WithValue(ctx, "panic-injection.server1.table2", true) - jobInstance := jd.Start(ctx, &SqlSummaryJobLibAdvanced{ + jobInstance := SqlSummaryAsyncJobDefinition.Start(ctx, &SqlSummaryJobLib{ Params: &SqlSummaryJobParameters{ ServerName: "server1", Table1: "table1", Query1: "query1", Table2: "table2", Query2: "query2", + PanicInjection: map[string]bool{ + "GetTableClient.server1.table2": true, + }, }, - SqlSummaryJobLib: SqlSummaryJobLib{}, }) - err = jobInstance.Wait(context.Background()) + err := jobInstance.Wait(context.Background()) assert.Error(t, err) - /* panic is out of reach of jobError, but planning to catch panic in the future jobErr := &asyncjob.JobError{} assert.True(t, errors.As(err, &jobErr)) assert.Equal(t, jobErr.Code, asyncjob.ErrStepFailed) - assert.Equal(t, jobErr.StepName, "getTableClient1")*/ + assert.Equal(t, jobErr.StepInstance.GetName(), "GetTableClient2") } func TestJobStepRetry(t *testing.T) { @@ -110,15 +100,17 @@ func TestJobStepRetry(t *testing.T) { ctx := context.WithValue(context.Background(), testLoggingContextKey, t) ctx = context.WithValue(ctx, "error-injection.server1.table1.query1", fmt.Errorf("query exeeded memory limit")) - jobInstance := jd.Start(ctx, &SqlSummaryJobLibAdvanced{ + jobInstance := jd.Start(ctx, &SqlSummaryJobLib{ Params: &SqlSummaryJobParameters{ ServerName: "server1", Table1: "table1", Query1: "query1", Table2: "table2", Query2: "query2", + ErrorInjection: map[string]func() error{ + "ExecuteQuery.server1.table1.query1": func() error { return fmt.Errorf("query exeeded memory limit") }, + }, }, - SqlSummaryJobLib: SqlSummaryJobLib{}, }) err = jobInstance.Wait(context.Background()) diff --git a/retryer.go b/retryer.go index 9ee2b31..74e647b 100644 --- a/retryer.go +++ b/retryer.go @@ -1,8 +1,6 @@ package asyncjob import ( - "fmt" - "runtime/debug" "time" ) @@ -17,24 +15,13 @@ func newRetryer[T any](policy RetryPolicy, report *RetryReport, toRetry func() ( return &retryer[T]{retryPolicy: policy, retryReport: report, function: toRetry} } -func (r *retryer[T]) funcWithPanicHandled() (result *T, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("Panic cought: %v, StackTrace: %s", r, debug.Stack()) - } - }() - result, err = r.function() - - return result, err -} - func (r retryer[T]) Run() (*T, error) { - t, err := r.funcWithPanicHandled() + t, err := r.function() for err != nil { if shouldRetry, duration := r.retryPolicy.ShouldRetry(err); shouldRetry { r.retryReport.Count++ time.Sleep(duration) - t, err = r.funcWithPanicHandled() + t, err = r.function() } else { break } diff --git a/step.go b/step.go index 099764a..c89f8c8 100644 --- a/step.go +++ b/step.go @@ -154,7 +154,7 @@ func (si *StepInstance[T]) GetState() StepState { func (si *StepInstance[T]) EnrichContext(ctx context.Context) (result context.Context) { result = ctx if si.Definition.executionOptions.ContextPolicy != nil { - // handle panic from user code + // TODO: bubble up the error somehow defer func() { if r := recover(); r != nil { fmt.Println("Recovered in EnrichContext", r) diff --git a/step_builder.go b/step_builder.go index 5eaaf32..1d53d9d 100644 --- a/step_builder.go +++ b/step_builder.go @@ -3,6 +3,7 @@ package asyncjob import ( "context" "fmt" + "runtime/debug" "time" "github.com/Azure/go-asynctask" @@ -10,6 +11,10 @@ import ( // AddStep adds a step to the job definition. func AddStep[JT, ST any](bCtx context.Context, j *JobDefinition[JT], stepName string, stepFuncCreator func(input *JT) asynctask.AsyncFunc[ST], optionDecorators ...ExecutionOptionPreparer) (*StepDefinition[ST], error) { + if j.Sealed() { + return nil, ErrAddStepInSealedJob + } + stepD := newStepDefinition[ST](stepName, stepTypeTask, optionDecorators...) precedingDefSteps, err := getDependsOnSteps(stepD, j) if err != nil { @@ -27,8 +32,21 @@ func AddStep[JT, ST any](bCtx context.Context, j *JobDefinition[JT], stepName st precedingInstances, precedingTasks, _ := getDependsOnStepInstances(stepD, ji) jiStrongTyped := ji.(*JobInstance[JT]) + stepFunc := stepFuncCreator(jiStrongTyped.input) + stepFuncWithPanicHandling := func(ctx context.Context) (result *ST, err error) { + // handle panic from user code + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("Panic cought: %v, StackTrace: %s", r, debug.Stack()) + } + }() + + result, err = stepFunc(ctx) + return result, err + } + stepInstance := newStepInstance[ST](stepD, ji) - stepInstance.task = asynctask.Start(ctx, instrumentedAddStep(stepInstance, precedingTasks, stepFuncCreator(jiStrongTyped.input))) + stepInstance.task = asynctask.Start(ctx, instrumentedAddStep(stepInstance, precedingTasks, stepFuncWithPanicHandling)) ji.addStepInstance(stepInstance, precedingInstances...) return stepInstance } @@ -39,6 +57,10 @@ func AddStep[JT, ST any](bCtx context.Context, j *JobDefinition[JT], stepName st // StepAfter add a step after a preceding step, also take input from that preceding step func StepAfter[JT, PT, ST any](bCtx context.Context, j *JobDefinition[JT], stepName string, parentStep *StepDefinition[PT], stepAfterFuncCreator func(input *JT) asynctask.ContinueFunc[PT, ST], optionDecorators ...ExecutionOptionPreparer) (*StepDefinition[ST], error) { + if j.Sealed() { + return nil, ErrAddStepInSealedJob + } + // check parentStepT is in this job if get, ok := j.GetStep(parentStep.GetName()); !ok || get != parentStep { return nil, fmt.Errorf("step [%s] not found in job", parentStep.GetName()) @@ -55,9 +77,22 @@ func StepAfter[JT, PT, ST any](bCtx context.Context, j *JobDefinition[JT], stepN precedingInstances, precedingTasks, _ := getDependsOnStepInstances(stepD, ji) jiStrongTyped := ji.(*JobInstance[JT]) + stepFunc := stepAfterFuncCreator(jiStrongTyped.input) + stepFuncWithPanicHandling := func(ctx context.Context, pt *PT) (result *ST, err error) { + // handle panic from user code + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("Panic cought: %v, StackTrace: %s", r, debug.Stack()) + } + }() + + result, err = stepFunc(ctx, pt) + return result, err + } + parentStepInstance := getStrongTypedStepInstance(parentStep, ji) stepInstance := newStepInstance[ST](stepD, ji) - stepInstance.task = asynctask.ContinueWith(ctx, parentStepInstance.task, instrumentedStepAfter(stepInstance, precedingTasks, stepAfterFuncCreator(jiStrongTyped.input))) + stepInstance.task = asynctask.ContinueWith(ctx, parentStepInstance.task, instrumentedStepAfter(stepInstance, precedingTasks, stepFuncWithPanicHandling)) ji.addStepInstance(stepInstance, precedingInstances...) return stepInstance } @@ -68,6 +103,10 @@ func StepAfter[JT, PT, ST any](bCtx context.Context, j *JobDefinition[JT], stepN // StepAfterBoth add a step after both preceding steps, also take input from both preceding steps func StepAfterBoth[JT, PT1, PT2, ST any](bCtx context.Context, j *JobDefinition[JT], stepName string, parentStep1 *StepDefinition[PT1], parentStep2 *StepDefinition[PT2], stepAfterBothFuncCreator func(input *JT) asynctask.AfterBothFunc[PT1, PT2, ST], optionDecorators ...ExecutionOptionPreparer) (*StepDefinition[ST], error) { + if j.Sealed() { + return nil, ErrAddStepInSealedJob + } + // check parentStepT is in this job if get, ok := j.GetStep(parentStep1.GetName()); !ok || get != parentStep1 { return nil, fmt.Errorf("step [%s] not found in job", parentStep1.GetName()) @@ -93,10 +132,22 @@ func StepAfterBoth[JT, PT1, PT2, ST any](bCtx context.Context, j *JobDefinition[ precedingInstances, precedingTasks, _ := getDependsOnStepInstances(stepD, ji) jiStrongTyped := ji.(*JobInstance[JT]) + stepFunc := stepAfterBothFuncCreator(jiStrongTyped.input) + stepFuncWithPanicHandling := func(ctx context.Context, pt1 *PT1, pt2 *PT2) (result *ST, err error) { + // handle panic from user code + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("Panic cought: %v, StackTrace: %s", r, debug.Stack()) + } + }() + + result, err = stepFunc(ctx, pt1, pt2) + return result, err + } parentStepInstance1 := getStrongTypedStepInstance(parentStep1, ji) parentStepInstance2 := getStrongTypedStepInstance(parentStep2, ji) stepInstance := newStepInstance[ST](stepD, ji) - stepInstance.task = asynctask.AfterBoth(ctx, parentStepInstance1.task, parentStepInstance2.task, instrumentedStepAfterBoth(stepInstance, precedingTasks, stepAfterBothFuncCreator(jiStrongTyped.input))) + stepInstance.task = asynctask.AfterBoth(ctx, parentStepInstance1.task, parentStepInstance2.task, instrumentedStepAfterBoth(stepInstance, precedingTasks, stepFuncWithPanicHandling)) ji.addStepInstance(stepInstance, precedingInstances...) return stepInstance } diff --git a/util_test.go b/util_test.go index af723dc..50ae090 100644 --- a/util_test.go +++ b/util_test.go @@ -6,85 +6,98 @@ import ( "testing" "time" - "github.com/Azure/go-asyncjob/v2" + "github.com/Azure/go-asyncjob" "github.com/Azure/go-asynctask" ) const testLoggingContextKey = "test-logging" -type SqlSummaryJobLibAdvanced struct { - SqlSummaryJobLib +// SqlSummaryAsyncJobDefinition is the job definition for the SqlSummaryJobLib +// JobDefinition fit perfectly in init() function +var SqlSummaryAsyncJobDefinition *asyncjob.JobDefinitionWithResult[SqlSummaryJobLib, SummarizedResult] + +func init() { + var err error + SqlSummaryAsyncJobDefinition, err = BuildJobWithResult(context.Background(), map[string]asyncjob.RetryPolicy{}) + if err != nil { + panic(err) + } + + SqlSummaryAsyncJobDefinition.Seal() +} + +type SqlSummaryJobLib struct { Params *SqlSummaryJobParameters } -func serverNameStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AsyncFunc[string] { +func serverNameStepFunc(sql *SqlSummaryJobLib) asynctask.AsyncFunc[string] { return func(ctx context.Context) (*string, error) { return &sql.Params.ServerName, nil } } -func table1NameStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AsyncFunc[string] { +func table1NameStepFunc(sql *SqlSummaryJobLib) asynctask.AsyncFunc[string] { return func(ctx context.Context) (*string, error) { return &sql.Params.Table1, nil } } -func table2NameStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AsyncFunc[string] { +func table2NameStepFunc(sql *SqlSummaryJobLib) asynctask.AsyncFunc[string] { return func(ctx context.Context) (*string, error) { return &sql.Params.Table2, nil } } -func query1ParamStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AsyncFunc[string] { +func query1ParamStepFunc(sql *SqlSummaryJobLib) asynctask.AsyncFunc[string] { return func(ctx context.Context) (*string, error) { return &sql.Params.Query1, nil } } -func query2ParamStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AsyncFunc[string] { +func query2ParamStepFunc(sql *SqlSummaryJobLib) asynctask.AsyncFunc[string] { return func(ctx context.Context) (*string, error) { return &sql.Params.Query2, nil } } -func connectionStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.ContinueFunc[string, SqlConnection] { +func connectionStepFunc(sql *SqlSummaryJobLib) asynctask.ContinueFunc[string, SqlConnection] { return func(ctx context.Context, serverName *string) (*SqlConnection, error) { return sql.GetConnection(ctx, serverName) } } -func checkAuthStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AsyncFunc[interface{}] { +func checkAuthStepFunc(sql *SqlSummaryJobLib) asynctask.AsyncFunc[interface{}] { return asynctask.ActionToFunc(func(ctx context.Context) error { return sql.CheckAuth(ctx) }) } -func tableClientStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AfterBothFunc[SqlConnection, string, SqlTableClient] { +func tableClientStepFunc(sql *SqlSummaryJobLib) asynctask.AfterBothFunc[SqlConnection, string, SqlTableClient] { return func(ctx context.Context, conn *SqlConnection, tableName *string) (*SqlTableClient, error) { return sql.GetTableClient(ctx, conn, tableName) } } -func queryTableStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AfterBothFunc[SqlTableClient, string, SqlQueryResult] { +func queryTableStepFunc(sql *SqlSummaryJobLib) asynctask.AfterBothFunc[SqlTableClient, string, SqlQueryResult] { return func(ctx context.Context, tableClient *SqlTableClient, query *string) (*SqlQueryResult, error) { return sql.ExecuteQuery(ctx, tableClient, query) } } -func summarizeQueryResultStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AfterBothFunc[SqlQueryResult, SqlQueryResult, SummarizedResult] { +func summarizeQueryResultStepFunc(sql *SqlSummaryJobLib) asynctask.AfterBothFunc[SqlQueryResult, SqlQueryResult, SummarizedResult] { return func(ctx context.Context, query1Result *SqlQueryResult, query2Result *SqlQueryResult) (*SummarizedResult, error) { return sql.SummarizeQueryResult(ctx, query1Result, query2Result) } } -func emailNotificationStepFunc(sql *SqlSummaryJobLibAdvanced) asynctask.AsyncFunc[interface{}] { +func emailNotificationStepFunc(sql *SqlSummaryJobLib) asynctask.AsyncFunc[interface{}] { return asynctask.ActionToFunc(func(ctx context.Context) error { return sql.EmailNotification(ctx) }) } -func BuildJob(bCtx context.Context, retryPolicies map[string]asyncjob.RetryPolicy) (*asyncjob.JobDefinition[SqlSummaryJobLibAdvanced], error) { - job := asyncjob.NewJobDefinition[SqlSummaryJobLibAdvanced]("sqlSummaryJob") +func BuildJob(bCtx context.Context, retryPolicies map[string]asyncjob.RetryPolicy) (*asyncjob.JobDefinition[SqlSummaryJobLib], error) { + job := asyncjob.NewJobDefinition[SqlSummaryJobLib]("sqlSummaryJob") serverNameParamTask, err := asyncjob.AddStep(bCtx, job, "ServerNameParam", serverNameStepFunc) if err != nil { return nil, fmt.Errorf("error adding step ServerNameParam: %w", err) @@ -152,7 +165,7 @@ func BuildJob(bCtx context.Context, retryPolicies map[string]asyncjob.RetryPolic return job, nil } -func BuildJobWithResult(bCtx context.Context, retryPolicies map[string]asyncjob.RetryPolicy) (*asyncjob.JobDefinitionWithResult[SqlSummaryJobLibAdvanced, SummarizedResult], error) { +func BuildJobWithResult(bCtx context.Context, retryPolicies map[string]asyncjob.RetryPolicy) (*asyncjob.JobDefinitionWithResult[SqlSummaryJobLib, SummarizedResult], error) { job, err := BuildJob(bCtx, retryPolicies) if err != nil { return nil, err @@ -169,9 +182,6 @@ func BuildJobWithResult(bCtx context.Context, retryPolicies map[string]asyncjob. return asyncjob.JobWithResult(job, summaryStep) } -type SqlSummaryJobLib struct { -} - type SqlSummaryJobParameters struct { ServerName string Table1 string @@ -179,6 +189,7 @@ type SqlSummaryJobParameters struct { Table2 string Query2 string ErrorInjection map[string]func() error + PanicInjection map[string]bool } type SqlConnection struct { @@ -201,9 +212,9 @@ type SummarizedResult struct { func (sql *SqlSummaryJobLib) GetConnection(ctx context.Context, serverName *string) (*SqlConnection, error) { sql.Logging(ctx, "GetConnection") - if v := ctx.Value(fmt.Sprintf("error-injection.%s", *serverName)); v != nil { - if err, ok := v.(error); ok { - return nil, err + if sql.Params.ErrorInjection != nil { + if errFunc, ok := sql.Params.ErrorInjection["GetConnection"]; ok { + return nil, errFunc() } } return &SqlConnection{ServerName: *serverName}, nil @@ -211,35 +222,47 @@ func (sql *SqlSummaryJobLib) GetConnection(ctx context.Context, serverName *stri func (sql *SqlSummaryJobLib) GetTableClient(ctx context.Context, conn *SqlConnection, tableName *string) (*SqlTableClient, error) { sql.Logging(ctx, fmt.Sprintf("GetTableClient with tableName: %s", *tableName)) - if v := ctx.Value(fmt.Sprintf("error-injection.%s.%s", conn.ServerName, *tableName)); v != nil { - if err, ok := v.(error); ok { - return nil, err + injectionKey := fmt.Sprintf("GetTableClient.%s.%s", conn.ServerName, *tableName) + if sql.Params.PanicInjection != nil { + if shouldPanic, ok := sql.Params.PanicInjection[injectionKey]; ok && shouldPanic { + panic("as you wish") } } - - if v := ctx.Value(fmt.Sprintf("panic-injection.%s.%s", conn.ServerName, *tableName)); v != nil { - if shouldPanic := v.(bool); shouldPanic { - panic("as you wish") + if sql.Params.ErrorInjection != nil { + if errFunc, ok := sql.Params.ErrorInjection[injectionKey]; ok { + return nil, errFunc() } } return &SqlTableClient{ServerName: conn.ServerName, TableName: *tableName}, nil } func (sql *SqlSummaryJobLib) CheckAuth(ctx context.Context) error { - if v := ctx.Value("error-injection.checkAuth"); v != nil { - if err, ok := v.(error); ok { - return err + sql.Logging(ctx, "CheckAuth") + injectionKey := "CheckAuth" + if sql.Params.PanicInjection != nil { + if shouldPanic, ok := sql.Params.PanicInjection[injectionKey]; ok && shouldPanic { + panic("as you wish") + } + } + if sql.Params.ErrorInjection != nil { + if errFunc, ok := sql.Params.ErrorInjection[injectionKey]; ok { + return errFunc() } } - return nil } func (sql *SqlSummaryJobLib) ExecuteQuery(ctx context.Context, tableClient *SqlTableClient, queryString *string) (*SqlQueryResult, error) { sql.Logging(ctx, fmt.Sprintf("ExecuteQuery: %s", *queryString)) - if v := ctx.Value(fmt.Sprintf("error-injection.%s.%s.%s", tableClient.ServerName, tableClient.TableName, *queryString)); v != nil { - if err, ok := v.(error); ok { - return nil, err + injectionKey := fmt.Sprintf("ExecuteQuery.%s.%s.%s", tableClient.ServerName, tableClient.TableName, *queryString) + if sql.Params.PanicInjection != nil { + if shouldPanic, ok := sql.Params.PanicInjection[injectionKey]; ok && shouldPanic { + panic("as you wish") + } + } + if sql.Params.ErrorInjection != nil { + if errFunc, ok := sql.Params.ErrorInjection[injectionKey]; ok { + return nil, errFunc() } } @@ -248,9 +271,15 @@ func (sql *SqlSummaryJobLib) ExecuteQuery(ctx context.Context, tableClient *SqlT func (sql *SqlSummaryJobLib) SummarizeQueryResult(ctx context.Context, result1 *SqlQueryResult, result2 *SqlQueryResult) (*SummarizedResult, error) { sql.Logging(ctx, "SummarizeQueryResult") - if v := ctx.Value("error-injection.summarize"); v != nil { - if err, ok := v.(error); ok { - return nil, err + injectionKey := "SummarizeQueryResult" + if sql.Params.PanicInjection != nil { + if shouldPanic, ok := sql.Params.PanicInjection[injectionKey]; ok && shouldPanic { + panic("as you wish") + } + } + if sql.Params.ErrorInjection != nil { + if errFunc, ok := sql.Params.ErrorInjection[injectionKey]; ok { + return nil, errFunc() } } return &SummarizedResult{QueryResult1: result1.Data, QueryResult2: result2.Data}, nil