Skip to content

Commit

Permalink
support prover task query api (#704)
Browse files Browse the repository at this point in the history
* support prover task query api

* fix unit test
  • Loading branch information
huangzhiran authored Oct 17, 2024
1 parent aa1ae1f commit 2375164
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 11 deletions.
56 changes: 53 additions & 3 deletions cmd/apinode/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"crypto/ecdsa"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
Expand All @@ -16,6 +18,7 @@ import (
"github.com/pkg/errors"

"github.com/iotexproject/w3bstream/cmd/apinode/persistence"
proverapi "github.com/iotexproject/w3bstream/cmd/prover/api"
"github.com/iotexproject/w3bstream/p2p"
)

Expand Down Expand Up @@ -47,6 +50,7 @@ type StateLog struct {
State string `json:"state"`
Time time.Time `json:"time"`
Comment string `json:"comment,omitempty"`
Error string `json:"error,omitempty"`
Tx string `json:"transaction_hash,omitempty"`
ProverID string `json:"prover_id,omitempty"`
}
Expand All @@ -63,6 +67,7 @@ type httpServer struct {
aggregationAmount int
prv *ecdsa.PrivateKey
pubSub *p2p.PubSub
proverAddr string
}

func (s *httpServer) handleMessage(c *gin.Context) {
Expand Down Expand Up @@ -176,6 +181,50 @@ func (s *httpServer) queryTask(c *gin.Context) {
return
}
if ts == nil {
reqJ, err := json.Marshal(proverapi.QueryTaskReq{
ProjectID: req.ProjectID,
TaskID: req.TaskID,
})
if err != nil {
slog.Error("failed to marshal prover request", "error", err)
c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to marshal prover request")))
return
}
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/task", s.proverAddr), bytes.NewBuffer(reqJ))
if err != nil {
slog.Error("failed to build http request", "error", err)
c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to build http request")))
return
}
req.Header.Set("Content-Type", "application/json")

proverResp, err := http.DefaultClient.Do(req)
if err != nil {
slog.Error("failed to call prover http server", "error", err)
c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to call prover http server")))
return
}
defer proverResp.Body.Close()

body, err := io.ReadAll(proverResp.Body)
if err != nil {
slog.Error("failed to read prover http server response", "error", err)
c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to read prover http server response")))
return
}
taskResp := &proverapi.QueryTaskResp{}
if err := json.Unmarshal(body, &taskResp); err != nil {
slog.Error("failed to unmarshal prover http server response", "error", err)
c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to unmarshal prover http server response")))
return
}
if taskResp.Processed && taskResp.Error != "" {
resp.States = append(resp.States, &StateLog{
State: "failed",
Error: taskResp.Error,
Time: taskResp.Time,
})
}
c.JSON(http.StatusOK, resp)
return
}
Expand All @@ -189,20 +238,21 @@ func (s *httpServer) queryTask(c *gin.Context) {
}

// this func will block caller
func Run(p *persistence.Persistence, prv *ecdsa.PrivateKey, pubSub *p2p.PubSub, aggregationAmount int, address string) error {
func Run(p *persistence.Persistence, prv *ecdsa.PrivateKey, pubSub *p2p.PubSub, aggregationAmount int, addr, proverAddr string) error {
s := &httpServer{
engine: gin.Default(),
p: p,
aggregationAmount: aggregationAmount,
prv: prv,
pubSub: pubSub,
proverAddr: proverAddr,
}

s.engine.POST("/message", s.handleMessage)
s.engine.GET("/task", s.queryTask)

if err := s.engine.Run(address); err != nil {
slog.Error("failed to start http server", "address", address, "error", err)
if err := s.engine.Run(addr); err != nil {
slog.Error("failed to start http server", "address", addr, "error", err)
return errors.Wrap(err, "could not start http server; check if the address is in use or network is accessible")
}
return nil
Expand Down
2 changes: 2 additions & 0 deletions cmd/apinode/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
type Config struct {
LogLevel slog.Level `env:"LOG_LEVEL,optional"`
ServiceEndpoint string `env:"HTTP_SERVICE_ENDPOINT"`
ProverServiceEndpoint string `env:"PROVER_SERVICE_ENDPOINT"`
AggregationAmount int `env:"AGGREGATION_AMOUNT,optional"`
DatabaseDSN string `env:"DATABASE_DSN"`
PrvKey string `env:"PRIVATE_KEY,optional"`
Expand All @@ -24,6 +25,7 @@ type Config struct {
var defaultTestnetConfig = &Config{
LogLevel: slog.LevelInfo,
ServiceEndpoint: ":9000",
ProverServiceEndpoint: "localhost:9002",
AggregationAmount: 1,
DatabaseDSN: "postgres://postgres:mysecretpassword@postgres:5432/w3bstream?sslmode=disable",
PrvKey: "dbfe03b0406549232b8dccc04be8224fcc0afa300a33d4f335dcfdfead861c85",
Expand Down
2 changes: 1 addition & 1 deletion cmd/apinode/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func main() {
}

go func() {
if err := api.Run(p, prv, pubSub, cfg.AggregationAmount, cfg.ServiceEndpoint); err != nil {
if err := api.Run(p, prv, pubSub, cfg.AggregationAmount, cfg.ServiceEndpoint, cfg.ProverServiceEndpoint); err != nil {
log.Fatal(err)
}
}()
Expand Down
76 changes: 76 additions & 0 deletions cmd/prover/api/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package api

import (
"log/slog"
"net/http"
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"

"github.com/iotexproject/w3bstream/cmd/prover/db"
)

type ErrResp struct {
Error string `json:"error,omitempty"`
}

func NewErrResp(err error) *ErrResp {
return &ErrResp{Error: err.Error()}
}

type QueryTaskReq struct {
ProjectID uint64 `json:"projectID" binding:"required"`
TaskID string `json:"taskID" binding:"required"`
}

type QueryTaskResp struct {
Time time.Time `json:"time"`
Processed bool `json:"processed"`
Error string `json:"error,omitempty"`
}

type httpServer struct {
engine *gin.Engine
db *db.DB
}

func (s *httpServer) queryTask(c *gin.Context) {
req := &QueryTaskReq{}
if err := c.ShouldBindJSON(req); err != nil {
slog.Error("failed to bind request", "error", err)
c.JSON(http.StatusBadRequest, NewErrResp(errors.Wrap(err, "invalid request payload")))
return
}
taskID := common.HexToHash(req.TaskID)

processed, errMsg, createdAt, err := s.db.ProcessedTask(req.ProjectID, taskID)
if err != nil {
slog.Error("failed to query processed task", "error", err)
c.JSON(http.StatusInternalServerError, NewErrResp(errors.Wrap(err, "failed to query processed task")))
return
}

c.JSON(http.StatusOK, &QueryTaskResp{
Time: createdAt,
Processed: processed,
Error: errMsg,
})
}

// this func will block caller
func Run(db *db.DB, address string) error {
s := &httpServer{
engine: gin.Default(),
db: db,
}

s.engine.GET("/task", s.queryTask)

if err := s.engine.Run(address); err != nil {
slog.Error("failed to start http server", "address", address, "error", err)
return errors.Wrap(err, "could not start http server; check if the address is in use or network is accessible")
}
return nil
}
2 changes: 2 additions & 0 deletions cmd/prover/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

type Config struct {
LogLevel slog.Level `env:"LOG_LEVEL,optional"`
ServiceEndpoint string `env:"HTTP_SERVICE_ENDPOINT"`
VMEndpoints string `env:"VM_ENDPOINTS"`
DatasourceDSN string `env:"DATASOURCE_DSN"`
ChainEndpoint string `env:"CHAIN_ENDPOINT,optional"`
Expand All @@ -24,6 +25,7 @@ type Config struct {
var (
defaultTestnetConfig = &Config{
LogLevel: slog.LevelInfo,
ServiceEndpoint: ":9002",
VMEndpoints: `{"1":"localhost:4001","2":"localhost:4002","3":"zkwasm:4001","4":"wasm:4001"}`,
ChainEndpoint: "https://babel-api.testnet.iotex.io",
DatasourceDSN: "postgres://postgres:mysecretpassword@postgres:5432/w3bstream?sslmode=disable",
Expand Down
2 changes: 2 additions & 0 deletions cmd/prover/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func TestConfig_Init(t *testing.T) {
t.Run("UseEnvConfig", func(t *testing.T) {
os.Clearenv()
expected := Config{
ServiceEndpoint: "test",
VMEndpoints: `{"1":"halo2:4001","2":"risc0:4001","3":"zkwasm:4001","4":"wasm:4001"}`,
ChainEndpoint: "http://abc.def.com",
DatasourceDSN: "postgres://root@localhost/abc?ext=666",
Expand All @@ -25,6 +26,7 @@ func TestConfig_Init(t *testing.T) {
LocalDBDir: "./test",
}

_ = os.Setenv("HTTP_SERVICE_ENDPOINT", expected.ServiceEndpoint)
_ = os.Setenv("VM_ENDPOINTS", expected.VMEndpoints)
_ = os.Setenv("CHAIN_ENDPOINT", expected.ChainEndpoint)
_ = os.Setenv("DATASOURCE_DSN", expected.DatasourceDSN)
Expand Down
20 changes: 18 additions & 2 deletions cmd/prover/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package db

import (
"bytes"
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/pkg/errors"
Expand Down Expand Up @@ -35,6 +36,7 @@ type task struct {
TaskID common.Hash `gorm:"uniqueIndex:task_uniq,not null"`
ProjectID uint64 `gorm:"uniqueIndex:task_uniq,not null"`
Processed bool `gorm:"index:unprocessed_task,not null,default:false"`
Error string `gorm:"not null,default:''"`
}

type DB struct {
Expand Down Expand Up @@ -128,11 +130,14 @@ func (p *DB) CreateTask(projectID uint64, taskID common.Hash, prover common.Addr
return errors.Wrap(err, "failed to upsert task")
}

func (p *DB) ProcessTask(projectID uint64, taskID common.Hash) error {
func (p *DB) ProcessTask(projectID uint64, taskID common.Hash, err error) error {
t := &task{
Processed: true,
}
err := p.db.Model(t).Where("task_id = ?", taskID).Where("project_id = ?", projectID).Updates(t).Error
if err != nil {
t.Error = err.Error()
}
err = p.db.Model(t).Where("task_id = ?", taskID).Where("project_id = ?", projectID).Updates(t).Error
return errors.Wrap(err, "failed to update task")
}

Expand All @@ -141,6 +146,17 @@ func (p *DB) DeleteTask(projectID uint64, taskID, tx common.Hash) error {
return errors.Wrap(err, "failed to delete task")
}

func (p *DB) ProcessedTask(projectID uint64, taskID common.Hash) (bool, string, time.Time, error) {
t := task{}
if err := p.db.Where("task_id = ?", taskID).Where("project_id = ?", projectID).First(&t).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return false, "", time.Now(), nil
}
return false, "", time.Time{}, errors.Wrap(err, "failed to query processed task")
}
return t.Processed, t.Error, t.CreatedAt, nil
}

func (p *DB) UnprocessedTask() (uint64, common.Hash, error) {
t := task{}
if err := p.db.Order("created_at ASC").Where("processed = false").First(&t).Error; err != nil {
Expand Down
7 changes: 7 additions & 0 deletions cmd/prover/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/pkg/errors"

"github.com/iotexproject/w3bstream/cmd/prover/api"
"github.com/iotexproject/w3bstream/cmd/prover/config"
"github.com/iotexproject/w3bstream/cmd/prover/db"
"github.com/iotexproject/w3bstream/datasource"
Expand Down Expand Up @@ -83,6 +84,12 @@ func main() {
log.Fatal(errors.Wrap(err, "failed to run task processor"))
}

go func() {
if err := api.Run(db, cfg.ServiceEndpoint); err != nil {
log.Fatal(err)
}
}()

done := make(chan os.Signal, 1)
signal.Notify(done, syscall.SIGINT, syscall.SIGTERM)
<-done
Expand Down
9 changes: 4 additions & 5 deletions task/processor/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type RetrieveTask func(projectID uint64, taskID common.Hash) (*task.Task, error)

type DB interface {
UnprocessedTask() (uint64, common.Hash, error)
ProcessTask(uint64, common.Hash) error
ProcessTask(uint64, common.Hash, error) error
}

type processor struct {
Expand Down Expand Up @@ -98,12 +98,11 @@ func (r *processor) run() {
time.Sleep(r.waitingTime)
continue
}
if err := r.process(projectID, taskID); err != nil {
err = r.process(projectID, taskID)
if err != nil {
slog.Error("failed to process task", "error", err)
time.Sleep(r.waitingTime)
continue
}
if err := r.db.ProcessTask(projectID, taskID); err != nil {
if err := r.db.ProcessTask(projectID, taskID, err); err != nil {
slog.Error("failed to process db task", "error", err)
}
}
Expand Down

0 comments on commit 2375164

Please sign in to comment.