-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #499 from oasisprotocol/pro-wh/feature/holders11
analyzer: add pubclient
- Loading branch information
Showing
2 changed files
with
223 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
package pubclient | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net" | ||
"net/http" | ||
"strconv" | ||
"syscall" | ||
"time" | ||
|
||
coreCommon "github.com/oasisprotocol/oasis-core/go/common" | ||
) | ||
|
||
// Use this package for connecting to untrusted URLs. | ||
// It only allows you to connect to globally routable addresses, i.e. public | ||
// IP addresses, not things on your LAN. So it's a client for public | ||
// resources. Anyway, be aware when making changes here, your code will be up | ||
// against untrusted URLs. | ||
|
||
var permittedNetworks = map[string]bool{ | ||
"tcp4": true, | ||
"tcp6": true, | ||
} | ||
|
||
type NotPermittedError struct { | ||
// Note: .error is the implementation of .Error, .Unwrap etc. It is not | ||
// in the Unwrap chain. Use something like | ||
// `NotPermittedError{fmt.Errorf("...: %w", err)}` to set up an | ||
// instance with `err` in the Unwrap chain. | ||
error | ||
} | ||
|
||
func (err NotPermittedError) Is(target error) bool { | ||
if _, ok := target.(NotPermittedError); ok { | ||
return true | ||
} | ||
return false | ||
} | ||
|
||
// client is an *http.Client that permits HTTP(S) connections to hosts that | ||
// oasis-core considers "likely to be globally reachable" on the default | ||
// HTTP(S) ports and unreserved ports. | ||
var client = &http.Client{ | ||
Transport: &http.Transport{ | ||
// Copied from http.DefaultTransport. | ||
Proxy: http.ProxyFromEnvironment, | ||
DialContext: (&net.Dialer{ | ||
// Copied from http.DefaultTransport. | ||
Timeout: 30 * time.Second, | ||
KeepAlive: 30 * time.Second, | ||
// https://www.agwa.name/blog/post/preventing_server_side_request_forgery_in_golang | ||
// Recommends using a net.Dialer Control to interpose on local connections. | ||
Control: func(network, address string, c syscall.RawConn) error { | ||
if !permittedNetworks[network] { | ||
return NotPermittedError{fmt.Errorf("network %s not permitted", network)} | ||
} | ||
host, portStr, err := net.SplitHostPort(address) | ||
if err != nil { | ||
return NotPermittedError{fmt.Errorf("net.SplitHostPort %s: %w", address, err)} | ||
} | ||
ip := net.ParseIP(host) | ||
if ip == nil { | ||
return NotPermittedError{fmt.Errorf("IP %s not valid", ip)} | ||
} | ||
if !coreCommon.IsProbablyGloballyReachable(ip) { | ||
return NotPermittedError{fmt.Errorf("IP %s not permitted", ip)} | ||
} | ||
port, err := strconv.ParseUint(portStr, 10, 16) | ||
if err != nil { | ||
return NotPermittedError{fmt.Errorf("strconv.ParseUint %s: %w", portStr, err)} | ||
} | ||
if port != 443 && port != 80 && port < 1024 { | ||
return NotPermittedError{fmt.Errorf("port %d not permitted", port)} | ||
} | ||
return nil | ||
}, | ||
}).DialContext, | ||
// Copied from http.DefaultTransport. | ||
ForceAttemptHTTP2: true, | ||
MaxIdleConns: 100, | ||
IdleConnTimeout: 90 * time.Second, | ||
TLSHandshakeTimeout: 10 * time.Second, | ||
ExpectContinueTimeout: 1 * time.Second, | ||
}, | ||
Timeout: 30 * time.Second, | ||
} | ||
|
||
func getWithContextWithClient(ctx context.Context, client *http.Client, url string) (*http.Response, error) { | ||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return client.Do(req) | ||
} | ||
|
||
func GetWithContext(ctx context.Context, url string) (*http.Response, error) { | ||
return getWithContextWithClient(ctx, client, url) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
package pubclient | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"net/http" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func wasteResp(resp *http.Response) error { | ||
_, err := io.Copy(io.Discard, resp.Body) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
func requireErrorAndWaste(t *testing.T, resp *http.Response, err error) { | ||
if err == nil { | ||
require.NoError(t, wasteResp(resp)) | ||
} | ||
require.Error(t, err) | ||
} | ||
|
||
func requireNoErrorAndWaste(t *testing.T, resp *http.Response, err error) { | ||
if err == nil { | ||
require.NoError(t, wasteResp(resp)) | ||
} | ||
require.NoError(t, err) | ||
} | ||
|
||
func TestMisc(t *testing.T) { | ||
var requested bool | ||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
requested = true | ||
}) | ||
testServer := http.Server{ | ||
Addr: "127.0.0.1:8001", | ||
Handler: handler, | ||
ReadHeaderTimeout: 30 * time.Second, | ||
} | ||
serverErr := make(chan error) | ||
go func() { | ||
serverErr <- testServer.ListenAndServe() | ||
}() | ||
testServer6 := http.Server{ | ||
Addr: "[::1]:8001", | ||
Handler: handler, | ||
ReadHeaderTimeout: 30 * time.Second, | ||
} | ||
serverErr6 := make(chan error) | ||
go func() { | ||
serverErr6 <- testServer6.ListenAndServe() | ||
}() | ||
ctx := context.Background() | ||
|
||
// Default client should reach local server. This makes sure the test server is working. | ||
resp, err := getWithContextWithClient(ctx, http.DefaultClient, "http://localhost:8001/test.json") | ||
requireNoErrorAndWaste(t, resp, err) | ||
require.True(t, requested) | ||
requested = false | ||
resp, err = getWithContextWithClient(ctx, http.DefaultClient, "http://[::1]:8001/test.json") | ||
requireNoErrorAndWaste(t, resp, err) | ||
require.True(t, requested) | ||
requested = false | ||
|
||
// Hostname of test server | ||
resp, err = GetWithContext(ctx, "http://localhost:8001/test.json") | ||
requireErrorAndWaste(t, resp, err) | ||
require.ErrorIs(t, err, NotPermittedError{}) | ||
require.False(t, requested) | ||
|
||
// IP address of test server | ||
resp, err = GetWithContext(ctx, "http://127.0.0.1:8001/test.json") | ||
requireErrorAndWaste(t, resp, err) | ||
require.ErrorIs(t, err, NotPermittedError{}) | ||
require.False(t, requested) | ||
resp, err = GetWithContext(ctx, "http://[::1]:8001/test.json") | ||
requireErrorAndWaste(t, resp, err) | ||
require.ErrorIs(t, err, NotPermittedError{}) | ||
require.False(t, requested) | ||
|
||
// Server that redirects to test server | ||
// Warning: external network dependency | ||
resp, err = GetWithContext(ctx, "https://httpbin.org/redirect-to?url=http%3A%2F%2F127.0.0.1%3A8001%2Ftest.json") | ||
requireErrorAndWaste(t, resp, err) | ||
require.ErrorIs(t, err, NotPermittedError{}) | ||
require.False(t, requested) | ||
resp, err = GetWithContext(ctx, "https://httpbin.org/redirect-to?url=http%3A%2F%2F%5B%3A%3A1%5D%3A8001%2Ftest.json") | ||
requireErrorAndWaste(t, resp, err) | ||
require.ErrorIs(t, err, NotPermittedError{}) | ||
require.False(t, requested) | ||
|
||
// Domain that resolves to test server | ||
// Warning: external network dependency | ||
resp, err = GetWithContext(ctx, "http://127.0.0.1.nip.io:8001/test.json") | ||
requireErrorAndWaste(t, resp, err) | ||
require.ErrorIs(t, err, NotPermittedError{}) | ||
require.False(t, requested) | ||
resp, err = GetWithContext(ctx, "http://0--1.sslip.io:8001/test.json") | ||
requireErrorAndWaste(t, resp, err) | ||
require.ErrorIs(t, err, NotPermittedError{}) | ||
require.False(t, requested) | ||
|
||
// Well known port other than HTTP(S) | ||
resp, err = GetWithContext(ctx, "http://smtp.google.com:25/") | ||
requireErrorAndWaste(t, resp, err) | ||
require.ErrorIs(t, err, NotPermittedError{}) | ||
|
||
// Other requests ought to work. | ||
// Warning: external network dependency | ||
resp, err = GetWithContext(ctx, "https://www.example.com/") | ||
requireNoErrorAndWaste(t, resp, err) | ||
|
||
err = testServer.Shutdown(ctx) | ||
require.NoError(t, err) | ||
require.ErrorIs(t, <-serverErr, http.ErrServerClosed) | ||
err = testServer6.Shutdown(ctx) | ||
require.NoError(t, err) | ||
require.ErrorIs(t, <-serverErr6, http.ErrServerClosed) | ||
} |