212 lines
5.0 KiB
Go
212 lines
5.0 KiB
Go
package marionette
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// Low-level send/receive //////////////////////////////////////////////////////
|
|
|
|
// send encodes obj as JSON and writes it to conn as one length-prefixed
// packet of the form "<len>:<json>", where <len> is the decimal byte length
// of the JSON text. The complete packet is passed to conn in a single Write.
// A marshaling error is returned before anything is written.
func send(conn io.Writer, obj any) error {
	payload, err := json.Marshal(obj)
	if err != nil {
		return err
	}

	// Reserve room for up to 10 length digits plus the ':' separator.
	packet := strconv.AppendInt(make([]byte, 0, len(payload)+11), int64(len(payload)), 10)
	packet = append(append(packet, ':'), payload...)

	_, err = conn.Write(packet)
	return err
}
|
|
|
|
// receive reads one length-prefixed JSON packet from conn and decodes its
// body into a value of type T.
//
// Packet format:
//
//	packet = digit{1,10} ":" body
//	digit  = "0"-"9"
//	body   = JSON text, exactly <digits> bytes long
//
// The header is read in at most two fixed-size chunks (4 bytes, then 7 more)
// instead of byte-by-byte, so a few body bytes may be consumed together with
// the header; they are carried over into the body buffer.
//
// If the stream ends before the first byte of a packet, io.EOF is returned.
// If it ends partway through a packet, io.ErrUnexpectedEOF is returned.
// A declared length that is smaller than the body bytes already consumed with
// the header is rejected, because those surplus bytes belong to the next
// packet and dropping them would desynchronize the stream.
func receive[T any](conn io.Reader) (T, error) {
	var zero T

	// Up to 10 length digits plus the ':' separator.
	head := make([]byte, 11)

	// Shortest expected message is `2:{}`, so a 4-byte initial read is
	// normally safe.
	headLen := 4
	if _, err := io.ReadFull(conn, head[:headLen]); err != nil {
		return zero, err
	}

	// No separator yet: the length has at least 4 digits, so (absent leading
	// zeros) the body is >= 1000 bytes and reading 7 more bytes stays inside
	// this packet.
	if bytes.IndexByte(head[:headLen], ':') < 0 {
		n, err := io.ReadFull(conn, head[headLen:])
		if err != nil {
			if err == io.EOF {
				// The packet already started; this is a truncation, not a clean close.
				err = io.ErrUnexpectedEOF
			}
			return zero, err
		}
		headLen += n
	}

	colon := bytes.IndexByte(head[:headLen], ':')
	if colon < 0 {
		return zero, fmt.Errorf("invalid message length: %q: length exceeded", head)
	}

	bodyLen, err := strconv.ParseUint(string(head[:colon]), 10, 32)
	if err != nil {
		return zero, fmt.Errorf("invalid message length: %q: %w", head[:colon], err)
	}

	// Body bytes that arrived as part of the header chunk.
	extra := head[colon+1 : headLen]
	if uint64(len(extra)) > bodyLen {
		return zero, fmt.Errorf("invalid message length: %q: shorter than the %d body bytes already read", head[:colon], len(extra))
	}

	body := make([]byte, bodyLen)
	n := copy(body, extra)
	if _, err := io.ReadFull(conn, body[n:]); err != nil {
		if err == io.EOF {
			// Header was read, so the stream ended mid-packet.
			err = io.ErrUnexpectedEOF
		}
		return zero, err
	}

	var val T
	if err := json.Unmarshal(body, &val); err != nil {
		return zero, err
	}
	return val, nil
}
|
|
|
|
// read hello //////////////////////////////////////////////////////////////////
|
|
|
|
type msgHello struct {
|
|
ApplicationType string `json:"applicationType"`
|
|
MarionetteProtocol int `json:"marionetteProtocol"`
|
|
}
|
|
|
|
func readHello(conn io.Reader) error {
|
|
hello, err := receive[msgHello](conn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if hello.ApplicationType != "gecko" || hello.MarionetteProtocol != 3 {
|
|
return fmt.Errorf("I only know how to speak marionette protocol v3 to gecko, but got response: %#v", hello)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// request/response semantics //////////////////////////////////////////////////
|
|
|
|
// errorObject is the error payload found in the third slot of a response
// packet. It is decoded from the JSON keys "error", "message" and
// "stacktrace".
type errorObject struct {
	ErrorCode  string `json:"error"`
	Message    string `json:"message"`
	StackTrace string `json:"stacktrace"`
}

var _ error = (*errorObject)(nil)

// Error implements [error]. The code and message are rendered quoted; the
// stack trace has trailing newlines stripped and each of its lines prefixed
// with a tab.
func (e *errorObject) Error() string {
	const layout = "{\n ErrorCode: %q,\n Message: %q\n StackTrace: |\n%s\n}"
	trace := strings.TrimRight(e.StackTrace, "\n")
	indented := "\t" + strings.ReplaceAll(trace, "\n", "\n\t")
	return fmt.Sprintf(layout, e.ErrorCode, e.Message, indented)
}

// msgResponse is a decoded response packet: a 4-element JSON array
// [1, id, error, result]. Error is nil when the error slot is JSON null.
type msgResponse[T any] struct {
	ID     uint32
	Error  *errorObject
	Result T
}

var _ json.Unmarshaler = (*msgResponse[any])(nil)

// UnmarshalJSON implements [json.Unmarshaler]. It requires exactly four
// array elements with a leading type marker of 1, then decodes the id,
// error and result slots in that order, stopping at the first failure.
func (msg *msgResponse[T]) UnmarshalJSON(dat []byte) error {
	var fields []json.RawMessage
	if err := json.Unmarshal(dat, &fields); err != nil {
		return err
	}
	if count := len(fields); count != 4 {
		return fmt.Errorf("responses should have 4 fields, but got %v", count)
	}

	var kind int
	if err := json.Unmarshal(fields[0], &kind); err != nil {
		return fmt.Errorf("invalid response field=type: %w: %q", err, fields[0])
	}
	if kind != 1 {
		return fmt.Errorf("invalid response field=type: not 1: %q", fields[0])
	}

	// Remaining slots, in wire order.
	slots := [...]struct {
		label string
		dst   any
	}{
		{"ID", &msg.ID},
		{"error", &msg.Error},
		{"result", &msg.Result},
	}
	for i, slot := range slots {
		raw := fields[i+1]
		if err := json.Unmarshal(raw, slot.dst); err != nil {
			return fmt.Errorf("invalid response field=%s: %w: %q", slot.label, err, raw)
		}
	}
	return nil
}
|
|
|
|
func doCommand[T any](conn io.ReadWriter, id uint32, method string, params any) (T, error) {
|
|
var zero T
|
|
if err := send(conn, []any{
|
|
0, // type=command
|
|
id,
|
|
method,
|
|
params,
|
|
}); err != nil {
|
|
return zero, err
|
|
}
|
|
for {
|
|
resp, err := receive[msgResponse[T]](conn)
|
|
if err != nil {
|
|
return zero, fmt.Errorf("%s: %w", method, err)
|
|
}
|
|
if resp.ID == id {
|
|
if resp.Error != nil {
|
|
return zero, resp.Error
|
|
}
|
|
return resp.Result, nil
|
|
}
|
|
log.Printf("discarding response: %#v", resp)
|
|
}
|
|
}
|
|
|
|
// High-level wrapper //////////////////////////////////////////////////////////
|
|
|
|
type TCPTransport struct {
|
|
conn *net.TCPConn
|
|
lastCmdID uint32
|
|
}
|
|
|
|
type Null struct{}
|
|
|
|
func NewTransport(conn *net.TCPConn) (*TCPTransport, error) {
|
|
if err := readHello(conn); err != nil {
|
|
return nil, err
|
|
}
|
|
return &TCPTransport{
|
|
conn: conn,
|
|
}, nil
|
|
}
|
|
|
|
func DoCommand[T any](t *TCPTransport, method string, params any) (T, error) {
|
|
t.lastCmdID++
|
|
id := t.lastCmdID
|
|
return doCommand[T](t.conn, id, method, params)
|
|
}
|
|
|
|
func (*Null) UnmarshalJSON(dat []byte) error {
|
|
if !bytes.Equal(dat, []byte("null")) {
|
|
return errors.New("expected null object")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (Null) MarshalJSON() ([]byte, error) {
|
|
return []byte("null"), nil
|
|
}
|