// Based on Unity's "Standard geometry shader example", which provides Unity's standard lighting
// together with a custom geometry stage: surface shaders don't allow a geometry stage, and a
// shader written from the ground up would have to implement all lighting by hand.
// https://github.com/keijiro/StandardGeometryShader

#include "UnityCG.cginc"
#include "UnityGBuffer.cginc"
#include "UnityStandardUtils.cginc"

// Cube map shadow caster; used to render point light shadows on platforms
// that lack depth cube map support.
#if defined(SHADOWS_CUBE) && !defined(SHADOWS_CUBE_IN_DEPTH_TEX)
#define PASS_CUBE_SHADOWCASTER
#endif

// Shader uniforms (material properties; presumably exposed by the owning .shader file, not visible here)

// Albedo tint, multiplied with _MainTex.rgb in the GBuffer fragment.
half4 _Color;
// Albedo texture; its alpha is alpha-tested against _AlphaCutoff.
sampler2D _MainTex;
// Scale/offset for _MainTex, consumed by TRANSFORM_TEX(..., _MainTex) in Vertex.
float4 _MainTex_ST;

// Smoothness written to the GBuffer.
half _Glossiness;
// Metalness used for the metallic -> specular conversion.
half _Metallic;

// Tangent-space normal map and its strength (passed to UnpackScaleNormal).
sampler2D _BumpMap;
float _BumpScale;

// Occlusion map (green channel is read) and its blend strength (LerpOneTo).
sampler2D _OcclusionMap;
float _OcclusionStrength;

// Animation time; only referenced by commented-out extrusion code in this file.
float _LocalTime;

// Vertex input attributes. Also used as the vertex -> geometry interface:
// Vertex rewrites position/normal/tangent into world space before Geometry runs.
struct Attributes {
    float4 position : POSITION;  // object space on input; world space after Vertex
    float3 normal : NORMAL;      // object space on input; world space after Vertex
    float4 tangent : TANGENT;    // xyz: direction; w: bitangent sign (multiplied into the cross product)
    float2 texcoord : TEXCOORD;  // UV; _MainTex scale/offset applied in Vertex
    };

// Fragment varyings. The layout depends on the pass being compiled:
// position only for the default shadow caster, position + light-relative
// vector for the cube shadow caster, and the full set for the GBuffer pass.
struct Varyings {
    float4 position : SV_POSITION;

    #if defined(PASS_CUBE_SHADOWCASTER)
        // Cube map shadow caster pass: world position minus light position.
        float3 shadow : TEXCOORD0;

    #elif defined(UNITY_PASS_SHADOWCASTER)
        // Default shadow caster pass: clip-space position only.

    #else
        // GBuffer construction pass
        float3 normal : NORMAL;       // world-space normal
        float2 texcoord : TEXCOORD0;  // UV passed through from the geometry stage
        float4 tspace0 : TEXCOORD1;   // xyz: (tangent.x, bitangent.x, normal.x); w: world position x
        float4 tspace1 : TEXCOORD2;   // xyz: (tangent.y, bitangent.y, normal.y); w: world position y
        float4 tspace2 : TEXCOORD3;   // xyz: (tangent.z, bitangent.z, normal.z); w: world position z
        half3 ambient : TEXCOORD4;    // per-vertex SH ambient from ShadeSHPerVertex

    #endif
    };

//
// Vertex stage
//

// Vertex stage: moves the mesh vertex into world space and applies the
// _MainTex scale/offset. Projection to clip space happens later, per emitted
// vertex, in VertexOutput.
Attributes Vertex(Attributes input) {
    // Start from a copy so tangent.w (the bitangent sign) carries over untouched.
    Attributes world = input;
    world.position = mul(unity_ObjectToWorld, input.position);
    world.normal = UnityObjectToWorldNormal(input.normal);
    world.tangent.xyz = UnityObjectToWorldDir(input.tangent.xyz);
    world.texcoord = TRANSFORM_TEX(input.texcoord, _MainTex);
    return world;
    }

//
// Geometry stage
//

// Builds one fragment-stage vertex from world-space data. Which Varyings
// fields are written depends on the pass being compiled.
//   wpos : world-space position
//   wnrm : world-space normal (assumed unit length by the shadow bias — TODO confirm for all callers)
//   wtan : world-space tangent in xyz, bitangent sign in w
//   uv   : texture coordinate, passed through unchanged
Varyings VertexOutput(float3 wpos, half3 wnrm, half4 wtan, float2 uv) {
    Varyings o;
    #if defined(PASS_CUBE_SHADOWCASTER)
        // Cube map shadow caster pass: Transfer the shadow vector.
        o.position = UnityWorldToClipPos(wpos);
        o.shadow = wpos - _LightPositionRange.xyz;
    #elif defined(UNITY_PASS_SHADOWCASTER)
        // Default shadow caster pass: Apply the shadow bias.
        // Normal-offset bias: push the vertex along -normal, scaled by the
        // sine of the angle between the normal and the light direction.
        float scos = dot(wnrm, normalize(UnityWorldSpaceLightDir(wpos)));
        // saturate() keeps rounding (|scos| slightly > 1) or an unnormalized
        // normal from feeding sqrt a negative value and producing NaN.
        wpos -= wnrm * unity_LightShadowBias.z * sqrt(saturate(1 - scos * scos));
        o.position = UnityApplyLinearShadowBias(UnityWorldToClipPos(wpos));
    #else
        // GBuffer construction pass
        // Explicit .xyz: cross() takes 3-component vectors; w is the bitangent sign.
        half3 bi = cross(wnrm, wtan.xyz) * wtan.w * unity_WorldTransformParams.w;
        o.position = UnityWorldToClipPos(wpos);
        o.normal = wnrm;
        o.texcoord = uv;
        // Rows of the tangent-to-world matrix, with world position packed in w.
        o.tspace0 = float4(wtan.x, bi.x, wnrm.x, wpos.x);
        o.tspace1 = float4(wtan.y, bi.y, wnrm.y, wpos.y);
        o.tspace2 = float4(wtan.z, bi.z, wnrm.z, wpos.z);
        o.ambient = ShadeSHPerVertex(wnrm, 0);
    #endif
    return o;
    }

// Unit face normal of triangle (v1, v2, v3); the winding order determines
// which side it points to.
float3 ConstructNormal(float3 v1, float3 v2, float3 v3) {
    float3 edgeA = v2 - v1;
    float3 edgeB = v3 - v1;
    float3 faceNormal = cross(edgeA, edgeB);
    return normalize(faceNormal);
    }

// Cheap sin-based hash: maps a 3D seed to a pseudo-random value in [0, 1).
float rand(float3 co){
    const float3 seedWeights = float3(12.9898, 78.233, 53.539);
    float scrambled = sin(dot(co.xyz, seedWeights)) * 43758.5453;
    return frac(scrambled);
    }
// Rotation matrix for `angle` radians about `axis` (expected to be unit length),
// in the form c*I + s*[axis]x + (1 - c)*axis*axis^T.
float3x3 AngleAxis3x3(float angle, float3 axis){
    float sinA, cosA;
    sincos(angle, sinA, cosA);
    float oneMinusCos = 1 - cosA;

    // (1 - c) * axis, shared by every outer-product term below.
    float3 scaled = oneMinusCos * axis;
    float xx = scaled.x * axis.x;
    float yy = scaled.y * axis.y;
    float zz = scaled.z * axis.z;
    float xy = scaled.x * axis.y;
    float xz = scaled.x * axis.z;
    float yz = scaled.y * axis.z;
    // sin * axis supplies the skew-symmetric (cross-product) terms.
    float3 skew = sinA * axis;

    return float3x3(
        xx + cosA, xy - skew.z, xz + skew.y,
        xy + skew.z, yy + cosA, yz - skew.x,
        xz - skew.y, yz + skew.x, zz + cosA
        );
    }

// Blade shape parameters: base height/width plus a random +/- range each, and
// the maximum forward bend as a fraction of 90 degrees (used in Geometry).
float _BladeHeight, _BladeHeightRandom, _BladeWidth, _BladeWidthRandom, _BendRotationRandom;
// Not referenced anywhere in this file.
int _Test;

// Geometry-shader invocations per input triangle. Grass nodes are dealt
// round-robin to the instances, so each node is emitted exactly once.
#define GRASS_GS_INSTANCES 5
// Vertices appended per grass node (two crossed triangles, 3 vertices each).
#define GRASS_VERTS_PER_NODE 6
// Per-invocation vertex budget; the value declared in [maxvertexcount].
#define GRASS_MAX_VERTS 42

// Geometry stage: scatters grass "nodes" (two crossed blade triangles) across
// the input triangle's bounding box, stepping `unit` world units at a time.
//   input      : triangle vertices, already in world space (see Vertex)
//   pid        : primitive ID (only used by the commented-out extrusion code)
//   outStream  : receives the generated blade triangles
//   InstanceID : which of the GRASS_GS_INSTANCES invocations this is
//
// An invocation must not append more than GRASS_MAX_VERTS vertices, so each
// invocation stops after GRASS_MAX_VERTS / GRASS_VERTS_PER_NODE nodes.
[instance(GRASS_GS_INSTANCES)]
[maxvertexcount(GRASS_MAX_VERTS)]
void Geometry( triangle Attributes input[3], uint pid : SV_PrimitiveID, inout TriangleStream<Varyings> outStream, uint InstanceID : SV_GSInstanceID) {
    //Vertex inputs;
    //float2 uv0 = input[0].texcoord;
    //float2 uv1 = input[1].texcoord;
    //float2 uv2 = input[2].texcoord;

    // Blade parameters are derived from the triangle's first vertex.
    float3 pos = input[0].position.xyz;
    float3 vNormal = input[0].normal;
    float4 vTangent = input[0].tangent;
    float3 vBinormal = cross(vNormal, vTangent.xyz) * vTangent.w;

    // Random size in [base - random, base + random], seeded by position.
    float height = (rand(pos.zyx) * 2 - 1) * _BladeHeightRandom + _BladeHeight;
    float width = (rand(pos.xzy) * 2 - 1) * _BladeWidthRandom + _BladeWidth;

    // Columns are tangent, binormal, normal: maps blade space (z along the
    // surface normal) into the inputs' space, which is world space here.
    float3x3 tangentToLocal = float3x3(
        vTangent.x, vBinormal.x, vNormal.x,
        vTangent.y, vBinormal.y, vNormal.y,
        vTangent.z, vBinormal.z, vNormal.z
        );

    // Blade orientation: random yaw about the blade's z axis, then a random
    // forward bend of up to _BendRotationRandom * 90 degrees. These depend only
    // on `pos`, so they are computed once instead of once per node.
    // NOTE(review): because the seed is `pos`, every node of this triangle gets
    // the same facing and bend; seed with the node position for per-blade variation.
    float3x3 facingRotationMatrix = AngleAxis3x3(rand(pos) * UNITY_TWO_PI, float3(0, 0, 1));
    float3x3 bendRotationMatrix = AngleAxis3x3(rand(pos.zzx) * _BendRotationRandom * UNITY_PI * 0.5, float3(-1, 0, 0));
    float3x3 transformationMatrix = mul(mul(tangentToLocal, facingRotationMatrix), bendRotationMatrix);

    //TODO create nodes based on world pos

    //Do triangle
    float3 wp0 = input[0].position.xyz;
    float3 wp1 = input[1].position.xyz;
    float3 wp2 = input[2].position.xyz;

    //1. Bounding corners. p1/p3 are the component-wise max/min, not
    //   necessarily vertices. Since p1 >= every vertex and p3 <= every vertex
    //   component-wise, the former "pick the middle vertex" test always
    //   selected wp0, so it is assigned directly (and is never uninitialized).
    float3 p1 = max(wp0, max(wp1, wp2));
    float3 p3 = min(wp0, min(wp1, wp2));
    float3 p2 = wp0;

    //2. Get directions; zero-length edges yield a zero direction instead of 0/0 = NaN.
    float3 vec1 = p3 - p1;
    float3 vec2 = p2 - p1;
    float dir1Len = length(vec1);
    float dir2Len = length(vec2);
    float3 dir1 = dir1Len > 0 ? vec1 / dir1Len : float3(0, 0, 0);
    float3 dir2 = dir2Len > 0 ? vec2 / dir2Len : float3(0, 0, 0);

    //3. Define values
    float unit = 0.5;
    float3 step1 = dir1 * unit;
    // Scale the p1->p2 step so its x advance tracks step1's. When vec1.x == 0
    // the outer loop below runs zero times, so the step is irrelevant.
    float3 step2 = vec1.x != 0 ? dir2 * unit * abs(vec2.x / vec1.x) : float3(0, 0, 0);
    float3 current1 = p1;
    float3 current2 = p1;

    //4. loop: advance along both edges; at each step, emit a row of nodes
    //   stepping +z by `unit`. Stop once this invocation's vertex budget is used.
    const uint maxNodes = GRASS_MAX_VERTS / GRASS_VERTS_PER_NODE;
    uint nodeIndex = 0;     // position of the node in the full enumeration (all instances)
    uint nodesEmitted = 0;  // nodes appended by this instance
    int unitsToX = (int)abs(floor(vec1.x / unit));
    for (int i = 0; i < unitsToX && nodesEmitted < maxNodes; ++i) {
        current1 += step1;
        current2 += step2;
        int unitsToZ = (int)abs(floor((current1.z - current2.z) / unit));
        //1. create start and ends; the row start is snapped to the integer grid.
        float3 current = floor(current1);
        for (int j = 0; j < unitsToZ && nodesEmitted < maxNodes; ++j) {//loop and create nodes
            // Round-robin across GS instances: each node is owned by one instance.
            if (nodeIndex % (uint)GRASS_GS_INSTANCES == InstanceID) {
                float k = i + j + 1;
                // NOTE(review): tangent float4(1,1,0,0) has w = 0, so the bitangent
                // built in VertexOutput is zero for these vertices — confirm intended.
                //-Create node- START
                outStream.Append(VertexOutput(current + mul(transformationMatrix, float3( width, 0, 0     )),half3(0,1,0),float4(1,1,0,0), float2(1.25,0) ));
                outStream.Append(VertexOutput(current + mul(transformationMatrix, float3(-width, 0, 0     )),half3(0,1,0),float4(1,1,0,0), float2(-0.25,0) ));
                outStream.Append(VertexOutput(current + mul(transformationMatrix, float3( 0    , 0, height*k)),half3(0,1,0),float4(1,1,0,0), float2(0.5,1.5) ));
                outStream.RestartStrip();
                outStream.Append(VertexOutput(current + mul(transformationMatrix, float3(0,  width, 0     )),half3(0,1,0),float4(1,1,0,0), float2(1.25,0) ));
                outStream.Append(VertexOutput(current + mul(transformationMatrix, float3(0, -width, 0     )),half3(0,1,0),float4(1,1,0,0), float2(-0.25,0) ));
                outStream.Append(VertexOutput(current + mul(transformationMatrix, float3(0,  0    , height*k)),half3(0,1,0),float4(1,1,0,0), float2(0.5,1.5) ));
                outStream.RestartStrip();
                //-Create node- END
                ++nodesEmitted;
                }
            ++nodeIndex;
            current.z += unit;
            }
        }


    // Extrusion amount
    //float ext = saturate(0.4 - cos(_LocalTime * UNITY_PI * 2) * 0.41);
    //ext *= 1 + 0.3 * sin(pid * 832.37843 + _LocalTime * 88.76);
//
    //// Extrusion points
    //float3 offs = ConstructNormal(wp0, wp1, wp2) * ext;
    //float3 wp3 = wp0 + offs;
    //float3 wp4 = wp1 + offs;
    //float3 wp5 = wp2 + offs;
//
    //// Cap triangle
    //float3 wn = ConstructNormal(wp3, wp4, wp5);
    //float np = saturate(ext * 10);
    //float3 wn0 = lerp(input[0].normal, wn, np);
    //float3 wn1 = lerp(input[1].normal, wn, np);
    //float3 wn2 = lerp(input[2].normal, wn, np);
    //outStream.Append(VertexOutput(wp3, wn0, input[0].tangent, uv0));
    //outStream.Append(VertexOutput(wp4, wn1, input[1].tangent, uv1));
    //outStream.Append(VertexOutput(wp5, wn2, input[2].tangent, uv2));
    //outStream.RestartStrip();
//
    //// Side faces
    //float4 wt = float4(normalize(wp3 - wp0), 1); // world space tangent
    //wn = ConstructNormal(wp3, wp0, wp4);
    //outStream.Append(VertexOutput(wp3, wn, wt, uv0));
    //outStream.Append(VertexOutput(wp0, wn, wt, uv0));
    //outStream.Append(VertexOutput(wp4, wn, wt, uv1));
    //outStream.Append(VertexOutput(wp1, wn, wt, uv1));
    //outStream.RestartStrip();
//
    //wn = ConstructNormal(wp4, wp1, wp5);
    //outStream.Append(VertexOutput(wp4, wn, wt, uv1));
    //outStream.Append(VertexOutput(wp1, wn, wt, uv1));
    //outStream.Append(VertexOutput(wp5, wn, wt, uv2));
    //outStream.Append(VertexOutput(wp2, wn, wt, uv2));
    //outStream.RestartStrip();
//
    //wn = ConstructNormal(wp5, wp2, wp3);
    //outStream.Append(VertexOutput(wp5, wn, wt, uv2));
    //outStream.Append(VertexOutput(wp2, wn, wt, uv2));
    //outStream.Append(VertexOutput(wp3, wn, wt, uv0));
    //outStream.Append(VertexOutput(wp0, wn, wt, uv0));
    //outStream.RestartStrip();
    }

//
// Fragment phase
//

#if defined(PASS_CUBE_SHADOWCASTER)

// Cube map shadow caster pass
// Writes the biased distance to the light, scaled by _LightPositionRange.w,
// in Unity's cube-shadow depth encoding.
half4 Fragment(Varyings input) : SV_Target {
    float biasedDistance = length(input.shadow) + unity_LightShadowBias.x;
    float scaledDepth = biasedDistance * _LightPositionRange.w;
    return UnityEncodeCubeShadowDepth(scaledDepth);
    }

#elif defined(UNITY_PASS_SHADOWCASTER)

// Default shadow caster pass
// Only depth matters in this pass; the color output is a constant zero.
half4 Fragment() : SV_Target {
    return half4(0, 0, 0, 0);
    }

#else

// Alpha-test threshold: fragments whose _MainTex alpha is <= this are discarded.
float _AlphaCutoff;

// GBuffer construction pass: fills Unity's standard deferred GBuffer and
// writes ambient (SH) diffuse lighting, scaled by occlusion, to the emission target.
//   input          : interpolated varyings from VertexOutput
//   outGBuffer0..2 : standard GBuffer targets, filled by UnityStandardDataToGbuffer
//   outEmission    : ambient lighting contribution
void Fragment( Varyings input, out half4 outGBuffer0 : SV_Target0, out half4 outGBuffer1 : SV_Target1, out half4 outGBuffer2 : SV_Target2, out half4 outEmission : SV_Target3) {
    // Sample textures. (Not named `sample`: that is a reserved word in HLSL
    // (interpolation modifier) and in GLSL, which breaks some cross-compilers.)
    float4 albedoSample = tex2D(_MainTex, input.texcoord);
    if (albedoSample.a <= _AlphaCutoff) clip(-1);
    half3 albedo = albedoSample.rgb * _Color.rgb;

    half4 normal = tex2D(_BumpMap, input.texcoord);
    normal.xyz = UnpackScaleNormal(normal, _BumpScale);

    half occ = tex2D(_OcclusionMap, input.texcoord).g;
    occ = LerpOneTo(occ, _OcclusionStrength);

    // PBS workflow conversion (metallic -> specular)
    half3 c_diff, c_spec;
    half refl10;
    c_diff = DiffuseAndSpecularFromMetallic(
        albedo, _Metallic, // input
        c_spec, refl10     // output
        );

    // Tangent space conversion (tangent space normal -> world space normal).
    // Explicit .xyz: the tspace rows are 3-component, normal is half4.
    float3 wn = normalize(float3(
        dot(input.tspace0.xyz, normal.xyz),
        dot(input.tspace1.xyz, normal.xyz),
        dot(input.tspace2.xyz, normal.xyz)));

    // Update the GBuffer.
    UnityStandardData data;
    data.diffuseColor = c_diff;
    data.occlusion = occ;
    data.specularColor = c_spec;
    data.smoothness = _Glossiness;
    data.normalWorld = wn;
    UnityStandardDataToGbuffer(data, outGBuffer0, outGBuffer1, outGBuffer2);

    // Calculate ambient lighting and output to the emission buffer.
    // World position was packed into the w components by VertexOutput.
    float3 wp = float3(input.tspace0.w, input.tspace1.w, input.tspace2.w);
    half3 sh = ShadeSHPerPixel(data.normalWorld, input.ambient, wp);
    // NOTE(review): Unity's own deferred path exp2-encodes this buffer when
    // UNITY_HDR_ON is not defined; confirm whether LDR cameras need that here.
    outEmission = half4(sh * c_diff, 1) * occ;
    }

#endif