#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Pre-commit helper for Objective-C / C++ / Swift sources.

For every candidate source file the script:

* sorts and de-duplicates contiguous ``#include`` / ``#import`` blocks and
  ``@class`` forward-declaration blocks (``.h``, ``.m`` and ``.mm`` only),
* runs ``swiftlint`` on Swift files,
* replaces any leading ``//`` comment lines with the copyright header.

Afterwards it runs ``git clang-format`` and greps the staged diff for a few
keywords that deserve a second look.

The code is written so that it runs on both Python 2.7 and Python 3: only
single-argument ``print(...)`` calls are used, and all subprocess output is
requested in text mode.
"""

import argparse
import datetime
import os
import subprocess
import sys


def _detect_git_repo_path():
    """Return the absolute path of the enclosing git work tree, or None.

    None is returned when ``git`` cannot be executed or the current directory
    is not inside a work tree, so that importing this module never raises.
    """
    try:
        toplevel = subprocess.check_output(
            ['git', 'rev-parse', '--show-toplevel'],
            universal_newlines=True)
    except (subprocess.CalledProcessError, OSError):
        return None
    return os.path.abspath(toplevel.strip())


# Absolute path of the repository root (None when it could not be determined).
git_repo_path = _detect_git_repo_path()


def _get_command_output(args):
    """Run *args* and return its combined stdout/stderr as text.

    Stand-in for the Python 2 only ``commands.getoutput``: the exit status is
    ignored, stderr is merged into stdout, and one trailing newline is removed.
    """
    proc = subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True)
    output, _ = proc.communicate()
    if output.endswith('\n'):
        output = output[:-1]
    return output


class include(object):
    """One parsed ``#include`` / ``#import`` directive."""

    def __init__(self, isInclude, isQuote, body, comment):
        # True for '#include', False for '#import'.
        self.isInclude = isInclude
        # True for "quoted" targets, False for <angle-bracket> targets.
        self.isQuote = isQuote
        # The text between the delimiters.
        self.body = body
        # Whatever followed the closing delimiter on the line.
        self.comment = comment

    def format(self):
        """Render the directive back into a single normalised source line."""
        directive = '#include' if self.isInclude else '#import'
        if self.isQuote:
            opener, closer = '"', '"'
        else:
            opener, closer = '<', '>'
        result = '%s %s%s%s' % (directive, opener, self.body.strip(), closer)
        # Tolerate comment=None for instances that were built by hand.
        comment = (self.comment or '').strip()
        if comment:
            result += ' ' + comment
        return result


def is_include_or_import(line):
    """Return True if *line* starts (after stripping) with '#include ' or '#import '."""
    return line.strip().startswith(('#include ', '#import '))


def _exit_unexpected(text):
    """Report an unparseable directive and terminate the script with status 1."""
    print('Unexpected import or include: ' + text)
    sys.exit(1)


def parse_include(line):
    """Parse one line of an include block.

    Returns an ``include`` instance, or None for blank lines and bare ``//``
    lines.  Any other content prints a message and exits with status 1.
    """
    remainder = line.strip()
    if remainder.startswith('#include '):
        isInclude = True
        remainder = remainder[len('#include '):]
    elif remainder.startswith('#import '):
        isInclude = False
        remainder = remainder[len('#import '):]
    elif remainder == '//' or not remainder:
        return None
    else:
        _exit_unexpected(line)

    # Accept more than one space between the directive and its target.
    remainder = remainder.lstrip()

    if remainder.startswith('"'):
        isQuote = True
        closer = '"'
    elif remainder.startswith('<'):
        isQuote = False
        closer = '>'
    else:
        _exit_unexpected(remainder)

    endIndex = remainder.find(closer, 1)
    if endIndex < 0:
        _exit_unexpected(line)
    body = remainder[1:endIndex]
    comment = remainder[endIndex + 1:]
    return include(isInclude, isQuote, body, comment)


def parse_includes(text):
    """Parse every line of *text* and return the resulting ``include`` objects."""
    parsed = (parse_include(line) for line in text.split('\n'))
    return [item for item in parsed if item]


def sort_include_block(text, filepath, filename, file_extension):
    """Return *text* (a block of include/import lines) normalised and sorted.

    The directive keyword is rewritten to ``#include`` for C/C++ extensions
    and to ``#import`` for everything else.  Headers whose base name matches
    *filename* come first, followed by quoted and then angle-bracket targets;
    every group is de-duplicated and sorted.  The result ends with a newline.
    *file_extension* may be given with or without its leading dot.
    """
    includes = parse_includes(text)

    # Callers pass extensions such as '.m', so drop the dot before comparing.
    extension = file_extension.lower().lstrip('.')
    use_include_directive = extension in ('c', 'cpp', 'hpp')
    for item in includes:
        item.isInclude = use_include_directive

    filename_wo_ext = os.path.splitext(filename)[0]

    def is_matching_header(item):
        item_wo_ext = os.path.splitext(os.path.basename(item.body))[0]
        return item_wo_ext == filename_wo_ext

    # Make sure the matching header is first.
    matching = [item for item in includes if is_matching_header(item)]
    others = [item for item in includes if not is_matching_header(item)]

    def format_block(items):
        return '\n'.join(sorted(set(item.format() for item in items)))

    groups = [
        matching,
        [i for i in others if i.isInclude and i.isQuote],
        [i for i in others if i.isInclude and not i.isQuote],
        [i for i in others if not i.isInclude and i.isQuote],
        [i for i in others if not i.isInclude and not i.isQuote],
    ]
    return '\n'.join(format_block(group) for group in groups if group) + '\n'


def sort_class_statement_block(text, filepath, filename, file_extension):
    """Return the non-blank lines of *text* stripped, de-duplicated and sorted.

    The result is wrapped in a leading and a trailing newline.
    """
    statements = set(line.strip() for line in text.split('\n') if line.strip())
    return '\n' + '\n'.join(sorted(statements)) + '\n'


def find_matching_section(text, match_test):
    """Locate the first run of lines in *text* accepted by *match_test*.

    Blank lines directly before the run and blank lines inside or after it
    are absorbed into the run.

    Returns None when no line matches, otherwise ``(text0, text1, text2)``:
    the text before the run, the run itself, and the text after it (None when
    the run extends to the end of *text*).
    """
    lines = text.split('\n')

    first_matching_line_index = None
    for index, line in enumerate(lines):
        if match_test(line):
            first_matching_line_index = index
            break
    if first_matching_line_index is None:
        return None

    # Absorb any leading empty lines.
    while first_matching_line_index > 0:
        if lines[first_matching_line_index - 1].strip():
            break
        first_matching_line_index -= 1

    first_non_matching_line_index = None
    for offset, line in enumerate(lines[first_matching_line_index:]):
        if not line.strip():
            # Absorb any trailing empty lines.
            continue
        if not match_test(line):
            first_non_matching_line_index = offset + first_matching_line_index
            break

    text0 = '\n'.join(lines[:first_matching_line_index])
    if first_non_matching_line_index is None:
        text1 = '\n'.join(lines[first_matching_line_index:])
        text2 = None
    else:
        text1 = '\n'.join(
            lines[first_matching_line_index:first_non_matching_line_index])
        text2 = '\n'.join(lines[first_non_matching_line_index:])
    return text0, text1, text2


def sort_matching_blocks(sort_name, filepath, filename, file_extension, text,
                         match_func, sort_func):
    """Apply *sort_func* to every section of *text* selected by *match_func*.

    *match_func* is called with single lines and its result is used as a
    truth value.  *sort_func* receives ``(section_text, filepath, filename,
    file_extension)`` and returns the replacement text.  When the result
    differs from *text*, ``sort_name`` and ``filepath`` are printed.

    Returns the rewritten text.
    """
    unprocessed = text
    processed = None
    while True:
        section = find_matching_section(unprocessed, match_func)
        if not section:
            if processed:
                processed = '\n'.join((processed, unprocessed))
            else:
                processed = unprocessed
            break

        text0, text1, text2 = section
        if processed:
            processed = '\n'.join((processed, text0))
        else:
            processed = text0

        text1 = sort_func(text1, filepath, filename, file_extension)
        processed = '\n'.join((processed, text1))

        if not text2:
            break
        unprocessed = text2

    if text != processed:
        print('%s %s' % (sort_name, filepath))
    return processed


def find_class_statement_section(text):
    """Return the first ``@class`` section of *text* (see find_matching_section)."""
    def is_class_statement(line):
        return line.strip().startswith('@class ')

    return find_matching_section(text, is_class_statement)


def find_include_section(text):
    """Return the first include/import section of *text* (see find_matching_section)."""
    def is_include_line(line):
        return is_include_or_import(line)

    return find_matching_section(text, is_include_line)


def sort_includes(filepath, filename, file_extension, text):
    """Sort all include/import blocks of *text*; only for .h, .m and .mm files."""
    if file_extension not in ('.h', '.m', '.mm'):
        return text
    # find_include_section doubles as the per-line predicate: for a single
    # line it returns a (truthy) tuple on a match and None otherwise.
    return sort_matching_blocks('sort_includes', filepath, filename,
                                file_extension, text,
                                find_include_section, sort_include_block)


def sort_class_statements(filepath, filename, file_extension, text):
    """Sort all ``@class`` blocks of *text*; only for .h, .m and .mm files."""
    if file_extension not in ('.h', '.m', '.mm'):
        return text
    # find_class_statement_section doubles as the per-line predicate, see
    # sort_includes.
    return sort_matching_blocks('sort_class_statements', filepath, filename,
                                file_extension, text,
                                find_class_statement_section,
                                sort_class_statement_block)


def splitall(path):
    """Split *path* into the list of all of its components."""
    allparts = []
    while True:
        head, tail = os.path.split(path)
        if head == path:
            # Sentinel for absolute paths.
            allparts.insert(0, head)
            break
        if tail == path:
            # Sentinel for relative paths.
            allparts.insert(0, tail)
            break
        path = head
        allparts.insert(0, tail)
    return allparts


def process(filepath):
    """Lint and normalise one source file in place.

    Swift files are passed through ``swiftlint autocorrect`` and
    ``swiftlint lint``.  For every file the include and ``@class`` blocks are
    sorted, leading ``//`` lines are replaced by the copyright header, and the
    file is rewritten only when its content changed.

    Raises ValueError when *filepath* names a dotfile.
    """
    # Path relative to the repository root, used for display and swiftlint.
    short_filepath = filepath
    if git_repo_path and filepath.startswith(git_repo_path):
        short_filepath = filepath[len(git_repo_path):]
        if short_filepath.startswith(os.sep):
            short_filepath = short_filepath[len(os.sep):]

    filename = os.path.basename(filepath)
    if filename.startswith('.'):
        raise ValueError("shouldn't call process with dotfile")
    file_ext = os.path.splitext(filename)[1]

    if file_ext == '.swift':
        env_copy = os.environ.copy()
        env_copy["SCRIPT_INPUT_FILE_COUNT"] = "1"
        env_copy["SCRIPT_INPUT_FILE_0"] = '%s' % (short_filepath,)
        lint_output = subprocess.check_output(
            ['swiftlint', 'autocorrect', '--use-script-input-files'],
            env=env_copy, universal_newlines=True)
        print(lint_output)
        try:
            lint_output = subprocess.check_output(
                ['swiftlint', 'lint', '--use-script-input-files'],
                env=env_copy, universal_newlines=True)
        except subprocess.CalledProcessError as e:
            # A non-zero exit still carries the lint report; show it.
            lint_output = e.output
        print(lint_output)

    with open(filepath, 'rt') as f:
        text = f.read()
    original_text = text

    text = sort_includes(filepath, filename, file_ext, text)
    text = sort_class_statements(filepath, filename, file_ext, text)

    # Drop the existing leading comment lines; the header is re-added below.
    lines = text.split('\n')
    while lines and lines[0].startswith('//'):
        lines = lines[1:]
    text = '\n'.join(lines).strip()

    # NOTE(review): the exact whitespace of this template (spaces after `//`
    # and the trailing blank line) should be confirmed against the headers
    # already present in the repository before running with --all.
    header = '''//
//  Copyright (c) %s Open Whisper Systems. All rights reserved.
//

''' % (datetime.datetime.now().year,)
    text = header + text + '\n'

    if original_text == text:
        return

    print('Updating: %s' % (short_filepath,))
    with open(filepath, 'wt') as f:
        f.write(text)


def should_ignore_path(path):
    """Return True when *path* must not be processed.

    Ignored are: anything below the repository's ``.git`` path, and any path
    with a component that starts with a dot, ends in ``.framework`` or is one
    of ``Pods``, ``ThirdParty`` or ``Carthage``.
    """
    if git_repo_path:
        if path.startswith(os.path.join(git_repo_path, '.git')):
            return True
    for component in splitall(path):
        if component.startswith('.'):
            return True
        if component.endswith('.framework'):
            return True
        if component in ('Pods', 'ThirdParty', 'Carthage'):
            return True
    return False


def process_if_appropriate(filepath):
    """Call process() on *filepath* if it is a non-ignored source file."""
    filename = os.path.basename(filepath)
    if filename.startswith('.'):
        return
    file_ext = os.path.splitext(filename)[1]
    if file_ext not in ('.h', '.hpp', '.cpp', '.m', '.mm', '.pch', '.swift'):
        return
    if should_ignore_path(filepath):
        return
    process(filepath)


def check_diff_for_keywords():
    """Print the staged diff lines (with context) that mention watched keywords."""
    keywords = [r"OWSAssert\(", r"OWSFail\(", r"ows_add_overflow\(",
                r"ows_sub_overflow\("]
    matching_expression = "|".join(keywords)
    command_line = 'git diff --staged | grep --color=always -C 3 -E "%s"' % matching_expression
    print(command_line)
    try:
        output = subprocess.check_output(command_line, shell=True,
                                         universal_newlines=True)
    except subprocess.CalledProcessError as e:
        # > man grep
        #  EXIT STATUS
        #  The grep utility exits with one of the following values:
        #  0     One or more lines were selected.
        #  1     No lines were selected.
        #  >1    An error occurred.
        if e.returncode == 1:
            # No keywords in diff output.
            return
        # Some other error - bad grep expression?
        raise

    if output:
        print("⚠️ keywords detected in diff:")
        print(output)


def _process_tree(root):
    """Walk *root* and process every appropriate file below it."""
    for rootdir, dirnames, filenames in os.walk(root):
        # Files below an ignored directory are ignored as well, so there is
        # no point in descending into it.
        dirnames[:] = [
            dirname for dirname in dirnames
            if not should_ignore_path(
                os.path.abspath(os.path.join(rootdir, dirname)))
        ]
        for filename in filenames:
            process_if_appropriate(
                os.path.abspath(os.path.join(rootdir, filename)))


def main():
    """Command-line entry point."""
    parser = argparse.ArgumentParser(description='Precommit script.')
    parser.add_argument('--all', action='store_true',
                        help='process all files in or below current dir')
    parser.add_argument('--path', help='used to specify a path to process.')
    args = parser.parse_args()

    if git_repo_path is None:
        sys.stderr.write('Unable to determine the git repository root; '
                         'run this script from inside a git checkout.\n')
        sys.exit(1)

    if args.all:
        _process_tree(git_repo_path)
    elif args.path:
        _process_tree(args.path)
    else:
        filepaths = []

        # Staging
        output = _get_command_output(
            ['git', 'diff', '--cached', '--name-only', '--diff-filter=ACMR'])
        filepaths.extend(line.strip() for line in output.split('\n'))

        # Working
        output = _get_command_output(
            ['git', 'diff', '--name-only', '--diff-filter=ACMR'])
        filepaths.extend(line.strip() for line in output.split('\n'))

        # Only process each path once; skip the empty entries produced by
        # empty command output.
        for filepath in sorted(set(path for path in filepaths if path)):
            process_if_appropriate(
                os.path.abspath(os.path.join(git_repo_path, filepath)))

    print('git clang-format...')
    print(_get_command_output(['git', 'clang-format']))

    check_diff_for_keywords()


if __name__ == "__main__":
    main()