vulkan-playground/vulkan_pipeline_utl.h

355 lines
15 KiB
C++

/*
* Copyright (c) 2024 mittorn
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sub license, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice (including the
* next paragraph) shall be included in all copies or substantial portions
* of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
* IN NO EVENT SHALL PRECISION INSIGHT AND/OR ITS SUPPLIERS BE LIABLE FOR
* ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef VULKAN_PIPELINE_UTL_H
#define VULKAN_PIPELINE_UTL_H
#include <vulkan/vulkan.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#ifndef VK_CHECK_RESULT
// Evaluates a Vulkan call once; on any result other than VK_SUCCESS prints the
// code with file/line and fires assert() (so with NDEBUG it only prints).
// Wrapped in do/while(0) so it behaves as a single statement (safe inside an
// unbraced if/else), and the local uses an unlikely name so an expression `f`
// that mentions a caller variable called `res` is not captured by the macro.
#define VK_CHECK_RESULT(f) \
do \
{ \
	VkResult vk_check_result_res_ = (f); \
	if (vk_check_result_res_ != VK_SUCCESS) \
	{ \
		printf("Fatal : VkResult is %d in %s at line %d\n", vk_check_result_res_, __FILE__, __LINE__); \
		assert(vk_check_result_res_ == VK_SUCCESS); \
	} \
} while (0)
#endif
#include <math.h>
#include "positional_utl.h"
// Holder for a compile-time boolean answer; CompareTypes_w derives from it.
template<bool Value>
struct CompareTypesResult_w
{
	static constexpr bool res = Value;
};
// Compile-time type equality: CompareTypes_w<A, B>::res is true exactly when
// A and B are the same type, including cv- and reference qualifiers.
template<class A, class B>
struct CompareTypes_w : CompareTypesResult_w<false> {};
// Partial specialization picked when both arguments name the same type.
template<class T>
struct CompareTypes_w<T, T> : CompareTypesResult_w<true> {};
// Shorthand usable in `if constexpr` and static_assert conditions.
#define CompareTypes(A,B) (CompareTypes_w<A,B>::res)
// Compile-time "unreachable" marker for template code (for example the last
// `else` of an `if constexpr` chain). The condition depends on A, so it is
// only evaluated -- and always fails with "BadStaticAssert" -- when this
// template is actually instantiated for some A.
template <class A>
void BadStaticAssert(const A & /* arg: only used to deduce A */)
{
	static_assert(!CompareTypes(A,A), "BadStaticAssert");
}
/*
 * Helper owning the Vulkan objects behind one pipeline: descriptor set layout,
 * descriptor pool, pipeline layout and the pipeline itself, plus static
 * builders for the various *CreateInfo structures.
 * `device` must be assigned by the owner before any Create* call; every handle
 * created here is released by Destroy().
 */
struct BaseVulkanPipeline
{
	VkDevice device = NULL;
	VkPipeline pipeline = NULL;
	VkPipelineLayout pipelineLayout = NULL;
	VkDescriptorPool descriptorPool = NULL;
	VkDescriptorSetLayout descriptorSetLayout = NULL;

	// Creates descriptorSetLayout from VkDescriptorSetLayoutBinding values
	// (see BasicBinding). At least one binding must be passed: the bindings are
	// collected into a fixed-size array, which cannot have zero length.
	template <typename... Args>
	void CreateDescriptorSetLayout(const Args&... arguments)
	{
		const VkDescriptorSetLayoutBinding descriptorSetLayoutBinding[sizeof... (Args)] = {arguments...};
		VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO};
		descriptorSetLayoutCreateInfo.bindingCount = sizeof... (Args);
		descriptorSetLayoutCreateInfo.pBindings = descriptorSetLayoutBinding;
		VK_CHECK_RESULT(vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, NULL, &descriptorSetLayout));
	}

	// Creates descriptorPool able to hand out up to maxSets descriptor sets,
	// with per-type capacities given as VkDescriptorPoolSize values (see
	// BasicPoolSize). Pool flags are left 0, so sets cannot be freed one by one.
	// todo: calculate based on layout?
	// todo: no way to pass flags!
	template <typename... Args>
	void CreatePool(size_t maxSets, const Args&... arguments)
	{
		const VkDescriptorPoolSize descriptorPoolSize[sizeof... (Args)] = {arguments...};
		VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO};
		// Vulkan stores the limit as uint32_t; make the narrowing explicit.
		descriptorPoolCreateInfo.maxSets = static_cast<uint32_t>(maxSets);
		descriptorPoolCreateInfo.poolSizeCount = sizeof... (Args);
		descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSize;
		VK_CHECK_RESULT(vkCreateDescriptorPool(device, &descriptorPoolCreateInfo, NULL, &descriptorPool));
	}

	// Applies the given VkWriteDescriptorSet values (see ImageWrite/BufferWrite)
	// in a single vkUpdateDescriptorSets call. At least one write is required.
	template <typename... Args>
	void WriteDescriptors(const Args&... arguments)
	{
		const VkWriteDescriptorSet writeDescriptorSet[sizeof... (Args)] = {arguments...};
		vkUpdateDescriptorSets(device, sizeof... (Args), writeDescriptorSet, 0, NULL);
	}

	// Allocates one descriptor set from descriptorPool using `layout`, or this
	// object's descriptorSetLayout when `layout` is NULL. The handle stays
	// VK_NULL_HANDLE if allocation fails (the failure is reported by
	// VK_CHECK_RESULT).
	VkDescriptorSet AllocateSingleDescriptorSet(VkDescriptorSetLayout layout = NULL)
	{
		VkDescriptorSet ret = VK_NULL_HANDLE;
		VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO};
		if(!layout)
			layout = descriptorSetLayout;
		descriptorSetAllocateInfo.descriptorPool = descriptorPool;
		descriptorSetAllocateInfo.descriptorSetCount = 1;
		descriptorSetAllocateInfo.pSetLayouts = &layout;
		VK_CHECK_RESULT(vkAllocateDescriptorSets(device, &descriptorSetAllocateInfo, &ret));
		return ret;
	}

	// Reads a whole (SPIR-V) file into a heap buffer whose size is rounded up
	// to a multiple of 4 bytes, the tail being zero-filled.
	// Returns NULL and prints a message if the file cannot be opened, sized or
	// read, or is empty; otherwise the caller must free() the returned buffer.
	// outCodeSize receives the padded size in bytes (0 on failure).
	static uint32_t *ReadShaderCode(const char *filename, size_t &outCodeSize)
	{
		outCodeSize = 0;
		FILE *fp = fopen(filename, "rb");
		if(fp == NULL)
		{
			printf("Could not find or open file: %s\n", filename);
			return NULL;
		}
		long filesize = -1;
		if(fseek(fp, 0, SEEK_END) == 0)
			filesize = ftell(fp);
		// vkCreateShaderModule requires codeSize > 0, so an empty file is an error too.
		if(filesize <= 0 || fseek(fp, 0, SEEK_SET) != 0)
		{
			printf("Could not get size of file: %s\n", filename);
			fclose(fp);
			return NULL;
		}
		const size_t byteCount = (size_t)filesize;
		const size_t wordCount = (byteCount + 3) / 4;
		// calloc zero-fills, which provides the padding up to the next word boundary.
		uint32_t *code = (uint32_t *)calloc(wordCount, sizeof(uint32_t));
		if(code == NULL)
		{
			printf("Out of memory reading file: %s\n", filename);
			fclose(fp);
			return NULL;
		}
		const size_t bytesRead = fread(code, 1, byteCount, fp);
		fclose(fp);
		if(bytesRead != byteCount)
		{
			printf("Could not read file: %s\n", filename);
			free(code);
			return NULL;
		}
		outCodeSize = wordCount * sizeof(uint32_t);
		return code;
	}

	// Loads `filename` into a new VkShaderModule (written to outShaderModule;
	// the caller owns it and must destroy it) and returns a shader stage
	// description for `stage` using that module, `entrypoint` and `sinfo`.
	// If the file cannot be loaded, a message is printed and outShaderModule
	// (and the returned stage's module) is VK_NULL_HANDLE.
	VkPipelineShaderStageCreateInfo ShaderFromFile(VkShaderModule &outShaderModule, const char *filename, VkShaderStageFlagBits stage, const VkSpecializationInfo *sinfo = NULL, const char *entrypoint = "main")
	{
		outShaderModule = VK_NULL_HANDLE;
		size_t codeSize = 0;
		uint32_t *code = ReadShaderCode(filename, codeSize);
		if(code != NULL)
		{
			VkShaderModuleCreateInfo info = {VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO};
			info.codeSize = codeSize;
			info.pCode = code;
			VK_CHECK_RESULT(vkCreateShaderModule(device, &info, NULL, &outShaderModule));
			// pCode is only read during vkCreateShaderModule, so it can be released now.
			free(code);
		}
		return $M(VkPipelineShaderStageCreateInfo{VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO},
			$(stage) = stage,
			$(module) = outShaderModule,
			$(pName) = entrypoint,
			$(pSpecializationInfo) = sinfo);
	}

	// Input assembly state: primitive topology (triangle list by default),
	// primitive-restart enable and create flags.
	static VkPipelineInputAssemblyStateCreateInfo AssemblyTopology(VkPrimitiveTopology topology = VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST, bool restart = false, VkPipelineInputAssemblyStateCreateFlags flags = 0)
	{
		return {
			.sType = VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO,
			.flags = flags,
			.topology = topology,
			.primitiveRestartEnable = restart
		};
	}

	// Rasterization state. Defaults: counter-clockwise front faces, filled
	// polygons, back-face culling, line width 1, depth clamp off. Extra
	// `arguments` are forwarded to $M (see positional_utl.h).
	template <typename... Args>
	static VkPipelineRasterizationStateCreateInfo RasterMode(VkFrontFace frontFace = VK_FRONT_FACE_COUNTER_CLOCKWISE, VkPolygonMode polygonMode = VK_POLYGON_MODE_FILL, VkCullModeFlags cull = VK_CULL_MODE_BACK_BIT, float lineWidth = 1.0f, bool depthClampEnable = false, const Args&... arguments)
	{
		return $M(VkPipelineRasterizationStateCreateInfo{VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO},
			$(depthClampEnable) = depthClampEnable, $(polygonMode) = polygonMode,
			$(cullMode) = cull, $(frontFace) = frontFace, $(lineWidth) = lineWidth,
			arguments...);
	}

	// One blend equation: result = (src * srcFactor) op (dst * dstFactor).
	// A value-initialized BlendOp{} has all members 0.
	struct BlendOp
	{
		VkBlendOp op;
		VkBlendFactor srcFactor;
		VkBlendFactor dstFactor;
	};

	// Per-attachment blend state. `color` and `alpha` supply the colour and
	// alpha equations; colorWriteMask defaults to 0xf (R, G, B and A written).
	static VkPipelineColorBlendAttachmentState BlendAttachment(bool enable = false, const BlendOp &color = BlendOp{}, const BlendOp& alpha = BlendOp{}, VkColorComponentFlags colorWriteMask = 0xf)
	{
		return {
			.blendEnable = enable,
			.srcColorBlendFactor = color.srcFactor,
			.dstColorBlendFactor = color.dstFactor,
			.colorBlendOp = color.op,
			.srcAlphaBlendFactor = alpha.srcFactor,
			.dstAlphaBlendFactor = alpha.dstFactor,
			.alphaBlendOp = alpha.op,
			.colorWriteMask = colorWriteMask
		};
	}

	// Colour blend state referencing `attachment` through pAttachments.
	// NOTE(review): only the address of `attachment` is stored, so it must
	// outlive every use of the returned struct. Relying on the default argument
	// leaves pAttachments dangling once the calling full-expression ends; pass
	// a named variable instead (as CreateGraphicsPipeline does).
	// todo: array initializer?
	template <typename... Args>
	static VkPipelineColorBlendStateCreateInfo ColorBlend(const VkPipelineColorBlendAttachmentState& attachment = BlendAttachment(), const Args&... arguments)
	{
		return $M(VkPipelineColorBlendStateCreateInfo{VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO},
			$(attachmentCount), $(pAttachments) = &attachment,arguments...);
	}

	// Depth/stencil state: depth test/write enables, depth compare op, and the
	// compare op used for both front and back stencil faces.
	// todo: check this defaults
	template <typename... Args>
	static VkPipelineDepthStencilStateCreateInfo DepthStencil(bool test = false, bool write = false,VkCompareOp compare = VK_COMPARE_OP_LESS_OR_EQUAL, VkCompareOp stencil = VK_COMPARE_OP_ALWAYS, const Args&... arguments )
	{
		return $M(VkPipelineDepthStencilStateCreateInfo{VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO},
			$(depthTestEnable) = test, $(depthWriteEnable) = write, $(depthCompareOp) = compare,
			$(front) = Vals($(compareOp) = stencil), $(back) = Vals($(compareOp) = stencil),
			arguments...);
	}

	// Viewport state. Viewport/scissor counts are filled by positional_utl's
	// bare $(field) form -- presumably 1 each, with the rectangles supplied as
	// dynamic state; confirm against positional_utl.h.
	template <typename... Args>
	static VkPipelineViewportStateCreateInfo ViewportState(const Args&... arguments)
	{
		return $M(VkPipelineViewportStateCreateInfo{VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO},
			$(viewportCount), $(scissorCount), arguments...);
	}

	// Multisample state with the given sample count (1 by default).
	template <typename... Args>
	static VkPipelineMultisampleStateCreateInfo MultisampleState( VkSampleCountFlagBits samples = VK_SAMPLE_COUNT_1_BIT, const Args&... arguments)
	{
		return $M(VkPipelineMultisampleStateCreateInfo{VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO},
			$(rasterizationSamples) = samples, arguments...);
	}

	ARRAY_WRAPPER(VkDynamicState, DynamicStates);
	ARRAY_WRAPPER(VkPipelineShaderStageCreateInfo,Stages);
	ARRAY_WRAPPER(VkVertexInputBindingDescription, VertexBindings);
	ARRAY_WRAPPER(VkVertexInputAttributeDescription,VertexAttributes);

	// Vertex buffer binding description (per-vertex input rate by default).
	VkVertexInputBindingDescription VertexBinding(uint32_t binding, uint32_t stride, VkVertexInputRate inputRate = VK_VERTEX_INPUT_RATE_VERTEX )
	{
		return {binding, stride, inputRate};
	}

	// Vertex attribute description (three 32-bit floats by default).
	VkVertexInputAttributeDescription VertAttrib(uint32_t location, uint32_t binding, uint32_t offset, VkFormat format = VK_FORMAT_R32G32B32_SFLOAT)
	{
		return {location, binding, format, offset};
	}

	// Creates pipelineLayout (from descriptorSetLayout) and a graphics pipeline
	// for `renderPass`. Default states: AssemblyTopology(), ViewportState(),
	// MultisampleState(), RasterMode(), DepthStencil() and a ColorBlend() with
	// one BlendAttachment(). Any VkPipeline{InputAssembly,Viewport,Multisample,
	// Rasterization,DepthStencil,ColorBlend}StateCreateInfo in `args` replaces
	// the matching default; every other arg is applied to the create info via
	// $F (positional_utl.h). `args` are referenced only during this call.
	template <typename Stages, typename VertexBindings, typename VertexAttributes, typename DynamicStates, typename... Ts>
	void CreateGraphicsPipeline(VkRenderPass renderPass, const Stages &stages,
		const VertexBindings &vertexBindings,
		const VertexAttributes &vertexAttributes,
		const DynamicStates &dynamicStates,const Ts&... args)
	{
		const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = $M(
			VkPipelineLayoutCreateInfo{VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO},
			$(setLayoutCount), $(pSetLayouts) &= descriptorSetLayout);
		VK_CHECK_RESULT(vkCreatePipelineLayout(device, &pipelineLayoutCreateInfo, NULL, &pipelineLayout));
		VkGraphicsPipelineCreateInfo info = {VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO};
		const VkPipelineDynamicStateCreateInfo dynamic_state = $M(
			VkPipelineDynamicStateCreateInfo{VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO},
			$(pDynamicStates).ptrWithLength($(dynamicStateCount)) = dynamicStates);
		const VkPipelineVertexInputStateCreateInfo vertex_input_state = $M(
			VkPipelineVertexInputStateCreateInfo{VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO},
			$(pVertexBindingDescriptions).ptrWithLength($(vertexBindingDescriptionCount)) = vertexBindings,
			$(pVertexAttributeDescriptions).ptrWithLength($(vertexAttributeDescriptionCount)) = vertexAttributes);
		// Defaults; these locals live until vkCreateGraphicsPipelines returns.
		const VkPipelineInputAssemblyStateCreateInfo asmInfo= AssemblyTopology();
		info.pInputAssemblyState = &asmInfo;
		const VkPipelineViewportStateCreateInfo viewportInfo = ViewportState();
		info.pViewportState = &viewportInfo;
		const VkPipelineMultisampleStateCreateInfo multisample = MultisampleState();
		info.pMultisampleState = &multisample;
		const VkPipelineRasterizationStateCreateInfo rasterMode = RasterMode();
		info.pRasterizationState = &rasterMode;
		const VkPipelineDepthStencilStateCreateInfo depthStencil = DepthStencil();
		info.pDepthStencilState = &depthStencil;
		const VkPipelineColorBlendAttachmentState blendAttachment = BlendAttachment();
		const VkPipelineColorBlendStateCreateInfo blendState = ColorBlend(blendAttachment);
		info.pColorBlendState = &blendState;
		// Routes each extra argument: state structs override a default pointer,
		// anything else goes through $F.
		auto filler = [](VkGraphicsPipelineCreateInfo &info, const auto &arg){
			if constexpr(CompareTypes(decltype(arg),const VkPipelineInputAssemblyStateCreateInfo &))info.pInputAssemblyState = &arg;
			else if constexpr(CompareTypes(decltype(arg),const VkPipelineViewportStateCreateInfo&))info.pViewportState = &arg;
			else if constexpr(CompareTypes(decltype(arg),const VkPipelineMultisampleStateCreateInfo&))info.pMultisampleState = &arg;
			else if constexpr(CompareTypes(decltype(arg),const VkPipelineRasterizationStateCreateInfo&))info.pRasterizationState = &arg;
			else if constexpr(CompareTypes(decltype(arg),const VkPipelineDepthStencilStateCreateInfo&))info.pDepthStencilState = &arg;
			else if constexpr(CompareTypes(decltype(arg),const VkPipelineColorBlendStateCreateInfo&))info.pColorBlendState = &arg;
			else $F(info,arg);
		};
		$F(info,
			$(pStages).ptrWithLength($(stageCount)) = stages,
			$(pVertexInputState) = &vertex_input_state,
			$(pDynamicState) = &dynamic_state,
			$(renderPass) = renderPass,
			$(layout) = pipelineLayout
		);
		(filler(info,args),...);
		// Checked like every other create call in this struct.
		VK_CHECK_RESULT(vkCreateGraphicsPipelines(device, VK_NULL_HANDLE, 1, &info, NULL, &pipeline));
	}

	// Creates pipelineLayout (one set: descriptorSetLayout) and a compute
	// pipeline from the given shader stage.
	void CreateComputePipeline(const VkPipelineShaderStageCreateInfo &shaderStageCreateInfo)
	{
		VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO};
		pipelineLayoutCreateInfo.setLayoutCount = 1;
		pipelineLayoutCreateInfo.pSetLayouts = &descriptorSetLayout;
		VK_CHECK_RESULT(vkCreatePipelineLayout(device, &pipelineLayoutCreateInfo, NULL, &pipelineLayout));
		VkComputePipelineCreateInfo pipelineCreateInfo = {VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO};
		pipelineCreateInfo.stage = shaderStageCreateInfo;
		pipelineCreateInfo.layout = pipelineLayout;
		VK_CHECK_RESULT(vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipelineCreateInfo, NULL, &pipeline));
	}

	// Descriptor set layout binding; stageFlags defaults to 0, so callers
	// normally pass the shader stages that access the binding.
	static VkDescriptorSetLayoutBinding BasicBinding(uint32_t binding, VkDescriptorType descriptorType, uint32_t descriptorCount = 1, VkShaderStageFlags stageFlags = 0, const VkSampler* pImmutableSamplers = NULL )
	{
		return {binding, descriptorType, descriptorCount, stageFlags, pImmutableSamplers};
	}

	// Pool capacity entry: descriptorCount descriptors of `type`.
	static VkDescriptorPoolSize BasicPoolSize(VkDescriptorType type, uint32_t descriptorCount)
	{
		return {type, descriptorCount};
	}

	// Image descriptor info (no sampler by default).
	VkDescriptorImageInfo ImageDescriptor(VkImageView imageView, VkImageLayout imageLayout, VkSampler sampler = 0)
	{
		return {sampler, imageView, imageLayout};
	}

	// Single-descriptor write of `info` to dstSet/binding (storage image by
	// default). Only the address of `info` is stored, so it must stay alive
	// until the write is passed to WriteDescriptors.
	// todo: different types
	VkWriteDescriptorSet ImageWrite(VkDescriptorSet dstSet, uint32_t binding, const VkDescriptorImageInfo &info, VkDescriptorType tp = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
	{
		VkWriteDescriptorSet wr = {VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET};
		wr.dstSet = dstSet;
		wr.dstBinding = binding;
		wr.descriptorCount = 1;
		wr.descriptorType = tp;
		wr.pImageInfo = &info;
		return wr;
	}

	// Single-descriptor write of buffer `info` to dstSet/binding. The type
	// defaults to a uniform buffer; pass e.g. VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
	// for other buffer descriptors. As with ImageWrite, only the address of
	// `info` is stored, so it must outlive the WriteDescriptors call.
	VkWriteDescriptorSet BufferWrite(VkDescriptorSet dstSet, uint32_t binding, const VkDescriptorBufferInfo &info, VkDescriptorType tp = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
	{
		VkWriteDescriptorSet wr = {VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET};
		wr.dstSet = dstSet;
		wr.dstBinding = binding;
		wr.descriptorCount = 1;
		wr.descriptorType = tp;
		wr.pBufferInfo = &info;
		return wr;
	}

	// Destroys every non-null handle owned by this object, in reverse creation
	// order, and resets them to NULL. Does nothing when device is NULL;
	// device itself is left untouched.
	void Destroy()
	{
		if(!device)
			return;
		if(pipeline)
			vkDestroyPipeline(device, pipeline, NULL);
		pipeline = NULL;
		if(pipelineLayout)
			vkDestroyPipelineLayout(device, pipelineLayout, NULL);
		pipelineLayout = NULL;
		if(descriptorPool)
			vkDestroyDescriptorPool(device, descriptorPool, NULL);
		descriptorPool = NULL;
		if(descriptorSetLayout)
			vkDestroyDescriptorSetLayout(device, descriptorSetLayout, NULL);
		descriptorSetLayout = NULL;
	}
};
#endif // VULKAN_PIPELINE_UTL_H