// Source: ffctl/ffscript/transport.go (171 lines, 3.9 KiB, Go)

package main
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net"
"strconv"
"strings"
)
// receive reads one length-prefixed JSON packet from conn and decodes its
// body into a value of type T.
//
// Packet format:
//
//	packet = length ":" body
//	length = digit{1,10}   ; decimal byte length of body, must fit in uint32
//	digit  = "0"-"9"
//	body   = JSON text
//
// The header is read in at most two fixed-size chunks so that, for a
// well-formed packet, no bytes past the end of the packet are consumed. If a
// header turns out to declare a body shorter than the bytes already read
// after the colon, bytes of the next packet have been consumed and cannot be
// pushed back into conn, so an error is returned instead of silently
// desynchronising the stream.
//
// Errors from conn (including io.EOF before any byte is read) are returned
// unchanged; a malformed length wraps the strconv parse error.
func receive[T any](conn io.Reader) (T, error) {
	var zero T

	// At most 10 length digits plus the colon.
	head := make([]byte, 11)

	// The shortest valid packet is `2:{}`, so a 4-byte first read can never
	// run past the end of the packet.
	headLen := 4
	if _, err := io.ReadFull(conn, head[:headLen]); err != nil {
		return zero, err
	}
	// No colon yet means the length has at least 4 digits, i.e. a body of at
	// least 1000 bytes unless zero-padded, so reading the rest of the header
	// buffer is safe for well-formed input.
	if bytes.IndexByte(head[:headLen], ':') < 0 {
		n, err := io.ReadFull(conn, head[headLen:])
		if err != nil {
			return zero, err
		}
		headLen += n
	}

	colon := bytes.IndexByte(head[:headLen], ':')
	if colon < 0 {
		return zero, fmt.Errorf("invalid message length: %q: length exceeded", head[:headLen])
	}
	bodyLen, err := strconv.ParseUint(string(head[:colon]), 10, 32)
	if err != nil {
		return zero, fmt.Errorf("invalid message length: %q: %w", head[:headLen], err)
	}

	// Bytes after the colon that were already read belong to the body.
	prefetched := head[colon+1 : headLen]
	if uint64(len(prefetched)) > bodyLen {
		return zero, fmt.Errorf("invalid message length: %q: body shorter than header read", head[:headLen])
	}

	body := make([]byte, bodyLen)
	n := copy(body, prefetched)
	if _, err := io.ReadFull(conn, body[n:]); err != nil {
		return zero, err
	}
	var val T
	if err := json.Unmarshal(body, &val); err != nil {
		return zero, err
	}
	return val, nil
}
// send encodes obj as JSON and writes it to conn as one length-prefixed
// packet of the form "<decimal byte length of the JSON>:<JSON>".
//
// The complete packet is passed to conn in a single Write call. Errors from
// json.Marshal or from the Write are returned unchanged.
func send(conn io.Writer, obj any) error {
	payload, err := json.Marshal(obj)
	if err != nil {
		return err
	}
	packet := append([]byte(strconv.Itoa(len(payload))+":"), payload...)
	_, err = conn.Write(packet)
	return err
}
// MsgHello mirrors a JSON object with the members "applicationType" and
// "marionetteProtocol".
//
// NOTE(review): presumably this is the greeting the server sends when a
// connection is opened — confirm against the caller that decodes it.
type MsgHello struct {
	// ApplicationType holds the "applicationType" member.
	ApplicationType string `json:"applicationType"`
	// MarionetteProtocol holds the "marionetteProtocol" member
	// (presumably a protocol version number — confirm).
	MarionetteProtocol int `json:"marionetteProtocol"`
}
// sendCommand writes a command packet to conn. The packet body is the
// 4-element JSON array [0, id, method, parameters]; the leading 0 marks the
// packet as a command. Any error from send is returned unchanged.
func sendCommand(conn io.Writer, id uint32, method string, parameters any) error {
	const typeCommand = 0
	packet := []any{typeCommand, id, method, parameters}
	return send(conn, packet)
}
// MsgResponse is a decoded response packet. On the wire it is the 4-element
// JSON array [1, id, error, result]; see UnmarshalJSON.
type MsgResponse[T any] struct {
	ID     uint32       // ID of the command this response answers.
	Error  *ErrorObject // Non-nil when the error element is an object; nil for JSON null.
	Result T            // The result element decoded into T.
}

// ErrorObject is the error element of a response.
type ErrorObject struct {
	ErrorCode  string `json:"error"`
	Message    string `json:"message"`
	StackTrace string `json:"stacktrace"`
}

// Error renders the error as a multi-line block. Trailing newlines are
// stripped from the stack trace, and each of its lines is indented by one tab.
func (e *ErrorObject) Error() string {
	stack := strings.TrimRight(e.StackTrace, "\n")
	stack = "\t" + strings.ReplaceAll(stack, "\n", "\n\t")
	return fmt.Sprintf("{\n ErrorCode: %q,\n Message: %q\n StackTrace: |\n%s\n}", e.ErrorCode, e.Message, stack)
}

// UnmarshalJSON decodes a response packet body of the form
// [1, id, error, result] into msg.
//
// It fails if the body is not a 4-element array, or if the type element is not
// the integer 1. It also fails if any element cannot be decoded into its
// field. Element decode errors are wrapped with %w so callers can inspect them
// with errors.As.
func (msg *MsgResponse[T]) UnmarshalJSON(dat []byte) error {
	var ary []json.RawMessage
	if err := json.Unmarshal(dat, &ary); err != nil {
		return err
	}
	if len(ary) != 4 {
		return fmt.Errorf("responses should have 4 fields, but got %v", len(ary))
	}

	var typ int
	if err := json.Unmarshal(ary[0], &typ); err != nil {
		return fmt.Errorf("invalid response field=type: %q: %w", ary[0], err)
	}
	if typ != 1 {
		return fmt.Errorf("invalid response field=type: %q: want 1", ary[0])
	}
	if err := json.Unmarshal(ary[1], &msg.ID); err != nil {
		return fmt.Errorf("invalid response field=ID: %q: %w", ary[1], err)
	}
	if err := json.Unmarshal(ary[2], &msg.Error); err != nil {
		return fmt.Errorf("invalid response field=error: %q: %w", ary[2], err)
	}
	if err := json.Unmarshal(ary[3], &msg.Result); err != nil {
		return fmt.Errorf("invalid response field=result: %q: %w", ary[3], err)
	}
	return nil
}
// doCommand sends the command (id, method, params) on conn, then reads
// responses until one carries the same id.
//
// Responses with any other id are logged and discarded. Their result payloads
// are kept as raw JSON, so a stale response whose result does not fit T
// cannot abort the wait for the response actually requested.
//
// If the matching response has a non-nil error object, that *ErrorObject is
// returned as the error. Otherwise its result is decoded into T. Send/receive
// errors are returned unchanged.
func doCommand[T any](conn io.ReadWriter, id uint32, method string, params any) (T, error) {
	var zero T
	if err := sendCommand(conn, id, method, params); err != nil {
		return zero, err
	}
	for {
		// Defer decoding the result: only the matching response must fit T.
		resp, err := receive[MsgResponse[json.RawMessage]](conn)
		if err != nil {
			return zero, err
		}
		if resp.ID != id {
			log.Printf("discarding response: ID=%d Result=%s", resp.ID, resp.Result)
			continue
		}
		if resp.Error != nil {
			return zero, resp.Error
		}
		var result T
		if err := json.Unmarshal(resp.Result, &result); err != nil {
			return zero, fmt.Errorf("invalid response field=result: %q: %w", resp.Result, err)
		}
		return result, nil
	}
}
// TCPTransport issues commands over a single TCP connection, giving each
// command the next ID in an incrementing uint32 sequence.
//
// NOTE(review): DoCommand updates lastCmdID without synchronisation, so a
// TCPTransport should not be shared between goroutines unless callers
// serialise access.
type TCPTransport struct {
	conn      *net.TCPConn // connection commands are written to and read from
	lastCmdID uint32       // ID of the most recent command; 0 before the first
}

// DoCommand sends method with params over t's connection, using the next
// command ID, and returns the result of the matching response.
func DoCommand[T any](t *TCPTransport, method string, params any) (T, error) {
	t.lastCmdID++
	return doCommand[T](t.conn, t.lastCmdID, method, params)
}

// NewTransport wraps conn in a TCPTransport whose first command uses ID 1.
func NewTransport(conn *net.TCPConn) *TCPTransport {
	t := new(TCPTransport)
	t.conn = conn
	return t
}