Skip to content

Commit

Permalink
fix: template refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
GGXXLL committed Sep 14, 2022
1 parent 6a7af45 commit 171640b
Show file tree
Hide file tree
Showing 23 changed files with 152 additions and 265 deletions.
6 changes: 1 addition & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@
SHA := $(shell git rev-parse --short=10 HEAD)

MAKEFILE_PATH := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
VERSION_DATE := $(shell $(MAKEFILE_PATH)/commit_date.sh)

# Build native Truss by default.
default: truss

dependencies:
go get -u google.golang.org/genproto
go get -u github.com/gogo/protobuf/protoc-gen-gogo
go get -u github.com/gogo/protobuf/protoc-gen-gogofaster
go get -u github.com/gogo/protobuf/protoc-gen-gofast
go get -u github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-grpc-gateway
go get -u github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2
go get -u google.golang.org/grpc/cmd/protoc-gen-go-grpc
Expand All @@ -21,7 +17,7 @@ dependencies:

# Install truss
truss:
go install -ldflags '-X "main.version=$(SHA)" -X "main.date=$(VERSION_DATE)"' github.com/DoNewsCode/truss/cmd/truss
go install -ldflags '-X "main.version=$(SHA)"' github.com/DoNewsCode/truss/cmd/truss

# Run the go tests and the truss integration tests
test: test-go
Expand Down
3 changes: 3 additions & 0 deletions cmd/truss/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,9 @@ func readPreviousGeneration(serviceDir string) (map[string]io.Reader, error) {
files := make(map[string]io.Reader)

addFileToFiles := func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
switch info.Name() {
// Only files within the handlers dir are used to
Expand Down
22 changes: 0 additions & 22 deletions commit_date.sh

This file was deleted.

39 changes: 9 additions & 30 deletions deftree/gogo/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ func IsStdBytes(field *google_protobuf.FieldDescriptorProto) bool {
}

func IsStdType(field *google_protobuf.FieldDescriptorProto) bool {
return (IsStdTime(field) || IsStdDuration(field) ||
return IsStdTime(field) || IsStdDuration(field) ||
IsStdDouble(field) || IsStdFloat(field) ||
IsStdInt64(field) || IsStdUInt64(field) ||
IsStdInt32(field) || IsStdUInt32(field) ||
IsStdBool(field) ||
IsStdString(field) || IsStdBytes(field))
IsStdString(field) || IsStdBytes(field)
}

func IsWktPtr(field *google_protobuf.FieldDescriptorProto) bool {
Expand All @@ -109,34 +109,22 @@ func NeedsNilCheck(proto3 bool, field *google_protobuf.FieldDescriptorProto) boo

func IsCustomType(field *google_protobuf.FieldDescriptorProto) bool {
typ := GetCustomType(field)
if len(typ) > 0 {
return true
}
return false
return len(typ) > 0
}

func IsCastType(field *google_protobuf.FieldDescriptorProto) bool {
typ := GetCastType(field)
if len(typ) > 0 {
return true
}
return false
return len(typ) > 0
}

func IsCastKey(field *google_protobuf.FieldDescriptorProto) bool {
typ := GetCastKey(field)
if len(typ) > 0 {
return true
}
return false
return len(typ) > 0
}

func IsCastValue(field *google_protobuf.FieldDescriptorProto) bool {
typ := GetCastValue(field)
if len(typ) > 0 {
return true
}
return false
return len(typ) > 0
}

func HasEnumDecl(file *google_protobuf.FileDescriptorProto, enum *google_protobuf.EnumDescriptorProto) bool {
Expand Down Expand Up @@ -201,26 +189,17 @@ func GetCastValue(field *google_protobuf.FieldDescriptorProto) string {

func IsCustomName(field *google_protobuf.FieldDescriptorProto) bool {
name := GetCustomName(field)
if len(name) > 0 {
return true
}
return false
return len(name) > 0
}

func IsEnumCustomName(field *google_protobuf.EnumDescriptorProto) bool {
name := GetEnumCustomName(field)
if len(name) > 0 {
return true
}
return false
return len(name) > 0
}

func IsEnumValueCustomName(field *google_protobuf.EnumValueDescriptorProto) bool {
name := GetEnumValueCustomName(field)
if len(name) > 0 {
return true
}
return false
return len(name) > 0
}

func GetCustomName(field *google_protobuf.FieldDescriptorProto) string {
Expand Down
58 changes: 22 additions & 36 deletions gengokit/handlers/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ func init() {
log.SetLevel(log.DebugLevel)
}

func fatalError(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}

func TestServerMethsTempl(t *testing.T) {
const def = `
syntax = "proto3";
Expand Down Expand Up @@ -63,19 +69,16 @@ func TestServerMethsTempl(t *testing.T) {
}
`
sd, err := svcdef.NewFromString(def, gopath)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

var he handlerData
he.Methods = sd.Service.Methods
he.ServiceName = sd.Service.Name

gen, err := applyServerMethsTempl(he)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)
genBytes, err := ioutil.ReadAll(gen)
fatalError(t, err)
const expected = `
func (s Service) ProtoMethod(ctx context.Context, in *pb.RequestMessage) (*pb.ResponseMessage, error){
var resp pb.ResponseMessage
Expand Down Expand Up @@ -124,13 +127,14 @@ func TestApplyServerTempl(t *testing.T) {
PBPackage: "github.com/DoNewsCode/truss/gengokit/general-service",
}
sd, err := svcdef.NewFromString(def, gopath)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)
te, err := gengokit.NewData(sd, conf)
fatalError(t, err)

gen, err := applyServerTempl(te)
fatalError(t, err)
genBytes, err := ioutil.ReadAll(gen)
fatalError(t, err)
expected := `
package proto
Expand Down Expand Up @@ -212,9 +216,7 @@ func TestIsValidFunc(t *testing.T) {
}
`
sd, err := svcdef.NewFromString(def, gopath)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

m := newMethodMap(sd.Service.Methods)
const validUnexported = `package p;
Expand Down Expand Up @@ -297,9 +299,7 @@ func TestPruneDecls(t *testing.T) {
}
`
sd, err := svcdef.NewFromString(def, gopath)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

m := newMethodMap(sd.Service.Methods)

Expand Down Expand Up @@ -429,9 +429,7 @@ func TestUpdateMethods(t *testing.T) {
`

sd, err := svcdef.NewFromString(def, gopath)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

svc := sd.Service
allMethods := svc.Methods
Expand All @@ -442,21 +440,15 @@ func TestUpdateMethods(t *testing.T) {
}

te, err := gengokit.NewData(sd, conf)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

svc.Methods = []*svcdef.ServiceMethod{allMethods[0]}

firstCode, err := renderService(svc, "", te)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

secondCode, err := renderService(svc, firstCode, te)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

if len(firstCode) != len(secondCode) {
t.Fatal("Generated service differs after regenerated with same definition\n" +
Expand All @@ -466,9 +458,7 @@ func TestUpdateMethods(t *testing.T) {
svc.Methods = append(svc.Methods, allMethods[1])

thirdCode, err := renderService(svc, secondCode, te)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

if len(thirdCode) <= len(secondCode) {
t.Fatal("Generated service not longer after regenerated with additional service method\n" +
Expand All @@ -479,9 +469,7 @@ func TestUpdateMethods(t *testing.T) {
svc.Methods = svc.Methods[1:]

forthCode, err := renderService(svc, thirdCode, te)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

if len(forthCode) >= len(thirdCode) {
t.Fatal("Generated service not shorter after regenerated with fewer service method\n" +
Expand All @@ -491,9 +479,7 @@ func TestUpdateMethods(t *testing.T) {
svc.Methods = allMethods

fifthCode, err := renderService(svc, forthCode, te)
if err != nil {
t.Fatal(err)
}
fatalError(t, err)

if len(fifthCode) <= len(forthCode) {
t.Fatal("Generated service not longer after regenerated with additional service method\n" +
Expand Down
7 changes: 1 addition & 6 deletions gengokit/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (m methodMap) pruneDecls(decls []ast.Decl, svcName string) []ast.Decl {
newDecls = append(newDecls, x)
continue
}
if ok := isValidFunc(x, m, svcName); ok == true {
if ok := isValidFunc(x, m, svcName); ok {
updateParams(x, m[name])
updateResults(x, m[name])
newDecls = append(newDecls, x)
Expand Down Expand Up @@ -265,11 +265,6 @@ func exprString(e ast.Expr) string {
}
// *foo.Foo or foo.Foo
if sel, _ := e.(*ast.SelectorExpr); sel != nil {
// *foo.Foo -> foo.Foo
if ptr, _ := e.(*ast.StarExpr); ptr != nil {
prefix = "*"
e = ptr.X
}
// foo.Foo
if x, _ := sel.X.(*ast.Ident); x != nil {
return prefix + x.Name + "." + sel.Sel.Name
Expand Down
22 changes: 0 additions & 22 deletions gengokit/handlers/templates/hook.go

This file was deleted.

29 changes: 9 additions & 20 deletions gengokit/httptransport/get_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ func FuncSourceCode(val interface{}) (string, error) {
ptr := reflect.ValueOf(val).Pointer()
fpath, _ := runtime.FuncForPC(ptr).FileLine(ptr)

funcName := runtime.FuncForPC(ptr).Name()
parts := strings.Split(funcName, ".")
funcName = parts[len(parts)-1]
parts := strings.Split(runtime.FuncForPC(ptr).Name(), ".")
funcName := parts[len(parts)-1]

// Parse the go file into the ast
fset := token.NewFileSet()
Expand All @@ -32,13 +31,9 @@ func FuncSourceCode(val interface{}) (string, error) {

// Search ast for function declaration with name of function passed
var fAst *ast.FuncDecl
for _, decs := range fileAst.Decls {
switch decs.(type) {
case *ast.FuncDecl:
f := decs.(*ast.FuncDecl)
if f.Name.String() == funcName {
fAst = f
}
for _, decl := range fileAst.Decls {
if decl, ok := decl.(*ast.FuncDecl); ok && funcName == decl.Name.String() {
fAst = decl
}
}
code := bytes.NewBuffer(nil)
Expand All @@ -59,10 +54,6 @@ func AllFuncSourceCode(val interface{}) (string, error) {
ptr := reflect.ValueOf(val).Pointer()
fpath, _ := runtime.FuncForPC(ptr).FileLine(ptr)

funcName := runtime.FuncForPC(ptr).Name()
parts := strings.Split(funcName, ".")
funcName = parts[len(parts)-1]

// Parse the go file into the ast
fset := token.NewFileSet()
fileAst, err := parser.ParseFile(fset, fpath, nil, parser.ParseComments)
Expand All @@ -71,12 +62,10 @@ func AllFuncSourceCode(val interface{}) (string, error) {
}

// Search ast for all function declarations
fncSlc := []*ast.FuncDecl{}
for _, decs := range fileAst.Decls {
switch decs.(type) {
case *ast.FuncDecl:
f := decs.(*ast.FuncDecl)
fncSlc = append(fncSlc, f)
var fncSlc []*ast.FuncDecl
for _, decl := range fileAst.Decls {
if decl, ok := decl.(*ast.FuncDecl); ok {
fncSlc = append(fncSlc, decl)
}
}

Expand Down
Loading

0 comments on commit 171640b

Please sign in to comment.