package main import ( "bytes" "context" "encoding/json" "fmt" "io" "log" "net" "os" "os/exec" "strconv" "time" "golang.org/x/sync/errgroup" ) func WithFirefox(ctx context.Context, fn func(*net.TCPConn) error) (err error) { maybeSetErr := func(_err error) { if err == nil && _err != nil { err = _err } } profile, err := os.MkdirTemp("", "") if err != nil { return err } defer func() { maybeSetErr(os.RemoveAll(profile)) }() cmd := exec.CommandContext(ctx, "firefox", "--new-instance", "--profile", profile, "--marionette") cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { return err } grp, ctx := errgroup.WithContext(ctx) grp.Go(func() error { return cmd.Wait() }) grp.Go(func() (err error) { maybeSetErr := func(_err error) { if err == nil && _err != nil { err = _err } } defer func() { cmd.Process.Signal(os.Interrupt) }() var conn *net.TCPConn for conn == nil { select { case <-ctx.Done(): return ctx.Err() default: _conn, _ := net.Dial("tcp", "localhost:2828") conn, _ = _conn.(*net.TCPConn) } } log.Println("connected") defer func() { maybeSetErr(conn.Close()) }() return fn(conn) }) return grp.Wait() } func receive[T any](conn io.Reader) (T, error) { var zero T // Packet format is length-prefixed JSON: // // packet = digit{1,10}+ ":" body // digit = "0"-"9" // body = JSON text head := make([]byte, 11) // Shortest possible message is `2:{}`, so a 4-byte initial read is always safe. headLen := 4 if _, err := io.ReadFull(conn, head[:4]); err != nil { return zero, err } colon := bytes.IndexByte(head[:4], ':') // If that wasn't enough, read another 7 bytes. if colon < 0 { n, err := io.ReadFull(conn, head[4:]) if err != nil { return zero, err } headLen += n } colon = bytes.IndexByte(head, ':') if colon < 0 { return zero, fmt.Errorf("invalid message length: %q: length exceeded", head) } bodyLen, err := strconv.ParseUint(string(head[:colon]), 10, 32) if err != nil { return zero, fmt.Errorf("invalid message length: %q: %w", head, bodyLen) } body := make([]byte, bodyLen) n := copy(body, head[colon+1:headLen]) if _, err := io.ReadFull(conn, body[n:]); err != nil { return zero, err } var val T if err := json.Unmarshal(body, &val); err != nil { return zero, err } return val, nil } func send(conn io.Writer, obj any) error { jsonBytes, err := json.Marshal(obj) if err != nil { return err } msgBytes := make([]byte, 0, 11+len(jsonBytes)) msgBytes = strconv.AppendInt(msgBytes, int64(len(jsonBytes)), 10) msgBytes = append(msgBytes, ':') msgBytes = append(msgBytes, jsonBytes...) if _, err := conn.Write(msgBytes); err != nil { return err } return nil } type MsgHello struct { ApplicationType string `json:"applicationType"` MarionetteProtocol int `json:"marionetteProtocol"` } func sendCommand(conn io.Writer, id uint32, method string, parameters any) error { return send(conn, []any{ 0, // type=command id, method, parameters, }) } type MsgResponse[T any] struct { ID uint32 Error *ErrorObject Result T } type ErrorObject struct { ErrorCode string `json:"error"` Message string `json:"message"` StackTrace string `json:"stacktrace"` } func (msg *MsgResponse[T]) UnmarshalJSON(dat []byte) error { var ary []json.RawMessage if err := json.Unmarshal(dat, &ary); err != nil { return err } if len(ary) != 4 { return fmt.Errorf("responses should have 4 fields, but got %v", len(ary)) } var typ int if err := json.Unmarshal(ary[0], &typ); err != nil || typ != 1 { return fmt.Errorf("invalid response field=type: %q", ary[0]) } if err := json.Unmarshal(ary[1], &msg.ID); err != nil || typ != 1 { return fmt.Errorf("invalid response field=ID: %q", ary[1]) } if err := json.Unmarshal(ary[2], &msg.Error); err != nil || typ != 1 { return fmt.Errorf("invalid response field=error: %q", ary[2]) } if err := json.Unmarshal(ary[3], &msg.Result); err != nil || typ != 1 { return fmt.Errorf("invalid response field=result: %q", ary[3]) } return nil } func doCommand[T any](conn io.ReadWriter, id uint32, method string, params any) (T, error) { var zero T if err := sendCommand(conn, id, method, params); err != nil { return zero, err } for { resp, err := receive[MsgResponse[T]](conn) if err != nil { return zero, err } if resp.ID == id { if resp.Error != nil { return zero, fmt.Errorf("error code: %#v", *resp.Error) } return resp.Result, nil } log.Printf("discarding response: %#v", resp) } } func main() { if err := WithFirefox(context.Background(), func(c *net.TCPConn) error { hello, err := receive[MsgHello](c) if err != nil { return err } if hello.ApplicationType != "gecko" || hello.MarionetteProtocol != 3 { return fmt.Errorf("I only know how to speak marionette protocol v3 to gecko, but got response: %#v", hello) } resp1, err := doCommand[any](c, 1, "WebDriver:NewSession", map[string]bool{ "strictFileInteractability": true, }) if err != nil { return err } log.Printf("resp1 = %#v", resp1) resp2, err := doCommand[any](c, 2, "Marionette:Quit", struct{}{}) if err != nil { return err } log.Printf("resp2 = %#v", resp2) io.Copy(io.Discard, c) time.Sleep(5 * time.Second) return nil }); err != nil { fmt.Fprintf(os.Stderr, "%s: error: %v\n", os.Args[0], err) os.Exit(1) } }