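# Repository page metadata captured with this file (kept as comments so the module parses):
# 1
# 0
# mirror of https://github.com/wbt5/real-url.git synced 2025-08-02 15:44:49 +08:00
# 2020-06-18 10:39:11 +08:00
#
# 547 lines
# 21 KiB
# Python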
# Tencent is pleased to support the open source community by making Tars available.
#
# Copyright (C) 2016 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
#
import struct
from .__util import util
from .exception import *
class BinBuffer:
    """Byte container shared by the Tars encoder and decoder.

    ``buffer`` holds the raw bytes and ``position`` is a cursor offset,
    starting at 0, that the stream classes move themselves; appending
    through ``writeBuf`` leaves ``position`` untouched.
    """

    def __init__(self, buff=b""):
        self.position = 0
        self.buffer = buff

    def writeBuf(self, buff):
        """Append ``buff`` after the bytes already stored."""
        self.buffer += buff

    def getBuffer(self):
        """Return everything stored so far."""
        return self.buffer

    def length(self):
        """Return how many bytes are stored."""
        stored = self.buffer
        return len(stored)
class DataHead:
    """Type codes for Tars field heads, plus the head encoder.

    Every encoded field starts with a head that packs the field's tag
    together with one of the ``EN_*`` type codes below.
    """

    EN_INT8 = 0
    EN_INT16 = 1
    EN_INT32 = 2
    EN_INT64 = 3
    EN_FLOAT = 4
    EN_DOUBLE = 5
    EN_STRING1 = 6  # string with a 1-byte length prefix
    EN_STRING4 = 7  # string with a 4-byte length prefix
    EN_MAP = 8
    EN_LIST = 9
    EN_STRUCTBEGIN = 10
    EN_STRUCTEND = 11
    EN_ZERO = 12  # numeric zero: head only, no payload
    EN_BYTES = 13  # byte blob: inner INT8 head, length field, raw bytes

    @staticmethod
    def writeTo(buff, tag, vtype):
        """Append the head for field ``tag`` of type ``vtype`` to ``buff``.

        Tags 0..14 share a single byte with the type: ``(tag << 4) | vtype``.
        Tags 15..255 put the escape nibble 0xF in the first byte (next to
        the type) and store the tag itself in a second byte.

        :param buff: object exposing ``writeBuf(bytes)`` (e.g. BinBuffer).
        :param tag: field tag; must be in 0..255.
        :param vtype: one of the ``EN_*`` codes; must be in 0..15.
        :raises ValueError: if ``tag`` or ``vtype`` does not fit in the head.
        """
        # Without these checks an out-of-range value was OR-ed into the
        # neighbouring bits and silently corrupted the head (tag 256, for
        # example, flipped the low type bit).
        if not 0 <= tag <= 0xFF:
            raise ValueError("tars tag out of range [0, 255]: %r" % (tag,))
        if not 0 <= vtype <= 0x0F:
            raise ValueError("tars head type out of range [0, 15]: %r" % (vtype,))
        if tag < 15:
            buff.writeBuf(struct.pack('!B', (tag << 4) | vtype))
        else:
            buff.writeBuf(struct.pack('!BB', 0xF0 | vtype, tag))
class TarsOutputStream(object):
    """Encoder that appends Tars-serialized fields to an in-memory buffer.

    Call ``write(coder, tag, value)`` for each field, then ``getBuffer()`` to
    obtain the encoded bytes. Integers are emitted with the narrowest head
    type that holds them, and zero is emitted as a bare EN_ZERO head.
    """

    def __init__(self):
        self.__buffer = BinBuffer()

    def __writeBoolean(self, tag, value):
        # Booleans are encoded as int8 0/1.
        self.__writeInt8(tag, int(value))

    def __writeInt8(self, tag, value):
        if value == 0:
            # Zero needs no payload: the EN_ZERO head alone carries it.
            DataHead.writeTo(self.__buffer, tag, DataHead.EN_ZERO)
        else:
            DataHead.writeTo(self.__buffer, tag, DataHead.EN_INT8)
            self.__buffer.writeBuf(struct.pack('!b', value))

    def __writeInt16(self, tag, value):
        if -128 <= value <= 127:
            self.__writeInt8(tag, value)
        else:
            DataHead.writeTo(self.__buffer, tag, DataHead.EN_INT16)
            self.__buffer.writeBuf(struct.pack('!h', value))

    def __writeInt32(self, tag, value):
        if -32768 <= value <= 32767:
            self.__writeInt16(tag, value)
        else:
            DataHead.writeTo(self.__buffer, tag, DataHead.EN_INT32)
            self.__buffer.writeBuf(struct.pack('!i', value))

    def __writeInt64(self, tag, value):
        if -2147483648 <= value <= 2147483647:
            self.__writeInt32(tag, value)
        else:
            DataHead.writeTo(self.__buffer, tag, DataHead.EN_INT64)
            self.__buffer.writeBuf(struct.pack('!q', value))

    def __writeFloat(self, tag, value):
        DataHead.writeTo(self.__buffer, tag, DataHead.EN_FLOAT)
        self.__buffer.writeBuf(struct.pack('!f', value))

    def __writeDouble(self, tag, value):
        DataHead.writeTo(self.__buffer, tag, DataHead.EN_DOUBLE)
        self.__buffer.writeBuf(struct.pack('!d', value))

    def __writeString(self, tag, value):
        """Write a string as EN_STRING1 (<= 255 bytes) or EN_STRING4.

        ``str`` values are UTF-8 encoded *before* measuring, so the length
        prefix counts encoded bytes rather than characters (a character
        count corrupts the stream for non-ASCII text). ``bytes`` and
        ``bytearray`` values -- TarsInputStream returns strings as bytes --
        are written unchanged.
        """
        if isinstance(value, (bytes, bytearray)):
            data = value
        else:
            data = str.encode(value)
        length = len(data)
        if length <= 255:
            DataHead.writeTo(self.__buffer, tag, DataHead.EN_STRING1)
            self.__buffer.writeBuf(struct.pack('!B', length))
        else:
            DataHead.writeTo(self.__buffer, tag, DataHead.EN_STRING4)
            self.__buffer.writeBuf(struct.pack('!I', length))
        self.__buffer.writeBuf(data)

    def __writeBytes(self, tag, value):
        """Write a blob: EN_BYTES head, INT8 element head, length, raw bytes."""
        DataHead.writeTo(self.__buffer, tag, DataHead.EN_BYTES)
        # Element-type head (tag 0, INT8) that TarsInputStream checks for.
        DataHead.writeTo(self.__buffer, 0, DataHead.EN_INT8)
        self.__writeInt32(0, len(value))
        self.__buffer.writeBuf(value)

    def __writeMap(self, coder, tag, value):
        """Write a map: size at tag 0, then each key at tag 0 and value at tag 1."""
        DataHead.writeTo(self.__buffer, tag, DataHead.EN_MAP)
        self.__writeInt32(0, len(value))
        for key, item in value.items():
            self.write(coder.ktype, 0, key)
            self.write(coder.vtype, 1, item)

    def __writeVector(self, coder, tag, value):
        """Write a list: size at tag 0, then every element at tag 0.

        The element type comes from ``coder.vtype`` -- as in __writeMap and
        TarsInputStream.__readVector -- so a plain list works as well as an
        instance of the coder class (``value.vtype`` fails on a plain list).
        """
        DataHead.writeTo(self.__buffer, tag, DataHead.EN_LIST)
        self.__writeInt32(0, len(value))
        for item in value:
            self.write(coder.vtype, 0, item)

    def __writeStruct(self, coder, tag, value):
        """Write a nested struct between STRUCTBEGIN and STRUCTEND heads."""
        DataHead.writeTo(self.__buffer, tag, DataHead.EN_STRUCTBEGIN)
        # The struct serializes its own fields. Called as
        # value.writeTo(oos, value), which presumably targets a static method
        # on the generated struct class -- confirm against generated code.
        value.writeTo(self, value)
        DataHead.writeTo(self.__buffer, 0, DataHead.EN_STRUCTEND)

    def write(self, coder, tag, value):
        """Encode ``value`` as field ``tag`` using the type described by ``coder``.

        ``coder.__tars_index__`` selects the encoding: 999 bool, 0 int8,
        1 int16, 2 int32, 3 int64, 4 float, 5 double, 13 bytes, 67 string,
        8 map, 9 vector, 1011 struct.

        :raises TarsTarsUnsupportType: for any other index.
        """
        index = coder.__tars_index__
        if index == 999:
            self.__writeBoolean(tag, value)
        elif index == 0:
            self.__writeInt8(tag, value)
        elif index == 1:
            self.__writeInt16(tag, value)
        elif index == 2:
            self.__writeInt32(tag, value)
        elif index == 3:
            self.__writeInt64(tag, value)
        elif index == 4:
            self.__writeFloat(tag, value)
        elif index == 5:
            self.__writeDouble(tag, value)
        elif index == 13:
            self.__writeBytes(tag, value)
        elif index == 67:
            self.__writeString(tag, value)
        elif index == 8:
            self.__writeMap(coder, tag, value)
        elif index == 9:
            self.__writeVector(coder, tag, value)
        elif index == 1011:
            self.__writeStruct(coder, tag, value)
        else:
            # The message needs a conversion specifier; without one the '%'
            # itself raised TypeError and this exception was never seen.
            raise TarsTarsUnsupportType(
                "tars unsupport data type:%s" % index)

    def getBuffer(self):
        """Return the bytes encoded so far."""
        return self.__buffer.getBuffer()

    def printHex(self):
        """Print the encoded bytes via util.printHex."""
        util.printHex(self.__buffer.getBuffer())
class TarsInputStream(object):
    """Decoder that reads Tars-serialized fields from a byte buffer.

    Fields are looked up by tag with ``read(coder, tag, require, default)``.
    The cursor only moves forward: fields with a smaller tag than the one
    requested are skipped.
    """

    # (struct format, payload size) of the fixed-width head types, indexed
    # by type code: DataHead.EN_INT8 (0) through DataHead.EN_DOUBLE (5).
    _FIXED_LAYOUT = (
        ('!b', 1),  # EN_INT8
        ('!h', 2),  # EN_INT16
        ('!i', 4),  # EN_INT32
        ('!q', 8),  # EN_INT64
        ('!f', 4),  # EN_FLOAT
        ('!d', 8),  # EN_DOUBLE
    )

    def __init__(self, buff):
        self.__buffer = BinBuffer(buff)

    def __peekFrom(self):
        """Decode the field head at the cursor without consuming it.

        :returns: ``(tag, type, head_length)``; the head is 1 byte, or 2
            bytes when the tag nibble holds the 0xF escape.
        """
        helper, = struct.unpack_from(
            '!B', self.__buffer.buffer, self.__buffer.position)
        t = (helper & 0xF0) >> 4
        p = helper & 0x0F
        l = 1
        if t >= 15:
            # Escaped head: the real tag is in the following byte.
            l = 2
            t, = struct.unpack_from(
                '!B', self.__buffer.buffer, self.__buffer.position + 1)
        return (t, p, l)

    def __readFrom(self):
        """Decode the field head at the cursor and advance past it."""
        t, p, l = self.__peekFrom()
        self.__buffer.position += l
        return (t, p, l)

    def __unpackAt(self, fmt, size):
        """Unpack one value with ``fmt`` at the cursor, then advance ``size`` bytes."""
        value, = struct.unpack_from(
            fmt, self.__buffer.buffer, self.__buffer.position)
        self.__buffer.position += size
        return value

    def __skipToStructEnd(self):
        """Consume fields up to and including the STRUCTEND closing this level."""
        t, p, l = self.__readFrom()
        while p != DataHead.EN_STRUCTEND:
            self.__skipField(p)
            t, p, l = self.__readFrom()

    def __skipField(self, p):
        """Advance past the payload of a field whose head (type ``p``) was consumed."""
        if DataHead.EN_INT8 <= p <= DataHead.EN_DOUBLE:
            self.__buffer.position += self._FIXED_LAYOUT[p][1]
        elif p == DataHead.EN_STRING1:
            length = self.__unpackAt('!B', 1)
            self.__buffer.position += length
        elif p == DataHead.EN_STRING4:
            length = self.__unpackAt('!i', 4)
            self.__buffer.position += length
        elif p == DataHead.EN_MAP:
            size = self.__readInt32(0, True)
            # Every entry is a key field followed by a value field.
            for _ in range(size * 2):
                ti, pi, li = self.__readFrom()
                self.__skipField(pi)
        elif p == DataHead.EN_LIST:
            size = self.__readInt32(0, True)
            for _ in range(size):
                ti, pi, li = self.__readFrom()
                self.__skipField(pi)
        elif p == DataHead.EN_BYTES:
            ti, pi, li = self.__readFrom()
            if pi != DataHead.EN_INT8:
                raise TarsTarsDecodeInvalidValue(
                    "skipField with invalid type, type value: %d, %d." % (p, pi))
            size = self.__readInt32(0, True)
            self.__buffer.position += size
        elif p == DataHead.EN_STRUCTBEGIN:
            self.__skipToStructEnd()
        elif p in (DataHead.EN_STRUCTEND, DataHead.EN_ZERO):
            pass  # head-only types: nothing follows the head
        else:
            raise TarsTarsDecodeMismatch(
                "skipField with invalid type, type value:%d" % p)

    def __skipToTag(self, tag):
        """Move the cursor onto the head of field ``tag``.

        Fields with smaller tags are skipped. Returns True (cursor on the
        head) when the field is present, and False when a larger tag, the
        end of the enclosing struct, or the end of the buffer comes first.
        """
        length = self.__buffer.length()
        while self.__buffer.position < length:
            t, p, l = self.__peekFrom()
            if p == DataHead.EN_STRUCTEND:
                return False
            if tag <= t:
                return t == tag
            self.__buffer.position += l
            self.__skipField(p)
        return False

    def __fieldHead(self, tag, require):
        """Locate field ``tag`` and consume its head.

        :returns: the field's type code, or None when the field is absent
            and not required.
        :raises TarsTarsDecodeRequireNotExist: when absent but required.
        """
        if self.__skipToTag(tag):
            _, p, _ = self.__readFrom()
            return p
        if require:
            raise TarsTarsDecodeRequireNotExist(
                "require field not exist, tag: %d" % tag)
        return None

    def __readNumber(self, tag, require, default, accepted, zero, mismatch):
        """Shared body of the numeric readers.

        :param accepted: head types whose fixed-width payload is accepted.
        :param zero: value returned for an EN_ZERO head.
        :param mismatch: ``%``-template taking (tag, type) for the
            TarsTarsDecodeMismatch raised on any other head type.
        """
        p = self.__fieldHead(tag, require)
        if p is None:  # explicit None test: EN_INT8 is type code 0
            return default
        if p == DataHead.EN_ZERO:
            return zero
        if p in accepted:
            fmt, size = self._FIXED_LAYOUT[p]
            return self.__unpackAt(fmt, size)
        raise TarsTarsDecodeMismatch(mismatch % (tag, p))

    def __readBoolean(self, tag, require, default=None):
        v = self.__readInt8(tag, require)
        if v is None:
            return default
        return v != 0

    def __readInt8(self, tag, require, default=None):
        return self.__readNumber(
            tag, require, default, (DataHead.EN_INT8,), 0,
            "read 'Char' type mismatch, tag: %d , get type: %d.")

    def __readInt16(self, tag, require, default=None):
        return self.__readNumber(
            tag, require, default,
            (DataHead.EN_INT8, DataHead.EN_INT16), 0,
            "read 'Short' type mismatch, tag: %d , get type: %d.")

    def __readInt32(self, tag, require, default=None):
        return self.__readNumber(
            tag, require, default,
            (DataHead.EN_INT8, DataHead.EN_INT16, DataHead.EN_INT32), 0,
            "read 'Int32' type mismatch, tag: %d, get type: %d.")

    def __readInt64(self, tag, require, default=None):
        return self.__readNumber(
            tag, require, default,
            (DataHead.EN_INT8, DataHead.EN_INT16,
             DataHead.EN_INT32, DataHead.EN_INT64), 0,
            "read 'Int64' type mismatch, tag: %d, get type: %d.")

    def __readFloat(self, tag, require, default=None):
        # EN_ZERO decodes to 0.0 so float fields always yield a float.
        return self.__readNumber(
            tag, require, default, (DataHead.EN_FLOAT,), 0.0,
            "read 'Float' type mismatch, tag: %d, get type: %d.")

    def __readDouble(self, tag, require, default=None):
        # A float-encoded payload is accepted and widened; zero is 0.0.
        return self.__readNumber(
            tag, require, default,
            (DataHead.EN_FLOAT, DataHead.EN_DOUBLE), 0.0,
            "read 'Double' type mismatch, tag: %d, get type: %d.")

    def __readString(self, tag, require, default=None):
        """Read a string field; the value is returned as ``bytes``."""
        p = self.__fieldHead(tag, require)
        if p is None:
            return default
        if p == DataHead.EN_STRING1:
            length = self.__unpackAt('!B', 1)
        elif p == DataHead.EN_STRING4:
            length = self.__unpackAt('!i', 4)
        else:
            raise TarsTarsDecodeMismatch(
                "read 'string' type mismatch, tag: %d, get type: %d." % (tag, p))
        return self.__unpackAt(str(length) + "s", length)

    def __readBytes(self, tag, require, default=None):
        """Read a byte blob (EN_BYTES head, INT8 element head, length, data)."""
        p = self.__fieldHead(tag, require)
        if p is None:
            return default
        if p != DataHead.EN_BYTES:
            raise TarsTarsDecodeMismatch(
                "type mismatch, tag: %d, type: %d" % (tag, p))
        ti, pi, li = self.__readFrom()
        if pi != DataHead.EN_INT8:
            raise TarsTarsDecodeMismatch(
                "type mismatch, tag: %d, type: %d, %d" % (tag, p, pi))
        size = self.__readInt32(0, True)
        return self.__unpackAt(str(size) + 's', size)

    def __readStruct(self, coder, tag, require, default=None):
        """Read a nested struct via ``coder.readFrom`` and skip its unread tail."""
        p = self.__fieldHead(tag, require)
        if p is None:
            return default
        if p != DataHead.EN_STRUCTBEGIN:
            raise TarsTarsDecodeMismatch(
                "read 'struct' type mismatch, tag: %d, get type: %d." % (tag, p))
        value = coder.readFrom(self)
        # Skip fields the coder did not consume, plus the STRUCTEND head.
        self.__skipToStructEnd()
        return value

    def __readMap(self, coder, tag, require, default=None):
        """Read a map into a new ``coder()`` instance (keys tag 0, values tag 1)."""
        p = self.__fieldHead(tag, require)
        if p is None:
            return default
        if p != DataHead.EN_MAP:
            raise TarsTarsDecodeMismatch(
                "read 'map' type mismatch, tag: %d, get type: %d." % (tag, p))
        size = self.__readInt32(0, True)
        omap = coder()
        for _ in range(size):
            k = self.read(coder.ktype, 0, True)
            omap[k] = self.read(coder.vtype, 1, True)
        return omap

    def __readVector(self, coder, tag, require, default=None):
        """Read a list into a new ``coder()`` instance (elements at tag 0)."""
        p = self.__fieldHead(tag, require)
        if p is None:
            return default
        if p != DataHead.EN_LIST:
            raise TarsTarsDecodeMismatch(
                "read 'vector' type mismatch, tag: %d, get type: %d." % (tag, p))
        size = self.__readInt32(0, True)
        value = coder()
        for _ in range(size):
            value.append(self.read(coder.vtype, 0, True))
        return value

    def read(self, coder, tag, require, default=None):
        """Decode field ``tag`` with the type described by ``coder``.

        ``coder.__tars_index__`` selects the decoding (same codes as the
        encoder: 999 bool, 0..5 numbers, 13 bytes, 67 string, 8 map,
        9 vector, 1011 struct). Returns ``default`` when the field is
        absent and ``require`` is false.

        :raises TarsTarsDecodeRequireNotExist: required field is absent.
        :raises TarsTarsDecodeMismatch: field has an incompatible head type.
        :raises TarsTarsUnsupportType: unknown ``__tars_index__``.
        """
        index = coder.__tars_index__
        if index == 999:
            return self.__readBoolean(tag, require, default)
        elif index == 0:
            return self.__readInt8(tag, require, default)
        elif index == 1:
            return self.__readInt16(tag, require, default)
        elif index == 2:
            return self.__readInt32(tag, require, default)
        elif index == 3:
            return self.__readInt64(tag, require, default)
        elif index == 4:
            return self.__readFloat(tag, require, default)
        elif index == 5:
            return self.__readDouble(tag, require, default)
        elif index == 13:
            return self.__readBytes(tag, require, default)
        elif index == 67:
            return self.__readString(tag, require, default)
        elif index == 8:
            return self.__readMap(coder, tag, require, default)
        elif index == 9:
            return self.__readVector(coder, tag, require, default)
        elif index == 1011:
            return self.__readStruct(coder, tag, require, default)
        else:
            # The message needs a conversion specifier; without one the '%'
            # itself raised TypeError and this exception was never seen.
            raise TarsTarsUnsupportType(
                "tars unsupport data type:%s" % index)

    def printHex(self):
        """Print the whole input buffer via util.printHex."""
        util.printHex(self.__buffer.buffer)