"""Split a single OBJ model into multiple OBJ files, one per material.

-------------------------------------
How to use
-------------------------------------

python split_obj.py -i infile.obj -o outfile

Will generate:

outfile_000.obj
outfile_001.obj

...

outfile_XXX.obj

-------------------------------------
Parser based on format description
-------------------------------------

    https://en.wikipedia.org/wiki/Wavefront_.obj_file

------
Author
------
AlteredQualia http://alteredqualia.com

"""

import fileinput
import operator
import random
import os.path
import getopt
import sys
import struct
import math
import glob

# #####################################################
# Configuration
# #####################################################
# Vertex output options. Both are rebound by the -x/--truncatescale
# command-line option in the __main__ block (TRUNCATE = True, SCALE = the
# value given on the command line).
TRUNCATE = False
SCALE = 1.0


# #####################################################
# Templates
# #####################################################
# Layout of one generated OBJ file. It is %-formatted with a dict holding the
# element counts (nfaces, nvertices, nnormals, nuvs) and the newline-joined
# record text of each section (vertices, normals, uvs, faces).
TEMPLATE_OBJ = u"""\
################################
# OBJ generated by split_obj.py
################################
# Faces:    %(nfaces)d
# Vertices: %(nvertices)d
# Normals:  %(nnormals)d
# UVs:      %(nuvs)d
################################

# vertices

%(vertices)s

# normals

%(normals)s

# uvs

%(uvs)s

# faces

%(faces)s
"""

# Position record: full floating point, or the TRUNCATE variant whose '%d'
# conversions drop the fractional part of each coordinate.
TEMPLATE_VERTEX = "v %f %f %f"
TEMPLATE_VERTEX_TRUNCATE = "v %d %d %d"

# Normal and texture-coordinate records with 5 significant digits; only the
# u and v components are written for texture coordinates.
TEMPLATE_NORMAL = "vn %.5g %.5g %.5g"
TEMPLATE_UV = "vt %.5g %.5g"

# Face records: FACE3 = triangle, FACE4 = quad. The suffix gives the
# per-corner layout (v = vertex, t = texture, n = normal index):
#   _V   -> "v"      _VT  -> "v/t"
#   _VN  -> "v//n"   _VTN -> "v/t/n"
TEMPLATE_FACE3_V = "f %d %d %d"
TEMPLATE_FACE4_V = "f %d %d %d %d"

TEMPLATE_FACE3_VT = "f %d/%d %d/%d %d/%d"
TEMPLATE_FACE4_VT = "f %d/%d %d/%d %d/%d %d/%d"

TEMPLATE_FACE3_VN = "f %d//%d %d//%d %d//%d"
TEMPLATE_FACE4_VN = "f %d//%d %d//%d %d//%d %d//%d"

TEMPLATE_FACE3_VTN = "f %d/%d/%d %d/%d/%d %d/%d/%d"
TEMPLATE_FACE4_VTN = "f %d/%d/%d %d/%d/%d %d/%d/%d %d/%d/%d"


# #####################################################
# Utils
# #####################################################
def file_exists(filename):
    """Tell whether *filename* can actually be opened for reading.

    Opening the file, rather than only checking that the path exists, also
    accounts for symbolic links and permission bits on Unix filesystems.

    @param filename: path to probe.
    @rtype: boolean
    @return: True when opening the file for reading succeeds, False when it
        raises IOError.
    """

    try:
        # Open and immediately close; only the success of the open matters.
        with open(filename, 'r'):
            pass
    except IOError:
        return False
    return True

# #####################################################
# OBJ parser
# #####################################################
def parse_vertex(text):
    """Split one face-corner token into its vertex, texture and normal indices.

    Accepted token forms:
        v          vertex index only
        v/t        vertex and texture index
        v/t/n      vertex, texture and normal index
        v//n       vertex and normal index

    Indices are returned exactly as written (so they may be negative). A
    missing or empty texture/normal field is reported as 0.

    @return: dict with keys 'v', 't' and 'n'.
    """

    fields = text.split("/")
    field_count = len(fields)

    vertex_idx = int(fields[0])
    texture_idx = int(fields[1]) if field_count > 1 and fields[1] else 0
    normal_idx = int(fields[2]) if field_count > 2 and fields[2] else 0

    return {'v': vertex_idx, 't': texture_idx, 'n': normal_idx}

def _resolve_index(index, count):
    """Turn an OBJ element reference into an absolute 1-based index.

    The OBJ format allows negative references that count back from the most
    recently defined element (-1 is the last one); *count* is the number of
    elements of that kind defined so far. Positive indices pass through.
    """
    if index < 0:
        return count + index + 1
    return index


def parse_obj(fname):
    """Parse an OBJ file into faces, vertex attributes and material info.

    Recognized statements: v, vn, vt, f, g, o, mtllib, usemtl and s; every
    other line is ignored.

    Face indices are stored 1-based. Negative (relative) references are
    resolved to the equivalent positive index while parsing, so later code
    can treat every index as absolute.

    @param fname: path of the OBJ file to read.
    @return: tuple (faces, vertices, uvs, normals, materials, mtllib):
        faces     -- list of dicts with keys 'vertex', 'uv', 'normal'
                     (index lists), 'material' (integer id), 'group',
                     'object' and 'smooth' (raw names, or 0 if never set)
        vertices  -- list of [x, y, z]
        uvs       -- list of [u, v, w]; w is 0 when absent
        normals   -- list of [x, y, z]
        materials -- dict mapping material name to its integer id
        mtllib    -- file name from the last "mtllib" statement, or ""
    """

    vertices = []
    normals = []
    uvs = []

    faces = []

    materials = {}
    mcounter = 0
    # Faces seen before any "usemtl" get id 0, the same id the first
    # declared material receives.
    mcurrent = 0

    mtllib = ""

    # current face state
    group = 0
    obj = 0
    smooth = 0

    # A private FileInput instance (rather than the module-global
    # fileinput.input() state) closed in `finally`, so the file handle is
    # released even if parsing fails and later calls are not rejected with
    # "input() already active".
    stream = fileinput.FileInput(fname)
    try:
        for line in stream:
            chunks = line.split()
            if not chunks:
                continue
            keyword = chunks[0]

            # Vertices as (x,y,z) coordinates. Extra trailing components (an
            # optional w, or per-vertex colours from some exporters) are
            # ignored instead of dropping the vertex, which would shift every
            # later vertex index.
            # v 0.123 0.234 0.345 [...]
            if keyword == "v" and len(chunks) >= 4:
                vertices.append([float(chunks[1]), float(chunks[2]), float(chunks[3])])

            # Normals in (x,y,z) form; normals might not be unit
            # vn 0.707 0.000 0.707
            elif keyword == "vn" and len(chunks) >= 4:
                normals.append([float(chunks[1]), float(chunks[2]), float(chunks[3])])

            # Texture coordinates in (u,v[,w]) coordinates, w is optional
            # vt 0.500 -1.352 [0.234]
            elif keyword == "vt" and len(chunks) >= 3:
                u = float(chunks[1])
                v = float(chunks[2])
                w = float(chunks[3]) if len(chunks) > 3 else 0
                uvs.append([u, v, w])

            # Face: at least three corner tokens
            elif keyword == "f" and len(chunks) >= 4:
                vertex_index = []
                uv_index = []
                normal_index = []

                for token in chunks[1:]:
                    corner = parse_vertex(token)
                    # 0 means "not given" (see parse_vertex); references are
                    # resolved against what has been defined so far.
                    if corner['v']:
                        vertex_index.append(_resolve_index(corner['v'], len(vertices)))
                    if corner['t']:
                        uv_index.append(_resolve_index(corner['t'], len(uvs)))
                    if corner['n']:
                        normal_index.append(_resolve_index(corner['n'], len(normals)))

                faces.append({
                    'vertex': vertex_index,
                    'uv': uv_index,
                    'normal': normal_index,

                    'material': mcurrent,
                    'group': group,
                    'object': obj,
                    'smooth': smooth,
                    })

            # Group
            elif keyword == "g" and len(chunks) == 2:
                group = chunks[1]

            # Object
            elif keyword == "o" and len(chunks) == 2:
                obj = chunks[1]

            # Materials definition
            elif keyword == "mtllib" and len(chunks) == 2:
                mtllib = chunks[1]

            # Material: ids are assigned in order of first appearance
            elif keyword == "usemtl" and len(chunks) == 2:
                material = chunks[1]
                if material not in materials:
                    materials[material] = mcounter
                    mcounter += 1
                mcurrent = materials[material]

            # Smooth shading
            elif keyword == "s" and len(chunks) == 2:
                smooth = chunks[1]
    finally:
        stream.close()

    return faces, vertices, uvs, normals, materials, mtllib

# #############################################################################
# API - Breaker
# #############################################################################
def _format_vertex(vertex):
    """Return the 'v' record for one [x, y, z] position.

    When the module-level TRUNCATE flag is set (by the -x/--truncatescale
    option), coordinates are multiplied by SCALE and written as integers via
    TEMPLATE_VERTEX_TRUNCATE; otherwise TEMPLATE_VERTEX is used.
    """
    # The globals are read at call time so that the command-line option,
    # which rebinds them before break_obj() runs, takes effect.
    if TRUNCATE:
        return TEMPLATE_VERTEX_TRUNCATE % (SCALE * vertex[0], SCALE * vertex[1], SCALE * vertex[2])
    return TEMPLATE_VERTEX % (vertex[0], vertex[1], vertex[2])


def _index_map(indices):
    """Map each original 1-based index in *indices* to its new 1-based position."""
    return dict((old, new) for new, old in enumerate(indices, 1))


def _format_face(face, vmap, tmap, nmap):
    """Return the 'f' record for *face*, renumbered through the chunk's maps.

    UV and normal references are written only when every corner has one,
    giving the per-corner layouts v, v/t, v//n and v/t/n. For triangles and
    quads this matches the TEMPLATE_FACE3_* / TEMPLATE_FACE4_* formats, and
    larger polygons are written the same way instead of being dropped.
    """
    fv = face["vertex"]
    ft = face["uv"]
    fn = face["normal"]
    with_uv = len(ft) == len(fv)
    with_normal = len(fn) == len(fv)

    corners = []
    for k, vi in enumerate(fv):
        v = vmap[vi]
        if with_uv and with_normal:
            corners.append("%d/%d/%d" % (v, tmap[ft[k]], nmap[fn[k]]))
        elif with_normal:
            corners.append("%d//%d" % (v, nmap[fn[k]]))
        elif with_uv:
            corners.append("%d/%d" % (v, tmap[ft[k]]))
        else:
            corners.append("%d" % v)
    return "f " + " ".join(corners)


def break_obj(infile, outfile):
    """Split *infile* into one OBJ file per material, named <outfile>_NNN.obj.

    Faces are grouped by the material id parse_obj assigned them. Each group
    is written with only the vertices, normals and UVs its faces reference,
    renumbered from 1 while keeping their original relative order. Files are
    numbered 000, 001, ... in ascending material id order.

    If *infile* cannot be opened for reading, a message is printed and
    nothing is written.
    """

    if not file_exists(infile):
        print("Couldn't find [%s]" % infile)
        return

    faces, vertices, uvs, normals, _materials, _mtllib = parse_obj(infile)

    # Group faces by material, collecting the attribute indices each group uses.
    chunks = {}
    for face in faces:
        # A face needs at least three corners to be a valid 'f' record
        # (parse_obj drops 0 indices, which can leave fewer).
        if len(face["vertex"]) < 3:
            continue
        chunk = chunks.get(face["material"])
        if chunk is None:
            chunk = {"faces": [], "vertices": set(), "normals": set(), "uvs": set()}
            chunks[face["material"]] = chunk
        chunk["faces"].append(face)
        chunk["vertices"].update(face["vertex"])
        chunk["normals"].update(face["normal"])
        chunk["uvs"].update(face["uv"])

    # sorted() makes the file numbering and attribute order deterministic;
    # plain dict/set iteration order is an implementation detail.
    for mi, material in enumerate(sorted(chunks)):
        chunk = chunks[material]

        new_vertices = sorted(chunk["vertices"])
        new_normals = sorted(chunk["normals"])
        new_uvs = sorted(chunk["uvs"])

        # original index => index inside this chunk's file
        vmap = _index_map(new_vertices)
        nmap = _index_map(new_normals)
        tmap = _index_map(new_uvs)

        str_vertices = "\n".join(_format_vertex(vertices[i - 1]) for i in new_vertices)
        str_normals = "\n".join(
            TEMPLATE_NORMAL % (normals[i - 1][0], normals[i - 1][1], normals[i - 1][2])
            for i in new_normals)
        str_uvs = "\n".join(TEMPLATE_UV % (uvs[i - 1][0], uvs[i - 1][1]) for i in new_uvs)
        str_faces = "\n".join(_format_face(face, vmap, tmap, nmap) for face in chunk["faces"])

        content = TEMPLATE_OBJ % {
            "nfaces": len(chunk["faces"]),
            "nvertices": len(new_vertices),
            "nnormals": len(new_normals),
            "nuvs": len(new_uvs),

            "vertices": str_vertices,
            "normals": str_normals,
            "uvs": str_uvs,
            "faces": str_faces,
        }

        outname = "%s_%03d.obj" % (outfile, mi)
        with open(outname, "w") as f:
            f.write(content)


# #############################################################################
# Helpers
# #############################################################################
def usage():
    """Print the command-line synopsis of this script to stdout."""
    # print() with a single argument behaves the same on Python 2 and 3,
    # unlike the Python-2-only print statement.
    print("Usage: %s -i filename.obj -o prefix" % os.path.basename(sys.argv[0]))

# #####################################################
# Main
# #####################################################
if __name__ == "__main__":

    # Command-line entry point. This runs at module level, so the TRUNCATE
    # and SCALE assignments below rebind the module-wide configuration.

    try:
        opts, _ = getopt.getopt(sys.argv[1:], "hi:o:x:", ["help", "input=", "output=", "truncatescale="])

    except getopt.GetoptError:
        usage()
        sys.exit(2)

    infile = outfile = ""

    for o, a in opts:
        if o in ("-h", "--help"):
            usage()
            sys.exit()

        elif o in ("-i", "--input"):
            infile = a

        elif o in ("-o", "--output"):
            outfile = a

        elif o in ("-x", "--truncatescale"):
            # Reject a non-numeric scale with the usage text instead of a
            # ValueError traceback.
            try:
                SCALE = float(a)
            except ValueError:
                print("Invalid scale [%s]" % a)
                usage()
                sys.exit(2)
            TRUNCATE = True

    if infile == "" or outfile == "":
        usage()
        sys.exit(2)

    print("Splitting [%s] into [%s_XXX.obj] ..." % (infile, outfile))

    break_obj(infile, outfile)

