In our previous post we talked about how to use context.WithTimeout() to ensure our code didn’t run for longer than a specified amount of time. In this post we are going to look at how to add context support to our own logic so that your users can benefit from it.
The idea is that you want to provide code to your users that lets them use the context functionality so they can, for example, tell your code: run your logic, but time out if it runs for longer than two seconds.
package main
import (
"context"
"fmt"
"log"
"time"
)
// Timing knobs for the example, both in whole seconds (callers multiply them
// by time.Second): defaultTaskD is how long the simulated task sleeps, and
// defaultTO is the timeout main places on that task.
const defaultTaskD, defaultTO int = 1, 2
// This is the logic that performs that task that can potentially take a lot of time
// We simulate that by adding an artificial sleep
func longRuningThing(data string) (string, error) {
taskDuration := time.Duration(defaultTaskD) * time.Second
log.Printf("start running task")
time.Sleep(taskDuration)
return fmt.Sprintf("hello %s", data), nil
}
// longRunningThingManager is the context-aware entry point exposed to users.
// It runs longRuningThing(data) in a separate goroutine and waits for whichever
// happens first: the task finishing, or ctx being done (timeout, deadline or
// cancellation).
//
// If the task finishes first, its result and error are returned unchanged.
// If ctx finishes first, it returns "" and ctx.Err() (for example
// context.DeadlineExceeded). If ctx is already done when this is called, it
// returns ctx.Err() immediately without starting the task.
//
// Note: on timeout the background goroutine is not stopped, because
// longRuningThing accepts no context. It still runs to completion, delivers
// its result into the buffered channel and exits, so it does not leak forever.
func longRunningThingManager(ctx context.Context, data string) (string, error) {
	// Fail fast: don't launch work whose result nobody will wait for.
	if err := ctx.Err(); err != nil {
		return "", err
	}

	// outcome carries both return values of longRuningThing over one channel.
	type outcome struct {
		result string
		err    error
	}

	// Buffer of 1 so the goroutine's send never blocks, even after we have
	// stopped listening because ctx finished first.
	done := make(chan outcome, 1)
	go func() {
		result, err := longRuningThing(data)
		done <- outcome{result, err}
	}()

	// Block until either the computation completes or ctx is done.
	select {
	case out := <-done:
		return out.result, out.err
	case <-ctx.Done():
		return "", ctx.Err()
	}
}
// main demonstrates longRunningThingManager: it gives the task at most
// defaultTO seconds and logs either the result or the error returned
// (for example a deadline-exceeded error if the task overran).
func main() {
	// Derive a context that is cancelled automatically after defaultTO seconds.
	// cancel is still deferred so the context's resources are released as soon
	// as main returns, even if the deadline has not yet been reached.
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(defaultTO)*time.Second)
	defer cancel()

	// Compute, but make sure we don't wait for more than defaultTO seconds.
	result, err := longRunningThingManager(ctx, "drio")
	if err != nil {
		log.Printf("something went wrong: %s", err)
		return
	}
	log.Printf("result: %s ", result)
}