diff --git a/route/router.go b/route/router.go index 07d264d..12ef1a8 100644 --- a/route/router.go +++ b/route/router.go @@ -45,32 +45,33 @@ func (r *Router) Initialize(ctx context.Context, logger *log.Logger, options Rou if r.started { return errors.New("already initialized") } - r.ctx = ctx - r.logger = logger - r.outboundMap = options.OutboundMap - r.listMap = options.ListMap - r.ruleRegistry = options.RuleRegistry - r.snifferRegistry = options.SnifferRegistry - r.rules = make([]Rule, 0, len(options.Config.Rules)) + rules := make([]Rule, 0, len(options.Config.Rules)) for i, ruleConfig := range options.Config.Rules { - rule, err := NewRule(logger, ruleConfig, r.listMap, r.ruleRegistry) + rule, err := NewRule(logger, ruleConfig, options.ListMap, options.RuleRegistry) if err != nil { return fmt.Errorf("initialize rule [index=%d]: %w", i, err) } - r.rules = append(r.rules, rule) + rules = append(rules, rule) } if options.Config.DefaultOutbound != "" { - var err error - r.defaultOutbound, err = r.FindOutboundByName(options.Config.DefaultOutbound) + defaultOutbound, err := r.findOutboundByName(options.OutboundMap, options.Config.DefaultOutbound) if err != nil { return common.Cause("default outbound is not found: ", err) } + r.defaultOutbound = defaultOutbound } else { r.defaultOutbound, _ = protocol.NewOutbound(r.logger, &config.Outbound{ Name: "default", }) - r.defaultOutbound.PostInitialize(r) + r.defaultOutbound.PostInitialize(r) // this is dangerous since the router is not fully initialized yet } + r.ctx = ctx + r.logger = logger + r.outboundMap = options.OutboundMap + r.listMap = options.ListMap + r.ruleRegistry = options.RuleRegistry + r.snifferRegistry = options.SnifferRegistry + r.rules = rules r.started = true return nil } @@ -169,16 +170,20 @@ func (r *Router) HandleConnection(conn net.Conn, metadata *adapter.Metadata) { } func (r *Router) FindOutboundByName(name string) (adapter.Outbound, error) { + return r.findOutboundByName(r.outboundMap, name) +} + +func (r *Router) findOutboundByName(outboundMap map[string]adapter.Outbound, name string) (adapter.Outbound, error) { switch name { case "REJECT": return rejectOutbound{}, nil case "RESET": return resetOutbound{}, nil } - if r.outboundMap == nil { + if outboundMap == nil { return nil, errors.New("outbounds are not initialized") } - outbound, ok := r.outboundMap[name] + outbound, ok := outboundMap[name] if !ok { return nil, fmt.Errorf("outbound not found [%s]", name) }