NOIP学习小站
西安交通大学附属中学航天学校

模板函数

一、自定义交换函数

我们来看一个特殊的场景,要自定义一个交换函数用来交换两个变量的值(这里不使用C++内置的swap函数)。如果我们要交换两个int变量的值,函数参数使用引用传参的方式,很容易就能写出下面的函数:

void oldswap(int &a,int &b){
	int temp = a;
	a = b;
	b = temp;
}

当然,这个函数只能交换两个int类型的变量的值,如果要交换两个double类型的变量的值,直接调用这个函数会出现编译错误。如果程序里既要交换两个int类型变量的值,还要交换两个double类型变量的值,甚至还要交换两个结构体变量的值,该怎么办呢?现在能想到的最简单粗暴的方法,就是定义多个不同数据类型参数的函数:

struct Point{
	int x,y;
	Point(){};
	Point(int x,int y){
		this->x = x;
		this->y = y;
	}
	void print(){
		cout<<"("<<this->x<<","<<this->y<<")";
	}
};

/*************************************************
//普通交换函数,不同的参数类型都定义一个交换函数
//下面三个函数的函数名相同,但参数类型不同,是三个不同的函数(利用了C++函数重载的功能) 
**************************************************/
void oldswap(int &a,int &b){
	int temp = a;
	a = b;
	b = temp;
}
void oldswap(double &a,double &b){
	double temp = a;
	a = b;
	b = temp;
}
void oldswap(Point &a,Point &b){
	Point temp = a;
	a = b;
	b = temp;
}

现在调用oldswap函数就能实现上述效果(其实不同的参数类型调用的是不同的函数):

int x1=12,y1=23;
oldswap(x1,y1);
cout<<x1<<","<<y1<<endl;

double x2=12.23,y2=23.45;
oldswap(x2,y2);
cout<<x2<<","<<y2<<endl;

Point x3(1,2),y3(0,0);
oldswap(x3,y3);
x3.print();
cout<<"\t";
y3.print();
cout<<endl;

二、模板函数简介

很显然,这样实现的代码不够简洁!可以利用C++中的模板函数来简化代码,先来定义一个模板函数:

template<typename T>
void newswap(T &a,T &b){
	T temp = a;
	a = b;
	b = temp;
}

看函数申明语句,两个引用传参的参数a和b都是T类型,那么T类型究竟是什么数据类型呢?看函数定义前的模板前缀标记语句template<typename T>,这里出现了T,但是这里的T却不是一个自定义数据类型,而是一个通用数据类型。同样的函数体中的temp变量也是T类型,也是一个通用数据类型。可见,这里的参数a、b,函数体中的变量temp都是同一个通用数据类型T。这里的模板函数体现了“泛型编程”的概念:即不考虑具体数据类型的编程方式。

上面的newswap就是一个模板函数,函数模板不是实际的函数,而是编译器用于生成一个或多个函数的 "模具"。在编写模板函数时,不必为形参、返回值或局部变量指定实际数据类型,而是使用类型名称(上面的typename T)来指定通用数据类型。当编译器遇到对模板函数的调用时,它将检查其实际参数的数据类型,并对照模板函数的实现代码自动生成将与数据类型对应的实际函数代码。

例如下面的语句:

int x1=12,y1=23;
newswap(x1,y1);
cout<<x1<<","<<y1<<endl;

newswap调用的是编译器根据实际参数类型(int)对照模板函数自动生成的函数代码:

//编译时,模板函数的通用数据类型T根据实际参数类型被识别成int,自动生成对应的实际函数代码
void newswap(int &a,int &b){
	int temp = a;
	a = b;
	b = temp;
}

这里的函数调用语句newswap(x1,y1);有一个自动推导调用通用数据类型T为int的模板函数的过程。其实还可以显示地调用模板函数,指定通用数据类型的实际数据类型,例如:newswap<int>(x1,y1);

同样地,下面的语句:

double x2=12.23,y2=23.45;
newswap(x2,y2);
cout<<x2<<","<<y2<<endl;

newswap调用的是编译器根据实际参数类型(double)对照模板函数自动生成的函数代码:

//编译时,模板函数的通用数据类型T根据实际参数类型被识别成double,自动生成对应的实际函数代码
void newswap(double &a,double &b){
	double temp = a;
	a = b;
	b = temp;
}

类似地,下面的语句:

Point x3(1,2),y3(0,0);
newswap(x3,y3);
x3.print();
cout<<"\t";
y3.print();
cout<<endl;

newswap调用的是编译器根据实际参数类型(Point)对照模板函数自动生成的函数代码:

//编译时,模板函数的通用数据类型T根据实际参数类型被识别成Point,自动生成对应的实际函数代码
void newswap(Point &a,Point &b){
	Point temp = a;
	a = b;
	b = temp;
}

可见,上面定义的newswap模板函数实现了任意两个相同数据类型变量的值的交换,甚至是自定义的结构体变量或者类变量。其实,C++内置的swap函数,也是一个模板函数,所以使用swap函数可以交换任意相同数据类型的变量的值。

类似地,cmath中的abs也是一个模板函数,abs(-1)的返回值是int型整数1,abs(-1.0)的返回值是double型浮点数1.0。

大家可以思考并测试下面模板函数的用途:

template<typename T>
T square(T number){
    return number * number;
}

再试试下面这个有多个通用数据类型的模板函数:

#include<iostream>
using namespace std;
//用来实现数据类型强制转换的模板函数 
template<typename T1,typename T2>
void convert(T1 in,T2 &out){
	out = (T2)in;
}
int main()
{   
	int a = 123,b = 65;
	double c = 12.56,d = 3.14;
	char ch;
	
	convert(c,a);
	cout<<a<<endl;
	
	convert(b,d);
	cout<<d<<endl;
	
	convert(b,ch);
	cout<<ch<<endl;
	
    return 0; 
}

前面介绍选择排序的时候,展示了使用比较函数实现选择排序的方法帮助大家类比理解sort函数的比较函数。这里换用快速排序,将前面实现的快速排序的函数改写成模版函数,帮助大家类比理解sort函数为什么可以对不同类型的数组排序:

#include<iostream>
using namespace std;
const int N = 110;
int a[N]; 

//坐标
struct Point{
    int x,y;
} b[N];

//快速排序:对指针begin~end区域排序(升序排序,不包括end)
template<typename T> 
void QuickSort(T *begin,T *end){
    T *p = begin;
    T *q = end-1;
    T flag = *(begin+(end-begin)/2);  //不能写成 *((begin+end)/2),两个指针不能做加法
    
    do{
        while(*p<flag) p++;        //左边查找第一个不小于flag的元素 
        while(flag<*q) q--;        //右边查找第一个不大于flag的元素
        if(p<=q){
            T t = *p;*p = *q;*q = t;      //交换 
            p++;q--;                        //下一个位置继续查找 
        }
    }while(p<=q);
    
    if(begin<q) QuickSort(begin,q+1);
    if(p<end-1) QuickSort(p,end);
}

//快速排序:对指针begin~end区域排序(不包括end)
//第三个参数cmpfun是排序比较函数,注意这个参数的类型是函数
template<typename T>  
void QuickSort(T *begin ,T *end ,bool(*cmpfun)(T a ,T b) ){
    T *p = begin;
    T *q = end-1;
    T flag = *(begin+(end-begin)/2);  //不能写成 *((begin+end)/2),两个指针不能做加法
    
    do{
        while(cmpfun(*p,flag)) p++;        //左边查找第一个不小于flag的元素 
        while(cmpfun(flag,*q)) q--;        //右边查找第一个不大于flag的元素
        if(p<=q){
            T t = *p;*p = *q;*q = t;      //交换 
            p++;q--;                      //下一个位置继续查找 
        }
    }while(p<=q);
    
    if(begin<q) QuickSort(begin,q+1,cmpfun);
    if(p<end-1) QuickSort(p,end,cmpfun);
}

//比较函数:降序排序
bool cmp(int a,int b){
    return a>b;
}

//比较函数:坐标到原点距离升序排序
bool cmp2(Point a,Point b){
    return a.x*a.x+a.y*a.y<b.x*b.x+b.y*b.y;
}

int main(){ 
    int n;
    cin>>n;
    //输入n个整数 
    for(int i=0;i<n;i++){
        cin>>a[i];
    }
    
    //调用没有比较函数的QuickSort 
    QuickSort(a,a+n);
    for(int i=0;i<n;i++){
        cout<<a[i]<<" ";
    }
    cout<<endl;
    
    //调用有比较函数的QuickSort
    QuickSort(a,a+n,cmp);
    for(int i=0;i<n;i++){
        cout<<a[i]<<" ";
    }
    cout<<endl;
    
    //输入n个坐标值 
    for(int i=0;i<n;i++){
        cin>>b[i].x>>b[i].y;
    }
    
    //调用有比较函数的QuickSort    
    QuickSort(b,b+n,cmp2);
    for(int i=0;i<n;i++){
        cout<<"("<<b[i].x<<","<<b[i].y<<") ";
    }
    cout<<endl;
    return 0;
}