diff --git a/chain/mocks_test.go b/chain/mocks_test.go index b740253e61..01c19c31a4 100644 --- a/chain/mocks_test.go +++ b/chain/mocks_test.go @@ -27,8 +27,10 @@ func newMockNeutrinoClient(t *testing.T) *NeutrinoClient { var ( chainParams = &chaincfg.Params{} chainSvc = &mockChainService{} - newRescanner = func(cs neutrino.ChainSource, ro ...neutrino.RescanOption) Rescanner { - return &mockRescanner{} + newRescanner = func(ro ...neutrino.RescanOption) Rescanner { + return &mockRescanner{ + make(chan []neutrino.UpdateOption), + } } ) return &NeutrinoClient{ @@ -38,18 +40,21 @@ func newMockNeutrinoClient(t *testing.T) *NeutrinoClient { } } -type mockRescanner struct{} +type mockRescanner struct { + updateCh chan []neutrino.UpdateOption +} func (m *mockRescanner) Start() <-chan error { - return nil + return make(<-chan error) } func (m *mockRescanner) WaitForShutdown() { - panic(ErrNotImplemented) + // no-op } -func (m *mockRescanner) Update(...neutrino.UpdateOption) error { - return ErrNotImplemented +func (m *mockRescanner) Update(opts ...neutrino.UpdateOption) error { + m.updateCh <- opts + return nil } type mockChainService struct{} diff --git a/chain/neutrino.go b/chain/neutrino.go index 4ee831ca33..c7d89766e1 100644 --- a/chain/neutrino.go +++ b/chain/neutrino.go @@ -46,7 +46,7 @@ type NeutrinoClient struct { // We currently support one rescan/notifiction goroutine per client rescan Rescanner - newRescanner func(neutrino.ChainSource, ...neutrino.RescanOption) Rescanner + newRescanner func(...neutrino.RescanOption) Rescanner enqueueNotification chan interface{} dequeueNotification chan interface{} @@ -88,23 +88,30 @@ func (s *NeutrinoClient) Start() error { s.clientMtx.Lock() defer s.clientMtx.Unlock() + // attempt to start the chain service if err := s.CS.Start(); err != nil { return fmt.Errorf("error starting chain service: %v", err) } if !s.started { + // restart the client state s.enqueueNotification = make(chan interface{}) s.dequeueNotification = make(chan interface{}) s.currentBlock = make(chan *waddrmgr.BlockStamp) s.quit = make(chan struct{}) s.started = true + + // launch the notification handler s.wg.Add(1) go s.notificationHandler() + + // place the client connected notification into the queue select { case s.enqueueNotification <- ClientConnected{}: case <-s.quit: } } + return nil } @@ -112,9 +119,11 @@ func (s *NeutrinoClient) Start() error { func (s *NeutrinoClient) Stop() { s.clientMtx.Lock() defer s.clientMtx.Unlock() + if !s.started { return } + close(s.quit) s.started = false } @@ -415,15 +424,11 @@ func (s *NeutrinoClient) Rescan(startHash *chainhash.Hash, addrs []btcutil.Addre newRescanner := s.getNewRescanner() - s.rescan = newRescanner( - &neutrino.RescanChainSource{ - ChainService: s.CS.(*neutrino.ChainService), - }, - neutrino.NotificationHandlers(rpcclient.NotificationHandlers{ - OnBlockConnected: s.onBlockConnected, - OnFilteredBlockConnected: s.onFilteredBlockConnected, - OnBlockDisconnected: s.onBlockDisconnected, - }), + s.rescan = newRescanner(neutrino.NotificationHandlers(rpcclient.NotificationHandlers{ + OnBlockConnected: s.onBlockConnected, + OnFilteredBlockConnected: s.onFilteredBlockConnected, + OnBlockDisconnected: s.onBlockDisconnected, + }), neutrino.StartBlock(&headerfs.BlockStamp{Hash: *startHash}), neutrino.StartTime(s.startTime), neutrino.QuitChan(s.rescanQuit), @@ -450,6 +455,8 @@ func (s *NeutrinoClient) NotifyBlocks() error { } // NotifyReceived replicates the RPC client's NotifyReceived command. +// +// TODO(mstreet3) error if the client is not started? func (s *NeutrinoClient) NotifyReceived(addrs []btcutil.Address) error { s.clientMtx.Lock() defer s.clientMtx.Unlock() @@ -471,15 +478,11 @@ func (s *NeutrinoClient) NotifyReceived(addrs []btcutil.Address) error { // Rescan with just the specified addresses. newRescanner := s.getNewRescanner() - s.rescan = newRescanner( - &neutrino.RescanChainSource{ - ChainService: s.CS.(*neutrino.ChainService), - }, - neutrino.NotificationHandlers(rpcclient.NotificationHandlers{ - OnBlockConnected: s.onBlockConnected, - OnFilteredBlockConnected: s.onFilteredBlockConnected, - OnBlockDisconnected: s.onBlockDisconnected, - }), + s.rescan = newRescanner(neutrino.NotificationHandlers(rpcclient.NotificationHandlers{ + OnBlockConnected: s.onBlockConnected, + OnFilteredBlockConnected: s.onFilteredBlockConnected, + OnBlockDisconnected: s.onBlockDisconnected, + }), neutrino.StartTime(s.startTime), neutrino.QuitChan(s.rescanQuit), neutrino.WatchAddrs(addrs...), @@ -759,9 +762,12 @@ out: // getNewRescanner injects the Rescanner constructor when called and defaults to using neutrino.NewRescan // when unspecified. -func (s *NeutrinoClient) getNewRescanner() func(neutrino.ChainSource, ...neutrino.RescanOption) Rescanner { +func (s *NeutrinoClient) getNewRescanner() func(...neutrino.RescanOption) Rescanner { if s.newRescanner == nil { - s.newRescanner = func(cs neutrino.ChainSource, ropts ...neutrino.RescanOption) Rescanner { + s.newRescanner = func(ropts ...neutrino.RescanOption) Rescanner { + cs := &neutrino.RescanChainSource{ + ChainService: s.CS.(*neutrino.ChainService), + } return neutrino.NewRescan(cs, ropts...) } } diff --git a/chain/neutrino_test.go b/chain/neutrino_test.go index 116b930f75..94dbd35b2f 100644 --- a/chain/neutrino_test.go +++ b/chain/neutrino_test.go @@ -5,12 +5,13 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcutil" "github.com/stretchr/testify/require" ) -// TestNeutrinoClientStartStop ensures that the client +// TestNeutrinoClientSequentialStartStop ensures that the client // can sequentially Start and Stop without errors or races. -func TestNeutrinoClientStartStop(t *testing.T) { +func TestNeutrinoClientSequentialStartStop(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) t.Cleanup(cancel) nc := newMockNeutrinoClient(t) @@ -37,3 +38,61 @@ func TestNeutrinoClientStartStop(t *testing.T) { } } } + +// TestNeutrinoClientNotifyReceived verifies that a call to NotifyReceived sets the client into +// the scanning state and that subsequent calls while scanning will call Update on the +// client's Rescanner +func TestNeutrinoClientNotifyReceived(t *testing.T) { + var ( + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + addrs []btcutil.Address + sent = make(chan struct{}) + read = make(chan struct{}) + called = make(chan struct{}) + nc = newMockNeutrinoClient(t) + wantNotifyReceivedCalls = 4 + wantUpdateCalls = wantNotifyReceivedCalls - 1 + gotUC = 0 + ) + t.Cleanup(cancel) + + go func() { + defer close(sent) + for i := 0; i < wantNotifyReceivedCalls; i++ { + err := nc.NotifyReceived(addrs) + require.NoError(t, err) + require.True(t, nc.scanning) + + // signal that NotifyReceived was called on first iteration + if i == 0 { + close(called) + } + } + }() + + // wait until called, then type cast and read from private channel + <-called + mockRescan := nc.rescan.(*mockRescanner) + go func() { + defer close(read) + for { + select { + case <-ctx.Done(): + return + case <-mockRescan.updateCh: + gotUC++ + if gotUC == wantUpdateCalls { + return + } + } + } + }() + + // wait for call to Update or test failure + select { + case <-ctx.Done(): + t.Fatal("timed out") + case <-read: + require.Equal(t, wantUpdateCalls, gotUC) + } +}