//===- SMTAPI.h -------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a generic SMT solver API, which serves as the base class
// for every solver-specific implementation.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_SUPPORT_SMTAPI_H
#define LLVM_SUPPORT_SMTAPI_H
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
namespace llvm {
/// Generic base class for SMT sorts
class SMTSort {
public:
SMTSort() = default;
virtual ~SMTSort() = default;
/// Returns true if the sort is a bitvector, calls isBitvectorSortImpl().
virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); }
/// Returns true if the sort is a floating-point, calls isFloatSortImpl().
virtual bool isFloatSort() const { return isFloatSortImpl(); }
/// Returns true if the sort is a boolean, calls isBooleanSortImpl().
virtual bool isBooleanSort() const { return isBooleanSortImpl(); }
/// Returns the bitvector size, fails if the sort is not a bitvector
/// Calls getBitvectorSortSizeImpl().
virtual unsigned getBitvectorSortSize() const {
assert(isBitvectorSort() && "Not a bitvector sort!");
unsigned Size = getBitvectorSortSizeImpl();
assert(Size && "Size is zero!");
return Size;
};
/// Returns the floating-point size, fails if the sort is not a floating-point
/// Calls getFloatSortSizeImpl().
virtual unsigned getFloatSortSize() const {
assert(isFloatSort() && "Not a floating-point sort!");
unsigned Size = getFloatSortSizeImpl();
assert(Size && "Size is zero!");
return Size;
};
virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
bool operator<(const SMTSort &Other) const {
llvm::FoldingSetNodeID ID1, ID2;
Profile(ID1);
Other.Profile(ID2);
return ID1 < ID2;
}
friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) {
return LHS.equal_to(RHS);
}
virtual void print(raw_ostream &OS) const = 0;
LLVM_DUMP_METHOD void dump() const;
protected:
/// Query the SMT solver and returns true if two sorts are equal (same kind
/// and bit width). This does not check if the two sorts are the same objects.
virtual bool equal_to(SMTSort const &other) const = 0;
/// Query the SMT solver and checks if a sort is bitvector.
virtual bool isBitvectorSortImpl() const = 0;
/// Query the SMT solver and checks if a sort is floating-point.
virtual bool isFloatSortImpl() const = 0;
/// Query the SMT solver and checks if a sort is boolean.
virtual bool isBooleanSortImpl() const = 0;
/// Query the SMT solver and returns the sort bit width.
virtual unsigned getBitvectorSortSizeImpl() const = 0;
/// Query the SMT solver and returns the sort bit width.
virtual unsigned getFloatSortSizeImpl() const = 0;
};
/// Handle to an SMT sort, as used throughout the SMTSolver API.
/// Note that this is a plain pointer-to-const, not a std::shared_ptr; who
/// owns the pointee is not visible in this header (presumably the solver that
/// created it -- TODO confirm against the backend).
using SMTSortRef = const SMTSort *;
/// Generic base class for SMT exprs
class SMTExpr {
public:
SMTExpr() = default;
virtual ~SMTExpr() = default;
bool operator<(const SMTExpr &Other) const {
llvm::FoldingSetNodeID ID1, ID2;
Profile(ID1);
Other.Profile(ID2);
return ID1 < ID2;
}
virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) {
return LHS.equal_to(RHS);
}
virtual void print(raw_ostream &OS) const = 0;
LLVM_DUMP_METHOD void dump() const;
protected:
/// Query the SMT solver and returns true if two sorts are equal (same kind
/// and bit width). This does not check if the two sorts are the same objects.
virtual bool equal_to(SMTExpr const &other) const = 0;
};
/// Handle to an SMT expression, as used throughout the SMTSolver API.
/// Note that this is a plain pointer-to-const, not a std::shared_ptr; who
/// owns the pointee is not visible in this header (presumably the solver that
/// created it -- TODO confirm against the backend).
using SMTExprRef = const SMTExpr *;
/// Generic base class for SMT solvers.
///
/// All sort and expression construction goes through this interface: the
/// get*Sort() methods hand out sorts and the mk* methods build expressions.
/// Apart from getFloatSort(), every operation is a pure virtual that a solver
/// backend must implement.
class SMTSolver {
public:
  SMTSolver() = default;
  virtual ~SMTSolver() = default;

  /// Debugging aid; defined out of line.
  LLVM_DUMP_METHOD void dump() const;

  /// Returns the floating-point sort matching \p BitWidth by dispatching to
  /// getFloat16Sort(), getFloat32Sort(), getFloat64Sort() or
  /// getFloat128Sort(). Any other width is reported through llvm_unreachable.
  SMTSortRef getFloatSort(unsigned BitWidth) {
    switch (BitWidth) {
    case 16:
      return getFloat16Sort();
    case 32:
      return getFloat32Sort();
    case 64:
      return getFloat64Sort();
    case 128:
      return getFloat128Sort();
    default:
      llvm_unreachable("Unsupported floating-point bitwidth!");
    }
  }

  /// Returns a boolean sort.
  virtual SMTSortRef getBoolSort() = 0;

  /// Returns a bitvector sort of the given bit width.
  virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;

  /// Returns a floating-point sort of width 16.
  virtual SMTSortRef getFloat16Sort() = 0;

  /// Returns a floating-point sort of width 32.
  virtual SMTSortRef getFloat32Sort() = 0;

  /// Returns a floating-point sort of width 64.
  virtual SMTSortRef getFloat64Sort() = 0;

  /// Returns a floating-point sort of width 128.
  virtual SMTSortRef getFloat128Sort() = 0;

  /// Returns the sort of the given expression.
  virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;

  /// Adds the constraint \p Exp to the solver.
  virtual void addConstraint(const SMTExprRef &Exp) const = 0;

  /// Creates a bitvector addition operation.
  virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector subtraction operation.
  virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector multiplication operation.
  virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector signed modulus operation.
  virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector unsigned modulus operation.
  virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector signed division operation.
  virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector unsigned division operation.
  virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector logical shift-left operation.
  virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector arithmetic shift-right operation.
  virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector logical shift-right operation.
  virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector negation operation.
  virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;

  /// Creates a bitvector not operation.
  virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;

  /// Creates a bitvector xor operation.
  virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector or operation.
  virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector and operation.
  virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector unsigned less-than operation.
  virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector signed less-than operation.
  virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector unsigned greater-than operation.
  virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector signed greater-than operation.
  virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector unsigned less-than-or-equal operation.
  virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector signed less-than-or-equal operation.
  virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector unsigned greater-than-or-equal operation.
  virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a bitvector signed greater-than-or-equal operation.
  virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a boolean not operation.
  virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;

  /// Creates a boolean equality operation.
  virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a boolean and operation.
  virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a boolean or operation.
  virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a boolean if-then-else operation.
  virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
                           const SMTExprRef &F) = 0;

  /// Creates a bitvector sign-extension operation.
  virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0;

  /// Creates a bitvector zero-extension operation.
  virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0;

  /// Creates a bitvector extract operation.
  virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low,
                                 const SMTExprRef &Exp) = 0;

  /// Creates a bitvector concat operation.
  virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS,
                                const SMTExprRef &RHS) = 0;

  /// Creates a predicate that checks for overflow in a bitvector addition
  /// operation.
  virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS,
                                       const SMTExprRef &RHS,
                                       bool isSigned) = 0;

  /// Creates a predicate that checks for underflow in a signed bitvector
  /// addition operation.
  virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS,
                                        const SMTExprRef &RHS) = 0;

  /// Creates a predicate that checks for overflow in a signed bitvector
  /// subtraction operation.
  virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS,
                                       const SMTExprRef &RHS) = 0;

  /// Creates a predicate that checks for underflow in a bitvector subtraction
  /// operation.
  virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS,
                                        const SMTExprRef &RHS,
                                        bool isSigned) = 0;

  /// Creates a predicate that checks for overflow in a signed bitvector
  /// division/modulus operation.
  virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS,
                                        const SMTExprRef &RHS) = 0;

  /// Creates a predicate that checks for overflow in a bitvector negation
  /// operation.
  virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0;

  /// Creates a predicate that checks for overflow in a bitvector
  /// multiplication operation.
  virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS,
                                       const SMTExprRef &RHS,
                                       bool isSigned) = 0;

  /// Creates a predicate that checks for underflow in a signed bitvector
  /// multiplication operation.
  virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS,
                                        const SMTExprRef &RHS) = 0;

  /// Creates a floating-point negation operation.
  virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;

  /// Creates a floating-point isInfinite operation.
  virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;

  /// Creates a floating-point isNaN operation.
  virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;

  /// Creates a floating-point isNormal operation.
  virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;

  /// Creates a floating-point isZero operation.
  virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;

  /// Creates a floating-point multiplication operation.
  virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a floating-point division operation.
  virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a floating-point remainder operation.
  virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a floating-point addition operation.
  virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a floating-point subtraction operation.
  virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a floating-point less-than operation.
  virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a floating-point greater-than operation.
  virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a floating-point less-than-or-equal operation.
  virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a floating-point greater-than-or-equal operation.
  virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;

  /// Creates a floating-point equality operation.
  virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
                               const SMTExprRef &RHS) = 0;

  /// Creates a conversion from floating-point to floating-point (of sort
  /// \p To).
  virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;

  /// Creates a conversion from signed bitvector to floating-point (of sort
  /// \p To).
  virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From,
                               const SMTSortRef &To) = 0;

  /// Creates a conversion from unsigned bitvector to floating-point (of sort
  /// \p To).
  virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From,
                               const SMTSortRef &To) = 0;

  /// Creates a conversion from floating-point to signed bitvector.
  virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0;

  /// Creates a conversion from floating-point to unsigned bitvector.
  virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0;

  /// Creates a new symbol, given a name and a sort.
  virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;

  /// Returns an appropriate floating-point rounding mode.
  virtual SMTExprRef getFloatRoundingMode() = 0;

  /// If a model is available, returns the value of the given bitvector
  /// symbol.
  virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
                                    bool isUnsigned) = 0;

  /// If a model is available, returns the value of the given boolean symbol.
  virtual bool getBoolean(const SMTExprRef &Exp) = 0;

  /// Constructs an SMTExprRef from a boolean.
  virtual SMTExprRef mkBoolean(const bool b) = 0;

  /// Constructs an SMTExprRef from a finite APFloat.
  virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;

  /// Constructs an SMTExprRef from an APSInt and its bit width.
  virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;

  /// Given an expression, extracts the value of this operand in the model
  /// into \p Int.
  virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;

  /// Given an expression, extracts the value of this operand in the model
  /// into \p Float.
  virtual bool getInterpretation(const SMTExprRef &Exp,
                                 llvm::APFloat &Float) = 0;

  /// Checks whether the constraints are satisfiable.
  /// NOTE(review): an empty optional presumably means the solver could not
  /// decide -- confirm against the backend implementation.
  virtual std::optional<bool> check() const = 0;

  /// Pushes the current solver state.
  virtual void push() = 0;

  /// Pops previously pushed solver state; \p NumStates defaults to 1.
  virtual void pop(unsigned NumStates = 1) = 0;

  /// Resets the solver and removes all constraints.
  virtual void reset() = 0;

  /// Checks whether the solver supports floating-points.
  virtual bool isFPSupported() = 0;

  /// Writes a representation of the solver to \p OS.
  virtual void print(raw_ostream &OS) const = 0;
};
/// Shared-ownership handle (std::shared_ptr) to an SMTSolver.
using SMTSolverRef = std::shared_ptr<SMTSolver>;
/// Convenience function to create a Z3Solver object, returned as an
/// SMTSolverRef. Declared here; the definition lives outside this header.
SMTSolverRef CreateZ3Solver();
} // namespace llvm
#endif