[Go SDK] Container Worker pool functionality. (#33572)
lostluck authored Jan 14, 2025
1 parent cf26a42 commit c0a8ff6
Showing 6 changed files with 368 additions and 9 deletions.
40 changes: 39 additions & 1 deletion sdks/go/container/boot.go
@@ -28,6 +28,7 @@ import (
"strings"
"time"

"github.com/apache/beam/sdks/v2/go/container/pool"
"github.com/apache/beam/sdks/v2/go/container/tools"
"github.com/apache/beam/sdks/v2/go/pkg/beam/artifact"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
@@ -44,6 +45,7 @@ import (
var (
// Contract: https://s.apache.org/beam-fn-api-container-contract.

workerPool = flag.Bool("worker_pool", false, "Run as worker pool (optional).")
id = flag.String("id", "", "Local identifier (required).")
loggingEndpoint = flag.String("logging_endpoint", "", "Local logging endpoint for FnHarness (required).")
artifactEndpoint = flag.String("artifact_endpoint", "", "Local artifact endpoint for FnHarness (required).")
@@ -56,6 +58,7 @@ const (
cloudProfilingJobName = "CLOUD_PROF_JOB_NAME"
cloudProfilingJobID = "CLOUD_PROF_JOB_ID"
enableGoogleCloudProfilerOption = "enable_google_cloud_profiler"
workerPoolIdEnv = "BEAM_GO_WORKER_POOL_ID"
)

func configureGoogleCloudProfilerEnvVars(ctx context.Context, logger *tools.Logger, metadata map[string]string) error {
@@ -78,6 +81,30 @@ func configureGoogleCloudProfilerEnvVars(ctx context.Context, logger *tools.Logg

func main() {
flag.Parse()

if *workerPool {
workerPoolId := fmt.Sprintf("%d", os.Getpid())
bin, err := os.Executable()
if err != nil {
log.Fatalf("Error starting worker pool, couldn't find boot loader path: %v", err)
}

os.Setenv(workerPoolIdEnv, workerPoolId)
log.Printf("Starting worker pool %v: Go %v binary: %vv", workerPoolId, ":50000", bin)

ctx := context.Background()
server, err := pool.New(ctx, 50000, bin)
if err != nil {
log.Fatalf("Error starting worker pool: %v", err)
}
defer server.Stop(ctx)
if err := server.ServeAndWait(); err != nil {
log.Fatalf("Error with worker pool: %v", err)
}
log.Print("Go SDK worker pool exited.")
os.Exit(0)
}

if *id == "" {
log.Fatal("No id provided.")
}
@@ -126,7 +153,13 @@ func main()

// (3) The persist dir may be on a noexec volume, so we must
// copy the binary to a different location to execute.
const prog = "/bin/worker"
tmpPrefix, err := os.MkdirTemp("/tmp/", "bin*")
if err != nil {
logger.Fatalf(ctx, "Failed to copy worker binary: %v", err)
}

prog := tmpPrefix + "/worker"
logger.Printf(ctx, "From: %q To: %q", filepath.Join(dir, name), prog)
if err := copyExe(filepath.Join(dir, name), prog); err != nil {
logger.Fatalf(ctx, "Failed to copy worker binary: %v", err)
}
@@ -233,6 +266,11 @@ func copyExe(from, to string) error {
}
defer src.Close()

// Ensure that the folder path exists locally.
if err := os.MkdirAll(filepath.Dir(to), 0755); err != nil {
return err
}

dst, err := os.OpenFile(to, os.O_WRONLY|os.O_CREATE, 0755)
if err != nil {
return err
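For context on the boot.go changes above (not part of this diff): when the container is launched with --worker_pool, boot.go serves the BeamFnExternalWorkerPool API on port 50000 and re-invokes its own binary once per StartWorker request. A minimal runner-side sketch of exercising that API follows; the localhost:50000 address, worker ID, and endpoint URL are illustrative assumptions.

package main

import (
	"context"
	"log"

	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Assumed address: the boot container listens on :50000 in worker pool mode.
	conn, err := grpc.Dial("localhost:50000", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("failed to dial worker pool: %v", err)
	}
	defer conn.Close()

	client := fnpb.NewBeamFnExternalWorkerPoolClient(conn)
	// Illustrative endpoint; a real runner supplies its own control/logging services.
	endpoint := &pipepb.ApiServiceDescriptor{Url: "localhost:12345"}

	resp, err := client.StartWorker(context.Background(), &fnpb.StartWorkerRequest{
		WorkerId:        "worker-1",
		ControlEndpoint: endpoint,
		LoggingEndpoint: endpoint,
	})
	if err != nil || resp.GetError() != "" {
		log.Fatalf("StartWorker failed: err=%v, resp=%v", err, resp.GetError())
	}
	log.Print("worker-1 started")
}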
160 changes: 160 additions & 0 deletions sdks/go/container/pool/workerpool.go
@@ -0,0 +1,160 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package pool facilitates an external worker service, as an alternate mode for
// the standard Beam container.
//
// This predominantly serves as a process spawner within a given container
// VM for an arbitrary number of jobs, instead of a single worker instance.
//
// Workers will be spawned as executed OS processes.
package pool

import (
"context"
"fmt"
"log/slog"
"net"
"os"
"os/exec"
"sync"

fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
"github.com/apache/beam/sdks/v2/go/pkg/beam/util/grpcx"
"google.golang.org/grpc"
)

// New initializes a process-based ExternalWorkerService at the given
// port.
func New(ctx context.Context, port int, containerExecutable string) (*Process, error) {
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
slog.Info("starting Process server", "addr", lis.Addr())
grpcServer := grpc.NewServer()
root, cancel := context.WithCancel(ctx)
s := &Process{lis: lis, root: root, rootCancel: cancel, workers: map[string]context.CancelFunc{},
grpcServer: grpcServer, containerExecutable: containerExecutable}
fnpb.RegisterBeamFnExternalWorkerPoolServer(grpcServer, s)
return s, nil
}

// ServeAndWait starts the ExternalWorkerService and blocks until exit.
func (s *Process) ServeAndWait() error {
return s.grpcServer.Serve(s.lis)
}

// Process implements fnpb.BeamFnExternalWorkerPoolServer, by starting external
// processes.
type Process struct {
fnpb.UnimplementedBeamFnExternalWorkerPoolServer

containerExecutable string // The host path to the container executable.

lis net.Listener
root context.Context
rootCancel context.CancelFunc

mu sync.Mutex
workers map[string]context.CancelFunc

grpcServer *grpc.Server
}

// StartWorker initializes a new worker harness, implementing BeamFnExternalWorkerPoolServer.StartWorker.
func (s *Process) StartWorker(_ context.Context, req *fnpb.StartWorkerRequest) (*fnpb.StartWorkerResponse, error) {
slog.Info("starting worker", "id", req.GetWorkerId())
s.mu.Lock()
defer s.mu.Unlock()
if s.workers == nil {
return &fnpb.StartWorkerResponse{
Error: "worker pool shutting down",
}, nil
}

if _, ok := s.workers[req.GetWorkerId()]; ok {
return &fnpb.StartWorkerResponse{
Error: fmt.Sprintf("worker with ID %q already exists", req.GetWorkerId()),
}, nil
}
if req.GetLoggingEndpoint() == nil {
return &fnpb.StartWorkerResponse{Error: fmt.Sprintf("Missing logging endpoint for worker %v", req.GetWorkerId())}, nil
}
if req.GetControlEndpoint() == nil {
return &fnpb.StartWorkerResponse{Error: fmt.Sprintf("Missing control endpoint for worker %v", req.GetWorkerId())}, nil
}
if req.GetLoggingEndpoint().Authentication != nil || req.GetControlEndpoint().Authentication != nil {
return &fnpb.StartWorkerResponse{Error: "[BEAM-10610] Secure endpoints not supported."}, nil
}

ctx := grpcx.WriteWorkerID(s.root, req.GetWorkerId())
ctx, s.workers[req.GetWorkerId()] = context.WithCancel(ctx)

args := []string{
"--id=" + req.GetWorkerId(),
"--control_endpoint=" + req.GetControlEndpoint().GetUrl(),
"--artifact_endpoint=" + req.GetArtifactEndpoint().GetUrl(),
"--provision_endpoint=" + req.GetProvisionEndpoint().GetUrl(),
"--logging_endpoint=" + req.GetLoggingEndpoint().GetUrl(),
}

cmd := exec.CommandContext(ctx, s.containerExecutable, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = nil // Use the current environment.

if err := cmd.Start(); err != nil {
return &fnpb.StartWorkerResponse{Error: fmt.Sprintf("Unable to start boot for worker %v: %v", req.GetWorkerId(), err)}, nil
}
return &fnpb.StartWorkerResponse{}, nil
}

// StopWorker terminates a worker harness, implementing BeamFnExternalWorkerPoolServer.StopWorker.
func (s *Process) StopWorker(_ context.Context, req *fnpb.StopWorkerRequest) (*fnpb.StopWorkerResponse, error) {
slog.Info("stopping worker", "id", req.GetWorkerId())
s.mu.Lock()
defer s.mu.Unlock()
if s.workers == nil {
// Worker pool is already shutting down, so no action is needed.
return &fnpb.StopWorkerResponse{}, nil
}
if cancelfn, ok := s.workers[req.GetWorkerId()]; ok {
cancelfn()
delete(s.workers, req.GetWorkerId())
return &fnpb.StopWorkerResponse{}, nil
}
return &fnpb.StopWorkerResponse{
Error: fmt.Sprintf("no worker with id %q running", req.GetWorkerId()),
}, nil

}

// Stop terminates the service and stops all workers.
func (s *Process) Stop(ctx context.Context) error {
s.mu.Lock()

slog.Debug("stopping Process", "worker_count", len(s.workers))
s.workers = nil
s.rootCancel()

// There can be a deadlock between the StopWorker RPC and GracefulStop
// which waits for all RPCs to finish, so it must be outside the critical section.
s.mu.Unlock()

s.grpcServer.GracefulStop()
return nil
}
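As a usage note (not part of this change), a minimal sketch of embedding the pool package in a standalone binary; the /opt/apache/beam/boot path is a placeholder for the worker boot executable, and the SIGTERM hook simply shows Stop being used to shut workers down before the server exits.

package main

import (
	"context"
	"log"
	"os"
	"os/signal"
	"syscall"

	"github.com/apache/beam/sdks/v2/go/container/pool"
)

func main() {
	ctx := context.Background()
	// Placeholder path to the worker boot executable that the pool will spawn.
	srv, err := pool.New(ctx, 50000, "/opt/apache/beam/boot")
	if err != nil {
		log.Fatalf("starting worker pool: %v", err)
	}

	// On SIGTERM, stop all workers and gracefully stop the gRPC server,
	// which unblocks ServeAndWait below.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGTERM)
	go func() {
		<-sigs
		if err := srv.Stop(ctx); err != nil {
			log.Printf("error stopping worker pool: %v", err)
		}
	}()

	if err := srv.ServeAndWait(); err != nil {
		log.Fatalf("worker pool server error: %v", err)
	}
	log.Print("worker pool exited")
}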
150 changes: 150 additions & 0 deletions sdks/go/container/pool/workerpool_test.go
@@ -0,0 +1,150 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pool

import (
"context"
"os/exec"
"testing"

fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
)

func TestProcess(t *testing.T) {
// Use the no-op `true` binary if available; skip this test otherwise.
dummyExec, err := exec.LookPath("true")
if err != nil {
t.Skip("Binary `true` doesn't exist, skipping tests.")
}

endpoint := &pipepb.ApiServiceDescriptor{
Url: "localhost:0",
}
secureEndpoint := &pipepb.ApiServiceDescriptor{
Url: "localhost:0",
Authentication: &pipepb.AuthenticationSpec{
Urn: "beam:authentication:oauth2_client_credentials_grant:v1",
},
}

ctx, cancelFn := context.WithCancel(context.Background())
t.Cleanup(cancelFn)
server, err := New(ctx, 0, dummyExec)
if err != nil {
t.Fatalf("Unable to create server: %v", err)
}
go server.ServeAndWait()

startTests := []struct {
req *fnpb.StartWorkerRequest
errExpected bool
}{
{
req: &fnpb.StartWorkerRequest{
WorkerId: "Worker1",
ControlEndpoint: endpoint,
LoggingEndpoint: endpoint,
},
}, {
req: &fnpb.StartWorkerRequest{
WorkerId: "Worker2",
ControlEndpoint: endpoint,
LoggingEndpoint: endpoint,
},
}, {
req: &fnpb.StartWorkerRequest{
WorkerId: "Worker1",
ControlEndpoint: endpoint,
LoggingEndpoint: endpoint,
},
errExpected: true, // Repeated start
}, {
req: &fnpb.StartWorkerRequest{
WorkerId: "missingControl",
LoggingEndpoint: endpoint,
},
errExpected: true,
}, {
req: &fnpb.StartWorkerRequest{
WorkerId: "missingLogging",
ControlEndpoint: endpoint,
},
errExpected: true,
}, {
req: &fnpb.StartWorkerRequest{
WorkerId: "secureLogging",
LoggingEndpoint: secureEndpoint,
ControlEndpoint: endpoint,
},
errExpected: true,
}, {
req: &fnpb.StartWorkerRequest{
WorkerId: "secureControl",
LoggingEndpoint: endpoint,
ControlEndpoint: secureEndpoint,
},
errExpected: true,
},
}
for _, test := range startTests {
resp, err := server.StartWorker(ctx, test.req)
if test.errExpected {
if err != nil || resp.Error == "" {
t.Errorf("Expected error starting %v: err: %v, resp: %v", test.req.GetWorkerId(), err, resp)
}
} else {
if err != nil || resp.Error != "" {
t.Errorf("Unexpected error starting %v: err: %v, resp: %v", test.req.GetWorkerId(), err, resp)
}
}
}
stopTests := []struct {
req *fnpb.StopWorkerRequest
errExpected bool
}{
{
req: &fnpb.StopWorkerRequest{
WorkerId: "Worker1",
},
}, {
req: &fnpb.StopWorkerRequest{
WorkerId: "Worker1",
},
errExpected: true,
}, {
req: &fnpb.StopWorkerRequest{
WorkerId: "NonExistent",
},
errExpected: true,
},
}
for _, test := range stopTests {
resp, err := server.StopWorker(ctx, test.req)
if test.errExpected {
if err != nil || resp.Error == "" {
t.Errorf("Expected error starting %v: err: %v, resp: %v", test.req.GetWorkerId(), err, resp)
}
} else {
if err != nil || resp.Error != "" {
t.Errorf("Unexpected error starting %v: err: %v, resp: %v", test.req.GetWorkerId(), err, resp)
}
}
}
if err := server.Stop(ctx); err != nil {
t.Fatalf("error stopping server: err: %v", err)
}
}