package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net"
	"strconv"
	"strings"
)

// receive reads one length-prefixed JSON packet from conn and decodes its
// body into a value of type T.
//
// Packet format is length-prefixed JSON:
//
//	packet = digit{1,10} ":" body
//	digit  = "0"-"9"
//	body   = JSON text
//
// The header is read with at most two io.ReadFull calls, so an unbuffered
// conn costs few reads. Any body bytes read along with the header are
// carried over into the body buffer. I/O errors from conn (including
// io.EOF / io.ErrUnexpectedEOF) are returned unwrapped.
func receive[T any](conn io.Reader) (T, error) {
	var zero T

	// Up to 10 digits (enough for any uint32 length) plus the ':' separator.
	const maxHeadLen = 11
	head := make([]byte, maxHeadLen)

	// Shortest possible message is `2:{}`, so a 4-byte initial read never
	// consumes bytes that belong to the following packet.
	headLen := 4
	if _, err := io.ReadFull(conn, head[:headLen]); err != nil {
		return zero, err
	}
	colon := bytes.IndexByte(head[:headLen], ':')

	// No ':' in the first 4 bytes means the length has at least 4 digits, so
	// the body is at least 1000 bytes and reading the rest of the header
	// buffer is also safe.
	if colon < 0 {
		n, err := io.ReadFull(conn, head[headLen:])
		if err != nil {
			return zero, err
		}
		headLen += n
		colon = bytes.IndexByte(head[:headLen], ':')
		if colon < 0 {
			return zero, fmt.Errorf("invalid message length: %q: length exceeded", head[:headLen])
		}
	}

	bodyLen, err := strconv.ParseUint(string(head[:colon]), 10, 32)
	if err != nil {
		return zero, fmt.Errorf("invalid message length: %q: %w", head[:headLen], err)
	}

	// Bytes after the ':' that were already read belong to this body. If the
	// declared body is shorter than that, the extra bytes belong to the next
	// packet and have already been consumed; refuse rather than desync.
	prefetched := head[colon+1 : headLen]
	if uint64(len(prefetched)) > bodyLen {
		return zero, fmt.Errorf("invalid message length: %q: body shorter than bytes already read", head[:headLen])
	}

	body := make([]byte, bodyLen)
	n := copy(body, prefetched)
	if _, err := io.ReadFull(conn, body[n:]); err != nil {
		return zero, err
	}

	var val T
	if err := json.Unmarshal(body, &val); err != nil {
		return zero, err
	}
	return val, nil
}

// send marshals obj as JSON and writes it to conn as a single
// length-prefixed packet ("<decimal length>:<json>") in one Write call.
func send(conn io.Writer, obj any) error {
	jsonBytes, err := json.Marshal(obj)
	if err != nil {
		return err
	}
	// 11 = longest header: 10 length digits plus ':'.
	msgBytes := make([]byte, 0, 11+len(jsonBytes))
	msgBytes = strconv.AppendInt(msgBytes, int64(len(jsonBytes)), 10)
	msgBytes = append(msgBytes, ':')
	msgBytes = append(msgBytes, jsonBytes...)
	_, err = conn.Write(msgBytes)
	return err
}

// MsgHello decodes a JSON object carrying `applicationType` and
// `marionetteProtocol` fields. Presumably this is the greeting the server
// sends on connect; it is not used elsewhere in this file (verify against
// callers).
type MsgHello struct {
	ApplicationType    string `json:"applicationType"`
	MarionetteProtocol int    `json:"marionetteProtocol"`
}

// sendCommand writes a command packet of the form
// [0, id, method, parameters] to conn, where 0 marks the packet as a command.
func sendCommand(conn io.Writer, id uint32, method string, parameters any) error {
	return send(conn, []any{
		0, // type=command
		id,
		method,
		parameters,
	})
}

// MsgResponse is a response packet, decoded by UnmarshalJSON from the
// 4-element array [1, id, error, result].
type MsgResponse[T any] struct {
	// ID echoes the ID of the command this response answers.
	ID uint32
	// Error is nil when the packet's error element is JSON null.
	Error *ErrorObject
	// Result is the decoded result element.
	Result T
}

// ErrorObject is the error element of a response packet.
type ErrorObject struct {
	ErrorCode  string `json:"error"`
	Message    string `json:"message"`
	StackTrace string `json:"stacktrace"`
}

// Error renders the error as a multi-line block, with the stack trace
// indented by one tab per line.
func (e *ErrorObject) Error() string {
	return fmt.Sprintf("{\n  ErrorCode: %q,\n  Message: %q\n  StackTrace: |\n%s\n}",
		e.ErrorCode,
		e.Message,
		"\t"+strings.ReplaceAll(strings.TrimRight(e.StackTrace, "\n"), "\n", "\n\t"))
}

// UnmarshalJSON decodes a response packet [type, id, error, result]. It
// rejects packets that do not have exactly 4 elements or whose type is not 1.
// Decode errors for individual fields are wrapped with the offending element.
func (msg *MsgResponse[T]) UnmarshalJSON(dat []byte) error {
	var ary []json.RawMessage
	if err := json.Unmarshal(dat, &ary); err != nil {
		return err
	}
	if len(ary) != 4 {
		return fmt.Errorf("responses should have 4 fields, but got %v", len(ary))
	}

	var typ int
	if err := json.Unmarshal(ary[0], &typ); err != nil {
		return fmt.Errorf("invalid response field=type: %q: %w", ary[0], err)
	}
	if typ != 1 {
		return fmt.Errorf("invalid response field=type: %q", ary[0])
	}
	if err := json.Unmarshal(ary[1], &msg.ID); err != nil {
		return fmt.Errorf("invalid response field=ID: %q: %w", ary[1], err)
	}
	if err := json.Unmarshal(ary[2], &msg.Error); err != nil {
		return fmt.Errorf("invalid response field=error: %q: %w", ary[2], err)
	}
	if err := json.Unmarshal(ary[3], &msg.Result); err != nil {
		return fmt.Errorf("invalid response field=result: %q: %w", ary[3], err)
	}
	return nil
}

// doCommand sends a command with the given id and reads responses until one
// with a matching ID arrives. Responses for other IDs are logged and
// discarded. A non-nil error element is returned as an *ErrorObject;
// otherwise the result element is decoded into T.
func doCommand[T any](conn io.ReadWriter, id uint32, method string, params any) (T, error) {
	var zero T
	if err := sendCommand(conn, id, method, params); err != nil {
		return zero, err
	}
	for {
		// Keep the result raw until the ID is known to match: a response to
		// some other command may carry a result that does not fit T, and it
		// must be skipped rather than turned into a decode error.
		resp, err := receive[MsgResponse[json.RawMessage]](conn)
		if err != nil {
			return zero, err
		}
		if resp.ID != id {
			log.Printf("discarding response: id=%d, error=%v, result=%s", resp.ID, resp.Error, resp.Result)
			continue
		}
		if resp.Error != nil {
			return zero, resp.Error
		}
		var result T
		if err := json.Unmarshal(resp.Result, &result); err != nil {
			return zero, fmt.Errorf("invalid response field=result: %q: %w", resp.Result, err)
		}
		return result, nil
	}
}

// TCPTransport issues commands over a single TCP connection and numbers them
// sequentially. It is not safe for concurrent use: lastCmdID is incremented
// without synchronization, and responses are read inline by each call.
type TCPTransport struct {
	conn      *net.TCPConn
	lastCmdID uint32
}

// DoCommand sends method with params under the next command ID (starting
// at 1) and waits for the matching response, decoding its result as T.
func DoCommand[T any](t *TCPTransport, method string, params any) (T, error) {
	t.lastCmdID++
	id := t.lastCmdID
	return doCommand[T](t.conn, id, method, params)
}

// NewTransport wraps conn in a TCPTransport whose command IDs start at 1.
func NewTransport(conn *net.TCPConn) *TCPTransport {
	return &TCPTransport{
		conn: conn,
	}
}