//==- AMDGPUArgumentUsageInfo.h - Function Arg Usage Info --------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_AMDGPU_AMDGPUARGUMENTUSAGEINFO_H
#define LLVM_LIB_TARGET_AMDGPU_AMDGPUARGUMENTUSAGEINFO_H

#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include <variant>

namespace llvm {

// Initializer for the legacy wrapper pass declared later in this header. That
// pass's constructor calls it with the registry returned by
// PassRegistry::getPassRegistry(). Only declared here.
void initializeAMDGPUArgumentUsageInfoWrapperLegacyPass(PassRegistry &);

// Forward declarations of types this header refers to by name.
class Function;
class LLT;
class raw_ostream;
class TargetRegisterClass;
class TargetRegisterInfo;

struct ArgDescriptor {
private:
  friend struct AMDGPUFunctionArgInfo;
  friend class AMDGPUArgumentUsageInfo;

  std::variant<std::monostate, MCRegister, unsigned> Val;

  // Bitmask to locate argument within the register.
  unsigned Mask;

public:
  ArgDescriptor(unsigned Mask = ~0u) : Mask(Mask) {}

  static ArgDescriptor createRegister(Register Reg, unsigned Mask = ~0u) {
    ArgDescriptor Ret(Mask);
    Ret.Val = Reg.asMCReg();
    return Ret;
  }

  static ArgDescriptor createStack(unsigned Offset, unsigned Mask = ~0u) {
    ArgDescriptor Ret(Mask);
    Ret.Val = Offset;
    return Ret;
  }

  static ArgDescriptor createArg(const ArgDescriptor &Arg, unsigned Mask) {
    // Copy the descriptor, then change the mask.
    ArgDescriptor Ret(Arg);
    Ret.Mask = Mask;
    return Ret;
  }

  bool isSet() const { return !std::holds_alternative<std::monostate>(Val); }

  explicit operator bool() const {
    return isSet();
  }

  bool isRegister() const { return std::holds_alternative<MCRegister>(Val); }

  MCRegister getRegister() const { return std::get<MCRegister>(Val); }

  unsigned getStackOffset() const { return std::get<unsigned>(Val); }

  unsigned getMask() const {
    // None of the target SGPRs or VGPRs are expected to have a 'zero' mask.
    assert(Mask && "Invalid mask.");
    return Mask;
  }

  bool isMasked() const {
    return Mask != ~0u;
  }

  void print(raw_ostream &OS, const TargetRegisterInfo *TRI = nullptr) const;
};

/// Streams \p Arg to \p OS through ArgDescriptor::print with a null
/// TargetRegisterInfo, then returns \p OS so insertions can be chained.
inline raw_ostream &operator<<(raw_ostream &OS, const ArgDescriptor &Arg) {
  Arg.print(OS, /*TRI=*/nullptr);
  return OS;
}

/// An ArgDescriptor extended with a list of registers. It is the mapped value
/// type of AMDGPUFunctionArgInfo::PreloadKernArgs.
struct KernArgPreloadDescriptor : ArgDescriptor {
  // NOTE(review): presumably the registers occupied by the preloaded kernel
  // argument -- confirm against the code that populates this list.
  SmallVector<MCRegister> Regs;

  KernArgPreloadDescriptor() = default;
};

// Per-function record of where each special input value is located. Every
// ArgDescriptor member below is default-constructed, i.e. unset, until it is
// assigned.
struct AMDGPUFunctionArgInfo {
  // Identifiers accepted by getPreloadedValue(). The numeric values are
  // assigned explicitly and are not contiguous.
  // clang-format off
  enum PreloadedValue {
    // SGPRS:
    PRIVATE_SEGMENT_BUFFER = 0,
    DISPATCH_PTR        =  1,
    QUEUE_PTR           =  2,
    KERNARG_SEGMENT_PTR =  3,
    DISPATCH_ID         =  4,
    FLAT_SCRATCH_INIT   =  5,
    LDS_KERNEL_ID       =  6, // LLVM internal, not part of the ABI
    WORKGROUP_ID_X      = 10, // Also used for cluster ID X.
    WORKGROUP_ID_Y      = 11, // Also used for cluster ID Y.
    WORKGROUP_ID_Z      = 12, // Also used for cluster ID Z.
    PRIVATE_SEGMENT_WAVE_BYTE_OFFSET = 14,
    IMPLICIT_BUFFER_PTR = 15,
    IMPLICIT_ARG_PTR = 16,
    PRIVATE_SEGMENT_SIZE = 17,
    CLUSTER_WORKGROUP_ID_X = 21,
    CLUSTER_WORKGROUP_ID_Y = 22,
    CLUSTER_WORKGROUP_ID_Z = 23,
    CLUSTER_WORKGROUP_MAX_ID_X = 24,
    CLUSTER_WORKGROUP_MAX_ID_Y = 25,
    CLUSTER_WORKGROUP_MAX_ID_Z = 26,
    CLUSTER_WORKGROUP_MAX_FLAT_ID = 27,

    // VGPRS:
    WORKITEM_ID_X       = 28,
    WORKITEM_ID_Y       = 29,
    WORKITEM_ID_Z       = 30,
    // Alias for the lowest-numbered VGPR entry; every enumerator listed under
    // "SGPRS" above has a smaller value.
    FIRST_VGPR_VALUE    = WORKITEM_ID_X
  };
  // clang-format on

  // Kernel input registers setup for the HSA ABI in allocation order.

  // User SGPRs in kernels
  // XXX - Can these require argument spills?
  ArgDescriptor PrivateSegmentBuffer;
  ArgDescriptor DispatchPtr;
  ArgDescriptor QueuePtr;
  ArgDescriptor KernargSegmentPtr;
  ArgDescriptor DispatchID;
  ArgDescriptor FlatScratchInit;
  ArgDescriptor PrivateSegmentSize;
  ArgDescriptor LDSKernelId;

  // System SGPRs in kernels.
  ArgDescriptor WorkGroupIDX;
  ArgDescriptor WorkGroupIDY;
  ArgDescriptor WorkGroupIDZ;
  ArgDescriptor WorkGroupInfo;
  ArgDescriptor PrivateSegmentWaveByteOffset;

  // Pointer with offset from kernargsegmentptr to where special ABI arguments
  // are passed to callable functions.
  ArgDescriptor ImplicitArgPtr;

  // Input registers for non-HSA ABI
  ArgDescriptor ImplicitBufferPtr;

  // VGPRs inputs. For entry functions these are either v0, v1 and v2 or packed
  // into v0, 10 bits per dimension if packed-tid is set.
  ArgDescriptor WorkItemIDX;
  ArgDescriptor WorkItemIDY;
  ArgDescriptor WorkItemIDZ;

  // Map the index of preloaded kernel arguments to its descriptor.
  SmallDenseMap<int, KernArgPreloadDescriptor> PreloadKernArgs{};
  // The first user SGPR allocated for kernarg preloading.
  Register FirstKernArgPreloadReg;

  // Looks up Value and returns a (descriptor pointer, register class, LLT)
  // triple for it. Defined out of line.
  // NOTE(review): whether the pointers may be null for some values is not
  // visible from this header -- check the definition before dereferencing.
  std::tuple<const ArgDescriptor *, const TargetRegisterClass *, LLT>
  getPreloadedValue(PreloadedValue Value) const;

  // Returns an AMDGPUFunctionArgInfo by value. Defined out of line.
  // NOTE(review): presumably the argument layout used for the fixed ABI --
  // confirm against the definition.
  static AMDGPUFunctionArgInfo fixedABILayout();
};

class AMDGPUArgumentUsageInfo {
private:
  DenseMap<const Function *, AMDGPUFunctionArgInfo> ArgInfoMap;

public:
  static const AMDGPUFunctionArgInfo ExternFunctionInfo;
  static const AMDGPUFunctionArgInfo FixedABIFunctionInfo;

  void print(raw_ostream &OS, const Module *M = nullptr) const;

  void clear() { ArgInfoMap.clear(); }

  void setFuncArgInfo(const Function &F, const AMDGPUFunctionArgInfo &ArgInfo) {
    ArgInfoMap[&F] = ArgInfo;
  }

  const AMDGPUFunctionArgInfo &lookupFuncArgInfo(const Function &F) const;

  bool invalidate(Module &M, const PreservedAnalyses &PA,
                  ModuleAnalysisManager::Invalidator &Inv);
};

class AMDGPUArgumentUsageInfoWrapperLegacy : public ImmutablePass {
  std::unique_ptr<AMDGPUArgumentUsageInfo> AUIP;

public:
  static char ID;

  AMDGPUArgumentUsageInfoWrapperLegacy() : ImmutablePass(ID) {
    initializeAMDGPUArgumentUsageInfoWrapperLegacyPass(
        *PassRegistry::getPassRegistry());
  }

  AMDGPUArgumentUsageInfo &getArgUsageInfo() { return *AUIP; }
  const AMDGPUArgumentUsageInfo &getArgUsageInfo() const { return *AUIP; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
  }

  bool doInitialization(Module &M) override {
    AUIP = std::make_unique<AMDGPUArgumentUsageInfo>();
    return false;
  }

  bool doFinalization(Module &M) override {
    AUIP->clear();
    return false;
  }

  void print(raw_ostream &OS, const Module *M = nullptr) const override {
    AUIP->print(OS, M);
  }
};

/// Module-level analysis (run with a ModuleAnalysisManager) whose result type
/// is AMDGPUArgumentUsageInfo.
class AMDGPUArgumentUsageAnalysis
    : public AnalysisInfoMixin<AMDGPUArgumentUsageAnalysis> {
  // Grants the mixin access to the private Key below.
  friend AnalysisInfoMixin<AMDGPUArgumentUsageAnalysis>;
  static AnalysisKey Key;

public:
  using Result = AMDGPUArgumentUsageInfo;

  /// Produces the analysis result for \p M. Defined out of line.
  Result run(Module &M, ModuleAnalysisManager &);
};

} // end namespace llvm

#endif
