diff --git a/classname_bag.go b/classbag.go similarity index 62% rename from classname_bag.go rename to classbag.go index 66585a3..1c0b903 100644 --- a/classname_bag.go +++ b/classbag.go @@ -7,3 +7,7 @@ type ClassBag struct { func (bag *ClassBag) Add(classDesc *TCClassDesc) { bag.Classes = append(bag.Classes, classDesc) } + +func (bag *ClassBag) Merge(newBag *ClassBag) { + bag.Classes = append(bag.Classes, newBag.Classes...) +} diff --git a/example/main.go b/example/main.go index a3d6406..05fa363 100644 --- a/example/main.go +++ b/example/main.go @@ -8,7 +8,7 @@ import ( ) func main() { - data, err := ioutil.ReadFile("example/object.poc") + data, err := ioutil.ReadFile("example/object2.poc") if err != nil { fmt.Println(err.Error()) return diff --git a/tc_array.go b/tc_array.go index 967629f..6bd7cdb 100644 --- a/tc_array.go +++ b/tc_array.go @@ -28,11 +28,12 @@ func readTCArray(stream *ObjectStream) (*TCArray, error) { var err error _, _ = stream.ReadN(1) - array.ClassPointer, err = readTCClassPointer(stream, nil) + array.ClassPointer, err = readTCClassPointer(stream) if err != nil { return nil, err } + stream.AddReference(array) bs, err := stream.ReadN(4) if err != nil { sugar.Error(err) diff --git a/tc_class.go b/tc_class.go index 0855d7f..287fd76 100644 --- a/tc_class.go +++ b/tc_class.go @@ -15,7 +15,7 @@ func readTCClass(stream *ObjectStream) (*TCClass, error) { var err error _, _ = stream.ReadN(1) - class.ClassPointer, err = readTCClassPointer(stream, nil) + class.ClassPointer, err = readTCClassPointer(stream) if err != nil { return nil, err } diff --git a/tc_classdesc.go b/tc_classdesc.go index 36e272a..de2d4f0 100644 --- a/tc_classdesc.go +++ b/tc_classdesc.go @@ -37,14 +37,10 @@ func (desc *TCClassDesc) HasFlag(flag byte) bool { return (desc.ClassDescFlags & flag) == flag } -func readTCClassDesc(stream *ObjectStream, bag *ClassBag) (*TCClassDesc, error) { +func readTCClassDesc(stream *ObjectStream) (*TCClassDesc, error) { var err error var classDesc = new(TCClassDesc) - if bag != nil { - bag.Add(classDesc) - } - // read JAVA_TC_CLASSDESC flag _, _ = stream.ReadN(1) @@ -85,7 +81,7 @@ func readTCClassDesc(stream *ObjectStream, bag *ClassBag) (*TCClassDesc, error) } // superClassDesc - classDesc.SuperClassPointer, err = readTCClassPointer(stream, bag) + classDesc.SuperClassPointer, err = readTCClassPointer(stream) if err != nil { return nil, err } diff --git a/tc_classpointer.go b/tc_classpointer.go index 2e42d88..f711206 100644 --- a/tc_classpointer.go +++ b/tc_classpointer.go @@ -44,7 +44,35 @@ func (cp *TCClassPointer) GetClassDesc(stream *ObjectStream) (*TCClassDesc, erro } } -func readTCClassPointer(stream *ObjectStream, bag *ClassBag) (*TCClassPointer, error) { +func (cp *TCClassPointer) FindClassBag(stream *ObjectStream) (*ClassBag, error) { + var desc *TCClassDesc + var err error + if cp.Flag == JAVA_TC_NULL { + return nil, nil + } + + desc, err = cp.GetClassDesc(stream) + if err != nil { + return nil, err + } + + var bag = &ClassBag{ + Classes: []*TCClassDesc{desc}, + } + + newBag, err := desc.SuperClassPointer.FindClassBag(stream) + if err != nil { + return nil, err + } + + if newBag != nil { + bag.Merge(newBag) + } + + return bag, nil +} + +func readTCClassPointer(stream *ObjectStream) (*TCClassPointer, error) { // read JAVA_TC_CLASSDESC Flag flag, _ := stream.PeekN(1) if flag[0] == JAVA_TC_NULL { @@ -63,7 +91,7 @@ func readTCClassPointer(stream *ObjectStream, bag *ClassBag) (*TCClassPointer, e Reference: reference, }, nil } else if flag[0] == JAVA_TC_CLASSDESC { - desc, err := readTCClassDesc(stream, bag) + desc, err := readTCClassDesc(stream) if err != nil { return nil, err } diff --git a/tc_content.go b/tc_content.go index 973047b..5f70738 100644 --- a/tc_content.go +++ b/tc_content.go @@ -55,7 +55,7 @@ func readTCContent(stream *ObjectStream) (*TCContent, error) { case JAVA_TC_CLASS: content.Class, err = readTCClass(stream) case JAVA_TC_CLASSDESC: - content.ClassDesc, err = readTCClassDesc(stream, nil) + content.ClassDesc, err = readTCClassDesc(stream) case JAVA_TC_NULL: content.Null = readTCNull(stream) case JAVA_TC_REFERENCE: diff --git a/tc_enum.go b/tc_enum.go index d7cc649..07fd855 100644 --- a/tc_enum.go +++ b/tc_enum.go @@ -2,7 +2,7 @@ package javaserialize type TCEnum struct { ClassPointer *TCClassPointer - ConstantName *TCString + ConstantName *TCStringPointer } func (e *TCEnum) ToBytes() []byte { @@ -17,13 +17,13 @@ func readTCEnum(stream *ObjectStream) (*TCEnum, error) { var err error _, _ = stream.ReadN(1) - enum.ClassPointer, err = readTCClassPointer(stream, nil) + enum.ClassPointer, err = readTCClassPointer(stream) if err != nil { return nil, err } stream.AddReference(enum) - enum.ConstantName, err = readTCString(stream) + enum.ConstantName, err = readTCStringPointer(stream) if err != nil { return nil, err } diff --git a/tc_fielddesc.go b/tc_fielddesc.go index 22ca855..2c6b80b 100644 --- a/tc_fielddesc.go +++ b/tc_fielddesc.go @@ -12,7 +12,7 @@ var AllTypecode = append(PrimitiveTypecode, ObjectTypecode...) type TCFieldDesc struct { TypeCode string FieldName *TCString - ClassName *TCString + ClassName *TCStringPointer } func (f *TCFieldDesc) ToBytes() []byte { @@ -47,7 +47,7 @@ func readTCField(stream *ObjectStream) (*TCFieldDesc, error) { } if funk.ContainsString(ObjectTypecode, fieldDesc.TypeCode) { - fieldDesc.ClassName, err = readTCString(stream) + fieldDesc.ClassName, err = readTCStringPointer(stream) if err != nil { return nil, err } diff --git a/tc_object.go b/tc_object.go index f1be242..311a258 100644 --- a/tc_object.go +++ b/tc_object.go @@ -18,10 +18,9 @@ func (oo *TCObject) ToBytes() []byte { func readTCObject(stream *ObjectStream) (*TCObject, error) { var obj = new(TCObject) var err error - var bag = new(ClassBag) // save current TCClassDesc _, _ = stream.ReadN(1) - obj.ClassPointer, err = readTCClassPointer(stream, bag) + obj.ClassPointer, err = readTCClassPointer(stream) if err != nil { return nil, err } @@ -29,14 +28,11 @@ func readTCObject(stream *ObjectStream) (*TCObject, error) { stream.AddReference(obj) if obj.ClassPointer.Flag == JAVA_TC_NULL { return obj, nil - } else if obj.ClassPointer.Flag == JAVA_TC_REFERENCE { - classData, err := readTCClassData(stream, obj.ClassPointer.Reference.ClassDesc) - if err != nil { - return nil, err - } + } - obj.ClassDatas = append(obj.ClassDatas, classData) - return obj, nil + bag, err := obj.ClassPointer.FindClassBag(stream) + if err != nil { + return nil, err } for i := len(bag.Classes) - 1; i >= 0; i-- { diff --git a/tc_reference.go b/tc_reference.go index 6e07fc7..29747cc 100644 --- a/tc_reference.go +++ b/tc_reference.go @@ -57,5 +57,5 @@ func readTCReference(stream *ObjectStream) (*TCReference, error) { } Failed: - return nil, fmt.Errorf("object reference %v is not found", handler) + return nil, fmt.Errorf("object reference %v is not found on index %v", handler, stream.CurrentIndex()) } diff --git a/tc_stringpointer.go b/tc_stringpointer.go new file mode 100644 index 0000000..8f67450 --- /dev/null +++ b/tc_stringpointer.go @@ -0,0 +1,42 @@ +package javaserialize + +import "fmt" + +type TCStringPointer struct { + IsRef bool + String *TCString + Reference *TCReference +} + +func (sp *TCStringPointer) ToBytes() []byte { + var bs = []byte{JAVA_TC_STRING} + if sp.IsRef { + bs = append(bs, sp.Reference.ToBytes()...) + } else { + bs = append(bs, sp.String.ToBytes()...) + } + + return bs +} + +func readTCStringPointer(stream *ObjectStream) (*TCStringPointer, error) { + flag, err := stream.PeekN(1) + if err != nil { + return nil, fmt.Errorf("read JAVA_TC_STRING pointer failed on index %v", stream.CurrentIndex()) + } + + var sp = TCStringPointer { + IsRef: flag[0] != JAVA_TC_STRING, + } + if flag[0] == JAVA_TC_STRING { + sp.String, err = readTCString(stream) + } else { + sp.Reference, err = readTCReference(stream) + } + + if err != nil { + return nil, err + } else { + return &sp, nil + } +}