From 8c15199ca9d756c1d375836e934534ca2d04d1a2 Mon Sep 17 00:00:00 2001 From: Owen Mansel-Chan Date: Fri, 8 Apr 2022 12:12:46 +0100 Subject: [PATCH] Use generic struct field not instantiated one in Uses We do not extract instantiated named types, and instead use the generic type. But fields of the underlying struct of an instantiated named types are obtained from the Uses map. We solve this keeping track of which objects should be overridden by which other objects. --- extractor/extractor.go | 34 ++++++++++++++++++++++++++++++++++ extractor/trap/trapwriter.go | 2 ++ 2 files changed, 36 insertions(+) diff --git a/extractor/extractor.go b/extractor/extractor.go index 7a14e876..f57bd2eb 100644 --- a/extractor/extractor.go +++ b/extractor/extractor.go @@ -1531,6 +1531,7 @@ func extractType(tw *trap.Writer, tp types.Type) trap.Label { dbscheme.TypeNameTable.Emit(tw, lbl, origintp.Obj().Name()) underlying := origintp.Underlying() extractUnderlyingType(tw, lbl, underlying) + trackInstantiatedStructFields(tw, tp, origintp) entitylbl, exists := tw.Labeler.LookupObjectID(origintp.Obj(), lbl) if entitylbl == trap.InvalidLabel { @@ -1902,6 +1903,9 @@ func getObjectBeingUsed(tw *trap.Writer, ident *ast.Ident) types.Object { if obj == nil { return nil } + if override, ok := tw.ObjectsOverride[obj]; ok { + return override + } if funcObj, ok := obj.(*types.Func); ok { sig := funcObj.Type().(*types.Signature) if recv := sig.Recv(); recv != nil { @@ -1948,3 +1952,33 @@ func tryGetGenericType(tp types.Type) (*types.Named, bool) { } return nil, false } + +// trackInstantiatedStructFields tries to give the fields of an instantiated +// struct type underlying `tp` the same labels as the corresponding fields of +// the generic struct type. This is so that when we come across the +// instantiated field in `tw.Package.TypesInfo.Uses` we will get the label for +// the generic field instead. +func trackInstantiatedStructFields(tw *trap.Writer, tp, origintp *types.Named) { + if tp == origintp { + return + } + + if instantiatedStruct, ok := tp.Underlying().(*types.Struct); ok { + genericStruct, ok2 := origintp.Underlying().(*types.Struct) + if !ok2 { + log.Fatalf( + "Error: underlying type of instantiated type is a struct but underlying type of generic type is %s", + origintp.Underlying()) + } + + if instantiatedStruct.NumFields() != genericStruct.NumFields() { + log.Fatalf( + "Error: instantiated struct %s has different number of fields than the generic version %s (%d != %d)", + instantiatedStruct, genericStruct, instantiatedStruct.NumFields(), genericStruct.NumFields()) + } + + for i := 0; i < instantiatedStruct.NumFields(); i++ { + tw.ObjectsOverride[instantiatedStruct.Field(i)] = genericStruct.Field(i) + } + } +} diff --git a/extractor/trap/trapwriter.go b/extractor/trap/trapwriter.go index 21a47078..713cb341 100644 --- a/extractor/trap/trapwriter.go +++ b/extractor/trap/trapwriter.go @@ -27,6 +27,7 @@ type Writer struct { Package *packages.Package TypesOverride map[ast.Expr]types.Type TypeParamParent map[*types.TypeParam]Label + ObjectsOverride map[types.Object]types.Object } func FileFor(path string) (string, error) { @@ -66,6 +67,7 @@ func NewWriter(path string, pkg *packages.Package) (*Writer, error) { pkg, make(map[ast.Expr]types.Type), make(map[*types.TypeParam]Label), + make(map[types.Object]types.Object), } tw.Labeler = newLabeler(tw) return tw, nil