# ------------------------------------------------------------------------------
# Color analyzer example
# coding: utf-8
# Copyright (c) 2016 The Foundry Visionmongers Ltd.  All Rights Reserved.
# ------------------------------------------------------------------------------

import mari


# ------------------------------------------------------------------------------
# Print All Colorspace Info:
# ------------------------------------------------------------------------------

class ColorspaceAnalyzer(object):
    """Mari colorspace analyzer.

    This class can be used to walk a project and analyze the colorspace
    configuration for it. Currently it has one feature whereby it figures out
    which colorspace transforms to register shader code for so that the project
    gets the benefit of the full GPU acceleration.

    Notes:
        The shader-transform search runs once, from ``__init__``, over the color
        and scalar colorspace configs of every channel of every geo returned by
        ``mari.geo.list()``. The ``project`` argument is only used for the
        project-defaults section of the report.

        All output uses single-argument ``print`` calls, so the report is
        formatted the same whether Mari embeds Python 2 or Python 3 (under
        Python 2 a multi-argument ``print(a, b)`` prints a tuple).
    """

    # Horizontal rule used to frame each section of the report.
    _SEPARATOR = '=========================================================================='

    # --------------------------------------------------------------------------

    def __init__(self, project):
        """Store the project and immediately collect the required shader transforms.

        Args:
            project: The project to report on (e.g. ``mari.projects.current()``).
                If it is None, ``printProjectInfo`` prints nothing.
        """
        self._project = project
        self._findShaderTransforms()

    # --------------------------------------------------------------------------

    def _findShaderTransforms(self):
        """Walk the project and find all required shader transforms to enable full GPU acceleration.

        Populates ``self._shader_transforms``, a mapping of OCIO config file name to a
        set of ``(source_colorspace, target_colorspace)`` tuples.
        """
        from collections import defaultdict

        self._shader_transforms = defaultdict(set)

        for geo in mari.geo.list():
            for channel in geo.channelList():
                self._insertShaderTransforms(channel.colorspaceConfig())
                self._insertShaderTransforms(channel.scalarColorspaceConfig())

    # --------------------------------------------------------------------------

    def _insertShaderTransforms(self, colorspace_config):
        """Record the native<->working transforms required by one colorspace config.

        Nothing is recorded when the config resolves to raw, or when its native and
        working colorspaces resolve to the same colorspace. Otherwise both directions
        of the transform are added under the config's file name.

        Args:
            colorspace_config: A Mari colorspace config (channel color or scalar config).
        """

        # Ignore the colorspace config if it has been disabled via the raw option.
        if colorspace_config.resolveRaw():
            return

        file_name = colorspace_config.fileName()
        native_colorspace = colorspace_config.resolveColorspace(colorspace_config.COLORSPACE_STAGE_NATIVE)
        working_colorspace = colorspace_config.resolveColorspace(colorspace_config.COLORSPACE_STAGE_WORKING)

        # Make sure we resolve any roles to their corresponding colorspace as the GPU acceleration system will resolve
        # roles itself when doing its lookup.
        if mari.ocio.hasRole(file_name, native_colorspace):
            native_colorspace = mari.ocio.toColorspace(file_name, native_colorspace)
        if mari.ocio.hasRole(file_name, working_colorspace):
            working_colorspace = mari.ocio.toColorspace(file_name, working_colorspace)

        if native_colorspace != working_colorspace:
            self._shader_transforms[file_name].add((native_colorspace, working_colorspace))
            self._shader_transforms[file_name].add((working_colorspace, native_colorspace))

    # --------------------------------------------------------------------------

    def printAll(self):
        """Print every colorspace section of the report, followed by the shader-transform summary."""

        print(self._SEPARATOR)
        print('PROJECT COLORSPACE SETUP:')
        print(self._SEPARATOR)

        self.printProjectInfo()
        self.printChannelInfo()
        self.printImageInfo()
        self.printProjectorInfo()

        print(self._SEPARATOR)
        if self._shader_transforms:
            print('REQUIRED SHADER TRANSFORMS:')
            print("(Put the following code in a *.py file, fill in the missing '// TODO', and")
            print('execute the script on startup, before the project is opened.)')
            print(self._SEPARATOR)

            # printShaderTransforms() closes its output with its own separator line,
            # so no additional separator is printed here.
            self.printShaderTransforms()
        else:
            print('NO SHADER TRANSFORMS REQUIRED:')
            print("After analyzing your Channels' color configs, there is no need to register")
            print('any colorspace shader transformations.')
            print(self._SEPARATOR)

    # --------------------------------------------------------------------------

    def printProjectInfo(self):
        """Prints out information about the global colorspaces of the project (nothing if no project)."""

        if self._project is None:
            return

        print('Project Defaults : %s' % self._project.name())

        project_colorspaces = self._project.colorspaceDefaults()
        defaults = mari.ColorspaceDefaults
        targets = (
            ('8 bit Data', defaults.COLORSPACE_TARGET_INT8),
            ('16 bit Data', defaults.COLORSPACE_TARGET_INT16),
            ('8 bit Scalar', defaults.COLORSPACE_TARGET_INT_SCALAR),
            ('16/32 bit Float Data', defaults.COLORSPACE_TARGET_FLOAT),
            ('Working', defaults.COLORSPACE_TARGET_WORKING),
            ('Monitor', defaults.COLORSPACE_TARGET_MONITOR),
            ('Color Picking', defaults.COLORSPACE_TARGET_COLOR_PICKER),
            ('Blending', defaults.COLORSPACE_TARGET_BLENDING),
        )

        # Labels are right-aligned to a 20-character column so the ' : ' separators line up.
        print('%20s : %s' % ('OCIO Config', project_colorspaces.fileName()))
        for label, target in targets:
            print('%20s : %s' % (label, project_colorspaces.colorspace(target)))
        print('')

    # --------------------------------------------------------------------------

    def printChannelInfo(self):
        """Prints out the color and mask colorspace information of every channel of every geo."""

        for geo_index, geo in enumerate(mari.geo.list()):
            print('Geo %2d : %s' % (geo_index, geo.name()))
            for channel_index, channel in enumerate(geo.channelList()):
                print('    Channel %2d : %s' % (channel_index, channel.name()))
                print('        Color Data :')
                self._printColorspaceConfigInfo(channel.colorspaceConfig(), 12)
                print('         Mask Data :')
                self._printColorspaceConfigInfo(channel.scalarColorspaceConfig(), 12)
                print('')

    # --------------------------------------------------------------------------

    def printImageInfo(self):
        """Prints out the colorspace information of all images in the image manager."""

        for image_index, image in enumerate(mari.images.list()):
            print('Image %2d : %s' % (image_index, image.filePath()))
            self._printColorspaceConfigInfo(image.colorspaceConfig(), 4)
            print('')

    # --------------------------------------------------------------------------

    def printProjectorInfo(self):
        """Prints out the import/export colorspace information of all projectors in the project."""

        for projector_index, projector in enumerate(mari.projectors.list()):
            print('Projector %2d : %s' % (projector_index, projector.name()))
            print('    Import :')
            self._printColorspaceConfigInfo(projector.importColorspaceConfig(), 8)
            print('    Export :')
            self._printColorspaceConfigInfo(projector.exportColorspaceConfig(), 8)
            print('')

    # --------------------------------------------------------------------------

    def _getColorspacePrettyName(self, colorspace_config, stage):
        """Return the colorspace set for ``stage``, with its resolved name appended when they differ.

        Example: ``'scene_linear (Linear)'`` when the stage is set to a role that resolves to
        ``Linear``; just ``'sRGB'`` when no resolution happens.
        """
        colorspace = colorspace_config.colorspace(stage)
        actual_colorspace = colorspace_config.resolveColorspace(stage)
        if colorspace == actual_colorspace:
            return colorspace
        return colorspace + ' (' + actual_colorspace + ')'

    # --------------------------------------------------------------------------

    def _printColorspaceConfigInfo(self, colorspace_config, indent_num_spaces):
        """Prints out information about a colorspace config.

        Args:
            colorspace_config: The Mari colorspace config to describe.
            indent_num_spaces: Number of spaces prefixed to every printed line.
        """

        indent = ' ' * indent_num_spaces
        stages = (
            ('Native', colorspace_config.COLORSPACE_STAGE_NATIVE),
            ('Output', colorspace_config.COLORSPACE_STAGE_OUTPUT),
            ('Working', colorspace_config.COLORSPACE_STAGE_WORKING),
        )

        # Labels are right-aligned to an 11-character column so the ' : ' separators line up.
        print('%s%11s : %s' % (indent, 'OCIO config', colorspace_config.fileName()))
        for label, stage in stages:
            print('%s%11s : %s' % (indent, label, self._getColorspacePrettyName(colorspace_config, stage)))
        print('%s%11s : %s' % (indent, 'Raw', colorspace_config.raw()))
        print('%s%11s : %s' % (indent, 'Scalar', colorspace_config.scalar()))

    # --------------------------------------------------------------------------

    def printShaderTransforms(self):
        """Print a Python startup script that registers every required shader transform.

        The script contains one ``mari.ocio.setShaderTransformCode`` call per recorded
        transform, emitted in sorted order so the output is stable between runs. Names
        are written with ``%r`` so that backslashes (e.g. Windows config paths) and
        quotes survive as valid Python string literals.
        """

        print('\n'
              '# ------------------------------------------------------------------------------\n'
              '# Register Optimized GLSL Shader Transforms:\n'
              '# ------------------------------------------------------------------------------\n'
              '\n'
              'import mari\n'
              '\n'
              '\n'
              'def registerOptimizedTransforms():\n'
              '    """Registers optimized GLSL shader transform code for the OCIO config. This method can not only significantly\n'
              '    improve the performance within Mari but can also improve the accuracy of the transform itself due to the ability\n'
              '    to remove the need for a LUT."""\n')

        for file_name, transforms in sorted(self._shader_transforms.items()):
            for source_colorspace, target_colorspace in sorted(transforms):
                print("    mari.ocio.setShaderTransformCode(%r,\n"
                      "                                     %r,\n"
                      "                                     %r,\n"
                      "                                     '    vec3 v1 = #Input.rgb;\\n'\n"
                      "                                     '    // TODO\\n'\n"
                      "                                     '    #Output = vec4(v1, #Input.a);\\n')\n"
                      % (file_name, source_colorspace, target_colorspace))

        print('# ------------------------------------------------------------------------------\n'
              '\n'
              'if mari.app.isRunning():\n'
              '    # We want to use the optimized transforms all the time and not just when the user executes the example. The action\n'
              '    # registered with the menu is really only there to make the user aware of the technique.\n'
              '    registerOptimizedTransforms()\n'
              '\n')

        print(self._SEPARATOR)

# ------------------------------------------------------------------------------

# Report on the currently open project; without one there is nothing to analyze.
project = mari.projects.current()
if project is not None:
    analyzer = ColorspaceAnalyzer(project)
    analyzer.printAll()
else:
    mari.utils.message('Please open a project first')

    