diff --git a/sangrenel.go b/sangrenel.go index b273124..6529b98 100644 --- a/sangrenel.go +++ b/sangrenel.go @@ -22,8 +22,12 @@ package main import ( + "bufio" + "crypto/tls" + "crypto/x509" "flag" "fmt" + "io" "log" "math/rand" "os" @@ -49,6 +53,9 @@ var ( clients int producers int noop bool + tlsconfig *tls.Config + + source MessageSource // Character selection for random messages. chars = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$^&*(){}][:<>.") @@ -59,6 +66,87 @@ var ( sentCntr = make(chan int64, 1) ) +type MessageSource interface { + GetMessage() []byte + Clone() MessageSource +} + +type RandomMessageSource struct { + generator *rand.Rand + buffer []byte +} + +func NewRandomMessageSource() *RandomMessageSource { + source := rand.NewSource(time.Now().UnixNano()) + return &RandomMessageSource{ + generator: rand.New(source), + buffer: make([]byte, msgSize), + } +} + +func (source *RandomMessageSource) GetMessage() []byte { + for i := range source.buffer { + source.buffer[i] = chars[source.generator.Intn(len(chars))] + } + return source.buffer +} + +func (source *RandomMessageSource) Clone() MessageSource { + s := rand.NewSource(time.Now().UnixNano()) + return &RandomMessageSource{ + generator: rand.New(s), + buffer: make([]byte, msgSize), + } +} + +type ReplayMessageSource struct { + lines [][]byte + index int +} + +func NewReplayMessageSource(path string) (*ReplayMessageSource, error) { + handle, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("Could not open data file %s for replay: %v", path, err) + } + + lines := make([][]byte, 0, 100) + reader := bufio.NewReader(handle) + for { + line, err := reader.ReadBytes('\n') + if err == io.EOF { + break + } else if err != nil { + return nil, fmt.Errorf("Error reading from data file %s: %v", path, err) + } else { + lines = append(lines, line) + } + } + + return &ReplayMessageSource{ + lines: lines, + index: 0, + }, nil +} + +func (source *ReplayMessageSource) Clone() MessageSource { + return &ReplayMessageSource{ + lines: source.lines, + index: 0, + } +} + +func (source *ReplayMessageSource) GetMessage() []byte { + if source.index >= len(source.lines) { + source.index = 0 + } + line := source.lines[source.index] + buffer := make([]byte, len(line)) + copy(buffer, line) + source.index++ + return buffer +} + func init() { flag.StringVar(&topic, "topic", "sangrenel", "Topic to publish to") flag.IntVar(&msgSize, "size", 300, "Message size in bytes") @@ -68,7 +156,11 @@ func init() { flag.BoolVar(&noop, "noop", false, "Test message generation performance, do not transmit messages") flag.IntVar(&clients, "clients", 1, "Number of Kafka client workers") flag.IntVar(&producers, "producers", 5, "Number of producer instances per client") + dataPath := flag.String("data", "", "File of lines that each producer should send to the broker") brokerString := flag.String("brokers", "localhost:9092", "Comma delimited list of Kafka brokers") + clientCertPath := flag.String("cert", "", "Path to TLS client certificate in PEM format") + clientKeyPath := flag.String("key", "", "Path to TLS client private key in PEM format") + caPath := flag.String("ca", "", "Path to CA root certificate in PEM format") flag.Parse() brokers = strings.Split(*brokerString, ",") @@ -85,6 +177,69 @@ func init() { os.Exit(1) } + // Select the proper message source based on command line options. + if len(*dataPath) == 0 { + fmt.Printf("Writing random strings of %d bytes.\n", msgSize) + source = NewRandomMessageSource() + } else { + fmt.Printf("Writing data from %s.\n", *dataPath) + var err error + source, err = NewReplayMessageSource(*dataPath) + if err != nil { + log.Println(err) + os.Exit(1) + } + } + + // Build TLS configuration if command line options are specified. + hasCert := len(*clientCertPath) > 0 + hasKey := len(*clientKeyPath) > 0 + hasCA := len(*caPath) > 0 + if (hasCert || hasKey || hasCA) != (hasCert && hasKey && hasCA) { + fmt.Printf("Must specify all three of cert, key, and ca, or none.\n") + os.Exit(1) + } else if hasCert { // Build TLS config + cert, err := tls.LoadX509KeyPair(*clientCertPath, *clientKeyPath) + if err != nil { + fmt.Printf("Failed to load key pair from cert file %s and key file %s: %v\n", *clientCertPath, *clientKeyPath, err) + os.Exit(1) + } + + h, err := os.Open(*caPath) + if err != nil { + fmt.Printf("Could not open CA %s: %v\n", *caPath, err) + os.Exit(1) + } + defer h.Close() + fi, err := h.Stat() + if err != nil { + fmt.Printf("Could not stat %s: %v\n", *caPath, err) + os.Exit(1) + } + certBuffer := make([]byte, fi.Size()) + n, err := h.Read(certBuffer) + if err != nil { + fmt.Printf("Could not read from %s: %v\n", *caPath, err) + os.Exit(1) + } + if n != int(fi.Size()) { + fmt.Printf("Bytes read didn't match file size in %s: expected %d, read %d\n", *caPath, fi.Size(), n) + os.Exit(1) + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(certBuffer) { + fmt.Printf("No certs found in %s.\n", *caPath) + os.Exit(1) + } + + tlsconfig = &tls.Config{ + InsecureSkipVerify: true, + RootCAs: pool, + Certificates: []tls.Certificate{cert}, + } + } + sentCntr <- 0 } @@ -98,9 +253,7 @@ func clientProducer(c kafka.Client, t *tachymeter.Tachymeter) { } defer producer.Close() - source := rand.NewSource(time.Now().UnixNano()) - generator := rand.New(source) - msgData := make([]byte, msgSize) + localSource := source.Clone() // Use a local accumulator then periodically update global counter. // Global counter can become a bottleneck with too many threads. @@ -117,7 +270,7 @@ func clientProducer(c kafka.Client, t *tachymeter.Tachymeter) { countStart := fetchSent() var start time.Time for fetchSent()-countStart < msgRate { - randMsg(msgData, *generator) + msgData := localSource.GetMessage() msg := &kafka.ProducerMessage{Topic: topic, Value: kafka.ByteEncoder(msgData)} start = time.Now() @@ -147,16 +300,14 @@ func clientProducer(c kafka.Client, t *tachymeter.Tachymeter) { // clientDummyProducer is a dummy function that kafkaClient calls if noop is True. // It is used in place of starting actual Kafka client connections to test message creation performance. func clientDummyProducer(t *tachymeter.Tachymeter) { - source := rand.NewSource(time.Now().UnixNano()) - generator := rand.New(source) - msg := make([]byte, msgSize) + localSource := source.Clone() var n int64 var times [10]time.Duration for { start := time.Now() - randMsg(msg, *generator) + localSource.GetMessage() // Increment global counter and // tachymeter every 10 messages. @@ -185,8 +336,12 @@ func kafkaClient(n int, t *tachymeter.Tachymeter) { conf.Producer.Compression = compression } conf.Producer.Flush.MaxMessages = batchSize + conf.Producer.MaxMessageBytes = 1024 * 1024 * 10 - conf.Producer.MaxMessageBytes = msgSize + 50 + if tlsconfig != nil { + conf.Net.TLS.Enable = true + conf.Net.TLS.Config = tlsconfig + } client, err := kafka.NewClient(brokers, conf) if err != nil { @@ -208,14 +363,6 @@ func kafkaClient(n int, t *tachymeter.Tachymeter) { <-killClients } -// Returns a random message generated from the chars byte slice. -// Message length of m bytes as defined by msgSize. -func randMsg(m []byte, generator rand.Rand) { - for i := range m { - m[i] = chars[generator.Intn(len(chars))] - } -} - // Global counter functions. func incrSent(n int64) { i := <-sentCntr