mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2025-01-11 15:41:19 -05:00
292 lines
5.7 KiB
Go
292 lines
5.7 KiB
Go
|
package hbase
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"io"
|
||
|
"net"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
|
||
|
pb "github.com/golang/protobuf/proto"
|
||
|
"github.com/juju/errors"
|
||
|
"github.com/ngaut/log"
|
||
|
"github.com/pingcap/go-hbase/iohelper"
|
||
|
"github.com/pingcap/go-hbase/proto"
|
||
|
)
|
||
|
|
||
|
// ServiceType identifies which HBase service a connection is opened against.
type ServiceType byte

// Known service identifiers, numbered consecutively starting at 1.
//
// The constants are untyped integer constants, so they convert implicitly
// wherever a ServiceType is expected (for example as ServiceString keys).
const (
	// MasterMonitorService has the value 1.
	MasterMonitorService = 1 + iota
	// MasterService has the value 2.
	MasterService
	// MasterAdminService has the value 3.
	MasterAdminService
	// AdminService has the value 4.
	AdminService
	// ClientService has the value 5.
	ClientService
	// RegionServerStatusService has the value 6.
	RegionServerStatusService
)
|
||
|
|
||
|
// convert above const to protobuf string
|
||
|
var ServiceString = map[ServiceType]string{
|
||
|
MasterMonitorService: "MasterMonitorService",
|
||
|
MasterService: "MasterService",
|
||
|
MasterAdminService: "MasterAdminService",
|
||
|
AdminService: "AdminService",
|
||
|
ClientService: "ClientService",
|
||
|
RegionServerStatusService: "RegionServerStatusService",
|
||
|
}
|
||
|
|
||
|
// idGenerator is a mutex-protected integer counter used to hand out call ids.
type idGenerator struct {
	// n is the most recently issued id; it is 0 until the first increment.
	n int
	// mu guards n.
	mu *sync.RWMutex
}

// newIdGenerator returns a generator whose counter starts at 0.
func newIdGenerator() *idGenerator {
	gen := new(idGenerator)
	gen.mu = new(sync.RWMutex)
	return gen
}

// get returns the current counter value without modifying it.
func (a *idGenerator) get() int {
	a.mu.RLock()
	defer a.mu.RUnlock()
	return a.n
}

// incrAndGet increments the counter by one and returns the new value.
func (a *idGenerator) incrAndGet() int {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.n++
	return a.n
}
|
||
|
|
||
|
type connection struct {
|
||
|
mu sync.Mutex
|
||
|
addr string
|
||
|
conn net.Conn
|
||
|
bw *bufio.Writer
|
||
|
idGen *idGenerator
|
||
|
serviceType ServiceType
|
||
|
in chan *iohelper.PbBuffer
|
||
|
ongoingCalls map[int]*call
|
||
|
}
|
||
|
|
||
|
func processMessage(msg []byte) ([][]byte, error) {
|
||
|
buf := pb.NewBuffer(msg)
|
||
|
payloads := make([][]byte, 0)
|
||
|
|
||
|
// Question: why can we ignore this error?
|
||
|
for {
|
||
|
hbytes, err := buf.DecodeRawBytes(true)
|
||
|
if err != nil {
|
||
|
// Check whether error is `unexpected EOF`.
|
||
|
if strings.Contains(err.Error(), "unexpected EOF") {
|
||
|
break
|
||
|
}
|
||
|
|
||
|
log.Errorf("Decode raw bytes error - %v", errors.ErrorStack(err))
|
||
|
return nil, errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
payloads = append(payloads, hbytes)
|
||
|
}
|
||
|
|
||
|
return payloads, nil
|
||
|
}
|
||
|
|
||
|
func readPayloads(r io.Reader) ([][]byte, error) {
|
||
|
nBytesExpecting, err := iohelper.ReadInt32(r)
|
||
|
if err != nil {
|
||
|
return nil, errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
if nBytesExpecting > 0 {
|
||
|
buf, err := iohelper.ReadN(r, nBytesExpecting)
|
||
|
// Question: why should we return error only when we get an io.EOF error?
|
||
|
if err != nil && ErrorEqual(err, io.EOF) {
|
||
|
return nil, errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
payloads, err := processMessage(buf)
|
||
|
if err != nil {
|
||
|
return nil, errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
if len(payloads) > 0 {
|
||
|
return payloads, nil
|
||
|
}
|
||
|
}
|
||
|
return nil, errors.New("unexpected payload")
|
||
|
}
|
||
|
|
||
|
func newConnection(addr string, srvType ServiceType) (*connection, error) {
|
||
|
conn, err := net.Dial("tcp", addr)
|
||
|
if err != nil {
|
||
|
return nil, errors.Trace(err)
|
||
|
}
|
||
|
if _, ok := ServiceString[srvType]; !ok {
|
||
|
return nil, errors.Errorf("unexpected service type [serviceType=%d]", srvType)
|
||
|
}
|
||
|
c := &connection{
|
||
|
addr: addr,
|
||
|
bw: bufio.NewWriter(conn),
|
||
|
conn: conn,
|
||
|
in: make(chan *iohelper.PbBuffer, 20),
|
||
|
serviceType: srvType,
|
||
|
idGen: newIdGenerator(),
|
||
|
ongoingCalls: map[int]*call{},
|
||
|
}
|
||
|
|
||
|
err = c.init()
|
||
|
if err != nil {
|
||
|
return nil, errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
return c, nil
|
||
|
}
|
||
|
|
||
|
func (c *connection) init() error {
|
||
|
err := c.writeHead()
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
err = c.writeConnectionHeader()
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
go func() {
|
||
|
err := c.processMessages()
|
||
|
if err != nil {
|
||
|
log.Warnf("process messages failed - %v", errors.ErrorStack(err))
|
||
|
return
|
||
|
}
|
||
|
}()
|
||
|
go c.dispatch()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *connection) processMessages() error {
|
||
|
for {
|
||
|
msgs, err := readPayloads(c.conn)
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
var rh proto.ResponseHeader
|
||
|
err = pb.Unmarshal(msgs[0], &rh)
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
callId := rh.GetCallId()
|
||
|
c.mu.Lock()
|
||
|
call, ok := c.ongoingCalls[int(callId)]
|
||
|
if !ok {
|
||
|
c.mu.Unlock()
|
||
|
return errors.Errorf("Invalid call id: %d", callId)
|
||
|
}
|
||
|
delete(c.ongoingCalls, int(callId))
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
exception := rh.GetException()
|
||
|
if exception != nil {
|
||
|
call.complete(errors.Errorf("Exception returned: %s\n%s", exception.GetExceptionClassName(), exception.GetStackTrace()), nil)
|
||
|
} else if len(msgs) == 2 {
|
||
|
call.complete(nil, msgs[1])
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *connection) writeHead() error {
|
||
|
buf := bytes.NewBuffer(nil)
|
||
|
buf.Write(hbaseHeaderBytes)
|
||
|
buf.WriteByte(0)
|
||
|
buf.WriteByte(80)
|
||
|
_, err := c.conn.Write(buf.Bytes())
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
func (c *connection) writeConnectionHeader() error {
|
||
|
buf := iohelper.NewPbBuffer()
|
||
|
service := pb.String(ServiceString[c.serviceType])
|
||
|
|
||
|
err := buf.WritePBMessage(&proto.ConnectionHeader{
|
||
|
UserInfo: &proto.UserInformation{
|
||
|
EffectiveUser: pb.String("pingcap"),
|
||
|
},
|
||
|
ServiceName: service,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
err = buf.PrependSize()
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
_, err = c.conn.Write(buf.Bytes())
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *connection) dispatch() {
|
||
|
for {
|
||
|
select {
|
||
|
case buf := <-c.in:
|
||
|
// TODO: add error check.
|
||
|
c.bw.Write(buf.Bytes())
|
||
|
if len(c.in) == 0 {
|
||
|
c.bw.Flush()
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *connection) call(request *call) error {
|
||
|
id := c.idGen.incrAndGet()
|
||
|
rh := &proto.RequestHeader{
|
||
|
CallId: pb.Uint32(uint32(id)),
|
||
|
MethodName: pb.String(request.methodName),
|
||
|
RequestParam: pb.Bool(true),
|
||
|
}
|
||
|
|
||
|
request.id = uint32(id)
|
||
|
|
||
|
bfrh := iohelper.NewPbBuffer()
|
||
|
err := bfrh.WritePBMessage(rh)
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
bfr := iohelper.NewPbBuffer()
|
||
|
err = bfr.WritePBMessage(request.request)
|
||
|
if err != nil {
|
||
|
return errors.Trace(err)
|
||
|
}
|
||
|
|
||
|
// Buf =>
|
||
|
// | total size | pb1 size | pb1 | pb2 size | pb2 | ...
|
||
|
buf := iohelper.NewPbBuffer()
|
||
|
buf.WriteDelimitedBuffers(bfrh, bfr)
|
||
|
|
||
|
c.mu.Lock()
|
||
|
c.ongoingCalls[id] = request
|
||
|
c.in <- buf
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *connection) close() error {
|
||
|
return c.conn.Close()
|
||
|
}
|