Skip to content

Commit

Permalink
feat(cosmos): make chain and node conform to the new interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Zygimantass committed Jan 10, 2025
1 parent c0cddd3 commit 0e0af2a
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 54 deletions.
4 changes: 4 additions & 0 deletions core/provider/docker/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,7 @@ func (t *Task) ensureVolume(ctx context.Context) error {

return nil
}

func (t *Task) GetDefinition() provider.TaskDefinition {
return t.GetState().Definition
}
2 changes: 2 additions & 0 deletions core/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type TaskI interface {
Stop(context.Context) error
Destroy(context.Context) error

GetDefinition() TaskDefinition

GetStatus(context.Context) (TaskStatus, error)

Modify(context.Context, TaskDefinition) error
Expand Down
18 changes: 8 additions & 10 deletions core/types/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package types
import (
"context"
"fmt"
"go.uber.org/zap"

rpchttp "github.com/cometbft/cometbft/rpc/client/http"
sdk "github.com/cosmos/cosmos-sdk/types"
"go.uber.org/zap"
"google.golang.org/grpc"

"github.com/skip-mev/petri/core/v2/provider"
Expand All @@ -19,8 +19,7 @@ type NodeConfig struct {

IsValidator bool // IsValidator denotes whether this node is a validator

Chain ChainI // Chain is the chain this node is running on
Provider provider.Provider // Provider is the provider this node is running on
Chain ChainI // Chain is the chain this node is running on
}

func (c NodeConfig) ValidateBasic() error {
Expand All @@ -32,10 +31,6 @@ func (c NodeConfig) ValidateBasic() error {
return fmt.Errorf("chain cannot be nil")
}

if c.Provider == nil {
return fmt.Errorf("provider cannot be nil")
}

return nil
}

Expand All @@ -45,10 +40,12 @@ func (c NodeConfig) ValidateBasic() error {
type NodeDefinitionModifier func(provider.TaskDefinition, NodeConfig) provider.TaskDefinition

// NodeCreator is a type of function that given a NodeConfig creates a new logical node
type NodeCreator func(context.Context, *zap.Logger, NodeConfig) (NodeI, error)
type NodeCreator func(context.Context, *zap.Logger, provider.ProviderI, NodeConfig) (NodeI, error)

// NodeI represents an interface for a logical node that is running on a chain
type NodeI interface {
provider.TaskI

// GetConfig returns the configuration of the node
GetConfig() NodeConfig

Expand Down Expand Up @@ -93,8 +90,9 @@ type NodeI interface {
// NodeId returns the p2p peer ID of the node
NodeId(context.Context) (string, error)

// GetTask returns the underlying node's Task
GetTask() *provider.Task
// GetDefinition returns the task definition of the node
GetDefinition() provider.TaskDefinition

// GetIP returns the IP address of the node
GetIP(context.Context) (string, error)
}
32 changes: 15 additions & 17 deletions cosmos/chain/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type Chain struct {
var _ petritypes.ChainI = &Chain{}

// CreateChain creates the Chain object and initializes the node tasks, their backing compute and the validator wallets
func CreateChain(ctx context.Context, logger *zap.Logger, infraProvider provider.Provider, config petritypes.ChainConfig) (*Chain, error) {
func CreateChain(ctx context.Context, logger *zap.Logger, infraProvider provider.ProviderI, config petritypes.ChainConfig) (*Chain, error) {
if err := config.ValidateBasic(); err != nil {
return nil, fmt.Errorf("failed to validate chain config: %w", err)
}
Expand All @@ -68,11 +68,10 @@ func CreateChain(ctx context.Context, logger *zap.Logger, infraProvider provider

logger.Info("creating validator", zap.String("name", validatorName))

validator, err := config.NodeCreator(ctx, logger, petritypes.NodeConfig{
validator, err := config.NodeCreator(ctx, logger, infraProvider, petritypes.NodeConfig{
Index: i,
Name: validatorName,
IsValidator: true,
Provider: infraProvider,
Chain: &chain,
})
if err != nil {
Expand All @@ -97,11 +96,10 @@ func CreateChain(ctx context.Context, logger *zap.Logger, infraProvider provider

logger.Info("creating node", zap.String("name", nodeName))

node, err := config.NodeCreator(ctx, logger, petritypes.NodeConfig{
node, err := config.NodeCreator(ctx, logger, infraProvider, petritypes.NodeConfig{
Index: i,
Name: nodeName,
IsValidator: true,
Provider: infraProvider,
Chain: &chain,
})
if err != nil {
Expand Down Expand Up @@ -139,7 +137,7 @@ func (c *Chain) Height(ctx context.Context) (uint64, error) {

client, err := node.GetTMClient(ctx)

c.logger.Debug("fetching height from", zap.String("node", node.GetTask().Definition.Name), zap.String("ip", client.Remote()))
c.logger.Debug("fetching height from", zap.String("node", node.GetDefinition().Name), zap.String("ip", client.Remote()))

if err != nil {
return 0, err
Expand Down Expand Up @@ -178,7 +176,7 @@ func (c *Chain) Init(ctx context.Context) error {
v := v
idx := idx
eg.Go(func() error {
c.logger.Info("setting up validator home dir", zap.String("validator", v.GetTask().Definition.Name))
c.logger.Info("setting up validator home dir", zap.String("validator", v.GetDefinition().Name))
if err := v.InitHome(ctx); err != nil {
return fmt.Errorf("error initializing home dir: %v", err)
}
Expand Down Expand Up @@ -211,7 +209,7 @@ func (c *Chain) Init(ctx context.Context) error {
n := n

eg.Go(func() error {
c.logger.Info("setting up node home dir", zap.String("node", n.GetTask().Definition.Name))
c.logger.Info("setting up node home dir", zap.String("node", n.GetDefinition().Name))
if err := n.InitHome(ctx); err != nil {
return err
}
Expand Down Expand Up @@ -247,7 +245,7 @@ func (c *Chain) Init(ctx context.Context) error {
return err
}

c.logger.Info("setting up validator keys", zap.String("validator", validatorN.GetTask().Definition.Name), zap.String("address", bech32))
c.logger.Info("setting up validator keys", zap.String("validator", validatorN.GetDefinition().Name), zap.String("address", bech32))
if err := firstValidator.AddGenesisAccount(ctx, bech32, genesisAmounts); err != nil {
return err
}
Expand Down Expand Up @@ -292,7 +290,7 @@ func (c *Chain) Init(ctx context.Context) error {
for i := range c.Validators {
v := c.Validators[i]
eg.Go(func() error {
c.logger.Info("overwriting genesis for validator", zap.String("validator", v.GetTask().Definition.Name))
c.logger.Info("overwriting genesis for validator", zap.String("validator", v.GetDefinition().Name))
if err := v.OverwriteGenesisFile(ctx, genbz); err != nil {
return err
}
Expand All @@ -309,7 +307,7 @@ func (c *Chain) Init(ctx context.Context) error {
for i := range c.Nodes {
n := c.Nodes[i]
eg.Go(func() error {
c.logger.Info("overwriting node genesis", zap.String("node", n.GetTask().Definition.Name))
c.logger.Info("overwriting node genesis", zap.String("node", n.GetDefinition().Name))
if err := n.OverwriteGenesisFile(ctx, genbz); err != nil {
return err
}
Expand All @@ -330,8 +328,8 @@ func (c *Chain) Init(ctx context.Context) error {
for i := range c.Validators {
v := c.Validators[i]
eg.Go(func() error {
c.logger.Info("starting validator task", zap.String("validator", v.GetTask().Definition.Name))
if err := v.GetTask().Start(ctx, true); err != nil {
c.logger.Info("starting validator task", zap.String("validator", v.GetDefinition().Name))
if err := v.Start(ctx); err != nil {
return err
}
return nil
Expand All @@ -341,8 +339,8 @@ func (c *Chain) Init(ctx context.Context) error {
for i := range c.Nodes {
n := c.Nodes[i]
eg.Go(func() error {
c.logger.Info("starting node task", zap.String("node", n.GetTask().Definition.Name))
if err := n.GetTask().Start(ctx, true); err != nil {
c.logger.Info("starting node task", zap.String("node", n.GetDefinition().Name))
if err := n.Start(ctx); err != nil {
return err
}
return nil
Expand All @@ -361,13 +359,13 @@ func (c *Chain) Teardown(ctx context.Context) error {
c.logger.Info("tearing down chain", zap.String("name", c.Config.ChainId))

for _, v := range c.Validators {
if err := v.GetTask().Destroy(ctx, true); err != nil {
if err := v.Destroy(ctx); err != nil {
return err
}
}

for _, n := range c.Nodes {
if err := n.GetTask().Destroy(ctx, true); err != nil {
if err := n.Destroy(ctx); err != nil {
return err
}
}
Expand Down
1 change: 1 addition & 0 deletions cosmos/node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package node
import (
"bytes"
"context"
"fmt"
toml "github.com/pelletier/go-toml/v2"

petritypes "github.com/skip-mev/petri/core/v2/types"
Expand Down
24 changes: 12 additions & 12 deletions cosmos/node/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import (

// GenesisFileContent returns the genesis file on the node in byte format
func (n *Node) GenesisFileContent(ctx context.Context) ([]byte, error) {
n.logger.Info("reading genesis file", zap.String("node", n.Definition.Name))
n.logger.Info("reading genesis file", zap.String("node", n.GetDefinition().Name))

bz, err := n.Task.ReadFile(ctx, "config/genesis.json")
bz, err := n.ReadFile(ctx, "config/genesis.json")
if err != nil {
return nil, err
}
Expand All @@ -36,18 +36,18 @@ func (n *Node) CopyGenTx(ctx context.Context, dstNode petritypes.NodeI) error {
path := fmt.Sprintf("config/gentx/gentx-%s.json", nid)

n.logger.Debug("reading gen tx", zap.String("node", n.GetConfig().Name))
gentx, err := n.Task.ReadFile(context.Background(), path)
gentx, err := n.ReadFile(context.Background(), path)
if err != nil {
return err
}

n.logger.Debug("writing gen tx", zap.String("node", dstNode.GetConfig().Name))
return dstNode.GetTask().WriteFile(context.Background(), path, gentx)
return dstNode.WriteFile(context.Background(), path, gentx)
}

// AddGenesisAccount adds a genesis account to the node's local genesis file
func (n *Node) AddGenesisAccount(ctx context.Context, address string, genesisAmounts []types.Coin) error {
n.logger.Debug("adding genesis account", zap.String("node", n.Definition.Name), zap.String("address", address))
n.logger.Debug("adding genesis account", zap.String("node", n.GetDefinition().Name), zap.String("address", address))

amount := ""

Expand All @@ -71,7 +71,7 @@ func (n *Node) AddGenesisAccount(ctx context.Context, address string, genesisAmo
command = append(command, "add-genesis-account", address, amount)
command = n.BinCommand(command...)

stdout, stderr, exitCode, err := n.Task.RunCommand(ctx, command)
stdout, stderr, exitCode, err := n.RunCommand(ctx, command)
n.logger.Debug("add-genesis-account", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode))

if err != nil {
Expand All @@ -87,7 +87,7 @@ func (n *Node) AddGenesisAccount(ctx context.Context, address string, genesisAmo

// GenerateGenTx generates a genesis transaction for the node
func (n *Node) GenerateGenTx(ctx context.Context, genesisSelfDelegation types.Coin) error {
n.logger.Info("generating genesis transaction", zap.String("node", n.Definition.Name))
n.logger.Info("generating genesis transaction", zap.String("node", n.GetDefinition().Name))

chainConfig := n.chain.GetConfig()

Expand All @@ -103,7 +103,7 @@ func (n *Node) GenerateGenTx(ctx context.Context, genesisSelfDelegation types.Co

command = n.BinCommand(command...)

stdout, stderr, exitCode, err := n.Task.RunCommand(ctx, command)
stdout, stderr, exitCode, err := n.RunCommand(ctx, command)
n.logger.Debug("gentx", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode))

if err != nil {
Expand All @@ -119,7 +119,7 @@ func (n *Node) GenerateGenTx(ctx context.Context, genesisSelfDelegation types.Co

// CollectGenTxs collects the genesis transactions from the node and create a finalized genesis file
func (n *Node) CollectGenTxs(ctx context.Context) error {
n.logger.Info("collecting genesis transactions", zap.String("node", n.Definition.Name))
n.logger.Info("collecting genesis transactions", zap.String("node", n.GetDefinition().Name))

command := []string{}

Expand All @@ -129,7 +129,7 @@ func (n *Node) CollectGenTxs(ctx context.Context) error {

command = append(command, "collect-gentxs")

stdout, stderr, exitCode, err := n.Task.RunCommand(ctx, n.BinCommand(command...))
stdout, stderr, exitCode, err := n.RunCommand(ctx, n.BinCommand(command...))
n.logger.Debug("collect-gentxs", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode))

if err != nil {
Expand All @@ -145,7 +145,7 @@ func (n *Node) CollectGenTxs(ctx context.Context) error {

// OverwriteGenesisFile overwrites the genesis file on the node with the provided genesis file
func (n *Node) OverwriteGenesisFile(ctx context.Context, bz []byte) error {
n.logger.Info("overwriting genesis file", zap.String("node", n.Definition.Name))
n.logger.Info("overwriting genesis file", zap.String("node", n.GetDefinition().Name))

return n.Task.WriteFile(ctx, "config/genesis.json", bz)
return n.WriteFile(ctx, "config/genesis.json", bz)
}
4 changes: 2 additions & 2 deletions cosmos/node/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (

// InitHome initializes the node's home directory
func (n *Node) InitHome(ctx context.Context) error {
n.logger.Info("initializing home", zap.String("name", n.Definition.Name))
n.logger.Info("initializing home", zap.String("name", n.GetDefinition().Name))
chainConfig := n.chain.GetConfig()

stdout, stderr, exitCode, err := n.Task.RunCommand(ctx, n.BinCommand([]string{"init", n.Definition.Name, "--chain-id", chainConfig.ChainId}...))
stdout, stderr, exitCode, err := n.RunCommand(ctx, n.BinCommand([]string{"init", n.GetDefinition().Name, "--chain-id", chainConfig.ChainId}...))
n.logger.Debug("init home", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode))

if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion cosmos/node/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (n *Node) KeyBech32(ctx context.Context, name, bech string) (string, error)
command = append(command, "--bech", bech)
}

stdout, stderr, exitCode, err := n.Task.RunCommand(ctx, command)
stdout, stderr, exitCode, err := n.RunCommand(ctx, command)
n.logger.Debug("show key", zap.String("name", name), zap.String("stdout", stdout), zap.String("stderr", stderr))

if err != nil {
Expand Down
19 changes: 7 additions & 12 deletions cosmos/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
)

type Node struct {
*provider.Task
provider.TaskI

logger *zap.Logger
config petritypes.NodeConfig
Expand All @@ -29,7 +29,7 @@ type Node struct {
var _ petritypes.NodeCreator = CreateNode

// CreateNode creates a new logical node and creates the underlying workload for it
func CreateNode(ctx context.Context, logger *zap.Logger, nodeConfig petritypes.NodeConfig) (petritypes.NodeI, error) {
func CreateNode(ctx context.Context, logger *zap.Logger, infraProvider provider.ProviderI, nodeConfig petritypes.NodeConfig) (petritypes.NodeI, error) {
if err := nodeConfig.ValidateBasic(); err != nil {
return nil, fmt.Errorf("failed to validate node config: %w", err)
}
Expand Down Expand Up @@ -71,24 +71,19 @@ func CreateNode(ctx context.Context, logger *zap.Logger, nodeConfig petritypes.N
def = nodeConfig.Chain.GetConfig().NodeDefinitionModifier(def, nodeConfig)
}

task, err := provider.CreateTask(ctx, node.logger, nodeConfig.Provider, def)
task, err := infraProvider.CreateTask(ctx, def)
if err != nil {
return nil, err
}

node.Task = task
node.TaskI = task

return &node, nil
}

// GetTask returns the underlying task of the node
func (n *Node) GetTask() *provider.Task {
return n.Task
}

// GetTMClient returns a CometBFT HTTP client for the node
func (n *Node) GetTMClient(ctx context.Context) (*rpchttp.HTTP, error) {
addr, err := n.Task.GetExternalAddress(ctx, "26657")
addr, err := n.GetExternalAddress(ctx, "26657")
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -127,7 +122,7 @@ func (n *Node) GetGRPCClient(ctx context.Context) (*grpc.ClientConn, error) {

// Height returns the current block height of the node
func (n *Node) Height(ctx context.Context) (uint64, error) {
n.logger.Debug("getting height", zap.String("node", n.Definition.Name))
n.logger.Debug("getting height", zap.String("node", n.GetDefinition().Name))
client, err := n.GetTMClient(ctx)
if err != nil {
return 0, err
Expand All @@ -143,7 +138,7 @@ func (n *Node) Height(ctx context.Context) (uint64, error) {

// NodeId returns the node's p2p ID
func (n *Node) NodeId(ctx context.Context) (string, error) {
j, err := n.Task.ReadFile(ctx, "config/node_key.json")
j, err := n.ReadFile(ctx, "config/node_key.json")
if err != nil {
return "", fmt.Errorf("getting node_key.json content: %w", err)
}
Expand Down

0 comments on commit 0e0af2a

Please sign in to comment.