spatch

Lenient unified diff patcher
git clone https://git.sinitax.com/sinitax/spatch
Log | Files | Refs | LICENSE | sfeed.txt

spatch (4643B)


      1#!/usr/bin/env python3
      2
      3import sys, re
      4
      5file_header = """\
      6--- ([^\\n\\t]*)(\\t[^\n]*)?
      7\+\+\+ ([^\\n\\t]*)(\\t[^\\n]*)?
      8"""
      9file_header_pattern = re.compile(file_header)
     10
     11chunk_header = "@@ -([0-9]*),([0-9]*) \+([0-9]*),([0-9]*) @@(.*)\n";
     12chunk_header_pattern = re.compile(chunk_header)
     13
     14def patch_file(src_filename, dst_filename, content):
     15    prev_header_match = chunk_header_pattern.search(content)
     16    if prev_header_match == None:
     17        print("[CHUNK] No chunks found, skipping.")
     18        return
     19
     20    src_content = open(src_filename, "r").read()
     21
     22    chunks = list()
     23    next_header_match = True # for do_while loop
     24    while next_header_match:
     25        next_header_match = chunk_header_pattern\
     26            .search(content, prev_header_match.span()[1])
     27        chunks.append(prev_header_match)
     28        prev_header_match = next_header_match
     29
     30    for i,c in enumerate(chunks):
     31        src_line = c.group(1)
     32        src_count = c.group(2)
     33        dst_line = c.group(3)
     34        dst_count = c.group(4)
     35        comment = c.group(5).strip()
     36
     37        if comment != "":
     38            print("[CHUNK] Applying chunk with comment: {}".format(comment))
     39        else:
     40            print("[CHUNK] Applying chunk at line: {}".format(src_line))
     41
     42        start_pos = c.span()[1]
     43        if i != len(chunks) - 1:
     44            end_pos = chunks[i + 1].span()[0]
     45        else:
     46            end_pos = len(content)
     47
     48        chunk_content = content[start_pos:end_pos].split("\n")
     49        valid_lines = 0
     50        for l in chunk_content:
     51            if len(l) == 0 or l[0] not in (' ', '+', '-', '\\'):
     52                break
     53            valid_lines += 1
     54
     55        chunk_content = chunk_content[:valid_lines]
     56
     57        src_lines = "\n".join([l[1:] for l in chunk_content if l[0] in (' ', '-')])
     58        dst_lines = "\n".join([l[1:] for l in chunk_content if l[0] in (' ', '+')])
     59
     60        if src_lines == 0 and dst_lines == 0:
     61            print("[ERROR] Chunk has no valid lines")
     62            sys.exit(1)
     63
     64        src_nl = dst_nl = True
     65        for i,l in enumerate(chunk_content):
     66            if i != 0 and l == '\\ No newline at end of file':
     67                if chunk_content[i-1][0] == '+':
     68                    src_nl = False
     69                elif chunk_content[i-1][0] == '-':
     70                    dst_nl = False
     71        src_lines += "\n" if src_nl and len(src_lines) > 0 else ""
     72        dst_lines += "\n" if dst_nl and len(dst_lines) > 0 else ""
     73
     74        try:
     75            replace_start = src_content.index(src_lines)
     76            src_content = src_content[:replace_start] \
     77                    + dst_lines + src_content[replace_start+len(src_lines):]
     78        except Exception as e:
     79            print("[ERROR] Failed to find corresponding lines for chunk, exiting..")
     80            sys.exit(1)
     81
     82    open(dst_filename, "w+").write(src_content)
     83
     84def main():
     85    if len(sys.argv) < 2:
     86        print("Supply the path of a unified diff file as argument")
     87        return 1
     88    elif len(sys.argv) == 3:
     89        targetdir = sys.argv[2]
     90    else:
     91        targetdir = None
     92
     93    diff_file = sys.argv[1]
     94    content = open(diff_file).read()
     95
     96    prev_header_match = file_header_pattern.search(content)
     97    if prev_header_match == None:
     98        print("[ERROR] Not a unified diff file!")
     99        return 1
    100    header_matches = list()
    101    next_header_match = True # for do_while loop
    102    while next_header_match:
    103        next_header_match = file_header_pattern.search(content, prev_header_match.span()[1])
    104        header_matches.append(prev_header_match)
    105        prev_header_match = next_header_match
    106
    107    print("[GLOBAL] Processing diff file '{}'".format(diff_file))
    108    print("[GLOBAL] Found {} file patch headers..".format(len(header_matches)))
    109
    110    for i in range(len(header_matches)):
    111        if len(header_matches[i].groups()) == 4:
    112            src_file = header_matches[i].group(1)
    113            dst_file = header_matches[i].group(3)
    114        else:
    115            src_file = header_matches[i].group(1)
    116            dst_file = header_matches[i].group(2)
    117
    118        if targetdir:
    119            if src_file[0] != "/":
    120                src_file = targetdir + "/" + src_file.split("/",1)[1]
    121            if dst_file[0] != "/":
    122                dst_file = targetdir + "/" + dst_file.split("/",1)[1]
    123
    124        print("[PATCH] Applying patch from {} to {}".format(src_file, dst_file))
    125
    126        startpos = header_matches[i].span()[1]
    127        if i == len(header_matches) - 1:
    128            endpos = len(content)
    129        else:
    130            endpos = header_matches[i + 1].span()[0]
    131
    132        patch_file(src_file, dst_file, content[startpos:endpos])
    133
    134    return 0
    135
    136if __name__ == "__main__":
    137    sys.exit(main())