243 lines
5.3 KiB
Go
243 lines
5.3 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"strconv"
|
|
"time"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
func WithFirefox(ctx context.Context, fn func(*net.TCPConn) error) (err error) {
|
|
maybeSetErr := func(_err error) {
|
|
if err == nil && _err != nil {
|
|
err = _err
|
|
}
|
|
}
|
|
|
|
profile, err := os.MkdirTemp("", "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { maybeSetErr(os.RemoveAll(profile)) }()
|
|
|
|
cmd := exec.CommandContext(ctx, "firefox",
|
|
"--new-instance",
|
|
"--profile", profile,
|
|
"--marionette")
|
|
cmd.Stdout = os.Stdout
|
|
cmd.Stderr = os.Stderr
|
|
if err := cmd.Start(); err != nil {
|
|
return err
|
|
}
|
|
|
|
grp, ctx := errgroup.WithContext(ctx)
|
|
|
|
grp.Go(func() error {
|
|
return cmd.Wait()
|
|
})
|
|
|
|
grp.Go(func() (err error) {
|
|
maybeSetErr := func(_err error) {
|
|
if err == nil && _err != nil {
|
|
err = _err
|
|
}
|
|
}
|
|
|
|
defer func() { cmd.Process.Signal(os.Interrupt) }()
|
|
|
|
var conn *net.TCPConn
|
|
for conn == nil {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
_conn, _ := net.Dial("tcp", "localhost:2828")
|
|
conn, _ = _conn.(*net.TCPConn)
|
|
}
|
|
}
|
|
log.Println("connected")
|
|
defer func() { maybeSetErr(conn.Close()) }()
|
|
|
|
return fn(conn)
|
|
})
|
|
|
|
return grp.Wait()
|
|
}
|
|
|
|
// receive reads one Marionette packet from conn and decodes its JSON body
// into a value of type T.
//
// Packets are length-prefixed JSON:
//
//	packet = digit{1,10} ":" body
//	digit  = "0"-"9"
//	body   = JSON text, exactly as many bytes as the digits say
//
// Read errors from conn (including io.EOF / io.ErrUnexpectedEOF produced by
// io.ReadFull) are returned as-is. A malformed length header yields an
// "invalid message length" error, wrapping the strconv error when the digits
// do not parse. JSON decoding errors are returned as-is.
func receive[T any](conn io.Reader) (T, error) {
	var zero T

	// The longest header is 10 digits plus the colon.
	const maxHead = 11
	head := make([]byte, maxHead)

	// The shortest message expected is `2:{}` (bodies are JSON arrays or
	// objects), so a 4-byte initial read never consumes bytes that belong
	// to the next packet.
	headLen, err := io.ReadFull(conn, head[:4])
	if err != nil {
		return zero, err
	}
	colon := bytes.IndexByte(head[:headLen], ':')
	if colon < 0 {
		// A length of 4+ digits implies a body of at least 1000 bytes, so
		// reading the rest of the maximal header cannot overrun this packet.
		n, err := io.ReadFull(conn, head[headLen:])
		if err != nil {
			return zero, err
		}
		headLen += n
		colon = bytes.IndexByte(head[:headLen], ':')
		if colon < 0 {
			return zero, fmt.Errorf("invalid message length: %q: length exceeded", head[:headLen])
		}
	}

	bodyLen, err := strconv.ParseUint(string(head[:colon]), 10, 32)
	if err != nil {
		return zero, fmt.Errorf("invalid message length: %q: %w", head[:headLen], err)
	}

	// Part of the body may already sit in head. If the declared length is
	// shorter than that, we have read into the next packet and the stream is
	// out of sync; fail loudly instead of silently dropping those bytes.
	prefetched := head[colon+1 : headLen]
	if uint64(len(prefetched)) > bodyLen {
		return zero, fmt.Errorf("invalid message length: %q: body shorter than bytes already read", head[:headLen])
	}
	body := make([]byte, bodyLen)
	n := copy(body, prefetched)
	if _, err := io.ReadFull(conn, body[n:]); err != nil {
		return zero, err
	}

	var val T
	if err := json.Unmarshal(body, &val); err != nil {
		return zero, err
	}
	return val, nil
}
|
|
|
|
// send frames obj as a Marionette packet — the decimal byte length of its
// JSON encoding, a colon, then the JSON itself — and writes the whole packet
// to conn in a single Write call. Marshal and write errors are returned
// unchanged.
func send(conn io.Writer, obj any) error {
	body, err := json.Marshal(obj)
	if err != nil {
		return err
	}

	// Up to 10 length digits plus the colon precede the body.
	frame := strconv.AppendInt(make([]byte, 0, len(body)+11), int64(len(body)), 10)
	frame = append(append(frame, ':'), body...)

	_, err = conn.Write(frame)
	return err
}
|
|
|
|
// MsgHello is the first packet read from a fresh Marionette connection. It
// identifies the application on the other end and the Marionette protocol
// version it speaks.
type MsgHello struct {
	// ApplicationType names the application, e.g. "gecko".
	ApplicationType string `json:"applicationType"`
	// MarionetteProtocol is the protocol version number offered by the server.
	MarionetteProtocol int `json:"marionetteProtocol"`
}
|
|
|
|
func sendCommand(conn io.Writer, id uint32, method string, parameters any) error {
|
|
return send(conn, []any{
|
|
0, // type=command
|
|
id,
|
|
method,
|
|
parameters,
|
|
})
|
|
}
|
|
|
|
// MsgResponse is a decoded Marionette response packet. On the wire a
// response is a four-element JSON array:
//
//	[1, id, error, result]
//
// where 1 marks the message as a response, id is the ID of the command being
// answered, error is null or an ErrorObject, and result is decoded into T.
type MsgResponse[T any] struct {
	ID     uint32
	Error  *ErrorObject
	Result T
}

// ErrorObject is the error payload of a failed Marionette response.
type ErrorObject struct {
	ErrorCode  string `json:"error"`
	Message    string `json:"message"`
	StackTrace string `json:"stacktrace"`
}

// UnmarshalJSON decodes a response array [1, id, error, result] into msg.
//
// It fails if dat is not a four-element array, if the first element is not
// the response type 1, or if any element cannot be decoded into the
// corresponding field. Decoding errors are wrapped with %w so callers can
// inspect them with errors.Is / errors.As.
func (msg *MsgResponse[T]) UnmarshalJSON(dat []byte) error {
	var ary []json.RawMessage
	if err := json.Unmarshal(dat, &ary); err != nil {
		return err
	}
	if len(ary) != 4 {
		return fmt.Errorf("responses should have 4 fields, but got %d", len(ary))
	}

	var typ int
	if err := json.Unmarshal(ary[0], &typ); err != nil {
		return fmt.Errorf("invalid response field=type: %q: %w", ary[0], err)
	}
	if typ != 1 {
		return fmt.Errorf("invalid response field=type: %q", ary[0])
	}
	if err := json.Unmarshal(ary[1], &msg.ID); err != nil {
		return fmt.Errorf("invalid response field=ID: %q: %w", ary[1], err)
	}
	// A JSON null here leaves msg.Error nil, meaning "no error".
	if err := json.Unmarshal(ary[2], &msg.Error); err != nil {
		return fmt.Errorf("invalid response field=error: %q: %w", ary[2], err)
	}
	if err := json.Unmarshal(ary[3], &msg.Result); err != nil {
		return fmt.Errorf("invalid response field=result: %q: %w", ary[3], err)
	}
	return nil
}
|
|
|
|
func doCommand[T any](conn io.ReadWriter, id uint32, method string, params any) (T, error) {
|
|
var zero T
|
|
if err := sendCommand(conn, id, method, params); err != nil {
|
|
return zero, err
|
|
}
|
|
for {
|
|
resp, err := receive[MsgResponse[T]](conn)
|
|
if err != nil {
|
|
return zero, err
|
|
}
|
|
if resp.ID == id {
|
|
if resp.Error != nil {
|
|
return zero, fmt.Errorf("error code: %#v", *resp.Error)
|
|
}
|
|
return resp.Result, nil
|
|
}
|
|
log.Printf("discarding response: %#v", resp)
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
if err := WithFirefox(context.Background(), func(c *net.TCPConn) error {
|
|
hello, err := receive[MsgHello](c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if hello.ApplicationType != "gecko" || hello.MarionetteProtocol != 3 {
|
|
return fmt.Errorf("I only know how to speak marionette protocol v3 to gecko, but got response: %#v", hello)
|
|
}
|
|
resp1, err := doCommand[any](c, 1, "WebDriver:NewSession", map[string]bool{
|
|
"strictFileInteractability": true,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Printf("resp1 = %#v", resp1)
|
|
|
|
resp2, err := doCommand[any](c, 2, "Marionette:Quit", struct{}{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Printf("resp2 = %#v", resp2)
|
|
io.Copy(io.Discard, c)
|
|
time.Sleep(5 * time.Second)
|
|
return nil
|
|
}); err != nil {
|
|
fmt.Fprintf(os.Stderr, "%s: error: %v\n", os.Args[0], err)
|
|
os.Exit(1)
|
|
}
|
|
}
|