//===--- RecursiveSymbolVisitor.h - Clang refactoring library -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// A wrapper class around \c RecursiveASTVisitor that visits each
/// occurrence of a named symbol.
///
//===----------------------------------------------------------------------===//
#ifndef LLVM_CLANG_TOOLING_REFACTORING_RECURSIVESYMBOLVISITOR_H
#define LLVM_CLANG_TOOLING_REFACTORING_RECURSIVESYMBOLVISITOR_H
#include "clang/AST/AST.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Lex/Lexer.h"
namespace clang {
namespace tooling {
/// Traverses the AST and reports every place where a named symbol is spelled
/// out in the source.
///
/// This is a CRTP wrapper around \c RecursiveASTVisitor: the derived class
/// \p T receives each occurrence through \c visitSymbolOccurrence together with
/// the source range covering the symbol's name. Occurrences whose declaration
/// could not be resolved to a \c NamedDecl are dropped before they reach the
/// derived class.
template <typename T>
class RecursiveSymbolVisitor
    : public RecursiveASTVisitor<RecursiveSymbolVisitor<T>> {
  using BaseType = RecursiveASTVisitor<RecursiveSymbolVisitor<T>>;

public:
  RecursiveSymbolVisitor(const SourceManager &SM, const LangOptions &LangOpts)
      : SM(SM), LangOpts(LangOpts) {}

  /// Called once for every occurrence of a named symbol.
  ///
  /// Derived classes hide this method to collect occurrences; it is invoked
  /// via a static_cast to \p T, not virtually. The default implementation
  /// ignores every occurrence.
  ///
  /// \param ND The declaration the occurrence refers to; never null.
  /// \param NameRanges The source range(s) spelling the name. Every caller in
  ///        this class currently passes exactly one range. For single-token
  ///        names the range ends just past the token, as computed by
  ///        \c Lexer::getLocForEndOfToken (which may yield an invalid location
  ///        for names inside a macro expansion). For namespace qualifiers the
  ///        range is the nested-name-specifier's local range.
  /// \returns false to abort the traversal, true to continue.
  bool visitSymbolOccurrence(const NamedDecl *ND,
                             ArrayRef<SourceRange> NameRanges) {
    return true;
  }

  // Declaration visitors:

  /// Reports the name of every declaration at its location. Conversion
  /// functions are skipped because their name is spelled as a type
  /// (`operator T`).
  bool VisitNamedDecl(const NamedDecl *D) {
    return isa<CXXConversionDecl>(D) ? true : visit(D, D->getLocation());
  }

  /// Reports each field named in a written member initializer of a
  /// constructor, e.g. `x` in `S() : x(0) {}`.
  bool VisitCXXConstructorDecl(const CXXConstructorDecl *CD) {
    for (const auto *Initializer : CD->inits()) {
      // Ignore implicit initializers.
      if (!Initializer->isWritten())
        continue;
      // getMember() is null for initializers that do not name a direct field
      // (e.g. base-class initializers).
      if (const FieldDecl *FD = Initializer->getMember()) {
        if (!visit(FD, Initializer->getSourceLocation()))
          return false;
      }
    }
    return true;
  }

  // Expression visitors:

  /// Reports the declaration found by name lookup for a reference expression.
  bool VisitDeclRefExpr(const DeclRefExpr *Expr) {
    return visit(Expr->getFoundDecl(), Expr->getLocation());
  }

  /// Reports the member named in a member access such as `obj.member`.
  bool VisitMemberExpr(const MemberExpr *Expr) {
    return visit(Expr->getFoundDecl().getDecl(), Expr->getMemberLoc());
  }

  /// Reports each resolved field component of an `offsetof` expression.
  bool VisitOffsetOfExpr(const OffsetOfExpr *S) {
    for (unsigned I = 0, E = S->getNumComponents(); I != E; ++I) {
      const OffsetOfNode &Component = S->getComponent(I);
      if (Component.getKind() == OffsetOfNode::Field) {
        if (!visit(Component.getField(), Component.getEndLoc()))
          return false;
      }
      // FIXME: Try to resolve dependent field references.
    }
    return true;
  }

  // Other visitors:

  /// Reports type names written in the source. This covers template type
  /// parameters, the template of a template specialization, and record types.
  /// In each case the range is the first token of the type location.
  bool VisitTypeLoc(const TypeLoc Loc) {
    const SourceLocation TypeBeginLoc = Loc.getBeginLoc();
    const SourceLocation TypeEndLoc =
        Lexer::getLocForEndOfToken(TypeBeginLoc, 0, SM, LangOpts);
    if (const auto *TemplateTypeParm =
            dyn_cast<TemplateTypeParmType>(Loc.getType())) {
      if (!visit(TemplateTypeParm->getDecl(), TypeBeginLoc, TypeEndLoc))
        return false;
    }
    if (const auto *TemplateSpecType =
            dyn_cast<TemplateSpecializationType>(Loc.getType())) {
      // getAsTemplateDecl() may be null (e.g. for a dependent template name);
      // visit() drops such occurrences.
      if (!visit(TemplateSpecType->getTemplateName().getAsTemplateDecl(),
                 TypeBeginLoc, TypeEndLoc))
        return false;
    }
    if (const Type *TP = Loc.getTypePtr()) {
      // getAsCXXRecordDecl() is null for a plain C record; visit() drops it.
      if (TP->getTypeClass() == clang::Type::Record)
        return visit(TP->getAsCXXRecordDecl(), TypeBeginLoc, TypeEndLoc);
    }
    return true;
  }

  /// Reports a use of a typedef or alias name.
  bool VisitTypedefTypeLoc(TypedefTypeLoc TL) {
    return visit(TL.getTypedefNameDecl(), TL.getBeginLoc());
  }

  /// Reports namespace names used as qualifiers, e.g. `ns` in `ns::f()`.
  bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
    // The base visitor will visit NNSL prefixes, so we should only look at
    // the current NNS.
    if (NNS) {
      // Null when this qualifier component is not a namespace; visit() drops
      // such occurrences.
      const NamespaceDecl *ND = NNS.getNestedNameSpecifier()->getAsNamespace();
      if (!visit(ND, NNS.getLocalBeginLoc(), NNS.getLocalEndLoc()))
        return false;
    }
    return BaseType::TraverseNestedNameSpecifierLoc(NNS);
  }

  /// Reports each field named by a designated initializer, e.g. `x` in
  /// `{.x = 1}`.
  bool VisitDesignatedInitExpr(const DesignatedInitExpr *E) {
    for (const DesignatedInitExpr::Designator &D : E->designators()) {
      if (!D.isFieldDesignator())
        continue;
      if (const FieldDecl *Decl = D.getField()) {
        // Report [name, end-of-name) like every other single-token
        // occurrence, rather than the collapsed range [FieldLoc, FieldLoc].
        if (!visit(Decl, D.getFieldLoc()))
          return false;
      }
    }
    return true;
  }

private:
  const SourceManager &SM;
  const LangOptions &LangOpts;

  /// Forwards one occurrence of \p ND spanning [BeginLoc, EndLoc] to the
  /// derived class.
  ///
  /// Occurrences without a resolved declaration are skipped, so
  /// visitSymbolOccurrence never receives null. Several callers can produce
  /// null, e.g. a non-namespace qualifier or a C record type.
  /// \returns false if the derived class asked to stop the traversal.
  bool visit(const NamedDecl *ND, SourceLocation BeginLoc,
             SourceLocation EndLoc) {
    if (!ND)
      return true;
    return static_cast<T *>(this)->visitSymbolOccurrence(
        ND, SourceRange(BeginLoc, EndLoc));
  }

  /// Forwards a single-token occurrence at \p Loc. The reported range ends
  /// just past that token.
  bool visit(const NamedDecl *ND, SourceLocation Loc) {
    return visit(ND, Loc, Lexer::getLocForEndOfToken(Loc, 0, SM, LangOpts));
  }
};
} // end namespace tooling
} // end namespace clang
#endif // LLVM_CLANG_TOOLING_REFACTORING_RECURSIVESYMBOLVISITOR_H