ffctl/scrape.go

243 lines
5.3 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net"
"os"
"os/exec"
"strconv"
"time"
"golang.org/x/sync/errgroup"
)
// WithFirefox launches a new Firefox instance with Marionette enabled,
// connects to it over TCP on localhost:2828, and calls fn with that
// connection.
//
// A temporary profile directory is created for the instance and removed
// again before WithFirefox returns.  Firefox is started with
// exec.CommandContext, so cancelling ctx kills the process.  When fn returns,
// or when connecting is abandoned, Firefox is sent os.Interrupt; WithFirefox
// then waits for the process to exit.
//
// The result is the first error reported by the Firefox process or by the
// connect/fn goroutine; if both succeed, a failure to remove the profile
// directory is returned instead.
func WithFirefox(ctx context.Context, fn func(*net.TCPConn) error) (err error) {
	// keepFirst stores e as the function result unless an earlier error is
	// already on its way out.
	keepFirst := func(e error) {
		if err == nil && e != nil {
			err = e
		}
	}

	profile, err := os.MkdirTemp("", "")
	if err != nil {
		return err
	}
	defer func() { keepFirst(os.RemoveAll(profile)) }()

	cmd := exec.CommandContext(ctx, "firefox",
		"--new-instance",
		"--profile", profile,
		"--marionette")
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Start(); err != nil {
		return err
	}

	grp, ctx := errgroup.WithContext(ctx)
	grp.Go(func() error {
		return cmd.Wait()
	})
	grp.Go(func() (err error) {
		// Best effort: the process may already have exited, in which case
		// signalling it fails and there is nothing useful to do about it.
		defer func() { _ = cmd.Process.Signal(os.Interrupt) }()

		// Firefox needs a moment before it starts listening, so keep
		// retrying the dial.  Pause between attempts instead of spinning,
		// and give up as soon as ctx is done.
		const retryInterval = 100 * time.Millisecond
		var (
			dialer net.Dialer
			conn   *net.TCPConn
		)
		for conn == nil {
			c, dialErr := dialer.DialContext(ctx, "tcp", "localhost:2828")
			if dialErr != nil {
				select {
				case <-ctx.Done():
					return ctx.Err()
				case <-time.After(retryInterval):
					continue
				}
			}
			tcp, ok := c.(*net.TCPConn)
			if !ok {
				_ = c.Close()
				return fmt.Errorf("connecting to marionette: unexpected connection type %T", c)
			}
			conn = tcp
		}
		log.Println("connected")

		// Report a Close failure only when fn itself succeeded.
		defer func() {
			if closeErr := conn.Close(); err == nil && closeErr != nil {
				err = closeErr
			}
		}()
		return fn(conn)
	})
	return grp.Wait()
}
// receive reads one packet from conn and decodes its JSON body into a value
// of type T.
//
// Packet format is length-prefixed JSON:
//
//	packet = length ":" body
//	length = digit{1,10}     ; decimal size of body in bytes
//	digit  = "0"-"9"
//	body   = JSON text
//
// The header is read in at most two fixed-size chunks (4 bytes, then 7 more),
// so the first bytes of the body usually arrive together with the header and
// are carried over into the body buffer.
//
// It returns the zero value of T together with an error if reading fails, if
// the length prefix is missing or malformed, if the declared length is
// shorter than the bytes already consumed, or if the body is not valid JSON
// for T.
func receive[T any](conn io.Reader) (T, error) {
	var zero T

	// Up to 10 length digits plus the colon.
	head := make([]byte, 11)

	// Shortest possible message is `2:{}`, so a 4-byte initial read is always safe.
	headLen := 4
	if _, err := io.ReadFull(conn, head[:headLen]); err != nil {
		return zero, err
	}
	colon := bytes.IndexByte(head[:headLen], ':')

	// If that wasn't enough, read another 7 bytes.
	if colon < 0 {
		n, err := io.ReadFull(conn, head[headLen:])
		if err != nil {
			return zero, err
		}
		headLen += n
		colon = bytes.IndexByte(head[:headLen], ':')
	}
	if colon < 0 {
		return zero, fmt.Errorf("invalid message length: %q: length exceeded", head[:headLen])
	}

	bodyLen, err := strconv.ParseUint(string(head[:colon]), 10, 32)
	if err != nil {
		return zero, fmt.Errorf("invalid message length: %q: %w", head[:headLen], err)
	}

	// Everything after the colon that came in with the header belongs to the
	// body.  If the body is shorter than that, bytes of the following packet
	// have already been consumed and the stream cannot be resynchronised.
	prefetched := head[colon+1 : headLen]
	if uint64(len(prefetched)) > bodyLen {
		return zero, fmt.Errorf("invalid message length: %q: body shorter than data already read", head[:headLen])
	}

	body := make([]byte, bodyLen)
	n := copy(body, prefetched)
	if _, err := io.ReadFull(conn, body[n:]); err != nil {
		return zero, err
	}

	var val T
	if err := json.Unmarshal(body, &val); err != nil {
		return zero, err
	}
	return val, nil
}
// send encodes obj as JSON and writes it to conn as one length-prefixed
// packet: the decimal byte length of the JSON, a colon, then the JSON itself.
//
// The whole packet is handed to conn in a single Write call.  A marshalling
// error is returned before anything is written.
func send(conn io.Writer, obj any) error {
	payload, err := json.Marshal(obj)
	if err != nil {
		return err
	}

	prefix := strconv.Itoa(len(payload))

	packet := make([]byte, 0, len(prefix)+1+len(payload))
	packet = append(packet, prefix...)
	packet = append(packet, ':')
	packet = append(packet, payload...)

	_, err = conn.Write(packet)
	return err
}
// MsgHello is the JSON object that main decodes from the first packet
// received on a new connection.
type MsgHello struct {
// ApplicationType is read from the "applicationType" key; main requires
// it to be "gecko".
ApplicationType string `json:"applicationType"`
// MarionetteProtocol is read from the "marionetteProtocol" key; main
// requires it to be 3.
MarionetteProtocol int `json:"marionetteProtocol"`
}
// sendCommand writes one command packet to conn.
//
// The packet body is the four-element JSON array
// [0, id, method, parameters]; the leading 0 marks the message as a command.
// Any error from send is returned unchanged.
func sendCommand(conn io.Writer, id uint32, method string, parameters any) error {
	const typeCommand = 0

	message := []any{typeCommand, id, method, parameters}
	return send(conn, message)
}
// MsgResponse is a decoded response message.  On the wire it is the
// four-element JSON array [1, id, error, result]; see UnmarshalJSON.
type MsgResponse[T any] struct {
	// ID is the second array element.  doCommand matches it against the ID
	// of the command it sent.
	ID uint32
	// Error is the third array element; it is nil when that element is
	// JSON null.
	Error *ErrorObject
	// Result is the fourth array element decoded into T.
	Result T
}

// ErrorObject is the JSON object found in the error slot of a response.
type ErrorObject struct {
	ErrorCode  string `json:"error"`
	Message    string `json:"message"`
	StackTrace string `json:"stacktrace"`
}

// UnmarshalJSON decodes a response from its JSON array form.
//
// dat must be an array of exactly four elements whose first element is the
// number 1 (the response type marker).  The remaining elements are decoded
// into ID, Error and Result respectively.  If any step fails, an error naming
// the offending field is returned and *msg is left unmodified.
func (msg *MsgResponse[T]) UnmarshalJSON(dat []byte) error {
	var ary []json.RawMessage
	if err := json.Unmarshal(dat, &ary); err != nil {
		return err
	}
	if len(ary) != 4 {
		return fmt.Errorf("responses should have 4 fields, but got %v", len(ary))
	}

	var typ int
	if err := json.Unmarshal(ary[0], &typ); err != nil {
		return fmt.Errorf("invalid response field=type: %q: %w", ary[0], err)
	}
	if typ != 1 {
		return fmt.Errorf("invalid response field=type: %q", ary[0])
	}

	// Decode into a scratch value so that a failure part-way through does
	// not leave the receiver half-populated.
	var decoded MsgResponse[T]
	if err := json.Unmarshal(ary[1], &decoded.ID); err != nil {
		return fmt.Errorf("invalid response field=ID: %q: %w", ary[1], err)
	}
	if err := json.Unmarshal(ary[2], &decoded.Error); err != nil {
		return fmt.Errorf("invalid response field=error: %q: %w", ary[2], err)
	}
	if err := json.Unmarshal(ary[3], &decoded.Result); err != nil {
		return fmt.Errorf("invalid response field=result: %q: %w", ary[3], err)
	}

	*msg = decoded
	return nil
}
// doCommand sends one command over conn and waits for the response carrying
// the same id, returning its result decoded as T.
//
// Responses with a different ID are logged and dropped.  If the matching
// response has a non-nil Error, that error object is turned into the returned
// error.  A failure to send or to receive ends the call immediately.
func doCommand[T any](conn io.ReadWriter, id uint32, method string, params any) (T, error) {
	var none T

	err := sendCommand(conn, id, method, params)
	if err != nil {
		return none, err
	}

	for {
		reply, err := receive[MsgResponse[T]](conn)
		switch {
		case err != nil:
			return none, err
		case reply.ID != id:
			// Not ours; keep reading.
			log.Printf("discarding response: %#v", reply)
		case reply.Error != nil:
			return none, fmt.Errorf("error code: %#v", *reply.Error)
		default:
			return reply.Result, nil
		}
	}
}
func main() {
if err := WithFirefox(context.Background(), func(c *net.TCPConn) error {
hello, err := receive[MsgHello](c)
if err != nil {
return err
}
if hello.ApplicationType != "gecko" || hello.MarionetteProtocol != 3 {
return fmt.Errorf("I only know how to speak marionette protocol v3 to gecko, but got response: %#v", hello)
}
resp1, err := doCommand[any](c, 1, "WebDriver:NewSession", map[string]bool{
"strictFileInteractability": true,
})
if err != nil {
return err
}
log.Printf("resp1 = %#v", resp1)
resp2, err := doCommand[any](c, 2, "Marionette:Quit", struct{}{})
if err != nil {
return err
}
log.Printf("resp2 = %#v", resp2)
io.Copy(io.Discard, c)
time.Sleep(5 * time.Second)
return nil
}); err != nil {
fmt.Fprintf(os.Stderr, "%s: error: %v\n", os.Args[0], err)
os.Exit(1)
}
}