From 23751648f31355516cd7acf3e61f6bedf6c69250 Mon Sep 17 00:00:00 2001 From: huangzhiran <30522704+huangzhiran@users.noreply.github.com> Date: Thu, 17 Oct 2024 16:07:51 +0800 Subject: [PATCH] support prover task query api (#704) * support prover task query api * fix unit test --- cmd/apinode/api/http.go | 56 +++++++++++++++++++++-- cmd/apinode/config/config.go | 2 + cmd/apinode/main.go | 2 +- cmd/prover/api/http.go | 76 ++++++++++++++++++++++++++++++++ cmd/prover/config/config.go | 2 + cmd/prover/config/config_test.go | 2 + cmd/prover/db/db.go | 20 ++++++++- cmd/prover/main.go | 7 +++ task/processor/processor.go | 9 ++-- 9 files changed, 165 insertions(+), 11 deletions(-) create mode 100644 cmd/prover/api/http.go diff --git a/cmd/apinode/api/http.go b/cmd/apinode/api/http.go index 5d7af824..0429717d 100644 --- a/cmd/apinode/api/http.go +++ b/cmd/apinode/api/http.go @@ -4,6 +4,8 @@ import ( "bytes" "crypto/ecdsa" "encoding/json" + "fmt" + "io" "log/slog" "net/http" "strings" @@ -16,6 +18,7 @@ import ( "github.com/pkg/errors" "github.com/iotexproject/w3bstream/cmd/apinode/persistence" + proverapi "github.com/iotexproject/w3bstream/cmd/prover/api" "github.com/iotexproject/w3bstream/p2p" ) @@ -47,6 +50,7 @@ type StateLog struct { State string `json:"state"` Time time.Time `json:"time"` Comment string `json:"comment,omitempty"` + Error string `json:"error,omitempty"` Tx string `json:"transaction_hash,omitempty"` ProverID string `json:"prover_id,omitempty"` } @@ -63,6 +67,7 @@ type httpServer struct { aggregationAmount int prv *ecdsa.PrivateKey pubSub *p2p.PubSub + proverAddr string } func (s *httpServer) handleMessage(c *gin.Context) { @@ -176,6 +181,50 @@ func (s *httpServer) queryTask(c *gin.Context) { return } if ts == nil { + reqJ, err := json.Marshal(proverapi.QueryTaskReq{ + ProjectID: req.ProjectID, + TaskID: req.TaskID, + }) + if err != nil { + slog.Error("failed to marshal prover request", "error", err) + c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to marshal prover request"))) + return + } + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/task", s.proverAddr), bytes.NewBuffer(reqJ)) + if err != nil { + slog.Error("failed to build http request", "error", err) + c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to build http request"))) + return + } + req.Header.Set("Content-Type", "application/json") + + proverResp, err := http.DefaultClient.Do(req) + if err != nil { + slog.Error("failed to call prover http server", "error", err) + c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to call prover http server"))) + return + } + defer proverResp.Body.Close() + + body, err := io.ReadAll(proverResp.Body) + if err != nil { + slog.Error("failed to read prover http server response", "error", err) + c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to read prover http server response"))) + return + } + taskResp := &proverapi.QueryTaskResp{} + if err := json.Unmarshal(body, &taskResp); err != nil { + slog.Error("failed to unmarshal prover http server response", "error", err) + c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to unmarshal prover http server response"))) + return + } + if taskResp.Processed && taskResp.Error != "" { + resp.States = append(resp.States, &StateLog{ + State: "failed", + Error: taskResp.Error, + Time: taskResp.Time, + }) + } c.JSON(http.StatusOK, resp) return } @@ -189,20 +238,21 @@ func (s *httpServer) queryTask(c *gin.Context) { } // this func will block caller -func Run(p *persistence.Persistence, prv *ecdsa.PrivateKey, pubSub *p2p.PubSub, aggregationAmount int, address string) error { +func Run(p *persistence.Persistence, prv *ecdsa.PrivateKey, pubSub *p2p.PubSub, aggregationAmount int, addr, proverAddr string) error { s := &httpServer{ engine: gin.Default(), p: p, aggregationAmount: aggregationAmount, prv: prv, pubSub: pubSub, + proverAddr: proverAddr, } s.engine.POST("/message", s.handleMessage) s.engine.GET("/task", s.queryTask) - if err := s.engine.Run(address); err != nil { - slog.Error("failed to start http server", "address", address, "error", err) + if err := s.engine.Run(addr); err != nil { + slog.Error("failed to start http server", "address", addr, "error", err) return errors.Wrap(err, "could not start http server; check if the address is in use or network is accessible") } return nil diff --git a/cmd/apinode/config/config.go b/cmd/apinode/config/config.go index b30ddb4c..b2498697 100644 --- a/cmd/apinode/config/config.go +++ b/cmd/apinode/config/config.go @@ -10,6 +10,7 @@ import ( type Config struct { LogLevel slog.Level `env:"LOG_LEVEL,optional"` ServiceEndpoint string `env:"HTTP_SERVICE_ENDPOINT"` + ProverServiceEndpoint string `env:"PROVER_SERVICE_ENDPOINT"` AggregationAmount int `env:"AGGREGATION_AMOUNT,optional"` DatabaseDSN string `env:"DATABASE_DSN"` PrvKey string `env:"PRIVATE_KEY,optional"` @@ -24,6 +25,7 @@ type Config struct { var defaultTestnetConfig = &Config{ LogLevel: slog.LevelInfo, ServiceEndpoint: ":9000", + ProverServiceEndpoint: "localhost:9002", AggregationAmount: 1, DatabaseDSN: "postgres://postgres:mysecretpassword@postgres:5432/w3bstream?sslmode=disable", PrvKey: "dbfe03b0406549232b8dccc04be8224fcc0afa300a33d4f335dcfdfead861c85", diff --git a/cmd/apinode/main.go b/cmd/apinode/main.go index a240c442..ac550453 100644 --- a/cmd/apinode/main.go +++ b/cmd/apinode/main.go @@ -61,7 +61,7 @@ func main() { } go func() { - if err := api.Run(p, prv, pubSub, cfg.AggregationAmount, cfg.ServiceEndpoint); err != nil { + if err := api.Run(p, prv, pubSub, cfg.AggregationAmount, cfg.ServiceEndpoint, cfg.ProverServiceEndpoint); err != nil { log.Fatal(err) } }() diff --git a/cmd/prover/api/http.go b/cmd/prover/api/http.go new file mode 100644 index 00000000..c29077d3 --- /dev/null +++ b/cmd/prover/api/http.go @@ -0,0 +1,76 @@ +package api + +import ( + "log/slog" + "net/http" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + + "github.com/iotexproject/w3bstream/cmd/prover/db" +) + +type ErrResp struct { + Error string `json:"error,omitempty"` +} + +func NewErrResp(err error) *ErrResp { + return &ErrResp{Error: err.Error()} +} + +type QueryTaskReq struct { + ProjectID uint64 `json:"projectID" binding:"required"` + TaskID string `json:"taskID" binding:"required"` +} + +type QueryTaskResp struct { + Time time.Time `json:"time"` + Processed bool `json:"processed"` + Error string `json:"error,omitempty"` +} + +type httpServer struct { + engine *gin.Engine + db *db.DB +} + +func (s *httpServer) queryTask(c *gin.Context) { + req := &QueryTaskReq{} + if err := c.ShouldBindJSON(req); err != nil { + slog.Error("failed to bind request", "error", err) + c.JSON(http.StatusBadRequest, NewErrResp(errors.Wrap(err, "invalid request payload"))) + return + } + taskID := common.HexToHash(req.TaskID) + + processed, errMsg, createdAt, err := s.db.ProcessedTask(req.ProjectID, taskID) + if err != nil { + slog.Error("failed to query processed task", "error", err) + c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to query processed task"))) + return + } + + c.JSON(http.StatusOK, &QueryTaskResp{ + Time: createdAt, + Processed: processed, + Error: errMsg, + }) +} + +// this func will block caller +func Run(db *db.DB, address string) error { + s := &httpServer{ + engine: gin.Default(), + db: db, + } + + s.engine.GET("/task", s.queryTask) + + if err := s.engine.Run(address); err != nil { + slog.Error("failed to start http server", "address", address, "error", err) + return errors.Wrap(err, "could not start http server; check if the address is in use or network is accessible") + } + return nil +} diff --git a/cmd/prover/config/config.go b/cmd/prover/config/config.go index 88c58ed5..56cce1d6 100644 --- a/cmd/prover/config/config.go +++ b/cmd/prover/config/config.go @@ -9,6 +9,7 @@ import ( type Config struct { LogLevel slog.Level `env:"LOG_LEVEL,optional"` + ServiceEndpoint string `env:"HTTP_SERVICE_ENDPOINT"` VMEndpoints string `env:"VM_ENDPOINTS"` DatasourceDSN string `env:"DATASOURCE_DSN"` ChainEndpoint string `env:"CHAIN_ENDPOINT,optional"` @@ -24,6 +25,7 @@ type Config struct { var ( defaultTestnetConfig = &Config{ LogLevel: slog.LevelInfo, + ServiceEndpoint: ":9002", VMEndpoints: `{"1":"localhost:4001","2":"localhost:4002","3":"zkwasm:4001","4":"wasm:4001"}`, ChainEndpoint: "https://babel-api.testnet.iotex.io", DatasourceDSN: "postgres://postgres:mysecretpassword@postgres:5432/w3bstream?sslmode=disable", diff --git a/cmd/prover/config/config_test.go b/cmd/prover/config/config_test.go index 4906454e..e1f9b8bb 100644 --- a/cmd/prover/config/config_test.go +++ b/cmd/prover/config/config_test.go @@ -17,6 +17,7 @@ func TestConfig_Init(t *testing.T) { t.Run("UseEnvConfig", func(t *testing.T) { os.Clearenv() expected := Config{ + ServiceEndpoint: "test", VMEndpoints: `{"1":"halo2:4001","2":"risc0:4001","3":"zkwasm:4001","4":"wasm:4001"}`, ChainEndpoint: "http://abc.def.com", DatasourceDSN: "postgres://root@localhost/abc?ext=666", @@ -25,6 +26,7 @@ func TestConfig_Init(t *testing.T) { LocalDBDir: "./test", } + _ = os.Setenv("HTTP_SERVICE_ENDPOINT", expected.ServiceEndpoint) _ = os.Setenv("VM_ENDPOINTS", expected.VMEndpoints) _ = os.Setenv("CHAIN_ENDPOINT", expected.ChainEndpoint) _ = os.Setenv("DATASOURCE_DSN", expected.DatasourceDSN) diff --git a/cmd/prover/db/db.go b/cmd/prover/db/db.go index e408179e..53bd6d41 100644 --- a/cmd/prover/db/db.go +++ b/cmd/prover/db/db.go @@ -2,6 +2,7 @@ package db import ( "bytes" + "time" "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" @@ -35,6 +36,7 @@ type task struct { TaskID common.Hash `gorm:"uniqueIndex:task_uniq,not null"` ProjectID uint64 `gorm:"uniqueIndex:task_uniq,not null"` Processed bool `gorm:"index:unprocessed_task,not null,default:false"` + Error string `gorm:"not null,default:''"` } type DB struct { @@ -128,11 +130,14 @@ func (p *DB) CreateTask(projectID uint64, taskID common.Hash, prover common.Addr return errors.Wrap(err, "failed to upsert task") } -func (p *DB) ProcessTask(projectID uint64, taskID common.Hash) error { +func (p *DB) ProcessTask(projectID uint64, taskID common.Hash, err error) error { t := &task{ Processed: true, } - err := p.db.Model(t).Where("task_id = ?", taskID).Where("project_id = ?", projectID).Updates(t).Error + if err != nil { + t.Error = err.Error() + } + err = p.db.Model(t).Where("task_id = ?", taskID).Where("project_id = ?", projectID).Updates(t).Error return errors.Wrap(err, "failed to update task") } @@ -141,6 +146,17 @@ func (p *DB) DeleteTask(projectID uint64, taskID, tx common.Hash) error { return errors.Wrap(err, "failed to delete task") } +func (p *DB) ProcessedTask(projectID uint64, taskID common.Hash) (bool, string, time.Time, error) { + t := task{} + if err := p.db.Where("task_id = ?", taskID).Where("project_id = ?", projectID).First(&t).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return false, "", time.Now(), nil + } + return false, "", time.Time{}, errors.Wrap(err, "failed to query processed task") + } + return t.Processed, t.Error, t.CreatedAt, nil +} + func (p *DB) UnprocessedTask() (uint64, common.Hash, error) { t := task{} if err := p.db.Order("created_at ASC").Where("processed = false").First(&t).Error; err != nil { diff --git a/cmd/prover/main.go b/cmd/prover/main.go index 07ca5781..d43c51af 100644 --- a/cmd/prover/main.go +++ b/cmd/prover/main.go @@ -12,6 +12,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/pkg/errors" + "github.com/iotexproject/w3bstream/cmd/prover/api" "github.com/iotexproject/w3bstream/cmd/prover/config" "github.com/iotexproject/w3bstream/cmd/prover/db" "github.com/iotexproject/w3bstream/datasource" @@ -83,6 +84,12 @@ func main() { log.Fatal(errors.Wrap(err, "failed to run task processor")) } + go func() { + if err := api.Run(db, cfg.ServiceEndpoint); err != nil { + log.Fatal(err) + } + }() + done := make(chan os.Signal, 1) signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) <-done diff --git a/task/processor/processor.go b/task/processor/processor.go index 6ea3a3eb..5e033b9e 100644 --- a/task/processor/processor.go +++ b/task/processor/processor.go @@ -25,7 +25,7 @@ type RetrieveTask func(projectID uint64, taskID common.Hash) (*task.Task, error) type DB interface { UnprocessedTask() (uint64, common.Hash, error) - ProcessTask(uint64, common.Hash) error + ProcessTask(uint64, common.Hash, error) error } type processor struct { @@ -98,12 +98,11 @@ func (r *processor) run() { time.Sleep(r.waitingTime) continue } - if err := r.process(projectID, taskID); err != nil { + err = r.process(projectID, taskID) + if err != nil { slog.Error("failed to process task", "error", err) - time.Sleep(r.waitingTime) - continue } - if err := r.db.ProcessTask(projectID, taskID); err != nil { + if err := r.db.ProcessTask(projectID, taskID, err); err != nil { slog.Error("failed to process db task", "error", err) } }