ffctl/lib/marionette/transport.go

212 lines
5.0 KiB
Go

package marionette
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"strings"
)
// Low-level send/receive //////////////////////////////////////////////////////
// send JSON-encodes obj and writes it to conn as a single length-prefixed
// Marionette packet of the form "<decimal byte length>:<json>".
//
// The whole frame is handed to conn in one Write call. Any marshalling or
// write error is returned unchanged; nothing is written if marshalling fails.
func send(conn io.Writer, obj any) error {
	payload, err := json.Marshal(obj)
	if err != nil {
		return err
	}
	// Capacity hint: up to 10 length digits, the ':' separator, then the body.
	frame := make([]byte, 0, len(payload)+11)
	frame = strconv.AppendInt(frame, int64(len(payload)), 10)
	frame = append(append(frame, ':'), payload...)
	_, err = conn.Write(frame)
	return err
}
// receive reads exactly one length-prefixed Marionette packet from conn and
// decodes its JSON body into a value of type T.
//
// Packet format:
//
//	packet = length ":" body
//	length = 1*10 DIGIT   ; decimal byte count of body
//	body   = JSON text
//
// The length prefix is read in small fixed-size steps rather than through a
// buffer, so no bytes past the end of the packet are consumed and conn can be
// handed to subsequent receive calls. An error is returned if the prefix is
// malformed or declares fewer body bytes than were already read while looking
// for the ':' (that would otherwise silently drop the start of the next packet).
func receive[T any](conn io.Reader) (T, error) {
	var zero T

	// At most 10 length digits plus the ':' separator.
	const maxHead = 11
	head := make([]byte, maxHead)

	// The shortest message expected is `2:{}`, so a 4-byte initial read does not
	// reach into the next packet.
	headLen := 4
	if _, err := io.ReadFull(conn, head[:headLen]); err != nil {
		return zero, err
	}
	colon := bytes.IndexByte(head[:headLen], ':')
	if colon < 0 {
		// No ':' in the first 4 bytes means at least 4 length digits; absent
		// leading zeros that is a body of at least 1000 bytes, so reading the
		// rest of the head is safe. (Zero-padded lengths are caught below.)
		n, err := io.ReadFull(conn, head[headLen:])
		if err != nil {
			return zero, err
		}
		headLen += n
		colon = bytes.IndexByte(head[:headLen], ':')
		if colon < 0 {
			return zero, fmt.Errorf("invalid message length: %q: length exceeded", head)
		}
	}

	bodyLen, err := strconv.ParseUint(string(head[:colon]), 10, 32)
	if err != nil {
		return zero, fmt.Errorf("invalid message length: %q: %w", head[:colon], err)
	}

	// Bytes after the ':' that were already read are the start of the body. If
	// the declared length is shorter than that, the surplus would belong to the
	// next packet and the stream would be out of sync, so reject it.
	prefetched := head[colon+1 : headLen]
	if bodyLen < uint64(len(prefetched)) {
		return zero, fmt.Errorf("invalid message length: %q: shorter than the %d body bytes already read", head[:colon], len(prefetched))
	}

	body := make([]byte, bodyLen)
	n := copy(body, prefetched)
	if _, err := io.ReadFull(conn, body[n:]); err != nil {
		return zero, err
	}

	var val T
	if err := json.Unmarshal(body, &val); err != nil {
		return zero, err
	}
	return val, nil
}
// read hello //////////////////////////////////////////////////////////////////
// msgHello is the greeting packet decoded by readHello, which NewTransport
// calls as the first read on a new connection.
type msgHello struct {
	ApplicationType    string `json:"applicationType"`
	MarionetteProtocol int    `json:"marionetteProtocol"`
}

// readHello reads one packet from conn as a msgHello. It returns the read or
// decode error if there is one, an error if the hello does not announce
// applicationType "gecko" with marionetteProtocol 3, and nil otherwise.
func readHello(conn io.Reader) error {
	hello, err := receive[msgHello](conn)
	switch {
	case err != nil:
		return err
	case hello.ApplicationType == "gecko" && hello.MarionetteProtocol == 3:
		return nil
	default:
		return fmt.Errorf("I only know how to speak marionette protocol v3 to gecko, but got response: %#v", hello)
	}
}
// request/response semantics //////////////////////////////////////////////////
// errorObject is the payload found in the error element of a response packet
// when a command failed.
type errorObject struct {
	ErrorCode  string `json:"error"`
	Message    string `json:"message"`
	StackTrace string `json:"stacktrace"`
}

var _ error = (*errorObject)(nil)

// Error implements [error]. It renders the error code, message and stack
// trace; trailing newlines are trimmed from the stack trace and each of its
// lines is indented with a tab.
func (e *errorObject) Error() string {
	trace := "\t" + strings.ReplaceAll(strings.TrimRight(e.StackTrace, "\n"), "\n", "\n\t")
	return fmt.Sprintf("{\n ErrorCode: %q,\n Message: %q\n StackTrace: |\n%s\n}", e.ErrorCode, e.Message, trace)
}

// msgResponse is a decoded response packet. On the wire it is the 4-element
// JSON array [type, id, error, result], where type must be 1.
type msgResponse[T any] struct {
	ID     uint32
	Error  *errorObject
	Result T
}

var _ json.Unmarshaler = (*msgResponse[any])(nil)

// UnmarshalJSON implements [json.Unmarshaler].
//
// The result element is decoded into T only when the error element is null.
// When an error is present, Result is left as T's zero value. A failed
// command's result need not fit the success type T, and decoding it anyway
// could replace the server's error with an unrelated decode error.
func (msg *msgResponse[T]) UnmarshalJSON(dat []byte) error {
	var ary []json.RawMessage
	if err := json.Unmarshal(dat, &ary); err != nil {
		return err
	}
	if len(ary) != 4 {
		return fmt.Errorf("responses should have 4 fields, but got %v", len(ary))
	}

	var typ int
	if err := json.Unmarshal(ary[0], &typ); err != nil {
		return fmt.Errorf("invalid response field=type: %w: %q", err, ary[0])
	}
	if typ != 1 {
		return fmt.Errorf("invalid response field=type: not 1: %q", ary[0])
	}

	if err := json.Unmarshal(ary[1], &msg.ID); err != nil {
		return fmt.Errorf("invalid response field=ID: %w: %q", err, ary[1])
	}

	// JSON null leaves/sets msg.Error to nil; any object allocates one.
	if err := json.Unmarshal(ary[2], &msg.Error); err != nil {
		return fmt.Errorf("invalid response field=error: %w: %q", err, ary[2])
	}
	if msg.Error != nil {
		var zero T
		msg.Result = zero
		return nil
	}

	if err := json.Unmarshal(ary[3], &msg.Result); err != nil {
		return fmt.Errorf("invalid response field=result: %w: %q", err, ary[3])
	}
	return nil
}
// doCommand sends the command packet [0, id, method, params] on conn, then
// reads responses until one with the same id arrives, and decodes its result
// into T.
//
// Responses carrying a different id are logged and discarded. Their result is
// held only as raw JSON, so a stale response whose result does not fit T
// cannot abort this command. If the matching response carries an error, that
// *errorObject is returned unwrapped. Send, receive and result-decode failures
// are wrapped with the method name.
func doCommand[T any](conn io.ReadWriter, id uint32, method string, params any) (T, error) {
	var zero T
	if err := send(conn, []any{
		0, // type=command
		id,
		method,
		params,
	}); err != nil {
		return zero, fmt.Errorf("%s: %w", method, err)
	}
	for {
		resp, err := receive[msgResponse[json.RawMessage]](conn)
		if err != nil {
			return zero, fmt.Errorf("%s: %w", method, err)
		}
		if resp.ID != id {
			log.Printf("%s: discarding response id=%d (want %d): error=%#v result=%s", method, resp.ID, id, resp.Error, resp.Result)
			continue
		}
		if resp.Error != nil {
			return zero, resp.Error
		}
		// Only now, for the response we were waiting for, decode into T.
		var val T
		if err := json.Unmarshal(resp.Result, &val); err != nil {
			return zero, fmt.Errorf("%s: invalid response field=result: %w: %q", method, err, resp.Result)
		}
		return val, nil
	}
}
// High-level wrapper //////////////////////////////////////////////////////////
type (
	// TCPTransport is a Marionette client connection over a TCP socket. It
	// remembers the ID of the most recently issued command so that DoCommand
	// can assign the next one.
	TCPTransport struct {
		conn      *net.TCPConn
		lastCmdID uint32
	}

	// Null stands for the JSON value null: it marshals to null and only
	// unmarshals from null. It can be used as DoCommand's type argument when a
	// command's result is expected to be null.
	Null struct{}
)
// NewTransport reads the hello packet from conn and, if readHello accepts it,
// returns a TCPTransport wrapping conn. On failure the error from readHello is
// returned and conn is left open; closing it remains the caller's job.
func NewTransport(conn *net.TCPConn) (*TCPTransport, error) {
	err := readHello(conn)
	if err != nil {
		return nil, err
	}
	t := &TCPTransport{conn: conn}
	return t, nil
}
// DoCommand runs method with params over t and returns the decoded result.
// Each call first advances t.lastCmdID and uses the new value as the command
// ID. The counter is updated without synchronization.
func DoCommand[T any](t *TCPTransport, method string, params any) (T, error) {
	t.lastCmdID += 1
	return doCommand[T](t.conn, t.lastCmdID, method, params)
}
// UnmarshalJSON implements [json.Unmarshaler]. It succeeds only for the
// literal null and reports an error for any other JSON value.
func (*Null) UnmarshalJSON(dat []byte) error {
	if string(dat) == "null" {
		return nil
	}
	return errors.New("expected null object")
}

// MarshalJSON implements [json.Marshaler]: a Null always encodes as null.
func (Null) MarshalJSON() ([]byte, error) {
	return []byte("null"), nil
}