package marionette import ( "bytes" "encoding/json" "errors" "fmt" "io" "log" "net" "strconv" "strings" ) // Low-level send/receive ////////////////////////////////////////////////////// func send(conn io.Writer, obj any) error { jsonBytes, err := json.Marshal(obj) if err != nil { return err } msgBytes := make([]byte, 0, 11+len(jsonBytes)) msgBytes = strconv.AppendInt(msgBytes, int64(len(jsonBytes)), 10) msgBytes = append(msgBytes, ':') msgBytes = append(msgBytes, jsonBytes...) if _, err := conn.Write(msgBytes); err != nil { return err } return nil } func receive[T any](conn io.Reader) (T, error) { var zero T // Packet format is length-prefixed JSON: // // packet = digit{1,10}+ ":" body // digit = "0"-"9" // body = JSON text head := make([]byte, 11) // Shortest possible message is `2:{}`, so a 4-byte initial read is always safe. headLen := 4 if _, err := io.ReadFull(conn, head[:4]); err != nil { return zero, err } colon := bytes.IndexByte(head[:4], ':') // If that wasn't enough, read another 7 bytes. if colon < 0 { n, err := io.ReadFull(conn, head[4:]) if err != nil { return zero, err } headLen += n } colon = bytes.IndexByte(head, ':') if colon < 0 { return zero, fmt.Errorf("invalid message length: %q: length exceeded", head) } bodyLen, err := strconv.ParseUint(string(head[:colon]), 10, 32) if err != nil { return zero, fmt.Errorf("invalid message length: %q: %w", head[:colon], err) } body := make([]byte, bodyLen) n := copy(body, head[colon+1:headLen]) if _, err := io.ReadFull(conn, body[n:]); err != nil { return zero, err } var val T if err := json.Unmarshal(body, &val); err != nil { return zero, err } return val, nil } // read hello ////////////////////////////////////////////////////////////////// type msgHello struct { ApplicationType string `json:"applicationType"` MarionetteProtocol int `json:"marionetteProtocol"` } func readHello(conn io.Reader) error { hello, err := receive[msgHello](conn) if err != nil { return err } if hello.ApplicationType != "gecko" || hello.MarionetteProtocol != 3 { return fmt.Errorf("I only know how to speak marionette protocol v3 to gecko, but got response: %#v", hello) } return nil } // request/response semantics ////////////////////////////////////////////////// type errorObject struct { ErrorCode string `json:"error"` Message string `json:"message"` StackTrace string `json:"stacktrace"` } var _ error = (*errorObject)(nil) // Error implements [error]. func (e *errorObject) Error() string { return fmt.Sprintf("{\n ErrorCode: %q,\n Message: %q\n StackTrace: |\n%s\n}", e.ErrorCode, e.Message, "\t"+strings.ReplaceAll(strings.TrimRight(e.StackTrace, "\n"), "\n", "\n\t")) } type msgResponse[T any] struct { ID uint32 Error *errorObject Result T } var _ json.Unmarshaler = (*msgResponse[any])(nil) // UnmarshalJSON implements [json.Unmarshaler]. func (msg *msgResponse[T]) UnmarshalJSON(dat []byte) error { var ary []json.RawMessage if err := json.Unmarshal(dat, &ary); err != nil { return err } if len(ary) != 4 { return fmt.Errorf("responses should have 4 fields, but got %v", len(ary)) } var typ int if err := json.Unmarshal(ary[0], &typ); err != nil { return fmt.Errorf("invalid response field=type: %w: %q", err, ary[0]) } if typ != 1 { return fmt.Errorf("invalid response field=type: not 1: %q", ary[0]) } if err := json.Unmarshal(ary[1], &msg.ID); err != nil { return fmt.Errorf("invalid response field=ID: %w: %q", err, ary[1]) } if err := json.Unmarshal(ary[2], &msg.Error); err != nil { return fmt.Errorf("invalid response field=error: %w: %q", err, ary[2]) } if err := json.Unmarshal(ary[3], &msg.Result); err != nil { return fmt.Errorf("invalid response field=result: %w: %q", err, ary[3]) } return nil } func doCommand[T any](conn io.ReadWriter, id uint32, method string, params any) (T, error) { var zero T if err := send(conn, []any{ 0, // type=command id, method, params, }); err != nil { return zero, err } for { resp, err := receive[msgResponse[T]](conn) if err != nil { return zero, fmt.Errorf("%s: %w", method, err) } if resp.ID == id { if resp.Error != nil { return zero, resp.Error } return resp.Result, nil } log.Printf("discarding response: %#v", resp) } } // High-level wrapper ////////////////////////////////////////////////////////// type TCPTransport struct { conn *net.TCPConn lastCmdID uint32 } type Null struct{} func NewTransport(conn *net.TCPConn) (*TCPTransport, error) { if err := readHello(conn); err != nil { return nil, err } return &TCPTransport{ conn: conn, }, nil } func DoCommand[T any](t *TCPTransport, method string, params any) (T, error) { t.lastCmdID++ id := t.lastCmdID return doCommand[T](t.conn, id, method, params) } func (*Null) UnmarshalJSON(dat []byte) error { if !bytes.Equal(dat, []byte("null")) { return errors.New("expected null object") } return nil } func (Null) MarshalJSON() ([]byte, error) { return []byte("null"), nil }